Overview

In the previous section, you saw that the Fully Connected operator supports multiple GEMM (General Matrix Multiplication) variants.

To evaluate the performance of these variants across different hardware platforms, you will construct a series of benchmark models that utilize the Fully Connected operator with different GEMM implementations for comparative analysis.

These models will be used later with executor_runner to measure throughput, latency, and ETDump traces for various KleidiAI micro-kernels.

Define a linear benchmark model with PyTorch for ExecuTorch

Building a minimal model keeps the focus on the performance of the core operator. It lets you quickly test different GEMM implementations and see how each one performs on Arm-based hardware. If you run into errors, check that your PyTorch and ExecuTorch versions are up to date and that you are using the correct data types for your target GEMM variant. By adjusting the model's input parameters, you can also simulate the behavior of nodes that appear in real-world models.

import torch
import torch.nn as nn

class DemoLinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A single 256x256 Fully Connected (Linear) layer
        self.linear = nn.Linear(256, 256)

    def forward(self, x):
        return self.linear(x)

    def get_example_inputs(self, dtype=torch.float32):
        return (torch.randn(1, 256, dtype=dtype),)

This model creates a single 256×256 linear layer, which can easily be exported in different data types (FP32, FP16, INT8, INT4) to match KleidiAI’s GEMM variants.
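
Before exporting, you can run a quick sanity check, and, if you want to mimic a Linear node from a real network, make the layer dimensions configurable. The snippet below is a minimal sketch; the DemoLinearVariant class and its in_features/out_features arguments are illustrative only and are not part of the export script.

# Quick sanity check: run one forward pass and confirm the output shape.
model = DemoLinearModel().eval()
(x,) = model.get_example_inputs()
print(model(x).shape)  # torch.Size([1, 256])

# Hypothetical parameterized variant for simulating Linear nodes from real-world models.
class DemoLinearVariant(nn.Module):
    def __init__(self, in_features=256, out_features=256):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)

    def get_example_inputs(self, dtype=torch.float32):
        return (torch.randn(1, self.linear.in_features, dtype=dtype),)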

Export FP16 and FP32 models for pf16_gemm and pf32_gemm variants

XNNPACK GEMM Variant | Activations Data Type | Weights Data Type | Output Data Type
pf16_gemm | FP16 | FP16 | FP16
pf32_gemm | FP32 | FP32 | FP32

The following code demonstrates how to lower and export a model that leverages the pf16_gemm variant to accelerate computation:

import os

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower

def export_executorch_model(dtype: torch.dtype, model_name: str):
    model_file_name = "model/" + model_name
    pte_file = model_file_name + ".pte"
    etr_file = model_file_name + ".etrecord"
    os.makedirs("model", exist_ok=True)  # ensure the output directory exists

    model = DemoLinearModel().eval().to(dtype)
    example_inputs = model.get_example_inputs(dtype)

    exported_program = torch.export.export(model, example_inputs)

    # Partition and lower the graph to the XNNPACK backend
    partitioner = XnnpackPartitioner()
    edge_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        generate_etrecord=True
    )

    # Serialize the ExecuTorch program to a .pte file
    et_program = edge_program.to_executorch()
    with open(pte_file, "wb") as f:
        f.write(et_program.buffer)

    # Get and save the ETRecord for later profiling with the Inspector API
    etrecord = et_program.get_etrecord()
    etrecord.save(etr_file)

export_executorch_model(torch.float16, "linear_model_pf16_gemm")

To generate a model that uses the pf32_gemm variant, simply change the dtype in the previous code to torch.float32, as shown below:

export_executorch_model(torch.float32, "linear_model_pf32_gemm")
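
Optionally, you can sanity-check an exported .pte on the host before moving it to a target device. The snippet below is a minimal sketch that assumes your ExecuTorch Python installation ships the runtime bindings exposed as executorch.runtime; the exact API may differ between ExecuTorch releases.

import torch
from executorch.runtime import Runtime

# Load the exported FP32 program and run a single inference on the host.
runtime = Runtime.get()
program = runtime.load_program("model/linear_model_pf32_gemm.pte")
method = program.load_method("forward")
outputs = method.execute([torch.randn(1, 256)])
print(outputs[0].shape)  # expected: torch.Size([1, 256])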

Export INT8 quantized models for pqs8_qc8w_gemm and qp8_f32_qc8w_gemm variants

INT8 quantized GEMMs are designed to reduce memory footprint and improve performance while maintaining acceptable accuracy.

XNNPACK GEMM Variant | Activations Data Type | Weights Data Type | Output Data Type
qp8_f32_qc8w_gemm | Asymmetric INT8 per-row quantization | Per-channel symmetric INT8 quantization | FP32
pqs8_qc8w_gemm | Asymmetric INT8 quantization | Per-channel symmetric INT8 quantization | Asymmetric INT8 quantization

The following code demonstrates how to quantize a model so that it uses the pqs8_qc8w_gemm or qp8_f32_qc8w_gemm variant to accelerate computation. The dynamic argument selects the activation quantization scheme: False produces statically quantized activations (pqs8_qc8w_gemm), while True produces dynamically quantized activations with FP32 output (qp8_f32_qc8w_gemm):

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

def export_int8_quantize_model(dynamic: bool, model_name: str):
    model_file_name = "model/" + model_name
    pte_file = model_file_name + ".pte"
    etr_file = model_file_name + ".etrecord"

    model = DemoLinearModel().eval().to(torch.float32)
    example_inputs = model.get_example_inputs(torch.float32)

    # Quantize the model (static or dynamic activation quantization)
    model = torch.export.export(model, example_inputs).module()
    quantizer = XNNPACKQuantizer()
    operator_config = get_symmetric_quantization_config(
        is_per_channel=True,
        is_dynamic=dynamic
    )

    quantizer.set_global(operator_config)
    quantize_model = prepare_pt2e(model, quantizer)
    quantize_model(*example_inputs)  # calibrate with the example inputs
    quantize_model = convert_pt2e(quantize_model)

    # Lower and export the model
    exported_program = torch.export.export(quantize_model, example_inputs)

    partitioner = XnnpackPartitioner()
    edge_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        generate_etrecord=True
    )

    et_program = edge_program.to_executorch()
    with open(pte_file, "wb") as f:
        f.write(et_program.buffer)

    # Get and save the ETRecord
    etrecord = et_program.get_etrecord()
    etrecord.save(etr_file)

export_int8_quantize_model(False, "linear_model_pqs8_qc8w_gemm")
export_int8_quantize_model(True, "linear_model_qp8_f32_qc8w_gemm")
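
If you want to confirm that quantization was applied before lowering, you can print the converted graph and look for the quantize/dequantize nodes inserted around the linear op. The snippet below is an optional debugging aid that repeats the quantization steps on a fresh model instance purely for inspection; it is not part of the export script.

# Optional: inspect the quantized graph before lowering.
model = DemoLinearModel().eval()
example_inputs = model.get_example_inputs()
captured = torch.export.export(model, example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False))

prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)

# Look for quantize_per_tensor / dequantize_per_tensor (and per-channel weight) nodes.
print(converted.graph)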

Export INT4 quantized model for qp8_f32_qb4w_gemm variant

This final variant represents KleidiAI’s INT4 path, accelerated by SME2 micro-kernels.

XNNPACK GEMM Variant | Activations Data Type | Weights Data Type | Output Data Type
qp8_f32_qb4w_gemm | Asymmetric INT8 per-row quantization | INT4 (signed), shared blockwise quantization | FP32

The following code demonstrates how to quantize a model so that it uses the qp8_f32_qb4w_gemm variant to accelerate computation:

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)

def export_int4_quantize_model(model_name: str):
    model_file_name = "model/" + model_name
    pte_file = model_file_name + ".pte"
    etr_file = model_file_name + ".etrecord"

    model = DemoLinearModel().eval().to(torch.float32)
    example_inputs = model.get_example_inputs(torch.float32)

    # Quantize the model: dynamic INT8 activations, blockwise INT4 weights (group size 32)
    linear_config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
    )

    quantize_(model, linear_config)

    # Lower and export the model
    exported_program = torch.export.export(model, example_inputs)

    partitioner = XnnpackPartitioner()
    edge_program = to_edge_transform_and_lower(
        exported_program,
        partitioner=[partitioner],
        generate_etrecord=True
    )

    et_program = edge_program.to_executorch()
    with open(pte_file, "wb") as f:
        f.write(et_program.buffer)

    # Get and save the ETRecord
    etrecord = et_program.get_etrecord()
    etrecord.save(etr_file)

export_int4_quantize_model("linear_model_qp8_f32_qb4w_gemm")
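
With all the models exported, you can optionally compare the on-disk size of the .pte files to see the memory-footprint reduction that INT8 and INT4 quantization provide. This is a small illustrative check, not part of the export script.

import os

# Compare exported model sizes for the FP32, INT8, and INT4 variants.
for name in (
    "linear_model_pf32_gemm",
    "linear_model_pqs8_qc8w_gemm",
    "linear_model_qp8_f32_qb4w_gemm",
):
    size_kib = os.path.getsize(f"model/{name}.pte") / 1024
    print(f"{name}.pte: {size_kib:.1f} KiB")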

Note

When exporting models, the generate_etrecord option is enabled to produce the .etrecord file alongside the .pte model file. These ETRecord files are essential for subsequent model inspection and performance analysis using the ExecuTorch Inspector API.
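
As a preview of how the ETRecord is used, the sketch below pairs it with an ETDump collected on the target device. The ETDump file name is a placeholder, since you will only generate it in a later step with executor_runner, and the exact Inspector arguments may vary between ExecuTorch versions.

from executorch.devtools import Inspector

# Pair the on-device ETDump (placeholder name) with the ETRecord saved during export.
inspector = Inspector(
    etdump_path="linear_model_pf32_gemm.etdp",
    etrecord="model/linear_model_pf32_gemm.etrecord",
)
inspector.print_data_tabular()  # per-operator timing, including delegated XNNPACK kernels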

Run the benchmark model export script for ExecuTorch

Instead of manually executing each code block explained above, you can download and run the full example script that builds and exports all linear-layer benchmark models (FP16, FP32, INT8, and INT4). This script automatically performs quantization, partitioning, lowering, and export to ExecuTorch format.

wget https://raw.githubusercontent.com/ArmDeveloperEcosystem/arm-learning-paths/refs/heads/main/content/learning-paths/mobile-graphics-and-gaming/measure-kleidiai-kernel-performance-on-executorch/export-linear-model.py
chmod +x export-linear-model.py
python3 ./export-linear-model.py

Verify exported ExecuTorch and KleidiAI model files

After successful execution, you should see both .pte (ExecuTorch model) and .etrecord (profiling metadata) files in the model/ directory:

$ ls model/ -1
linear_model_pf16_gemm.etrecord
linear_model_pf16_gemm.pte
linear_model_pf32_gemm.etrecord
linear_model_pf32_gemm.pte
linear_model_pqs8_qc8w_gemm.etrecord
linear_model_pqs8_qc8w_gemm.pte
linear_model_qp8_f32_qb4w_gemm.etrecord
linear_model_qp8_f32_qb4w_gemm.pte
linear_model_qp8_f32_qc8w_gemm.etrecord
linear_model_qp8_f32_qc8w_gemm.pte

Great job! You now have a complete set of benchmark models exported for multiple GEMM variants and quantization levels. You’re ready to move on and measure performance using ExecuTorch and KleidiAI micro-kernels on Arm-based hardware.
