Introduction
Prepare a PyTorch development environment
Create a PyTorch model for MNIST
About PyTorch model training
Perform training and save the model
Use the model for inference
Understand inference on Android
Create an Android application
Prepare Test Data
Run the Application
Optimizing neural network models in PyTorch
Create an optimized PyTorch model for MNIST
Run optimization
Update the Android application
Review
Next Steps
In the previous section, you created a feedforward neural network for digit classification using the MNIST dataset. The network was left untrained and lacks the ability to make accurate predictions.
To enable the network to recognize handwritten digits effectively, training is needed. Training in PyTorch involves configuring the network’s parameters, such as weights and biases, by exposing the model to labeled data and iteratively adjusting these parameters to minimize prediction errors. This process allows the model to learn the patterns in the data, enabling it to make accurate classifications on new, unseen inputs.
The typical approach to training a neural network in PyTorch involves several key steps.
First, obtain and preprocess the dataset, which usually includes normalizing the data and converting it into a format suitable for the model.
Next, the dataset is split into training and testing subsets. Training data is used to update the model’s parameters, while testing data evaluates its performance. During training, feed batches of input data through the network, calculate the prediction error or loss using a loss function (such as cross-entropy for classification tasks), and optimize the model’s weights and biases using backpropagation. Backpropagation involves computing the gradient of the loss with respect to each parameter and then updating the parameters using an optimizer, like Stochastic Gradient Descent (SGD) or Adam. This process is repeated for multiple epochs until the model achieves satisfactory performance, balancing accuracy and generalization.
Loss is a measure of how well a model’s predictions match the true labels of the data. It quantifies the difference between the predicted output and the actual output. The lower the loss, the better the model’s performance. In classification tasks, a common loss function is Cross-Entropy Loss, while Mean Squared Error (MSE) is often used for regression tasks. The goal of training is to minimize the loss, which indicates that the model’s predictions are getting closer to the actual labels.
Gradients represent the rate of change of the loss with respect to each of the model’s parameters (weights and biases). They are used to update the model’s parameters in the direction that reduces the loss. Gradients are calculated during the backpropagation step, where the loss is propagated backward through the network to compute how each parameter contributes to the overall loss. Optimizers like SGD or Adam use these gradients to adjust the parameters, effectively “teaching” the model to improve its predictions.
An epoch refers to one complete pass through the entire training dataset. During each epoch, the model sees every data point once and updates its parameters accordingly. Multiple epochs are typically required to train a model effectively because, during each epoch, the model learns and fine-tunes its parameters based on the data it processes. The number of epochs is a hyperparameter that you set before training, and increasing it can improve the model’s performance, but too many epochs may lead to overfitting, where the model performs well on training data but poorly on new, unseen data.
Backpropagation is a fundamental algorithm used in training neural networks to optimize their parameters—weights and biases—by minimizing the loss function. It works by propagating the error backward through the network, calculating the gradients of the loss function with respect to each parameter, and updating these parameters accordingly.
To train a model in PyTorch, several essential components are required:
Dataset: the source of data that the model will learn from. It typically consists of input samples and their corresponding labels. PyTorch provides the torchvision.datasets
module for easy access to popular datasets like MNIST, CIFAR-10, and ImageNet. You can also create custom datasets using the torch.utils.data.Dataset
class.
DataLoader: used to efficiently load and batch the data during training. It handles data shuffling, batching, and parallel loading, making it easier to feed the data into the model in a structured manner. This is crucial for performance, especially when working with large datasets.
Model: the Neural Network Architecture defines the structure of the neural network. You learned that in PyTorch, models are typically created by subclassing torch.nn.Module
and defining the network layers and forward pass. This includes specifying the input and output dimensions and the sequence of layers, such as linear layers, activation functions, and dropout.
Loss Function: measures how far the model’s predictions are from the actual targets. It guides the optimization process by providing a signal that tells the model how to adjust its parameters. Common loss functions include Cross-Entropy Loss for classification tasks and Mean Squared Error (MSE) Loss for regression tasks. You can select a predefined loss function from torch.nn or define your own.
Optimizer: updates the model’s parameters based on the gradients computed during backpropagation. It determines how the model learns from the data. Popular optimizers include Stochastic Gradient Descent (SGD) and Adam, which are available in the torch.optim module. You need to specify the learning rate (a hyperparameter that controls how much to change the parameters in response to the gradient) and other hyperparameters when creating the optimizer.
Training Loop: where the actual learning happens. For each iteration of the loop:
This process is repeated for a specified number of epochs to gradually reduce the loss and improve the model’s performance.
In the next step you will see how to perform model training.