From 428fbd40f41481200b032333f60d014ac6a8e441 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=A1=E5=9D=82=E6=98=B4?= Date: Wed, 27 Nov 2024 16:25:26 +0800 Subject: [PATCH] =?UTF-8?q?mnist=E6=95=B0=E6=8D=AE=E9=9B=86=E8=8E=B7?= =?UTF-8?q?=E5=8F=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 实验六/get_dataset.py | 20 ++++++++++++++++++++ 实验六/main.py | 1 + 实验六/pretrain.py | 0 3 files changed, 21 insertions(+) create mode 100644 实验六/get_dataset.py create mode 100644 实验六/main.py create mode 100644 实验六/pretrain.py diff --git a/实验六/get_dataset.py b/实验六/get_dataset.py new file mode 100644 index 0000000..39b0a6c --- /dev/null +++ b/实验六/get_dataset.py @@ -0,0 +1,20 @@ +from torchvision import datasets +from tqdm import tqdm +import os + +# 数据集加载 +train_data = datasets.MNIST(root="./data/", train=True, download=True) +test_data = datasets.MNIST(root="./data/", train=False, download=True) + +# 保存图片函数 +def save_img_subset(data, save_path, num_samples): + if not os.path.exists(save_path): + os.mkdir(save_path) + + for i in tqdm(range(num_samples), desc=f"Saving {num_samples} images to {save_path}"): + img, label = data[i] + img.save(os.path.join(save_path, f"{i}-label-{label}.png")) + +# 保存前 600 张训练集图片和前 100 张测试集图片 +save_img_subset(train_data, './DataImages-Train', 600) +save_img_subset(test_data, './DataImages-Test', 100) diff --git a/实验六/main.py b/实验六/main.py new file mode 100644 index 0000000..e1b8f64 --- /dev/null +++ b/实验六/main.py @@ -0,0 +1 @@ +import cv2 as cv \ No newline at end of file diff --git a/实验六/pretrain.py b/实验六/pretrain.py new file mode 100644 index 0000000..e69de29