Adding a New Model Architecture to MLC-LLM with the TVM nn.Module Workflow#

This tutorial demonstrates how to add a new model architecture to MLC-LLM using the new TVM nn.Module workflow. TVM nn.Module is the new model compilation workflow, designed to bring modular, Python-first compilation to MLC-LLM so that users and developers can support new models and features more seamlessly.

For example, under the TVM nn.Module workflow, defining the Mistral model architecture takes only about half as much code as under the old workflow. At a high level, the TVM nn.Module interface is very similar to PyTorch's nn.Module.

GPT-2 is used here as the running example. GPT-2 is a transformers model pretrained on a very large corpus of English data in a self-supervised fashion and can be used to guess the next word in a sentence. Its model definition can be found in Huggingface.

Defining the GPT-2 Model#

Create a gpt2 folder under mlc-llm/python/mlc_llm/model/. Its structure will look like this:

mlc-llm/python/mlc_llm/model/gpt2/
├── gpt2_loader.py          # Load and convert weights from Huggingface
├── gpt2_model.py           # Define the model architecture and configuration
├── gpt2_quantization.py    # Define the quantization schemes
└── __init__.py

Start with gpt2_model.py. This file uses tvm.relax.frontend.nn.Module to define the GPT-2 model architecture in a modular way, much like its PyTorch counterpart.

from set_env import temp_dir  # notebook-local helper that provides a scratch directory used below

Defining the Configuration Class in gpt2_model.py#

First, define the configuration class, which is almost a direct translation of Huggingface's GPT2Config. Its attributes should use the same names as the corresponding attributes in the Huggingface configuration; otherwise the Huggingface config will not load correctly.

The __post_init__ function is called after all of the dataclass attributes have been initialized. Here it fills in derived defaults such as n_inner, context_window_size, head_dim, and prefill_chunk_size, as shown in the source below.

from mlc_llm.model.gpt2.gpt2_model import GPT2Config

GPT2Config??
Init signature:
GPT2Config(
    vocab_size: int,
    n_embd: int,
    n_layer: int,
    n_head: int,
    layer_norm_epsilon: float,
    n_inner: int = -1,
    context_window_size: int = 0,
    prefill_chunk_size: int = 0,
    scale_attn_by_inverse_layer_idx: bool = False,
    tensor_parallel_shards: int = 1,
    head_dim: int = 0,
    max_batch_size: int = 1,
    kwargs: Dict[str, Any] = <factory>,
) -> None
Source:        
@dataclasses.dataclass
class GPT2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes
    """Configuration of the GPT-2 model."""

    vocab_size: int
    n_embd: int
    n_layer: int
    n_head: int
    layer_norm_epsilon: float
    n_inner: int = -1
    context_window_size: int = 0
    prefill_chunk_size: int = 0
    scale_attn_by_inverse_layer_idx: bool = False
    tensor_parallel_shards: int = 1
    head_dim: int = 0
    max_batch_size: int = 1
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.n_inner is None or self.n_inner == -1:
            self.n_inner = 4 * self.n_embd
        if self.context_window_size == 0:
            for name in ["n_positions", "max_sequence_length"]:
                if name in self.kwargs:
                    self.context_window_size = self.kwargs.pop(name)
                    logger.info(
                        "%s not found in config.json. Falling back to %s (%d)",
                        bold("context_window_size"),
                        bold(name),
                        self.context_window_size,
                    )
                    break
            else:
                raise ValueError(
                    "Unable to determine the maximum sequence length, because none of "
                    "`context_window_size`, `n_positions` or `max_sequence_length` is "
                    "provided in `config.json`."
                )
        if self.head_dim == 0:
            self.head_dim = self.n_embd // self.n_head
        assert self.head_dim * self.n_head == self.n_embd
        if self.prefill_chunk_size == 0:
            logger.info(
                "%s defaults to %d",
                bold("prefill_chunk_size"),
                min(self.context_window_size, 8192),
            )
            self.prefill_chunk_size = min(self.context_window_size, 8192)
        elif self.prefill_chunk_size > self.context_window_size:
            logger.info(
                "Overriding %s from %d to %d",
                bold("prefill_chunk_size"),
                self.prefill_chunk_size,
                min(self.context_window_size, 8192),
            )
            self.prefill_chunk_size = min(self.context_window_size, 8192)
File:           /media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/model/gpt2/gpt2_model.py
Type:           type
Subclasses:     
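
As a quick sanity check of __post_init__, the minimal example below relies only on ConfigBase.from_dict routing unrecognized keys (such as n_positions) into kwargs, exactly as the compilation cells later in this tutorial do.

from mlc_llm.model.gpt2.gpt2_model import GPT2Config

cfg = GPT2Config.from_dict({
    "vocab_size": 50257,
    "n_embd": 768,
    "n_layer": 12,
    "n_head": 12,
    "layer_norm_epsilon": 1e-05,
    "n_positions": 1024,  # no context_window_size given -> falls back to n_positions
})
# n_inner defaults to 4 * n_embd, head_dim to n_embd // n_head,
# prefill_chunk_size to min(context_window_size, 8192)
print(cfg.context_window_size, cfg.n_inner, cfg.head_dim, cfg.prefill_chunk_size)
# expected: 1024 3072 64 1024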

Defining the Model Architecture in gpt2_model.py#

With tvm.relax.frontend.nn.Module, the model architecture can be defined in a modular way. It looks very similar to the PyTorch style, except that the forward function does not actually perform any computation. Instead, it traces the computation graph using the placeholders passed in as inputs.

You can optionally use op._print(some_tensor) to print intermediate tensor values while running the compiled module. If you do so, you must pass debug=True to export_tvm() or jit(). Beyond manual printing, an end-to-end debugging module, DebugChat, is also provided; it automatically dumps the intermediate values of all layers.
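
A minimal sketch of how this fits together (a toy module rather than part of the GPT-2 code; it assumes export_tvm accepts a raw spec dictionary, which is equivalent to the nn.spec usage shown later):

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, Tensor

class TinyBlock(nn.Module):
    def __init__(self):
        self.proj = nn.Linear(8, 8)

    def forward(self, x: Tensor):
        h = self.proj(x)
        op._print(h)  # printed when the compiled module actually runs
        return op.softmax(h, axis=-1)

# debug=True is required for op._print to take effect
mod, params = TinyBlock().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor([1, 8], "float32")}},
    debug=True,
)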

from mlc_llm.model.gpt2.gpt2_model import GPT2Attention
GPT2Attention??
Init signature: GPT2Attention(config: mlc_llm.model.gpt2.gpt2_model.GPT2Config)
Docstring:     
Base class for neural network components. Subclass it to build your models.
Modules can nest within each other in a tree structure using regular attribute assignment.
Source:        
class GPT2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes
    def __init__(self, config: GPT2Config):
        self.embed_dim = config.n_embd
        if config.n_head % config.tensor_parallel_shards != 0:
            raise ValueError(
                f"Cannot split {config.n_head} attention heads "
                f"evenly to {config.tensor_parallel_shards} GPUs."
            )
        self.num_heads = config.n_head // config.tensor_parallel_shards
        self.head_dim = config.head_dim
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx

        self.c_attn = nn.Linear(
            in_features=self.embed_dim,
            out_features=3 * self.num_heads * self.head_dim,
            bias=True,
        )
        self.c_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)

    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):
        d, h = self.head_dim, self.num_heads
        b, s, _ = hidden_states.shape

        qkv = self.c_attn(hidden_states)
        qkv = op.reshape(qkv, (b, s, 3 * h, d))

        if self.scale_attn_by_inverse_layer_idx:
            attn_score_scaling_factor = 1.0 / float(layer_id + 1)
        else:
            attn_score_scaling_factor = 1.0

        # Attention
        output = op.reshape(
            paged_kv_cache.attention_with_fused_qkv(
                layer_id, qkv, self.num_heads, attn_score_scaling_factor
            ),
            (b, s, h * d),
        )
        return self.c_proj(output)
File:           /media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/model/gpt2/gpt2_model.py
Type:           type
Subclasses:     

Note that a number of commonly used built-in modules are already provided, and you will find them very handy. For example, nn.Linear and the PagedKVCache used here are provided as built-in modules.

Likewise, many common built-in operators on tensors are provided, such as op.reshape, op.matmul, and op.softmax.
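
For instance, a toy helper built entirely from these operators (illustrative only; it is meant to be called from inside a module's forward, where tensors are traced rather than computed) could look like this:

from tvm.relax.frontend.nn import op, Tensor

def toy_attention_scores(q: Tensor, k: Tensor) -> Tensor:
    """Map (batch, seq, dim) query/key tensors to (batch, seq, seq) normalized scores."""
    scores = op.matmul(q, op.permute_dims(k, [0, 2, 1]))  # (b, s, s)
    return op.softmax(scores, axis=-1)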

Defining the Model Specification with nn.spec#

Once you have verified that each layer of the model behaves correctly, write the model specification, which converts the model from an nn.Module into a TVM IRModule.

In the get_default_spec function, define the model specification as follows:

from mlc_llm.model.gpt2.gpt2_model import GPT2LMHeadModel

GPT2LMHeadModel.get_default_spec??
Signature: GPT2LMHeadModel.get_default_spec(self)
Docstring: <no docstring>
Source:   
    def get_default_spec(self):
        mod_spec = {
            "embed": {
                "input_ids": nn.spec.Tensor(["seq_len"], "int32"),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "prefill": {
                "input_embed": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "decode": {
                "input_embed": nn.spec.Tensor([1, 1, self.n_embed], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "batch_prefill": {
                "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
                "logit_positions": nn.spec.Tensor(["batch_size"], "int32"),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "batch_decode": {
                "input_embeds": nn.spec.Tensor(["batch_size", 1, self.n_embed], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "batch_verify": {
                "input_embeds": nn.spec.Tensor([1, "seq_len", self.n_embed], self.dtype),
                "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache),
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
            "create_paged_kv_cache": {
                "max_batch_size": int,
                "max_total_seq_len": int,
                "prefill_chunk_size": int,
                "page_size": int,
                "support_sliding_window": int,
                "$": {
                    "param_mode": "none",
                    "effect_mode": "none",
                },
            },
        }
        return nn.spec.ModuleSpec.from_raw(mod_spec, self)
File:      /media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/model/gpt2/gpt2_model.py
Type:      function

All of the specified methods, such as embed, prefill, and decode, will be exported into the TVM IRModule. nn.spec.Tensor, nn.spec.Tuple, and plain integers are supported as inputs to the Relax functions.

The difference between the default "plain" and the "packed" calling conventions is that "plain" exposes each model parameter as an individual argument of the exported function, while "packed" bundles all parameters into a single packed argument. MLC-LLM uses the packed convention, via param_mode="packed" in the "$" entries above.

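A minimal sketch of the two conventions, differing only in the "$" entry (MLC-LLM itself uses the packed form for every exported method, as seen above):

from tvm.relax.frontend import nn

# "plain": every model parameter becomes its own argument of the exported function
plain_spec = {
    "forward": {
        "x": nn.spec.Tensor([1, "seq_len", 768], "float32"),
        "$": {"param_mode": "plain", "effect_mode": "none"},
    },
}

# "packed": all model parameters are grouped into one packed argument
packed_spec = {
    "forward": {
        "x": nn.spec.Tensor([1, "seq_len", 768], "float32"),
        "$": {"param_mode": "packed", "effect_mode": "none"},
    },
}
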
After the model specification is provided, the export_tvm function easily converts the TVM nn.Module into the Relax IR. You can then inspect the IR of the whole model, as well as the complete list of model parameter names and data types.

from mlc_llm.model.gpt2 import gpt2_model

config_dict = {
    "architectures": ["GPT2LMHeadModel"],
    "bos_token_id": 50256,
    "eos_token_id": 50256,
    "hidden_act": "gelu_new",
    "n_ctx": 1024,
    "n_embd": 768,
    "n_head": 12,
    "n_layer": 12,
    "n_positions": 1024,
    "layer_norm_epsilon": 1e-05,
    "scale_attn_by_inverse_layer_idx": False,
    "vocab_size": 50257,
}

config = gpt2_model.GPT2Config.from_dict(config_dict)
model = gpt2_model.GPT2LMHeadModel(config)
mod, named_params = model.export_tvm(
    spec=model.get_default_spec(),
)

# Uncomment the following line to show the model in Tensor IR
# mod.show(black_format=False)

for name, param in named_params:
    print(name, param.shape, param.dtype)
transformer.wte.weight [vocab_size, 768] float32
transformer.wpe.weight [1024, 768] float32
transformer.h.0.ln_1.weight [768] float32
transformer.h.0.ln_1.bias [768] float32
transformer.h.0.attn.c_attn.weight [2304, 768] float32
transformer.h.0.attn.c_attn.bias [2304] float32
transformer.h.0.attn.c_proj.weight [768, 768] float32
transformer.h.0.attn.c_proj.bias [768] float32
transformer.h.0.ln_2.weight [768] float32
transformer.h.0.ln_2.bias [768] float32
transformer.h.0.mlp.c_fc.weight [3072, 768] float32
transformer.h.0.mlp.c_fc.bias [3072] float32
transformer.h.0.mlp.c_proj.weight [768, 3072] float32
transformer.h.0.mlp.c_proj.bias [768] float32
transformer.h.1.ln_1.weight [768] float32
transformer.h.1.ln_1.bias [768] float32
transformer.h.1.attn.c_attn.weight [2304, 768] float32
transformer.h.1.attn.c_attn.bias [2304] float32
transformer.h.1.attn.c_proj.weight [768, 768] float32
transformer.h.1.attn.c_proj.bias [768] float32
transformer.h.1.ln_2.weight [768] float32
transformer.h.1.ln_2.bias [768] float32
transformer.h.1.mlp.c_fc.weight [3072, 768] float32
transformer.h.1.mlp.c_fc.bias [3072] float32
transformer.h.1.mlp.c_proj.weight [768, 3072] float32
transformer.h.1.mlp.c_proj.bias [768] float32
transformer.h.2.ln_1.weight [768] float32
transformer.h.2.ln_1.bias [768] float32
transformer.h.2.attn.c_attn.weight [2304, 768] float32
transformer.h.2.attn.c_attn.bias [2304] float32
transformer.h.2.attn.c_proj.weight [768, 768] float32
transformer.h.2.attn.c_proj.bias [768] float32
transformer.h.2.ln_2.weight [768] float32
transformer.h.2.ln_2.bias [768] float32
transformer.h.2.mlp.c_fc.weight [3072, 768] float32
transformer.h.2.mlp.c_fc.bias [3072] float32
transformer.h.2.mlp.c_proj.weight [768, 3072] float32
transformer.h.2.mlp.c_proj.bias [768] float32
transformer.h.3.ln_1.weight [768] float32
transformer.h.3.ln_1.bias [768] float32
transformer.h.3.attn.c_attn.weight [2304, 768] float32
transformer.h.3.attn.c_attn.bias [2304] float32
transformer.h.3.attn.c_proj.weight [768, 768] float32
transformer.h.3.attn.c_proj.bias [768] float32
transformer.h.3.ln_2.weight [768] float32
transformer.h.3.ln_2.bias [768] float32
transformer.h.3.mlp.c_fc.weight [3072, 768] float32
transformer.h.3.mlp.c_fc.bias [3072] float32
transformer.h.3.mlp.c_proj.weight [768, 3072] float32
transformer.h.3.mlp.c_proj.bias [768] float32
transformer.h.4.ln_1.weight [768] float32
transformer.h.4.ln_1.bias [768] float32
transformer.h.4.attn.c_attn.weight [2304, 768] float32
transformer.h.4.attn.c_attn.bias [2304] float32
transformer.h.4.attn.c_proj.weight [768, 768] float32
transformer.h.4.attn.c_proj.bias [768] float32
transformer.h.4.ln_2.weight [768] float32
transformer.h.4.ln_2.bias [768] float32
transformer.h.4.mlp.c_fc.weight [3072, 768] float32
transformer.h.4.mlp.c_fc.bias [3072] float32
transformer.h.4.mlp.c_proj.weight [768, 3072] float32
transformer.h.4.mlp.c_proj.bias [768] float32
transformer.h.5.ln_1.weight [768] float32
transformer.h.5.ln_1.bias [768] float32
transformer.h.5.attn.c_attn.weight [2304, 768] float32
transformer.h.5.attn.c_attn.bias [2304] float32
transformer.h.5.attn.c_proj.weight [768, 768] float32
transformer.h.5.attn.c_proj.bias [768] float32
transformer.h.5.ln_2.weight [768] float32
transformer.h.5.ln_2.bias [768] float32
transformer.h.5.mlp.c_fc.weight [3072, 768] float32
transformer.h.5.mlp.c_fc.bias [3072] float32
transformer.h.5.mlp.c_proj.weight [768, 3072] float32
transformer.h.5.mlp.c_proj.bias [768] float32
transformer.h.6.ln_1.weight [768] float32
transformer.h.6.ln_1.bias [768] float32
transformer.h.6.attn.c_attn.weight [2304, 768] float32
transformer.h.6.attn.c_attn.bias [2304] float32
transformer.h.6.attn.c_proj.weight [768, 768] float32
transformer.h.6.attn.c_proj.bias [768] float32
transformer.h.6.ln_2.weight [768] float32
transformer.h.6.ln_2.bias [768] float32
transformer.h.6.mlp.c_fc.weight [3072, 768] float32
transformer.h.6.mlp.c_fc.bias [3072] float32
transformer.h.6.mlp.c_proj.weight [768, 3072] float32
transformer.h.6.mlp.c_proj.bias [768] float32
transformer.h.7.ln_1.weight [768] float32
transformer.h.7.ln_1.bias [768] float32
transformer.h.7.attn.c_attn.weight [2304, 768] float32
transformer.h.7.attn.c_attn.bias [2304] float32
transformer.h.7.attn.c_proj.weight [768, 768] float32
transformer.h.7.attn.c_proj.bias [768] float32
transformer.h.7.ln_2.weight [768] float32
transformer.h.7.ln_2.bias [768] float32
transformer.h.7.mlp.c_fc.weight [3072, 768] float32
transformer.h.7.mlp.c_fc.bias [3072] float32
transformer.h.7.mlp.c_proj.weight [768, 3072] float32
transformer.h.7.mlp.c_proj.bias [768] float32
transformer.h.8.ln_1.weight [768] float32
transformer.h.8.ln_1.bias [768] float32
transformer.h.8.attn.c_attn.weight [2304, 768] float32
transformer.h.8.attn.c_attn.bias [2304] float32
transformer.h.8.attn.c_proj.weight [768, 768] float32
transformer.h.8.attn.c_proj.bias [768] float32
transformer.h.8.ln_2.weight [768] float32
transformer.h.8.ln_2.bias [768] float32
transformer.h.8.mlp.c_fc.weight [3072, 768] float32
transformer.h.8.mlp.c_fc.bias [3072] float32
transformer.h.8.mlp.c_proj.weight [768, 3072] float32
transformer.h.8.mlp.c_proj.bias [768] float32
transformer.h.9.ln_1.weight [768] float32
transformer.h.9.ln_1.bias [768] float32
transformer.h.9.attn.c_attn.weight [2304, 768] float32
transformer.h.9.attn.c_attn.bias [2304] float32
transformer.h.9.attn.c_proj.weight [768, 768] float32
transformer.h.9.attn.c_proj.bias [768] float32
transformer.h.9.ln_2.weight [768] float32
transformer.h.9.ln_2.bias [768] float32
transformer.h.9.mlp.c_fc.weight [3072, 768] float32
transformer.h.9.mlp.c_fc.bias [3072] float32
transformer.h.9.mlp.c_proj.weight [768, 3072] float32
transformer.h.9.mlp.c_proj.bias [768] float32
transformer.h.10.ln_1.weight [768] float32
transformer.h.10.ln_1.bias [768] float32
transformer.h.10.attn.c_attn.weight [2304, 768] float32
transformer.h.10.attn.c_attn.bias [2304] float32
transformer.h.10.attn.c_proj.weight [768, 768] float32
transformer.h.10.attn.c_proj.bias [768] float32
transformer.h.10.ln_2.weight [768] float32
transformer.h.10.ln_2.bias [768] float32
transformer.h.10.mlp.c_fc.weight [3072, 768] float32
transformer.h.10.mlp.c_fc.bias [3072] float32
transformer.h.10.mlp.c_proj.weight [768, 3072] float32
transformer.h.10.mlp.c_proj.bias [768] float32
transformer.h.11.ln_1.weight [768] float32
transformer.h.11.ln_1.bias [768] float32
transformer.h.11.attn.c_attn.weight [2304, 768] float32
transformer.h.11.attn.c_attn.bias [2304] float32
transformer.h.11.attn.c_proj.weight [768, 768] float32
transformer.h.11.attn.c_proj.bias [768] float32
transformer.h.11.ln_2.weight [768] float32
transformer.h.11.ln_2.bias [768] float32
transformer.h.11.mlp.c_fc.weight [3072, 768] float32
transformer.h.11.mlp.c_fc.bias [3072] float32
transformer.h.11.mlp.c_proj.weight [768, 3072] float32
transformer.h.11.mlp.c_proj.bias [768] float32
transformer.ln_f.weight [768] float32
transformer.ln_f.bias [768] float32
lm_head.weight [vocab_size, 768] float32

Defining the Loader in gpt2_loader.py#

gpt2_loader.py defines how to convert the Huggingface parameters into the format used by the MLC model.

The loader returns an ExternMapping, which contains two kinds of mappings (a minimal sketch follows the list below):

  • Source -> MLC parameter mapping: for example parameter renaming, parameter transformation, and so on.

  • Unused mapping: parameters in the source that are not used by the MLC model definition.
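
A minimal sketch of recording both kinds of entries is given below; the add_unused call and the "h.0.attn.bias" buffer name are illustrative assumptions rather than the actual GPT-2 loader code (see gpt2_loader.py for the real thing).

from mlc_llm.loader import ExternMapping

mapping = ExternMapping()

# Source -> MLC mapping: copy "ln_f.weight" into "transformer.ln_f.weight",
# casting it to the dtype the MLC model expects.
mapping.add_mapping(
    "transformer.ln_f.weight",
    ["ln_f.weight"],
    lambda x: x.astype("float16"),
)

# Unused mapping: a source weight that has no counterpart in the MLC model,
# e.g. GPT-2's precomputed causal-mask buffer.
mapping.add_unused("h.0.attn.bias")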

In GPT-2, because Conv1D is used, the weights of c_attn, c_proj, and c_fc need to be transposed. To do this, provide mapping functions as follows:

# Excerpt from gpt2_loader.py. This runs inside a loop over the layer index `i`;
# `mapping` is an ExternMapping and `named_parameters` holds the MLC parameter
# names and dtypes obtained from the model's export_tvm().
for conv1d_weight_name in ["attn.c_attn", "attn.c_proj", "mlp.c_proj", "mlp.c_fc"]:
    src_name = f"h.{i}.{conv1d_weight_name}.weight"
    mlc_name = f"transformer.{src_name}"
    mapping.add_mapping(
        mlc_name,
        [src_name],
        functools.partial(
            # Conv1D stores weights as (in_features, out_features), whereas
            # nn.Linear expects (out_features, in_features), hence the transpose.
            lambda x, dtype: x.transpose().astype(dtype),
            dtype=named_parameters[mlc_name].dtype,
        ),
    )

Some renaming is also needed for the GPT-2 parameter conversion to work; please refer to gpt2_loader.py.

Adding the Model to the Supported Prebuilt-Model Workflow#

Once the whole model (architecture, loader, and quantization scheme) is defined with TVM's nn.Module, it can be added to the supported prebuilt-model workflow.

In mlc-llm/python/mlc_llm/model/model.py, add a GPT-2 entry to MODELS:

"gpt2": Model(
    name="gpt2",
    model=gpt2_model.GPT2LMHeadModel,
    config=gpt2_model.GPT2Config,
    source={
        "huggingface-torch": gpt2_loader.huggingface,
        "huggingface-safetensor": gpt2_loader.huggingface,
    },
    quantize={
        "no-quant": gpt2_quantization.no_quant,
        "group-quant": gpt2_quantization.group_quant,
    },
)

Compiling the GPT-2 Model Library and Weights#

The following steps are the same as the general model compilation workflow: (1) convert the weights, (2) generate mlc-chat-config.json, and (3) compile the model library.

# Create directory
!mkdir -p {temp_dir}/dist/models
%cd {temp_dir}/dist/models

# Clone HF weights
!git lfs install
# git clone https://huggingface.co/openai-community/gpt2
# git clone git@hf.co:openai-community/gpt2
!git clone https://hf-mirror.com/openai-community/gpt2
%cd ../..
/media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models
Updated git hooks.
Git LFS initialized.
Cloning into 'gpt2'...
remote: Enumerating objects: 87, done.
remote: Counting objects: 100% (3/3), done.
remote: Compressing objects: 100% (2/2), done.
remote: Total 87 (delta 0), reused 0 (delta 0), pack-reused 84 (from 1)
Unpacking objects: 100% (87/87), 1.65 MiB | 38.00 KiB/s, done.
Filtering content: 100% (11/11), 5.23 GiB | 2.64 MiB/s, done.
/media/pc/data/lxw/ai/tvm-book/tests/.temp
# Convert weight
!python -m mlc_llm convert_weight {temp_dir}/dist/models/gpt2/ --device cuda --quantization q0f16 -o {temp_dir}/dist/gpt2-q0f16-MLC
[2025-01-07 11:11:43] INFO auto_config.py:116: Found model configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/config.json
[2025-01-07 11:11:46] INFO auto_device.py:79: Found device: cuda:0
[2025-01-07 11:11:46] INFO auto_device.py:79: Found device: cuda:1
[2025-01-07 11:11:46] INFO auto_weight.py:71: Finding weights in: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2
[2025-01-07 11:11:46] INFO auto_weight.py:130: Found source weight format: huggingface-torch. Source configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/pytorch_model.bin
[2025-01-07 11:11:49] INFO auto_weight.py:161: Found source weight format: huggingface-safetensor. Source configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/model.safetensors.index.json
[2025-01-07 11:11:49] INFO auto_weight.py:107: Using source weight configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/pytorch_model.bin. Use `--source` to override.
[2025-01-07 11:11:49] INFO auto_weight.py:111: Using source weight format: huggingface-torch. Use `--source-format` to override.
[2025-01-07 11:11:49] INFO auto_config.py:154: Found model type: gpt2. Use `--model-type` to override.
Weight conversion with arguments:
  --config          /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/config.json
  --quantization    NoQuantize(name='q0f16', kind='no-quant', model_dtype='float16')
  --model-type      gpt2
  --device          cuda:0
  --source          /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/pytorch_model.bin
  --source-format   huggingface-torch
  --output          /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC
[2025-01-07 11:11:49] INFO gpt2_model.py:47: context_window_size not found in config.json. Falling back to n_positions (1024)
[2025-01-07 11:11:49] INFO gpt2_model.py:64: prefill_chunk_size defaults to 1024
Start storing to cache /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC
[2025-01-07 11:11:52] INFO huggingface_loader.py:185: Loading HF parameters from: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/pytorch_model.bin
/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/loader/utils.py:43: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  for name, param in torch.load(path, map_location=torch.device("cpu")).items():
[2025-01-07 11:11:53] INFO huggingface_loader.py:175: [Not quantized] Parameter: "lm_head.weight", shape: (50257, 768), dtype: float16
[2025-01-07 11:11:53] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.wte.weight", shape: (50257, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.wpe.weight", shape: (1024, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.0.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.1.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.2.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.3.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:54] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.4.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.5.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.6.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.7.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.8.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.9.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.10.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.ln_1.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.ln_1.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.attn.c_attn.weight", shape: (2304, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.attn.c_attn.bias", shape: (2304,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.attn.c_proj.weight", shape: (768, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.attn.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.ln_2.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.ln_2.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.mlp.c_fc.weight", shape: (3072, 768), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.mlp.c_fc.bias", shape: (3072,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.mlp.c_proj.weight", shape: (768, 3072), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.h.11.mlp.c_proj.bias", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.ln_f.weight", shape: (768,), dtype: float16
[2025-01-07 11:11:55] INFO huggingface_loader.py:175: [Not quantized] Parameter: "transformer.ln_f.bias", shape: (768,), dtype: float16
100%|█████████████████████████████████████████| 149/149 [00:02<00:00, 51.89it/s]
[2025-01-07 11:11:55] INFO huggingface_loader.py:197: Unloading HF weight file: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/pytorch_model.bin
[2025-01-07 11:11:56] INFO stats.py:77: Time usage: HF loading: 0.481 sec; Pre-quantization mapping: 0.920 sec; Quantization: 0.000 sec
[2025-01-07 11:11:56] INFO stats.py:91: RAM usage: Peak RAM: 0.510 GB. Total bytes loaded from disk: 0.510 GB

All finished, 8 total shards committed, record saved to /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/ndarray-cache.json
[2025-01-07 11:11:56] INFO convert_weight.py:155: Parameter size after quantization: 0.304 GB
[2025-01-07 11:11:56] INFO convert_weight.py:160: Total parameters: 124,439,808
[2025-01-07 11:11:56] INFO convert_weight.py:161: Bits per parameter: 20.963
[2025-01-07 11:11:56] INFO convert_weight.py:166: Saved to directory: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC
  2. gen_config: generate mlc-chat-config.json and process the tokenizers

!python -m mlc_llm gen_config {temp_dir}/dist/models/gpt2 --quantization q0f16 --conv-template gpt2 -o {temp_dir}/dist/gpt2-q0f16-MLC/
[2025-01-07 11:12:41] INFO auto_config.py:116: Found model configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/config.json
[2025-01-07 11:12:41] INFO auto_config.py:154: Found model type: gpt2. Use `--model-type` to override.
[2025-01-07 11:12:41] INFO gpt2_model.py:47: context_window_size not found in config.json. Falling back to n_positions (1024)
[2025-01-07 11:12:41] INFO gpt2_model.py:64: prefill_chunk_size defaults to 1024
[2025-01-07 11:12:41] INFO config.py:107: Overriding max_batch_size from 1 to 128
[2025-01-07 11:12:41] INFO gen_config.py:150: [generation_config.json] Setting bos_token_id: 50256
[2025-01-07 11:12:41] INFO gen_config.py:150: [generation_config.json] Setting eos_token_id: 50256
[2025-01-07 11:12:41] INFO gen_config.py:164: Not found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/tokenizer.model
[2025-01-07 11:12:41] INFO gen_config.py:162: Found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/tokenizer.json. Copying to /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/tokenizer.json
[2025-01-07 11:12:41] INFO gen_config.py:162: Found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/vocab.json. Copying to /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/vocab.json
[2025-01-07 11:12:41] INFO gen_config.py:162: Found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/merges.txt. Copying to /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/merges.txt
[2025-01-07 11:12:41] INFO gen_config.py:164: Not found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/added_tokens.json
[2025-01-07 11:12:41] INFO gen_config.py:162: Found tokenizer config: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/models/gpt2/tokenizer_config.json. Copying to /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/tokenizer_config.json
[2025-01-07 11:12:41] INFO gen_config.py:223: Detected tokenizer info: {'token_postproc_method': 'byte_level', 'prepend_space_in_encode': False, 'strip_space_in_decode': False}
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting pad_token_id: 0
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting temperature: 1.0
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting presence_penalty: 0.0
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting frequency_penalty: 0.0
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting repetition_penalty: 1.0
[2025-01-07 11:12:41] INFO gen_config.py:32: [System default] Setting top_p: 1.0
[2025-01-07 11:12:41] INFO gen_config.py:251: Dumping configuration file to: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/mlc-chat-config.json
  3. compile: compile the model library according to the specifications in mlc-chat-config.json

!python -m mlc_llm compile {temp_dir}/dist/gpt2-q0f16-MLC/mlc-chat-config.json --device cuda -o {temp_dir}/dist/gpt2-q0f16-MLC/gpt2-q0f16-cuda.so
[2025-01-07 11:13:02] INFO auto_config.py:70: Found model configuration: /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/mlc-chat-config.json
[2025-01-07 11:13:05] INFO auto_device.py:79: Found device: cuda:0
[2025-01-07 11:13:05] INFO auto_device.py:79: Found device: cuda:1
[2025-01-07 11:13:05] INFO auto_target.py:78: Found configuration of target device "cuda:0": {"thread_warp_size": runtime.BoxInt(32), "arch": "sm_86", "max_threads_per_block": runtime.BoxInt(1024), "max_num_threads": runtime.BoxInt(1024), "kind": "cuda", "max_shared_memory_per_block": runtime.BoxInt(49152), "tag": "", "keys": ["cuda", "gpu"]}
[2025-01-07 11:13:05] INFO auto_target.py:110: Found host LLVM triple: x86_64-unknown-linux-gnu
[2025-01-07 11:13:05] INFO auto_target.py:111: Found host LLVM CPU: haswell
[2025-01-07 11:13:05] INFO auto_target.py:334: Generating code for CUDA architecture: sm_86
[2025-01-07 11:13:05] INFO auto_target.py:335: To produce multi-arch fatbin, set environment variable MLC_MULTI_ARCH. Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90a
[2025-01-07 11:13:05] INFO auto_config.py:154: Found model type: gpt2. Use `--model-type` to override.
Compiling with arguments:
  --config          GPT2Config(vocab_size=50257, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-05, n_inner=3072, context_window_size=1024, prefill_chunk_size=1024, scale_attn_by_inverse_layer_idx=False, tensor_parallel_shards=1, head_dim=64, max_batch_size=128, kwargs={})
  --quantization    NoQuantize(name='q0f16', kind='no-quant', model_dtype='float16')
  --model-type      gpt2
  --target          {"thread_warp_size": runtime.BoxInt(32), "host": {"mtriple": "x86_64-unknown-linux-gnu", "tag": "", "kind": "llvm", "mcpu": "haswell", "keys": ["cpu"]}, "arch": "sm_86", "max_threads_per_block": runtime.BoxInt(1024), "libs": ["thrust"], "max_num_threads": runtime.BoxInt(1024), "kind": "cuda", "max_shared_memory_per_block": runtime.BoxInt(49152), "tag": "", "keys": ["cuda", "gpu"]}
  --opt             flashinfer=0;cublas_gemm=1;faster_transformer=0;cudagraph=1;cutlass=1;ipc_allreduce_strategy=AUTO
  --system-lib-prefix ""
  --output          /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/gpt2-q0f16-cuda.so
  --overrides       context_window_size=None;sliding_window_size=None;prefill_chunk_size=None;attention_sink_size=None;max_batch_size=None;tensor_parallel_shards=None;pipeline_parallel_stages=None;disaggregation=None
[2025-01-07 11:13:05] INFO compile.py:140: Creating model from: GPT2Config(vocab_size=50257, n_embd=768, n_layer=12, n_head=12, layer_norm_epsilon=1e-05, n_inner=3072, context_window_size=1024, prefill_chunk_size=1024, scale_attn_by_inverse_layer_idx=False, tensor_parallel_shards=1, head_dim=64, max_batch_size=128, kwargs={})
[2025-01-07 11:13:05] INFO compile.py:158: Exporting the model to TVM Unity compiler
[2025-01-07 11:13:07] INFO compile.py:164: Running optimizations using TVM Unity
[2025-01-07 11:13:07] INFO compile.py:186: Registering metadata: {'model_type': 'gpt2', 'quantization': 'q0f16', 'context_window_size': 1024, 'sliding_window_size': -1, 'attention_sink_size': -1, 'prefill_chunk_size': 1024, 'tensor_parallel_shards': 1, 'pipeline_parallel_stages': 1, 'disaggregation': False, 'kv_state_kind': 'kv_cache', 'max_batch_size': 128}
[2025-01-07 11:13:15] INFO pipeline.py:55: Running TVM Relax graph-level optimizations
[2025-01-07 11:13:18] INFO pipeline.py:55: Lowering to TVM TIR kernels
[2025-01-07 11:13:21] WARNING thrust.py:25: thrust is requested but TVM is not built with thrust.
[2025-01-07 11:13:21] WARNING thrust.py:25: thrust is requested but TVM is not built with thrust.
[2025-01-07 11:13:24] INFO pipeline.py:55: Running TVM TIR-level optimizations
[2025-01-07 11:13:28] INFO pipeline.py:55: Running TVM Dlight low-level optimizations
[2025-01-07 11:13:31] INFO pipeline.py:55: Lowering to VM bytecode
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `alloc_embedding_tensor`: 1.50 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `argsort_probs`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_decode`: 1.69 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_prefill`: 13.69 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `batch_verify`: 13.50 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `create_tir_paged_kv_cache`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `decode`: 0.01 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `embed`: 1.50 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `multinomial_from_uniform`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `prefill`: 13.50 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `renormalize_by_top_p`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `sample_with_top_p`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `sampler_take_probs`: 0.01 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `sampler_verify_draft_tokens`: 0.00 MB
[2025-01-07 11:13:33] INFO estimate_memory_usage.py:58: [Memory usage] Function `softmax_with_temperature`: 0.00 MB
[2025-01-07 11:13:34] INFO pipeline.py:55: Compiling external modules
[2025-01-07 11:13:34] INFO pipeline.py:55: Compilation complete! Exporting to disk
Traceback (most recent call last):
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/interface/compile.py", line 189, in _compile
    args.build_func(
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/support/auto_target.py", line 301, in build
    relax.build(
  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/vm_build.py", line 353, in build
    return _vmlink(
           ^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/python/tvm/relax/vm_build.py", line 249, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/media/pc/data/lxw/ai/tvm/src/driver/driver_api.cc", line 531, in operator()
    return TIRToRuntime(inputs_arg, host_target);
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/src/driver/driver_api.cc", line 514, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
    device_modules.push_back(codegen::Build(device_mod, it.first));
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/src/target/codegen.cc", line 73, in tvm::codegen::Build(tvm::IRModule, tvm::Target)
    return (*bf)(mod, target);
                    ^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/src/target/opt/build_cuda_on.cc", line 161, in tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
    ptx = (*f)(code, target).operator std::string();
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/support/auto_target.py", line 352, in tvm_callback_cuda_compile
    ptx = nvcc.compile_cuda(code, target_format="fatbin")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/nvcc.py", line 120, in compile_cuda
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/subprocess.py", line 1026, in __init__
    self._execute_child(args, executable, preexec_fn, close_fds,
  File "/media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/subprocess.py", line 1953, in _execute_child
    raise child_exception_type(errno_num, err_msg, err_filename)
FileNotFoundError: [Errno 2] No such file or directory: 'nvcc'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/__main__.py", line 69, in <module>
    main()
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/__main__.py", line 34, in main
    cli.main(sys.argv[2:])
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/cli/compile.py", line 129, in main
    compile(
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/interface/compile.py", line 244, in compile
    _compile(args, model_config)
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/interface/compile.py", line 132, in _compile
    with args.target:
  File "/media/pc/data/lxw/ai/tvm/python/tvm/target/target.py", line 145, in __exit__
    _ffi_api.TargetExitScope(self)
  File "/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 245, in __call__
    raise_last_ffi_error()
  File "/media/pc/data/lxw/ai/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/media/pc/data/lxw/ai/tvm/src/target/target.cc", line 757, in tvm::Target::ExitWithScope()
    ICHECK(!entry->context_stack.empty());
                    ^^^^^^^^^^^^^^^^^^^^^^^
tvm.error.InternalError: Traceback (most recent call last):
  0: tvm::Target::ExitWithScope()
        at /media/pc/data/lxw/ai/tvm/src/target/target.cc:757
  File "/media/pc/data/lxw/ai/tvm/src/target/target.cc", line 758
InternalError: Check failed: (entry->context_stack.top().same_as(*this)) is false: 

Debugging the Compiled MLC Model with DebugChat#

After the model library has been compiled and the model weights converted successfully, it is important to check whether the model generates correct output. One way to do so is to compare the model's output logits with those of its Huggingface PyTorch counterpart on the same input tokens.

To help debug MLC models, the mlc_llm.testing.DebugChat module is provided. It can:

  • load the MLC model that was just compiled,

  • run the model's full forward pipeline with a user-specified prompt, and

  • dump the intermediate values of all layers.

You can then compare these intermediate values against those of the Huggingface PyTorch model. (For PyTorch, you can use register_forward_hook to extract the intermediate values; a sketch follows.)
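
A hedged sketch of the PyTorch side (assuming the transformers package is installed; hooking every GPT2Block is illustrative, not a fixed recipe):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

hf_model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # GPT2Block returns a tuple whose first element is the hidden states
        captured[name] = output[0] if isinstance(output, tuple) else output
    return hook

for i, block in enumerate(hf_model.transformer.h):
    block.register_forward_hook(make_hook(f"transformer.h.{i}"))

inputs = tokenizer("Hey how are you doing today?", return_tensors="pt")
with torch.no_grad():
    logits = hf_model(**inputs).logits  # compare against MLC's logits.npz

print(logits.shape, list(captured)[:3])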

!python -m mlc_llm.testing.debug_chat --model {temp_dir}/dist/gpt2-q0f16-MLC/ --model-lib {temp_dir}/dist/gpt2-q0f16-MLC/gpt2-q0f16-cuda.so --device cuda --debug-dir {temp_dir}/debug-gpt2 --generate-len 5 "Hey how are you doing today?"
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/testing/debug_chat.py", line 536, in <module>
    main()
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/testing/debug_chat.py", line 523, in main
    dc = DebugChat(
         ^^^^^^^^^^
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/testing/debug_chat.py", line 227, in __init__
    self.mod, self.params, self.metadata = _get_tvm_module(
                                           ^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/mlc-llm/python/mlc_llm/testing/debug_chat.py", line 49, in _get_tvm_module
    ex = tvm.runtime.load_module(lib_path)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/pc/data/lxw/ai/tvm/python/tvm/runtime/module.py", line 683, in load_module
    raise ValueError(f"cannot find file {path}")
ValueError: cannot find file /media/pc/data/lxw/ai/tvm-book/tests/.temp/dist/gpt2-q0f16-MLC/gpt2-q0f16-cuda.so

The intermediate outputs are dumped into the debug-gpt2 folder. For each prefill/decode step there is a separate folder containing .npz files, each storing the arguments of one kernel function call.

For example, ./debug-gpt2/decode_2/f0_take3.npz corresponds to the 0th take function call in the 2nd decode step. The output logits are saved to logits.npz.

Note: because TIR function calls use the destination-passing style, the arguments of each function call look like this:

def low_level_prim_func(in0, in1, ..., out):
    # implementation

So the last argument of each function call is the output.

The .npz files can be loaded as follows:

import numpy as np

data = np.load(f'{temp_dir}/debug-gpt2/decode_2/f0_take3.npz')
print(data)
print(data["arg_0"])
print(data["arg_1"])
print(data["arg_2"]) # This is the output of the take function
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[21], line 3
      1 import numpy as np
----> 3 data = np.load(f'{temp_dir}/debug-gpt2/decode_2/f0_take3.npz')
      4 print(data)
      5 print(data["arg_0"])

File /media/pc/data/lxw/envs/anaconda3a/envs/ai/lib/python3.12/site-packages/numpy/lib/npyio.py:427, in load(file, mmap_mode, allow_pickle, fix_imports, encoding, max_header_size)
    425     own_fid = False
    426 else:
--> 427     fid = stack.enter_context(open(os_fspath(file), "rb"))
    428     own_fid = True
    430 # Code to distinguish from NumPy binary files and pickles.

FileNotFoundError: [Errno 2] No such file or directory: '/media/pc/data/lxw/ai/tvm-book/tests/.temp/debug-gpt2/decode_2/f0_take3.npz'