mnist数据集获取
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
from torchvision import datasets
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
|
||||
# 数据集加载
|
||||
train_data = datasets.MNIST(root="./data/", train=True, download=True)
|
||||
test_data = datasets.MNIST(root="./data/", train=False, download=True)
|
||||
|
||||
# 保存图片函数
|
||||
def save_img_subset(data, save_path, num_samples):
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
|
||||
for i in tqdm(range(num_samples), desc=f"Saving {num_samples} images to {save_path}"):
|
||||
img, label = data[i]
|
||||
img.save(os.path.join(save_path, f"{i}-label-{label}.png"))
|
||||
|
||||
# 保存前 600 张训练集图片和前 100 张测试集图片
|
||||
save_img_subset(train_data, './DataImages-Train', 600)
|
||||
save_img_subset(test_data, './DataImages-Test', 100)
|
||||
@@ -0,0 +1 @@
|
||||
import cv2 as cv
|
||||
Reference in New Issue
Block a user