From 39aeac3be769dc616151022b292ad7d3f6a4d617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=A1=E5=9D=82=E6=98=B4?= Date: Wed, 27 Nov 2024 18:28:41 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=BA=86=E4=B8=AA=E8=BF=9B=E5=BA=A6?= =?UTF-8?q?=E6=9D=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 实验六/get_dataset.py | 4 ++-- 实验六/train.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/实验六/get_dataset.py b/实验六/get_dataset.py index 39b0a6c..e1f611f 100644 --- a/实验六/get_dataset.py +++ b/实验六/get_dataset.py @@ -16,5 +16,5 @@ def save_img_subset(data, save_path, num_samples): img.save(os.path.join(save_path, f"{i}-label-{label}.png")) # 保存前 600 张训练集图片和前 100 张测试集图片 -save_img_subset(train_data, './DataImages-Train', 600) -save_img_subset(test_data, './DataImages-Test', 100) +save_img_subset(train_data, './DataImages-Train', 6000) +save_img_subset(test_data, './DataImages-Test', 1000) diff --git a/实验六/train.py b/实验六/train.py index 3435a50..ad14bfb 100644 --- a/实验六/train.py +++ b/实验六/train.py @@ -1,3 +1,5 @@ +import time + import cv2 as cv import joblib import numpy as np @@ -6,6 +8,8 @@ import os from sklearn.metrics import accuracy_score from sklearn.svm import SVC +from sklearn.utils import parallel_backend + # 提取轮廓特征 def extract_contour_features(img): @@ -71,11 +75,19 @@ Y_test = np.array(test_labels) # 训练分类器 classifier = SVC(kernel="linear") -classifier.fit(X_train, Y_train) +with parallel_backend('threading',n_jobs=-1): + start_time=time.time() + classifier.fit(X_train, Y_train) + elapsed_time = time.time()-start_time + +print(f"模型训练耗时: {elapsed_time:.2f} 秒") # 在测试集上进行评估 -Y_pred = classifier.predict(X_test) +Y_pred = [] +for test_sample in tqdm.tqdm(X_test, desc="测试集中预测进度"): + Y_pred.append(classifier.predict(test_sample.reshape(1, -1))) accuracy = accuracy_score(Y_test, Y_pred) + print(f"性能: {accuracy * 100:.2f}%") # 保存模型