用于图像分类的广泛使用的数据集之一是 MNIST 数据集[LeCun et al., 1998]。虽然它作为基准数据集运行良好,但即使按照今天的标准,即使是简单的模型也能达到 95% 以上的分类准确率,这使得它不适合区分强模型和弱模型。今天,MNIST 更多的是作为健全性检查而不是基准。为了提高赌注,我们将在接下来的部分中将讨论重点放在质量相似但相对复杂的 Fashion-MNIST 数据集 [Xiao et al., ]上,该数据集于 年发布。
%matplotlib inlineimport torchimport torchvisionfrom torch.utils import datafrom torchvision import transformsfrom d2l import torch as d2ld2l.use_svg_display()
3.5.1 读取数据集
我们可以通过框架中的内置函数下载 Fashion-MNIST 数据集并将其读入内存。
# `ToTensor` converts the image data from PIL type to 32-bit floating point# tensors. It divides all numbers by 255 so that all pixel values are between# 0 and 1trans = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
Fashion-MNIST 由来自 10 个类别的图像组成,每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像表示。测试数据集(或测试集)用于 评估模型性能,而不是用于训练。因此,训练集和测试集分别包含 60000 和 10000 张图像。
len(mnist_train), len(mnist_test)
(60000, 10000)
每个输入图像的高度和宽度都是 28 像素。请注意,该数据集由灰度图像组成,其通道数为 1。为简洁起见,在本书中,我们存储了具有高度的任何图像的高度h
,宽度w
,像素为hxw
mnist_train[0][0].shape
torch.Size([1, 28, 28])
Fashion-MNIST 中的图像与以下类别相关联:T 恤、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和踝靴。以下函数在数字标签索引和它们在文本中的名称之间进行转换。
def get_fashion_mnist_labels(labels): #@save"""Return text labels for the Fashion-MNIST dataset."""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
我们现在可以创建一个函数来可视化这些示例。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save"""Plot a list of images."""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# Tensor Imageax.imshow(img.numpy())else:# PIL Imageax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
以下是训练数据集中前几个示例的图像及其对应的标签(以文本形式)。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
3.5.2. 读取小批量
为了让我们在读取训练集和测试集时更轻松,我们使用内置的数据迭代器,而不是从头开始创建一个。回想一下,在每次迭代中,数据迭代器每次都会读取具有大小的小批量数据batch_size。我们还随机打乱训练数据迭代器的示例。
batch_size = 256def get_dataloader_workers(): #@save"""Use 4 processes to read the data."""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())
让我们看看读取训练数据所需的时间。
timer = d2l.Timer()for X, y in train_iter:continuef'{timer.stop():.2f} sec'
'2.46 sec'
3.5.3. 把所有东西放在一起
现在我们定义load_data_fashion_mnist获取和读取 Fashion-MNIST 数据集的函数。它返回训练集和验证集的数据迭代器。此外,它接受一个可选参数以将图像大小调整为另一种形状。
def load_data_fashion_mnist(batch_size, resize=None): #@save"""Download the Fashion-MNIST dataset and then load it into memory."""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = pose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
下面我们 load_data_fashion_mnist通过指定resize 参数来测试函数的图像大小调整功能。
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break
torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64
我们现在已准备好在接下来的部分中使用 Fashion-MNIST 数据集。
3.5.4。概括
Fashion-MNIST 是一个服装分类数据集,由代表 10 个类别的图像组成。我们将在后续章节和章节中使用这个数据集来评估各种分类算法。
我们用高度存储任何图像的高度h
,宽度w
,像素hxw
.
数据迭代器是高效性能的关键组件。依靠利用高性能计算的良好实现的数据迭代器来避免减慢训练循环。
3.5.5。练习
减少batch_size(例如,减少到 1)会影响阅读性能吗?
读取的总数是一样的,工作总量是一样的。batch_size的目的一个是为了并行,一个是为了减少一次读太多数据,对内存存储要求太高。
数据迭代器的性能很重要。您认为当前的实施速度是否足够快?探索各种改进方案。
查看框架的在线 API 文档。还有哪些其他数据集可用?
/docs/stable/torchvision/datasets.html
Datasets:
MNIST
Fashion-MNIST
KMNIST
EMNIST
QMNIST
FakeData
COCO:Captions,Detection
LSUN
ImageFolder
DatasetFolder
ImageNet
CIFAR
STL10
SVHN
PhotoTour
SBU
Flickr
VOC
Cityscapes
SBD
USPS
Kinetics-400
HMDB51
UCF101
CelebA
参考
https://d2l.ai/chapter_linear-networks/image-classification-dataset.html