Create an Android application
In this section you will create an Android application to run digit classification.
The application loads a randomly selected image containing a handwritten digit together with its true label, runs inference on the image, and predicts the digit value.
Start by creating a project:
Open Android Studio and create a new project with an “Empty Views Activity.”
Set the project name to ArmPyTorchMNISTInference, set the package name to com.arm.armpytorchmnistinference, select Kotlin as the language, and set the minimum SDK to API 27 (Android 8.1, “Oreo”).
Set the API to Android 8.1 (API level 27) because this version introduced NNAPI, providing a standard interface for running computationally intensive machine learning models on Android devices.
Devices with hardware accelerators can leverage NNAPI to offload ML tasks to specialized hardware, such as NPUs (Neural Processing Units), DSPs (Digital Signal Processors), or GPUs (Graphics Processing Units).
The user interface design contains the following:

- An ImageView and a TextView to display the image and its true label.
- TextView controls to display the predicted label and inference time.

Use the Android Studio editor to replace the contents of activity_main.xml, located in src/main/res/layout, with the following code:
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    android:orientation="vertical"
    android:padding="16dp"
    android:gravity="center">

    <!-- Header -->
    <TextView
        android:id="@+id/header"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Digit Recognition"
        android:textSize="24sp"
        android:textStyle="bold"
        android:layout_marginBottom="16dp"/>

    <!-- ImageView to display the image -->
    <ImageView
        android:id="@+id/imageView"
        android:layout_width="200dp"
        android:layout_height="200dp"
        android:layout_gravity="center"
        android:contentDescription="Image for inference"
        android:layout_marginBottom="16dp"/>

    <!-- Label showing the true label of the image -->
    <TextView
        android:id="@+id/trueLabel"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="True Label: N/A"
        android:textSize="18sp"
        android:layout_marginBottom="16dp"/>

    <!-- Button to select an input image -->
    <Button
        android:id="@+id/selectImageButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Load Image"
        android:layout_marginBottom="16dp"/>

    <!-- Button to run inference -->
    <Button
        android:id="@+id/runInferenceButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Run Inference"
        android:layout_marginBottom="16dp"/>

    <!-- TextView to display the predicted label and inference time -->
    <TextView
        android:id="@+id/predictedLabel"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Predicted Label: N/A"
        android:textSize="18sp"
        android:layout_marginBottom="8dp"/>

    <TextView
        android:id="@+id/inferenceTime"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:text="Inference Time: N/A ms"
        android:textSize="18sp"/>

</LinearLayout>
The above XML code defines a user interface layout for an Android activity using a vertical LinearLayout. It includes several UI components arranged vertically with padding and centered alignment.

At the top, a TextView acts as a header, displaying the text Digit Recognition in bold with a large font size. Below the header, an ImageView displays the currently loaded image. This is followed by another TextView that shows the true label of the displayed image, initially set to True Label: N/A.

The layout also contains two buttons: one labeled Load Image for selecting an input image, and another labeled Run Inference to execute the inference process on the selected image. At the bottom, two TextView elements display the predicted label and the inference time, both initially set to N/A. The layout uses margins and appropriate sizes for each element to ensure a clean and organized appearance.
Add PyTorch to the project by opening the module-level build.gradle.kts file and adding the following two lines under the dependencies section:
implementation("org.pytorch:pytorch_android:1.10.0")
implementation("org.pytorch:pytorch_android_torchvision:1.10.0")
The dependencies section should look as follows:
dependencies {
    implementation(libs.androidx.core.ktx)
    implementation(libs.androidx.appcompat)
    implementation(libs.material)
    implementation(libs.androidx.activity)
    implementation(libs.androidx.constraintlayout)
    testImplementation(libs.junit)
    androidTestImplementation(libs.androidx.junit)
    androidTestImplementation(libs.androidx.espresso.core)
    implementation("org.pytorch:pytorch_android:1.10.0")
    implementation("org.pytorch:pytorch_android_torchvision:1.10.0")
}
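After editing the file, let Android Studio sync the Gradle project. Optionally, you can hoist the PyTorch version into a single variable so that both artifacts always stay on the same release; a minimal sketch of that tidy-up in the Kotlin DSL:

// Optional: declare the version once at the top of build.gradle.kts
// so both PyTorch artifacts use the same release.
val pytorchVersion = "1.10.0"

dependencies {
    // ... existing entries ...
    implementation("org.pytorch:pytorch_android:$pytorchVersion")
    implementation("org.pytorch:pytorch_android_torchvision:$pytorchVersion")
}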
You will now implement the logic for the application.
This includes loading the pre-trained model, loading and displaying images, and running inference.
Open MainActivity.kt
and modify it as follows:
package com.arm.armpytorchmnistinference
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.os.Bundle
import android.widget.Button
import android.widget.ImageView
import android.widget.TextView
import androidx.activity.enableEdgeToEdge
import androidx.appcompat.app.AppCompatActivity
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.Tensor
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import kotlin.random.Random
import kotlin.system.measureNanoTime
class MainActivity : AppCompatActivity() {
    private lateinit var imageView: ImageView
    private lateinit var trueLabel: TextView
    private lateinit var selectImageButton: Button
    private lateinit var runInferenceButton: Button
    private lateinit var predictedLabel: TextView
    private lateinit var inferenceTime: TextView
    private lateinit var model: Module
    private var currentBitmap: Bitmap? = null
    private var currentTrueLabel: Int? = null

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        enableEdgeToEdge()
        setContentView(R.layout.activity_main)

        // Initialize UI elements
        imageView = findViewById(R.id.imageView)
        trueLabel = findViewById(R.id.trueLabel)
        selectImageButton = findViewById(R.id.selectImageButton)
        runInferenceButton = findViewById(R.id.runInferenceButton)
        predictedLabel = findViewById(R.id.predictedLabel)
        inferenceTime = findViewById(R.id.inferenceTime)

        // Load the model from assets
        model = Module.load(assetFilePath("model.pth"))

        // Set up the button click listener for selecting a random image
        selectImageButton.setOnClickListener {
            selectRandomImageFromAssets()
        }

        // Set up the button click listener for running inference
        runInferenceButton.setOnClickListener {
            currentBitmap?.let { bitmap ->
                runInference(bitmap)
            }
        }
    }

    private fun selectRandomImageFromAssets() {
        try {
            // Get the list of files in the mnist_bitmaps folder
            val assetManager = assets
            val files = assetManager.list("mnist_bitmaps") ?: arrayOf()
            if (files.isEmpty()) {
                trueLabel.text = "No images found in assets/mnist_bitmaps"
                return
            }

            // Select a random file from the list
            val randomFile = files[Random.nextInt(files.size)]
            val inputStream: InputStream = assetManager.open("mnist_bitmaps/$randomFile")
            val bitmap = BitmapFactory.decodeStream(inputStream)

            // Extract the true label from the filename (e.g., 07_00.png -> true label is 7)
            currentTrueLabel = randomFile.split("_")[0].toInt()

            // Display the image and its true label
            imageView.setImageBitmap(bitmap)
            trueLabel.text = "True Label: $currentTrueLabel"

            // Set the current bitmap for inference
            currentBitmap = bitmap
        } catch (e: IOException) {
            e.printStackTrace()
            trueLabel.text = "Error loading image from assets"
        }
    }

    // Convert a grayscale bitmap to a float array and create a tensor with shape [1, 1, 28, 28]
    private fun createTensorFromBitmap(bitmap: Bitmap): Tensor {
        // Ensure the bitmap has the expected dimensions [28, 28]
        if (bitmap.width != 28 || bitmap.height != 28) {
            throw IllegalArgumentException("Expected bitmap of size [28, 28], but got [${bitmap.width}, ${bitmap.height}]")
        }

        // Convert the grayscale bitmap to a float array
        val width = bitmap.width
        val height = bitmap.height
        val floatArray = FloatArray(width * height)
        val pixels = IntArray(width * height)
        bitmap.getPixels(pixels, 0, width, 0, 0, width, height)
        for (i in pixels.indices) {
            // Normalize pixel values to the [0, 1] range. For a grayscale bitmap
            // R == G == B, so reading the lowest byte (the blue channel) is sufficient.
            floatArray[i] = (pixels[i] and 0xFF) / 255.0f
        }

        // Create a tensor with shape [1, 1, 28, 28] (batch size, channels, height, width)
        return Tensor.fromBlob(floatArray, longArrayOf(1, 1, height.toLong(), width.toLong()))
    }

    private fun runInference(bitmap: Bitmap) {
        // Convert the bitmap to a tensor with shape [1, 1, 28, 28]
        val inputTensor = createTensorFromBitmap(bitmap)

        // Run inference and measure the elapsed time
        val inferenceTimeMicros = measureTimeMicros {
            // Forward pass through the model
            val outputTensor = model.forward(IValue.from(inputTensor)).toTensor()
            val scores = outputTensor.dataAsFloatArray

            // Get the index of the class with the highest score
            val maxIndex = scores.indices.maxByOrNull { scores[it] } ?: -1
            predictedLabel.text = "Predicted Label: $maxIndex"
        }

        // Update the inference time TextView in microseconds
        inferenceTime.text = "Inference Time: $inferenceTimeMicros µs"
    }

    // Measure the execution time of a code block in microseconds
    private inline fun measureTimeMicros(block: () -> Unit): Long {
        val time = measureNanoTime(block)
        return time / 1000 // Convert nanoseconds to microseconds
    }

    // Helper function to copy a file from assets to internal storage and return its path
    private fun assetFilePath(assetName: String): String {
        val file = File(filesDir, assetName)
        assets.open(assetName).use { inputStream ->
            FileOutputStream(file).use { outputStream ->
                val buffer = ByteArray(4 * 1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
                outputStream.flush()
            }
        }
        return file.absolutePath
    }
}
The above Kotlin code defines an Android activity called MainActivity that performs inference on MNIST images using a pre-trained PyTorch model. The app allows the user to load a random MNIST image from the assets folder and run the model to classify it.

The MainActivity class contains several methods. The first one, onCreate(), is called when the activity is first created. It sets up the user interface by inflating the layout defined in activity_main.xml and initializes several UI components, including an ImageView to display the image, TextView controls to show the true label and predicted label, and two buttons (selectImageButton and runInferenceButton) to select an image and run inference. The method then loads the PyTorch model from the assets folder using the assetFilePath() function and sets up click listeners for the buttons. The selectImageButton is configured to select a random image from the mnist_bitmaps folder, while the runInferenceButton runs inference on the selected image.
Next, the selectRandomImageFromAssets() method is responsible for selecting a random image from the mnist_bitmaps folder in the assets. It lists all the files in the folder, picks one at random, and loads it as a bitmap. The method then extracts the true label from the filename (e.g., 07_00.png implies a true label of 7), displays the selected image in the ImageView, and updates the trueLabel TextView with the correct label. If there is an error loading the image, or if the folder is empty, an appropriate error message is displayed in the trueLabel TextView.
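To illustrate the filename convention the app relies on, a minimal sketch (not part of the application code):

// Filenames follow the pattern "<true label>_<sample index>.png".
val name = "07_00.png"
val label = name.split("_")[0].toInt() // -> 7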
Afterward, the createTensorFromBitmap() method converts a grayscale 28x28 bitmap (an image from the MNIST dataset) into a PyTorch tensor. First, the method verifies that the bitmap has the correct dimensions. Then, it extracts pixel data from the bitmap, normalizes each pixel value to a float in the range [0, 1], and stores the values in a float array. The method finally constructs and returns a tensor with the shape [1, 1, 28, 28], where the first 1 is the batch size, the second 1 is the number of channels (grayscale), and 28 is the height and width of the image. This is required to match the input expected by the model.
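Note that this plain [0, 1] scaling must match whatever preprocessing the model was trained with. If, for example, your training pipeline used torchvision's Normalize with the commonly used MNIST mean and standard deviation (0.1307 and 0.3081), you would apply the same transform here. A minimal sketch of that variant, under that assumption:

// Sketch: use this loop body instead of plain [0, 1] scaling *only if*
// the model was trained with Normalize((0.1307,), (0.3081,)).
val mean = 0.1307f
val std = 0.3081f
for (i in pixels.indices) {
    val gray = (pixels[i] and 0xFF) / 255.0f
    floatArray[i] = (gray - mean) / std
}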
Subsequently, the runInference() method accepts a bitmap as input and performs inference using the pre-trained PyTorch model. It first converts the bitmap to a tensor using the createTensorFromBitmap() method. Then, it measures the time taken to run the forward pass of the model using the measureTimeMicros() method. The output tensor from the model, which contains the scores for each digit class, is processed to determine the predicted label, which is displayed in the predictedLabel TextView. The method also updates the inferenceTime TextView with the time taken for the inference in microseconds. Note that, as written, the timed block also includes the argmax and the predictedLabel update, so the reported figure slightly overstates the pure forward-pass time.
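If you want the reported time to cover only the forward pass, one possible rearrangement (a sketch, not part of the original code) is:

// Time only model.forward(); do the argmax and UI update outside the block.
var outputTensor: Tensor? = null
val inferenceTimeMicros = measureTimeMicros {
    outputTensor = model.forward(IValue.from(inputTensor)).toTensor()
}
val scores = outputTensor!!.dataAsFloatArray
val maxIndex = scores.indices.maxByOrNull { scores[it] } ?: -1
predictedLabel.text = "Predicted Label: $maxIndex"
inferenceTime.text = "Inference Time: $inferenceTimeMicros µs"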
Also, the inline function measureTimeMicros() is a utility that measures the execution time of the provided code block in microseconds. It uses the measureNanoTime() function to get the execution time in nanoseconds and converts it to microseconds by dividing the result by 1000. This method is used to measure the time taken for model inference in the runInference() method.
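On Kotlin 1.9 or later, the standard library's kotlin.time API offers a ready-made alternative; a sketch, assuming that Kotlin version:

import kotlin.time.measureTime

// measureTime returns a Duration, which converts directly to microseconds.
val elapsed = measureTime {
    // ... run the forward pass here ...
}
inferenceTime.text = "Inference Time: ${elapsed.inWholeMicroseconds} µs"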
The assetFilePath() method is a helper function that copies a file from the assets folder to the application's internal storage and returns the absolute path of the copied file. This is necessary because PyTorch's Module.load() method requires a file path, not an InputStream. The function reads the specified asset file, writes its contents to a file in internal storage, and returns the path to this file. It is used in onCreate() to load the PyTorch model file, model.pth, from the assets folder.
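As written, the helper re-copies the asset on every launch. A possible refinement (a sketch, not part of the original code) is to skip the copy when a non-empty file is already present:

private fun assetFilePath(assetName: String): String {
    val file = File(filesDir, assetName)
    // Reuse the previously copied file if it already exists and is non-empty.
    if (file.exists() && file.length() > 0) {
        return file.absolutePath
    }
    assets.open(assetName).use { inputStream ->
        FileOutputStream(file).use { outputStream ->
            inputStream.copyTo(outputStream)
        }
    }
    return file.absolutePath
}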
In summary, the MainActivity class initializes the UI components, loads a pre-trained PyTorch model, and allows the user to select random MNIST images and run inference on them. Each method handles a specific aspect of the functionality, such as loading images, converting them to tensors, running inference, and measuring execution time, which keeps the code modular and easy to maintain.
To run the application successfully, you still need to add the model and prepare the bitmap images. Continue to the next section to see how to prepare the data.