train!
This commit is contained in:
+30
-3
@@ -1,4 +1,7 @@
|
|||||||
|
import bz2
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
import tarfile
|
import tarfile
|
||||||
@@ -62,6 +65,27 @@ def decompress(file_path, output_dir):
|
|||||||
tar.extract(member, path=output_dir)
|
tar.extract(member, path=output_dir)
|
||||||
bar.update(1)
|
bar.update(1)
|
||||||
|
|
||||||
|
|
||||||
|
def decompress_bz2(file_path, output_dir):
|
||||||
|
# 确保输出目录存在
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
print('解压 ' + file_path.split('/')[-1], ' 到 ', output_dir)
|
||||||
|
|
||||||
|
# 获取.bz2文件的名称,不包含扩展名
|
||||||
|
output_filename = os.path.basename(file_path).rsplit('.bz2', 1)[0]
|
||||||
|
|
||||||
|
# 构建输出文件的完整路径
|
||||||
|
output_file_path = os.path.join(output_dir, output_filename)
|
||||||
|
|
||||||
|
# 检查输出文件是否已存在,如果存在则跳过解压
|
||||||
|
if os.path.exists(output_file_path):
|
||||||
|
print(output_filename, ' 已存在,跳过解压')
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解压.bz2文件
|
||||||
|
with bz2.BZ2File(file_path, 'rb') as bz2_file, open(output_file_path, 'wb') as output_file:
|
||||||
|
output_file.write(bz2_file.read())
|
||||||
|
|
||||||
# 下载人脸数据集
|
# 下载人脸数据集
|
||||||
face_dataset_url='http://vis-www.cs.umass.edu/fddb/originalPics.tar.gz'
|
face_dataset_url='http://vis-www.cs.umass.edu/fddb/originalPics.tar.gz'
|
||||||
face_dataset_path='cache/dataset/face/'
|
face_dataset_path='cache/dataset/face/'
|
||||||
@@ -77,8 +101,11 @@ decompress(os.path.join(face_dataset_path,face_dataset_url.split('/')[-1]),face_
|
|||||||
decompress(os.path.join(face_label_path,face_label_url.split('/')[-1]),face_label_path)
|
decompress(os.path.join(face_label_path,face_label_url.split('/')[-1]),face_label_path)
|
||||||
|
|
||||||
# 下载ResNet模型
|
# 下载ResNet模型
|
||||||
model_url1='https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat'
|
model_url1='https://github.com/davisking/dlib-models/raw/refs/heads/master/shape_predictor_68_face_landmarks.dat.bz2'
|
||||||
model_url2='https://github.com/davisking/dlib-models/raw/master/dlib_face_recognition_resnet_model_v1.dat'
|
model_url2='https://github.com/davisking/dlib-models/raw/refs/heads/master/dlib_face_recognition_resnet_model_v1.dat.bz2'
|
||||||
model_path='models/'
|
model_path='models/'
|
||||||
download(model_url1,model_path)
|
download(model_url1,model_path)
|
||||||
download(model_url2,model_path)
|
download(model_url2,model_path)
|
||||||
|
|
||||||
|
decompress_bz2(os.path.join(model_path,model_url1.split('/')[-1]),model_path)
|
||||||
|
decompress_bz2(os.path.join(model_path,model_url2.split('/')[-1]),model_path)
|
||||||
+56
-13
@@ -1,31 +1,74 @@
|
|||||||
import dlib
|
import dlib
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
|
import joblib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.multioutput import MultiOutputClassifier
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
# 提取人脸特征
|
# 提取人脸特征
|
||||||
def extract_face_features(image_path):
|
def extract_face_features(image_path):
|
||||||
img=cv.imread(image_path)
|
img = cv.imread(image_path)
|
||||||
detections=detector(img,1)
|
detections = detector(img, 1)
|
||||||
face_features=[]
|
face_features = []
|
||||||
|
face_bboxes = []
|
||||||
for rect in detections:
|
for rect in detections:
|
||||||
shape=predictor(img,rect)
|
shape = predictor(img, rect)
|
||||||
face_descriptor=face_rec_model.compute_face_descriptor(img,shape)
|
face_descriptor = face_rec_model.compute_face_descriptor(img, shape)
|
||||||
face_features.append(face_descriptor)
|
face_features.append(face_descriptor)
|
||||||
return face_features
|
bbox = (rect.left(), rect.top(), rect.right(), rect.bottom())
|
||||||
|
face_bboxes.append(bbox)
|
||||||
|
return face_features, face_bboxes
|
||||||
|
|
||||||
# 获取图片路径
|
# 获取图片路径
|
||||||
def get_img_path(directory,extension=None):
|
def get_img_path(directory, extension=None):
|
||||||
if extension==None:
|
if extension is None:
|
||||||
extension=['.jpg','.jpeg','.png']
|
extension = ['.jpg', '.jpeg', '.png']
|
||||||
files=[]
|
files = []
|
||||||
for root, dirs, file_names in os.walk(directory):
|
for root, dirs, file_names in os.walk(directory):
|
||||||
for file_name in file_names:
|
for file_name in file_names:
|
||||||
if any(file_name.lower().endswith(ext) for ext in extension):
|
if any(file_name.lower().endswith(ext) for ext in extension):
|
||||||
files.append(os.path.join(root,file_name))
|
files.append(os.path.join(root, file_name))
|
||||||
return files
|
return files
|
||||||
|
|
||||||
|
|
||||||
# 加载dlib模型
|
# 加载dlib模型
|
||||||
detector = dlib.get_frontal_face_detector()
|
detector = dlib.get_frontal_face_detector()
|
||||||
predictor=dlib.shape_predictor('models/shape_predictor_68_face_landmarks.dat')
|
predictor = dlib.shape_predictor('models/shape_predictor_68_face_landmarks.dat')
|
||||||
face_rec_model=dlib.face_recognition_model_v1('实验八/models/dlib_face_recognition_resnet_model_v1.dat')
|
face_rec_model = dlib.face_recognition_model_v1('models/dlib_face_recognition_resnet_model_v1.dat')
|
||||||
|
|
||||||
|
# 获取图片
|
||||||
|
img_directory = 'cache/pretrained/'
|
||||||
|
images = get_img_path(img_directory)
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
features = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for image in tqdm(images, desc='提取图片特征中:'):
|
||||||
|
extracted_features, face_bboxes = extract_face_features(image)
|
||||||
|
for feature,bbox in zip(extracted_features, face_bboxes):
|
||||||
|
features.append(feature)
|
||||||
|
labels.append(bbox)
|
||||||
|
|
||||||
|
X_train = np.array(features)
|
||||||
|
Y_train = np.array(labels)
|
||||||
|
|
||||||
|
# 分割测试集
|
||||||
|
X_train, X_test, Y_train, Y_test = train_test_split(X_train, Y_train, test_size=0.2, random_state=42)
|
||||||
|
|
||||||
|
# 训练SVM模型
|
||||||
|
print('训练模型中')
|
||||||
|
clf=MultiOutputClassifier(RandomForestClassifier(n_estimators=100,random_state=42))
|
||||||
|
clf.fit(X_train, Y_train)
|
||||||
|
|
||||||
|
# 评估训练数据
|
||||||
|
predictions = clf.predict(X_test)
|
||||||
|
accuracy=(predictions==Y_test).mean()
|
||||||
|
print(f'分类器准确度:{accuracy * 100:.2f}%')
|
||||||
|
|
||||||
|
os.makedirs('models/', exist_ok=True)
|
||||||
|
joblib.dump(clf, 'models/my_classifier.pkl')
|
||||||
|
|||||||
Reference in New Issue
Block a user