900字范文,内容丰富有趣,生活中的好帮手!
900字范文 > 翻译: 3.5. 图像分类数据集Fashion-MNIST pytorch

翻译: 3.5. 图像分类数据集Fashion-MNIST pytorch

时间:2023-08-24 05:27:47

相关推荐

翻译: 3.5. 图像分类数据集Fashion-MNIST pytorch

用于图像分类的广泛使用的数据集之一是 MNIST 数据集[LeCun et al., 1998]。虽然它作为基准数据集运行良好,但即使按照今天的标准,即使是简单的模型也能达到 95% 以上的分类准确率,这使得它不适合区分强模型和弱模型。今天,MNIST 更多的是作为健全性检查而不是基准。为了提高赌注,我们将在接下来的部分中将讨论重点放在质量相似但相对复杂的 Fashion-MNIST 数据集 [Xiao et al., ]上,该数据集于 年发布。

%matplotlib inlineimport torchimport torchvisionfrom torch.utils import datafrom torchvision import transformsfrom d2l import torch as d2ld2l.use_svg_display()

3.5.1 读取数据集

我们可以通过框架中的内置函数下载 Fashion-MNIST 数据集并将其读入内存。

# `ToTensor` converts the image data from PIL type to 32-bit floating point# tensors. It divides all numbers by 255 so that all pixel values are between# 0 and 1trans = transforms.ToTensor()mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

Fashion-MNIST 由来自 10 个类别的图像组成,每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像表示。测试数据集(或测试集)用于 评估模型性能,而不是用于训练。因此,训练集和测试集分别包含 60000 和 10000 张图像。

len(mnist_train), len(mnist_test)

(60000, 10000)

每个输入图像的高度和宽度都是 28 像素。请注意,该数据集由灰度图像组成,其通道数为 1。为简洁起见,在本书中,我们存储了具有高度的任何图像的高度h,宽度w,像素为hxw

mnist_train[0][0].shape

torch.Size([1, 28, 28])

Fashion-MNIST 中的图像与以下类别相关联:T 恤、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和踝靴。以下函数在数字标签索引和它们在文本中的名称之间进行转换。

def get_fashion_mnist_labels(labels): #@save"""Return text labels for the Fashion-MNIST dataset."""text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat','sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

我们现在可以创建一个函数来可视化这些示例。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save"""Plot a list of images."""figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):# Tensor Imageax.imshow(img.numpy())else:# PIL Imageax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

以下是训练数据集中前几个示例的图像及其对应的标签(以文本形式)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

3.5.2. 读取小批量

为了让我们在读取训练集和测试集时更轻松,我们使用内置的数据迭代器,而不是从头开始创建一个。回想一下,在每次迭代中,数据迭代器每次都会读取具有大小的小批量数据batch_size。我们还随机打乱训练数据迭代器的示例。

batch_size = 256def get_dataloader_workers(): #@save"""Use 4 processes to read the data."""return 4train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers())

让我们看看读取训练数据所需的时间。

timer = d2l.Timer()for X, y in train_iter:continuef'{timer.stop():.2f} sec'

'2.46 sec'

3.5.3. 把所有东西放在一起

现在我们定义load_data_fashion_mnist获取和读取 Fashion-MNIST 数据集的函数。它返回训练集和验证集的数据迭代器。此外,它接受一个可选参数以将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None): #@save"""Download the Fashion-MNIST dataset and then load it into memory."""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = pose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

下面我们 load_data_fashion_mnist通过指定resize 参数来测试函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)for X, y in train_iter:print(X.shape, X.dtype, y.shape, y.dtype)break

torch.Size([32, 1, 64, 64]) torch.float32 torch.Size([32]) torch.int64

我们现在已准备好在接下来的部分中使用 Fashion-MNIST 数据集。

3.5.4。概括

Fashion-MNIST 是一个服装分类数据集,由代表 10 个类别的图像组成。我们将在后续章节和章节中使用这个数据集来评估各种分类算法。

我们用高度存储任何图像的高度h,宽度w,像素hxw.

数据迭代器是高效性能的关键组件。依靠利用高性能计算的良好实现的数据迭代器来避免减慢训练循环。

3.5.5。练习

减少batch_size(例如,减少到 1)会影响阅读性能吗?

读取的总数是一样的,工作总量是一样的。batch_size的目的一个是为了并行,一个是为了减少一次读太多数据,对内存存储要求太高。

数据迭代器的性能很重要。您认为当前的实施速度是否足够快?探索各种改进方案。

查看框架的在线 API 文档。还有哪些其他数据集可用?

/docs/stable/torchvision/datasets.html

Datasets:

MNIST

Fashion-MNIST

KMNIST

EMNIST

QMNIST

FakeData

COCO:Captions,Detection

LSUN

ImageFolder

DatasetFolder

ImageNet

CIFAR

STL10

SVHN

PhotoTour

SBU

Flickr

VOC

Cityscapes

SBD

USPS

Kinetics-400

HMDB51

UCF101

CelebA

参考

https://d2l.ai/chapter_linear-networks/image-classification-dataset.html

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。