加了个进度条
This commit is contained in:
+2
-2
@@ -16,5 +16,5 @@ def save_img_subset(data, save_path, num_samples):
|
|||||||
img.save(os.path.join(save_path, f"{i}-label-{label}.png"))
|
img.save(os.path.join(save_path, f"{i}-label-{label}.png"))
|
||||||
|
|
||||||
# 保存前 600 张训练集图片和前 100 张测试集图片
|
# 保存前 600 张训练集图片和前 100 张测试集图片
|
||||||
save_img_subset(train_data, './DataImages-Train', 600)
|
save_img_subset(train_data, './DataImages-Train', 6000)
|
||||||
save_img_subset(test_data, './DataImages-Test', 100)
|
save_img_subset(test_data, './DataImages-Test', 1000)
|
||||||
|
|||||||
+14
-2
@@ -1,3 +1,5 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import joblib
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -6,6 +8,8 @@ import os
|
|||||||
|
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
from sklearn.svm import SVC
|
from sklearn.svm import SVC
|
||||||
|
from sklearn.utils import parallel_backend
|
||||||
|
|
||||||
|
|
||||||
# 提取轮廓特征
|
# 提取轮廓特征
|
||||||
def extract_contour_features(img):
|
def extract_contour_features(img):
|
||||||
@@ -71,11 +75,19 @@ Y_test = np.array(test_labels)
|
|||||||
|
|
||||||
# 训练分类器
|
# 训练分类器
|
||||||
classifier = SVC(kernel="linear")
|
classifier = SVC(kernel="linear")
|
||||||
classifier.fit(X_train, Y_train)
|
with parallel_backend('threading',n_jobs=-1):
|
||||||
|
start_time=time.time()
|
||||||
|
classifier.fit(X_train, Y_train)
|
||||||
|
elapsed_time = time.time()-start_time
|
||||||
|
|
||||||
|
print(f"模型训练耗时: {elapsed_time:.2f} 秒")
|
||||||
|
|
||||||
# 在测试集上进行评估
|
# 在测试集上进行评估
|
||||||
Y_pred = classifier.predict(X_test)
|
Y_pred = []
|
||||||
|
for test_sample in tqdm.tqdm(X_test, desc="测试集中预测进度"):
|
||||||
|
Y_pred.append(classifier.predict(test_sample.reshape(1, -1)))
|
||||||
accuracy = accuracy_score(Y_test, Y_pred)
|
accuracy = accuracy_score(Y_test, Y_pred)
|
||||||
|
|
||||||
print(f"性能: {accuracy * 100:.2f}%")
|
print(f"性能: {accuracy * 100:.2f}%")
|
||||||
|
|
||||||
# 保存模型
|
# 保存模型
|
||||||
|
|||||||
Reference in New Issue
Block a user