From 7a39db597b3936da6ebbc0f9f944e1a0fb12f0c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=A1=E5=9D=82=E6=98=B4?= Date: Wed, 27 Nov 2024 18:11:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 实验六/train.py | 83 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 实验六/train.py diff --git a/实验六/train.py b/实验六/train.py new file mode 100644 index 0000000..3435a50 --- /dev/null +++ b/实验六/train.py @@ -0,0 +1,83 @@ +import cv2 as cv +import joblib +import numpy as np +import tqdm +import os + +from sklearn.metrics import accuracy_score +from sklearn.svm import SVC + +# 提取轮廓特征 +def extract_contour_features(img): + contours, _ = cv.findContours(img, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE) + contour = contours[0] + area = cv.contourArea(contour) + perimeter = cv.arcLength(contour, True) + return [area, perimeter] + +# 提取形状特征 +def extract_shape_features(contour): + x, y, w, h = cv.boundingRect(contour) + aspect_ratio = float(w) / h + rect_area = w * h + shape_factor = cv.contourArea(contour) / rect_area + return [aspect_ratio, shape_factor] + +# 计算HU矩 +def extract_hu_moments(contour): + moments = cv.moments(contour) + hu_moments = cv.HuMoments(moments) + return hu_moments.flatten() + +# 特征向量构建 +def extract_features(img_path): + img = cv.imread(img_path, cv.IMREAD_GRAYSCALE) + if img is None: + raise FileNotFoundError(f"无法加载图像: {img_path}") + _, img_bin = cv.threshold(img, 128, 255, cv.THRESH_BINARY) + contours, _ = cv.findContours(img_bin, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE) + contour = contours[0] + contour_features = extract_contour_features(img_bin) + shape_features = extract_shape_features(contour) + hu_moments = extract_hu_moments(contour) + feature_vector = contour_features + shape_features + hu_moments.tolist() + return feature_vector + +# 加载图像路径和标签 +def load_data(dataset_path): + image_paths = [] + labels = [] + for file_name in os.listdir(dataset_path): + if file_name.endswith(".png"): + label = int(file_name.split("-")[-1].split(".")[0]) + image_paths.append(os.path.join(dataset_path, file_name)) + labels.append(label) + return image_paths, labels + +# 创建文件夹 +def ensure_dir_exists(directory): + if not os.path.exists(directory): + os.makedirs(directory) + +# 加载训练数据 +trains_paths, trains_labels = load_data("cache/pretrains/train") +test_paths, test_labels = load_data("cache/pretrains/test") + +# 提取特征和标签 +X_train = np.array([extract_features(train_path) for train_path in tqdm.tqdm(trains_paths, desc="训练集特征提取中:")]) +Y_train = np.array(trains_labels) +X_test = np.array([extract_features(test_path) for test_path in tqdm.tqdm(test_paths, desc="测试集特征提取中:")]) +Y_test = np.array(test_labels) + +# 训练分类器 +classifier = SVC(kernel="linear") +classifier.fit(X_train, Y_train) + +# 在测试集上进行评估 +Y_pred = classifier.predict(X_test) +accuracy = accuracy_score(Y_test, Y_pred) +print(f"性能: {accuracy * 100:.2f}%") + +# 保存模型 +ensure_dir_exists("models") +joblib.dump(classifier, "models/classifier.pkl")