From e394669ac2048c61915b6313cda833cb13629f57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=A1=E5=9D=82=E6=98=B4?= Date: Thu, 12 Dec 2024 21:58:18 +0800 Subject: [PATCH] train! --- 实验八/1.get_dataset.py | 33 ++++++++++++++++++-- 实验八/3.train.py | 69 +++++++++++++++++++++++++++++++++-------- 2 files changed, 86 insertions(+), 16 deletions(-) diff --git a/实验八/1.get_dataset.py b/实验八/1.get_dataset.py index 87b3569..04b8ac1 100644 --- a/实验八/1.get_dataset.py +++ b/实验八/1.get_dataset.py @@ -1,4 +1,7 @@ +import bz2 import os +import tempfile + import requests import tqdm import tarfile @@ -62,6 +65,27 @@ def decompress(file_path, output_dir): tar.extract(member, path=output_dir) bar.update(1) + +def decompress_bz2(file_path, output_dir): + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + print('解压 ' + file_path.split('/')[-1], ' 到 ', output_dir) + + # 获取.bz2文件的名称,不包含扩展名 + output_filename = os.path.basename(file_path).rsplit('.bz2', 1)[0] + + # 构建输出文件的完整路径 + output_file_path = os.path.join(output_dir, output_filename) + + # 检查输出文件是否已存在,如果存在则跳过解压 + if os.path.exists(output_file_path): + print(output_filename, ' 已存在,跳过解压') + return + + # 解压.bz2文件 + with bz2.BZ2File(file_path, 'rb') as bz2_file, open(output_file_path, 'wb') as output_file: + output_file.write(bz2_file.read()) + # 下载人脸数据集 face_dataset_url='http://vis-www.cs.umass.edu/fddb/originalPics.tar.gz' face_dataset_path='cache/dataset/face/' @@ -77,8 +101,11 @@ decompress(os.path.join(face_dataset_path,face_dataset_url.split('/')[-1]),face_ decompress(os.path.join(face_label_path,face_label_url.split('/')[-1]),face_label_path) # 下载ResNet模型 -model_url1='https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat' -model_url2='https://github.com/davisking/dlib-models/raw/master/dlib_face_recognition_resnet_model_v1.dat' +model_url1='https://github.com/davisking/dlib-models/raw/refs/heads/master/shape_predictor_68_face_landmarks.dat.bz2' +model_url2='https://github.com/davisking/dlib-models/raw/refs/heads/master/dlib_face_recognition_resnet_model_v1.dat.bz2' model_path='models/' download(model_url1,model_path) -download(model_url2,model_path) \ No newline at end of file +download(model_url2,model_path) + +decompress_bz2(os.path.join(model_path,model_url1.split('/')[-1]),model_path) +decompress_bz2(os.path.join(model_path,model_url2.split('/')[-1]),model_path) \ No newline at end of file diff --git a/实验八/3.train.py b/实验八/3.train.py index 394685b..09a2734 100644 --- a/实验八/3.train.py +++ b/实验八/3.train.py @@ -1,31 +1,74 @@ import dlib import cv2 as cv +import joblib import numpy as np import os +from sklearn.ensemble import RandomForestClassifier +from sklearn.model_selection import train_test_split +from sklearn.multioutput import MultiOutputClassifier +from tqdm import tqdm + # 提取人脸特征 def extract_face_features(image_path): - img=cv.imread(image_path) - detections=detector(img,1) - face_features=[] + img = cv.imread(image_path) + detections = detector(img, 1) + face_features = [] + face_bboxes = [] for rect in detections: - shape=predictor(img,rect) - face_descriptor=face_rec_model.compute_face_descriptor(img,shape) + shape = predictor(img, rect) + face_descriptor = face_rec_model.compute_face_descriptor(img, shape) face_features.append(face_descriptor) - return face_features + bbox = (rect.left(), rect.top(), rect.right(), rect.bottom()) + face_bboxes.append(bbox) + return face_features, face_bboxes # 获取图片路径 -def get_img_path(directory,extension=None): - if extension==None: - extension=['.jpg','.jpeg','.png'] - files=[] +def get_img_path(directory, extension=None): + if extension is None: + extension = ['.jpg', '.jpeg', '.png'] + files = [] for root, dirs, file_names in os.walk(directory): for file_name in file_names: if any(file_name.lower().endswith(ext) for ext in extension): - files.append(os.path.join(root,file_name)) + files.append(os.path.join(root, file_name)) return files + # 加载dlib模型 detector = dlib.get_frontal_face_detector() -predictor=dlib.shape_predictor('models/shape_predictor_68_face_landmarks.dat') -face_rec_model=dlib.face_recognition_model_v1('实验八/models/dlib_face_recognition_resnet_model_v1.dat') \ No newline at end of file +predictor = dlib.shape_predictor('models/shape_predictor_68_face_landmarks.dat') +face_rec_model = dlib.face_recognition_model_v1('models/dlib_face_recognition_resnet_model_v1.dat') + +# 获取图片 +img_directory = 'cache/pretrained/' +images = get_img_path(img_directory) + +# 提取特征 +features = [] +labels = [] + +for image in tqdm(images, desc='提取图片特征中:'): + extracted_features, face_bboxes = extract_face_features(image) + for feature,bbox in zip(extracted_features, face_bboxes): + features.append(feature) + labels.append(bbox) + +X_train = np.array(features) +Y_train = np.array(labels) + +# 分割测试集 +X_train, X_test, Y_train, Y_test = train_test_split(X_train, Y_train, test_size=0.2, random_state=42) + +# 训练SVM模型 +print('训练模型中') +clf=MultiOutputClassifier(RandomForestClassifier(n_estimators=100,random_state=42)) +clf.fit(X_train, Y_train) + +# 评估训练数据 +predictions = clf.predict(X_test) +accuracy=(predictions==Y_test).mean() +print(f'分类器准确度:{accuracy * 100:.2f}%') + +os.makedirs('models/', exist_ok=True) +joblib.dump(clf, 'models/my_classifier.pkl')