
pytorch/ao

torchao: PyTorch Architecture Optimization

This repository is currently under heavy development. If you have suggestions on the API or use cases you'd like covered, please open an issue.

Introduction

torchao is a PyTorch library for quantization and sparsity.

Get Started

Installation

torchao makes liberal use of several new features in PyTorch, so we recommend using it with the current nightly or the latest stable version of PyTorch.

Stable Release

pip install torchao

Nightly Release

pip install torchao-nightly

From source

git clone https://github.com/pytorch/ao
cd ao
pip install .

Quantization

import torch
import torchao

# inductor settings which improve torch.compile performance for quantized modules
torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# Plug in your model and example input
model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')

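# perform autoquantization, using `input` as the example input
# (the original `(input)` was not a tuple, just the same tensor in parentheses)
torchao.autoquant(model, input)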

# compile the model to recover performance
model = torch.compile(model, mode='max-autotune')
model(input)
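As an illustration of the kind of transformation quantization applies to a linear layer (and not of how autoquant decides between techniques), here is a minimal pure-Python sketch of int8 weight-only quantization. It assumes symmetric, per-output-channel (per-row) scales. The function names are hypothetical, and this is not torchao's implementation, which works on tensors and relies on compiled kernels.

```python
# Minimal pure-Python sketch of symmetric, per-row (per-output-channel)
# int8 weight-only quantization. Illustrative only; not torchao's code.
# Function names are hypothetical.

def quantize_int8_per_row(weight):
    """Return (int8 rows, per-row scales) for a 2-D list of floats."""
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(x) for x in row)
        # Map the largest magnitude in the row to 127; guard all-zero rows.
        scale = max_abs / 127 if max_abs > 0 else 1.0
        q_rows.append([max(-128, min(127, round(x / scale))) for x in row])
        scales.append(scale)
    return q_rows, scales


def int8_weight_only_linear(x, q_rows, scales):
    """Compute y = x @ W^T from int8 weights; activations stay in float.

    Each output is an int8-weight dot product rescaled by its row's scale.
    """
    return [sum(xi * qi for xi, qi in zip(x, q)) * s for q, s in zip(q_rows, scales)]


weight = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
x = [1.0, 2.0, 3.0]
q_rows, scales = quantize_int8_per_row(weight)
exact = [sum(xi * wi for xi, wi in zip(x, w)) for w in weight]
approx = int8_weight_only_linear(x, q_rows, scales)
print("int8 weights:", q_rows)
print("exact:", exact, "quantized:", approx)
```

Only the weights are stored in int8; activations remain in floating point, which is what "weight-only" refers to. The rounding error per weight is at most half a scale step, so the outputs stay close to the full-precision result.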

Sparsity

import torch
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
from torch.ao.pruning import WeightNormSparsifier

# bfloat16 CUDA model
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).cuda().to(torch.bfloat16)

# Accuracy: Finding a sparse subnetwork
sparse_config = []
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        sparse_config.append({"tensor_fqn": f"{name}.weight"})

sparsifier = WeightNormSparsifier(sparsity_level=1.0,
                                 sparse_block_shape=(1,4),
                                 zeros_per_block=2)

# attach FakeSparsity
sparsifier.prepare(model, sparse_config)
sparsifier.step()
sparsifier.squash_mask()
# now we have a dense model with sparse weights

# Performance: Accelerated sparse inference
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear):
        mod.weight = torch.nn.Parameter(to_sparse_semi_structured(mod.weight))
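The sparsifier configuration above (`sparse_block_shape=(1,4)`, `zeros_per_block=2`, `sparsity_level=1.0`) requests 2:4 semi-structured sparsity: in every block of four consecutive weights along a row, two become zero, which is the pattern `to_sparse_semi_structured` compresses. The pure-Python sketch below produces that pattern with plain magnitude pruning (keep the two largest-magnitude weights per block). It is a simplified stand-in with hypothetical function names, not the `WeightNormSparsifier` implementation.

```python
# Pure-Python sketch of 2:4 semi-structured magnitude pruning: in every
# 1x4 block of a row, zero the 2 smallest-magnitude weights.
# Illustrative only; not the WeightNormSparsifier code. Names are hypothetical.

def prune_2_4(weight):
    """Return a pruned copy of a 2-D list whose row length is a multiple of 4."""
    pruned = []
    for row in weight:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4")
        new_row = list(row)
        for start in range(0, len(row), 4):
            block = range(start, start + 4)
            # Indices of the two smallest-magnitude entries in this 1x4 block.
            drop = sorted(block, key=lambda i: abs(row[i]))[:2]
            for i in drop:
                new_row[i] = 0.0
        pruned.append(new_row)
    return pruned


def is_2_4_sparse(weight):
    """Check that every 1x4 block in every row has at least 2 zeros."""
    return all(
        sum(1 for v in row[s:s + 4] if v == 0) >= 2
        for row in weight
        for s in range(0, len(row), 4)
    )


w = [[0.9, -0.1, 0.3, -0.7, 0.05, 0.2, -0.4, 0.0]]
p = prune_2_4(w)
print(p)                 # -> [[0.9, 0.0, 0.0, -0.7, 0.0, 0.2, -0.4, 0.0]]
print(is_2_4_sparse(p))  # -> True
```

Fixing the number of zeros per block (rather than pruning arbitrary weights) is what lets the hardware store only the surviving values plus small per-block indices.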

To learn more and try out our APIs, check out the API examples in

Supported Features

  1. Quantization algorithms
  2. Sparsity algorithms, such as Wanda, that help improve the accuracy of sparse networks
  3. Support for lower precision dtypes such as
  4. Bleeding-edge kernels: experimental kernels without backward-compatibility guarantees

Our Goals

  • Composability with torch.compile: We rely heavily on torch.compile to write pure PyTorch code and codegen efficient kernels. There are, however, limits to what a compiler can do, so we don't shy away from writing our own custom CUDA/Triton kernels.
  • Composability with FSDP: The new support for FSDP per-parameter sharding means engineers and researchers alike can experiment with different quantization and distributed strategies concurrently.
  • Performance: We measure our performance on every commit using an A10G GPU. We also regularly run performance benchmarks on the TorchBench suite.
  • Heterogeneous hardware: Efficient kernels that can run on CPU/GPU-based servers (with torch.compile) and mobile backends (with ExecuTorch).
  • Packaging kernels should be easy: We support custom CUDA and Triton extensions, so you can focus on writing your kernels and we'll ensure that they work on most operating systems and devices.

Integrations

torchao has been integrated with other libraries, including:

  • torchtune leverages our 8-bit and 4-bit weight-only quantization techniques, with optional support for GPTQ.
  • ExecuTorch leverages our GPTQ implementation for both 8da4w (int8 dynamic activation with int4 weight) and int4 weight-only quantization.
  • HQQ leverages our int4mm kernel for low-latency inference.
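8da4w combines two ideas: weights are quantized ahead of time to int4, while activations are quantized to int8 at run time using a scale computed from the current input, which is what "dynamic" refers to. The minimal pure-Python sketch below walks through that arithmetic for a single input row. Symmetric scaling, one int4 scale per group of weights, the group size of 2 in the usage, and all function names are illustrative assumptions; this is not the torchao or ExecuTorch implementation, and those may use different scaling schemes.

```python
# Pure-Python sketch of the arithmetic behind "8da4w": int8 dynamic
# activation quantization combined with group-wise int4 weight quantization.
# Illustrative only: symmetric scaling, group size, and names are assumptions.

def quantize_symmetric(values, qmax):
    """Quantize to integers in [-qmax - 1, qmax]; return (ints, scale)."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    ints = [max(-qmax - 1, min(qmax, round(v / scale))) for v in values]
    return ints, scale


def quantize_weight_int4(row, group_size):
    """Offline step: int4 weights in [-8, 7], one scale per group of columns."""
    return [
        quantize_symmetric(row[start:start + group_size], qmax=7)
        for start in range(0, len(row), group_size)
    ]


def linear_8da4w(x, weight_groups_per_row, group_size):
    """Runtime step: quantize the activation row to int8 with a freshly
    computed scale (dynamic), then do integer dot products per weight group."""
    x_int, x_scale = quantize_symmetric(x, qmax=127)
    out = []
    for groups in weight_groups_per_row:
        acc = 0.0
        for g, (w_int, w_scale) in enumerate(groups):
            xs = x_int[g * group_size:(g + 1) * group_size]
            # Integer multiply-accumulate, then rescale by this group's scale.
            acc += sum(a * b for a, b in zip(xs, w_int)) * w_scale
        # The activation scale is shared by the whole row, so apply it once.
        out.append(acc * x_scale)
    return out


W = [[0.5, -0.25, 0.75, -1.0], [0.1, 0.2, -0.3, 0.4]]
x = [1.0, -2.0, 0.5, 3.0]
packed = [quantize_weight_int4(row, group_size=2) for row in W]
print("8da4w:", linear_8da4w(x, packed, group_size=2))
print("exact:", [sum(a * b for a, b in zip(x, row)) for row in W])
```

Keeping a separate scale per small group of weights limits how much a single large weight degrades the precision of its neighbours, which matters more at 4 bits than at 8.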

Success stories

Our kernels have been used to achieve state-of-the-art inference performance on

License

torchao is released under the BSD 3-Clause license.