# `AnnotateSpans`



{func}`tvm.relay.transform.AnnotateSpans` 的作用是为程序添加跨度（span）信息。具体来说，它先对模块进行美化打印（pretty-print），生成程序的文本表示形式，然后再将该文本解析回 Relay 抽象语法树（AST），从而为所有 Relay 子表达式建立跨度（span）和来源（source）信息。这有助于改善以编程方式构建的模块在下游的错误报告和调试诊断。


In [1]:
import testing

In [5]:
import torchvision
import torch
from tvm import relay
# Build a torchvision ResNet-18 (no weights argument is passed), switch it to
# eval mode, and trace it with a random input of shape (1, 3, 224, 224).
model = torchvision.models.resnet18().eval()
inp = torch.randn([1, 3, 224, 224])
trace = torch.jit.trace(model, inp).eval()
# Import the traced model into Relay; the second return value (the parameter
# dict) is unused here. use_parser_friendly_name=True makes the generated
# parameter names (e.g. %aten___convolution_0_weight in the output below)
# safe for the Relay text parser, which the AnnotateSpans round trip needs.
mod, _ = relay.frontend.from_pytorch(
    trace, [("input", inp.shape)], use_parser_friendly_name=True
)
# Print and re-parse the module so each sub-expression gets a GeneratedSource
# span, as shown in the printed function signature.
mod = relay.transform.AnnotateSpans()(mod)
print(mod["main"])

fn (%input: Tensor[(1, 3, 224, 224), float32] /* ty=Tensor[(1, 3, 224, 224), float32] span=GeneratedSource:116:18 */, %aten___convolution_0_weight: Tensor[(64, 3, 7, 7), float32] /* ty=Tensor[(64, 3, 7, 7), float32] span=GeneratedSource:116:28 */, %aten__batch_norm_0_weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=GeneratedSource:117:26 */, %aten__batch_norm_0_bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=GeneratedSource:117:54 */, %aten__batch_norm_0_mean: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=GeneratedSource:117:80 */, %aten__batch_norm_0_var: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=GeneratedSource:117:106 */, %aten___convolution_1_weight: Tensor[(64, 64, 3, 3), float32] /* ty=Tensor[(64, 64, 3, 3), float32] span=GeneratedSource:121:22 */, %aten__batch_norm_1_weight: Tensor[(64), float32] /* ty=Tensor[(64), float32] span=GeneratedSource:122:26 */, %aten__batch_norm_1_bias: Tensor[(64), float32] /* ty=Tensor[(64), float32] 

In [8]:
import numpy as np
import tvm
# A single clip with -inf/+inf bounds; presumably here to check that the
# AnnotateSpans print/re-parse round trip copes with infinite attribute
# values — TODO confirm intent.
x = relay.var("x", shape=(3, 4), dtype="float32")
y = relay.clip(x, -np.inf, np.inf)

# Wrap the expression in a function and then in an IRModule, the unit the
# AnnotateSpans pass is applied to.
f = relay.Function([x], y)
mod = tvm.IRModule.from_expr(f)

# Annotate spans and display the resulting module.
mod = relay.transform.AnnotateSpans()(mod)
mod.show()

In [9]:
import tvm
import tvm.relay as relay
from tvm.relay import testing
import tvm.testing


def test_annotate_spans_compatibility():
    """Smoke-test that AnnotateSpans composes with DefuseOps after relay.optimize.

    A conv2d followed by batch_norm is wrapped in a function, turned into a
    workload, optimized at opt_level 0 for the first enabled test target, and
    then run through ``AnnotateSpans`` followed by ``DefuseOps``. There are no
    assertions: the test passes as long as none of these steps raises.
    """
    data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32"))
    weight = relay.var("weight")
    gamma, beta, moving_mean, moving_var = (
        relay.var(name) for name in ("bn_gamma", "bn_beta", "bn_mean", "bn_var")
    )

    conv = relay.nn.conv2d(
        data=data, weight=weight, kernel_size=(3, 3), channels=3, padding=(1, 1)
    )
    # Only element 0 of batch_norm's result (the normalized tensor) is used.
    normalized = relay.nn.batch_norm(conv, gamma, beta, moving_mean, moving_var)[0]
    net = relay.Function(relay.analysis.free_vars(normalized), normalized)

    module, params = testing.create_workload(net)

    # Legalize the IR with a low-opt-level optimize before the passes under test.
    with tvm.transform.PassContext(opt_level=0):
        target = tvm.testing.enabled_targets()[0][0]
        module, params = relay.optimize(module, target=target, params=params)

    pipeline = tvm.transform.Sequential(
        [relay.transform.AnnotateSpans(), relay.transform.DefuseOps()]
    )
    with tvm.transform.PassContext(opt_level=3):
        module = pipeline(module)

# Run the smoke test as soon as this cell executes; any failure raises here.
test_annotate_spans_compatibility()



In [10]:
import tvm
import tvm.relay
from tvm.relay import op
from tvm.ir.instrument import PassTimingInstrument, pass_instrument

def get_test_model():
    """Build a small elementwise Relay module for pass-instrument tests.

    ``x``, ``y`` and ``z`` are float32 variables of shape (3, 4). The returned
    IRModule computes ``(x + y) * ((x + y) / (x - z)) + (x - z)``.
    """
    x, y, z = (
        tvm.relay.var(name, shape=(3, 4), dtype="float32") for name in ("x", "y", "z")
    )
    total = op.add(x, y)
    difference = op.subtract(x, z)
    scaled = op.multiply(total, total / difference)
    return tvm.IRModule.from_expr(scaled + difference)

def test_pass_timing_instrument():
    """Check that PassTimingInstrument records passes only while installed.

    Phase 1 installs the instrument on the current PassContext and runs
    AnnotateSpans -> ToANormalForm -> InferType. The rendered profile must
    mention all three passes. Phase 2 resets the instruments to ``None``,
    runs the same pipeline again, and expects the instrument to render an
    empty profile.
    """

    def run_pipeline():
        # The same three passes run in both phases so the profiles are comparable.
        mod = get_test_model()
        mod = tvm.relay.transform.AnnotateSpans()(mod)
        mod = tvm.relay.transform.ToANormalForm()(mod)
        return tvm.relay.transform.InferType()(mod)

    pass_timing = PassTimingInstrument()

    # Override current PassContext's instruments
    tvm.transform.PassContext.current().override_instruments([pass_timing])
    try:
        run_pipeline()
        profiles = pass_timing.render()
        assert "AnnotateSpans" in profiles
        assert "ToANormalForm" in profiles
        assert "InferType" in profiles
    finally:
        # Reset current PassContext's instruments to None. This runs even if a
        # pass or assertion above fails, so the instrument never stays
        # installed on the shared context for later cells/tests.
        tvm.transform.PassContext.current().override_instruments(None)

    run_pipeline()
    profiles = pass_timing.render()
    assert profiles == ""