Normalize CIFAR10 Dataset Tensor

Use Torchvision Transforms Normalize (transforms.Normalize) to normalize CIFAR10 dataset tensors using the mean and standard deviation of the dataset

Type: FREE   By: AIWorkbox.com Instructor Tylan O'Flynn   Duration: 1:42   Technologies: PyTorch, Python





Transcript:

Now that we know how to convert CIFAR10 PIL images to PyTorch tensors, we may also want to normalize the resulting tensors.

Dataset normalization has consistently been shown to improve generalization behavior in deep learning models.


We will first want to import PyTorch and Torchvision.

import torch

import torchvision


We will then want to import torchvision.datasets as datasets and torchvision.transforms as transforms.

import torchvision.datasets as datasets

import torchvision.transforms as transforms


We will also want to check our installed versions; this tutorial uses PyTorch 0.4.0 and Torchvision 0.2.1.

print(torch.__version__)

print(torchvision.__version__)


We will then define our normalize transform as follows: normalize equals transforms.Normalize.

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

The CIFAR10 tensors have three channels – red, green, and blue – and the mean parameter specifies the per-channel mean that Normalize subtracts from each channel.

In this case, 0.5 for all three.

Similarly, the std parameter takes a list of per-channel standard deviations that each channel is divided by, which we also specify here to be 0.5.

This tends to be a good starting point.
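Under the hood, Normalize applies (value - mean) / std to every element, channel by channel. A minimal pure-Python sketch of that arithmetic (plain floats rather than tensors, to keep the illustration self-contained):

```python
def normalize_value(x, mean, std):
    # The per-channel formula applied by transforms.Normalize
    return (x - mean) / std

# With mean=0.5 and std=0.5, values in [0, 1] are mapped to [-1, 1]
print(normalize_value(0.0, 0.5, 0.5))  # -1.0
print(normalize_value(1.0, 0.5, 0.5))  # 1.0
```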


If we import the CIFAR10 set as usual, transforming the PIL images to tensors on import:

cifar_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())


and pick out a tensor:

datapoint = cifar_trainset[0][0]


then we can print that tensor to see what it looks like.

print(datapoint)


We can then apply our newly defined normalize transform to this tensor by calling normalize with that tensor as the argument.

normalize(datapoint)

We can see here that our normalization transform did in fact alter the tensor.
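The lesson title mentions using the dataset's own mean and standard deviation; 0.5 is only a convenient starting point. As a hedged sketch of how the per-channel statistics could be computed, using a random tensor batch as a stand-in for the stacked CIFAR10 training images:

```python
import torch

# Random data stands in for the real CIFAR10 training tensors here
batch = torch.rand(100, 3, 32, 32)

# Reduce over batch, height, and width to get one value per channel
channel_mean = batch.mean(dim=[0, 2, 3])
channel_std = batch.std(dim=[0, 2, 3])

print(channel_mean)  # roughly 0.5 per channel for uniform random data
print(channel_std)   # roughly 0.29 per channel
```

The resulting three-element mean and std could then be passed to transforms.Normalize in place of the 0.5 values.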


We could normalize the entire dataset by looping over it and calling normalize on each tensor individually.

However, this is not the cleanest way to include a normalization step when importing datasets from torchvision.

We should instead include normalize in the transform argument when importing the CIFAR10 set, and for that we will need to combine the ToTensor and Normalize transforms using transforms.Compose.


