使用数据集mnist的训练集部分
import os
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
平台是python3.x
python main.py
- mnist数据集自动下载,选择download=True,也可进行变换,只是用到了它的训练集(剩下的测试集没有使用)
datasets.MNIST('data', train=True, download=True, transform=transform)
- 代码里面有详细的讲解,应该比较容易弄懂