torchao
概述#
原文:pytorch-native-architecture-optimization
torchao 是 PyTorch 原生库,通过利用低位宽数据类型、量化和稀疏性,使模型更快更小。torchao
是一个易于访问的工具包,包含(主要是)用易于阅读的 PyTorch 代码编写的技术,涵盖推理和训练两个方面。
除非另有说明,基线是在 A100 80GB GPU 上运行的 bf16。
针对 LLama 3 的主要指标包括:
使用
autoquant
和仅int4
权重量化加hqq
,使 LLama 3 8B 推理速度提升 \(97\%\)。在 128K 上下文长度下,使用量化 KV 缓存,使 LLama 3.1 8B 推理的峰值 VRAM 减少 \(73\%\)。
使用
float8
训练在 H100 上进行 LLama 3 70B 预训练,速度提升 \(50\%\)。使用 4 比特量化优化器,使 LLama 3 8B 的峰值 VRAM 减少 \(30\%\)。
针对扩散模型推理的主要指标包括:
在
flux1.dev
上使用 float8 动态量化推理和 float8 逐行缩放,在 H100 上速度提升 \(53\%\)。对于
CogVideoX
,使用int8
动态量化使模型 VRAM 减少 \(50\%\)。
推理量化算法#
推理量化算法适用于包含 nn.Linear
层的任意 PyTorch 模型。通过我们的顶层 API quantize_
,可以选择仅权重和动态激活量化,支持多种数据类型和稀疏布局。
from torchao.quantization import (
quantize_,
int4_weight_only,
)
quantize_(model, int4_weight_only())
有时,由于开销问题,量化一个层可能会使其变慢。因此,如果你希望我们为你选择如何量化模型中的每一层,那么你可以选择运行
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
quantize_
API 根据模型是计算密集型还是内存密集型提供了一些不同的选项。
from torchao.quantization import (
# Memory bound models
int4_weight_only,
int8_weight_only,
# Compute bound models
int8_dynamic_activation_int8_semi_sparse_weight,
int8_dynamic_activation_int8_weight,
# Device capability 8.9+
float8_weight_only,
float8_dynamic_activation_float8_weight,
)
API是可组合的,例如我们结合了稀疏性和量化,为 ViT-H 推理带来了 \(5\%\) 的速度提升。
但我们也可以做一些事情,比如将权重量化为 int4
,并将 kv 缓存量化为 int8
,以支持在不到 18.9GB VRAM 下全长度 128K 上下文运行的 Llama 3.1 8B。
QAT(量化感知训练)#
在 4 比特以下的后训练量化中,准确性可能会严重下降。通过使用量化感知训练(Quantization Aware Training, QAT),我们已经成功恢复了高达 \(96\%\) 的准确性损失。我们将这一方法作为端到端方案集成到了 torchtune
中,并附带了一个简单的教程。
低精度计算和通信#
torchao
提供易于使用的端到端工作流,用于降低训练计算和分布式通信的精度,从 torch.nn.Linear
层的 float8
开始。以下是将训练运行的计算 gemm
转换为 float8
的一行代码:
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)
有关如何通过使用 float8
将 LLaMa 3 70B 预训练速度提高多达 1.5 倍的端到端示例,请参阅我们的 README、torchtitan 的博客和 float8
配方。
我们正在扩展我们的训练工作流以支持更多的数据类型和布局。
低比特优化器#
受到 Bits and Bytes 的启发,我们还添加了 8 比特和 4 比特优化器的原型支持,作为 AdamW
的即插即用替代品。
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())
集成#
我们一直在积极努力,确保 torchao
在开源中一些最重要的项目中能够良好工作。
在 diffusers-torchao 中作为加速扩散模型的参考实现
在 HQQ 中用于快速 4 比特推理
在
torchtune
中用于 PyTorch 原生 QLoRA 和 QAT 配方在
torchchat
中用于后训练量化在 SGLang 中用于
int4
和int8
后训练量化