diff --git a/实验六/main.py b/实验六/main.py index e1b8f64..5678a2f 100644 --- a/实验六/main.py +++ b/实验六/main.py @@ -1 +1,58 @@ -import cv2 as cv \ No newline at end of file +import cv2 as cv +import joblib +from matplotlib import pyplot as plt, rcParams + +# 导入模型 +classifier = joblib.load('models/classifier.pkl') + +# 导入测试图片 +img=cv.imread('test.png',cv.IMREAD_GRAYSCALE) + +# 设置中文字体 +rcParams['font.sans-serif'] = ['SimHei'] +rcParams['axes.unicode_minus'] = False + +# 预处理图像 +_,img_classifier=cv.threshold(img, 0, 255, cv.THRESH_BINARY + cv.THRESH_OTSU) +img_classifier=cv.blur(img,(3,3)) + +# 特征提取 +def extract_contour_features(img): + contours, _ = cv.findContours(img, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE) + contour = contours[0] + area = cv.contourArea(contour) + perimeter = cv.arcLength(contour, True) + return [area, perimeter] + +def extract_shape_features(contour): + x, y, w, h = cv.boundingRect(contour) + aspect_ratio = float(w) / h + rect_area = w * h + shape_factor = cv.contourArea(contour) / rect_area + return [aspect_ratio, shape_factor] + +def extract_hu_moments(contour): + moments = cv.moments(contour) + hu_moments = cv.HuMoments(moments) + return hu_moments.flatten() + +def extract_features(img): + contours, _ = cv.findContours(img, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE) + contour = contours[0] + contour_features = extract_contour_features(img) + shape_features = extract_shape_features(contour) + hu_moments = extract_hu_moments(contour) + feature_vector = contour_features + shape_features + hu_moments.tolist() + return feature_vector + +feature_vector = extract_features(img) + +# 预测 +predicted_label=classifier.predict([feature_vector]) + +plt.figure() +plt.imshow(img,cmap='gray') +plt.title('预测结果:'+str(predicted_label[0])) +plt.axis('off') +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/实验六/test.png b/实验六/test.png new file mode 100644 index 0000000..1b1f72c Binary files /dev/null and b/实验六/test.png differ diff --git a/实验六/train.py b/实验六/train.py index ad14bfb..ed943d0 100644 --- a/实验六/train.py +++ b/实验六/train.py @@ -75,11 +75,10 @@ Y_test = np.array(test_labels) # 训练分类器 classifier = SVC(kernel="linear") -with parallel_backend('threading',n_jobs=-1): - start_time=time.time() +with parallel_backend('threading', n_jobs=-1): + start_time = time.time() classifier.fit(X_train, Y_train) - elapsed_time = time.time()-start_time - + elapsed_time = time.time() - start_time print(f"模型训练耗时: {elapsed_time:.2f} 秒") # 在测试集上进行评估