torchao 概述#

原文:pytorch-native-architecture-optimization

torchao 是 PyTorch 原生库,通过利用低位宽数据类型、量化和稀疏性,使模型更快更小。torchao 是一个易于访问的工具包,包含(主要是)用易于阅读的 PyTorch 代码编写的技术,涵盖推理和训练两个方面。

除非另有说明,基线是在 A100 80GB GPU 上运行的 bf16。

针对 LLama 3 的主要指标包括:

  • 使用 autoquant 和仅 int4 权重量化加 hqq,使 LLama 3 8B 推理速度提升 \(97\%\)

  • 在 128K 上下文长度下,使用量化 KV 缓存,使 LLama 3.1 8B 推理的峰值 VRAM 减少 \(73\%\)

  • 使用 float8 训练在 H100 上进行 LLama 3 70B 预训练,速度提升 \(50\%\)

  • 使用 4 比特量化优化器,使 LLama 3 8B 的峰值 VRAM 减少 \(30\%\)

针对扩散模型推理的主要指标包括:

  • flux1.dev 上使用 float8 动态量化推理和 float8 逐行缩放,在 H100 上速度提升 \(53\%\)

  • 对于 CogVideoX,使用 int8 动态量化使模型 VRAM 减少 \(50\%\)

推理量化算法#

推理量化算法适用于包含 nn.Linear 层的任意 PyTorch 模型。通过我们的顶层 API quantize_,可以选择仅权重和动态激活量化,支持多种数据类型和稀疏布局。

from torchao.quantization import (  
    quantize_,  
    int4_weight_only,  
)  
quantize_(model, int4_weight_only())

有时,由于开销问题,量化一个层可能会使其变慢。因此,如果你希望我们为你选择如何量化模型中的每一层,那么你可以选择运行

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

quantize_ API 根据模型是计算密集型还是内存密集型提供了一些不同的选项。

from torchao.quantization import (  
    # Memory bound models  
    int4_weight_only,  
    int8_weight_only,

    # Compute bound models  
    int8_dynamic_activation_int8_semi_sparse_weight,  
    int8_dynamic_activation_int8_weight,  
      
    # Device capability 8.9+  
    float8_weight_only,  
    float8_dynamic_activation_float8_weight,  
)

API是可组合的,例如我们结合了稀疏性和量化,为 ViT-H 推理带来了 \(5\%\) 的速度提升。

但我们也可以做一些事情,比如将权重量化为 int4,并将 kv 缓存量化为 int8,以支持在不到 18.9GB VRAM 下全长度 128K 上下文运行的 Llama 3.1 8B。

QAT(量化感知训练)#

在 4 比特以下的后训练量化中,准确性可能会严重下降。通过使用量化感知训练(Quantization Aware Training, QAT),我们已经成功恢复了高达 \(96\%\) 的准确性损失。我们将这一方法作为端到端方案集成到了 torchtune 中,并附带了一个简单的教程

低精度计算和通信#

torchao提供易于使用的端到端工作流,用于降低训练计算和分布式通信的精度,从 torch.nn.Linear 层的 float8 开始。以下是将训练运行的计算 gemm 转换为 float8 的一行代码:

from torchao.float8 import convert_to_float8_training  
convert_to_float8_training(model)

有关如何通过使用 float8 将 LLaMa 3 70B 预训练速度提高多达 1.5 倍的端到端示例,请参阅我们的 READMEtorchtitan 的博客float8 配方

我们正在扩展我们的训练工作流以支持更多的数据类型和布局。

低比特优化器#

受到 Bits and Bytes 的启发,我们还添加了 8 比特和 4 比特优化器的原型支持,作为 AdamW 的即插即用替代品。

from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit  
optim = AdamW8bit(model.parameters())

集成#

我们一直在积极努力,确保 torchao 在开源中一些最重要的项目中能够良好工作。