from torchvision import datasets from tqdm import tqdm import os # 数据集加载 train_data = datasets.MNIST(root="./data/", train=True, download=True) test_data = datasets.MNIST(root="./data/", train=False, download=True) # 保存图片函数 def save_img_subset(data, save_path, num_samples): if not os.path.exists(save_path): os.mkdir(save_path) for i in tqdm(range(num_samples), desc=f"Saving {num_samples} images to {save_path}"): img, label = data[i] img.save(os.path.join(save_path, f"{i}-label-{label}.png")) # 保存前 600 张训练集图片和前 100 张测试集图片 save_img_subset(train_data, './DataImages-Train', 60000) save_img_subset(test_data, './DataImages-Test', 10000)