资料内容:
2. 加载和预处理数据
创建一个Python文件,比如叫 main.py ,然后开始编写代码。首先,导入必要的库:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 设置数据转换
transform = transforms.Compose([
transforms.ToTensor(), # 转换图像为Tensor
transforms.Normalize((0.5,), (0.5,)) # 标准化
])
# 下载训练集和测试集
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True,
transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True,
transform=transform)
# 加载数据
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)