PyTorch Relax 前端

PyTorch Relax 前端

import torch
import torch.nn.functional as F
from torch import fx
from torch.nn import Module
from torchvision.models import resnet18, ResNet18_Weights

import tvm
from tvm import relax
import tvm.testing
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.relax.frontend import detach_params
from tvm.relax.frontend.torch import from_fx
def load_model(input_shape):
    """Return a TorchScript trace of a pretrained torchvision ResNet-18.

    The network is created with the ``IMAGENET1K_V1`` weights, switched to
    eval mode, and traced with ``torch.jit.trace`` on a random example input.

    Args:
        input_shape: Shape of the example input; it is unpacked into
            ``torch.randn`` (e.g. ``(1, 3, 224, 224)``).

    Returns:
        The traced module returned by ``torch.jit.trace``.
    """
    # Build the network before sampling the example input: constructing the
    # model may draw from torch's global RNG, which would change the sample.
    net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    net.eval()
    example = torch.randn(*input_shape)
    return torch.jit.trace(net, example)
from tvm import relay


# Input specification for the conversion: one float32 tensor of this shape.
input_shape = (1, 3, 224, 224)
# Not passed to from_fx; the converted Relax function names its input itself.
input_name = "data"
input_info = [(input_shape, "float32")]

# Pretrained ResNet-18 in eval mode, captured as a torch.fx GraphModule.
torch_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
torch_model.eval()
graph_model = fx.symbolic_trace(torch_model)

# Convert the FX graph to a Relax IRModule with gradient tracking disabled,
# then print it as TVMScript.
with torch.no_grad():
    mod = from_fx(graph_model, input_info)
mod.show()
/media/pc/data/lxw/ai/tvm/xinetzone/__pypackages__/3.10/lib/tvm/script/highlight.py:117: UserWarning: No module named 'black'
To print formatted TVM script, please install the formatter 'Black':
/media/pc/data/tmp/cache/conda/envs/tvmz/bin/python -m pip install "black==22.3.0" --upgrade --user
  warnings.warn(
# from tvm.script import ir as I
# from tvm.script import relax as R

# IRModule printed by ``mod.show()`` for the ResNet-18 converted above.
# Weights appear as ``metadata["relax.expr.Constant"][i]``; the metadata is
# omitted from the print, so ``metadata`` is not defined in this file and the
# listing is meant for reading rather than for re-parsing as-is.
# NOTE(review): if this text were executed as Python, ``class Module`` would
# rebind the ``Module`` name imported from ``torch.nn`` at the top of the file.
@I.ir_module
class Module:
    @R.function
    def main(inp_0: R.Tensor((1, 3, 224, 224), dtype="float32")) -> R.Tensor((1, 1000), dtype="float32"):
        # Single entry point: (1, 3, 224, 224) float32 in -> (1, 1000) float32 out.
        with R.dataflow():
            # Stem: conv2d (stride 2, padding 3) -> batch_norm -> relu -> max_pool2d (3x3, stride 2).
            # R.nn.batch_norm yields a 3-tuple; only element [0] (the normalized tensor) is used.
            lv: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.conv2d(inp_0, metadata["relax.expr.Constant"][0], strides=[2, 2], padding=[3, 3, 3, 3], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv1: R.Tuple(R.Tensor((1, 64, 112, 112), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv, metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], metadata["relax.expr.Constant"][4], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv2: R.Tensor((1, 64, 112, 112), dtype="float32") = lv1[0]
            lv3: R.Tensor((1, 64, 112, 112), dtype="float32") = R.nn.relu(lv2)
            lv4: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.max_pool2d(lv3, pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], ceil_mode=False, layout="NCHW", out_layout="NCHW")
            # Stage 1 (64 channels, 56x56), block 1: conv/bn/relu/conv/bn, then add the block input lv4 and relu.
            lv5: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv4, metadata["relax.expr.Constant"][5], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv6: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv5, metadata["relax.expr.Constant"][6], metadata["relax.expr.Constant"][7], metadata["relax.expr.Constant"][8], metadata["relax.expr.Constant"][9], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv7: R.Tensor((1, 64, 56, 56), dtype="float32") = lv6[0]
            lv8: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv7)
            lv9: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv8, metadata["relax.expr.Constant"][10], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv10: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv9, metadata["relax.expr.Constant"][11], metadata["relax.expr.Constant"][12], metadata["relax.expr.Constant"][13], metadata["relax.expr.Constant"][14], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv11: R.Tensor((1, 64, 56, 56), dtype="float32") = lv10[0]
            lv12: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv11, lv4)
            lv13: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv12)
            # Stage 1, block 2: same pattern with identity shortcut lv13.
            lv14: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv13, metadata["relax.expr.Constant"][15], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv15: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv14, metadata["relax.expr.Constant"][16], metadata["relax.expr.Constant"][17], metadata["relax.expr.Constant"][18], metadata["relax.expr.Constant"][19], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv16: R.Tensor((1, 64, 56, 56), dtype="float32") = lv15[0]
            lv17: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv16)
            lv18: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.conv2d(lv17, metadata["relax.expr.Constant"][20], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv19: R.Tuple(R.Tensor((1, 64, 56, 56), dtype="float32"), R.Tensor((64,), dtype="float32"), R.Tensor((64,), dtype="float32")) = R.nn.batch_norm(lv18, metadata["relax.expr.Constant"][21], metadata["relax.expr.Constant"][22], metadata["relax.expr.Constant"][23], metadata["relax.expr.Constant"][24], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv20: R.Tensor((1, 64, 56, 56), dtype="float32") = lv19[0]
            lv21: R.Tensor((1, 64, 56, 56), dtype="float32") = R.add(lv20, lv13)
            lv22: R.Tensor((1, 64, 56, 56), dtype="float32") = R.nn.relu(lv21)
            # Stage 2 (128 channels, 28x28), block 1 (downsampling): the main path starts with a
            # stride-2 conv; the shortcut (lv30-lv32) is a stride-2, zero-padding conv + batch_norm on lv22.
            lv23: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, metadata["relax.expr.Constant"][25], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv24: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv23, metadata["relax.expr.Constant"][26], metadata["relax.expr.Constant"][27], metadata["relax.expr.Constant"][28], metadata["relax.expr.Constant"][29], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv25: R.Tensor((1, 128, 28, 28), dtype="float32") = lv24[0]
            lv26: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv25)
            lv27: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv26, metadata["relax.expr.Constant"][30], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv28: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv27, metadata["relax.expr.Constant"][31], metadata["relax.expr.Constant"][32], metadata["relax.expr.Constant"][33], metadata["relax.expr.Constant"][34], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv29: R.Tensor((1, 128, 28, 28), dtype="float32") = lv28[0]
            lv30: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv22, metadata["relax.expr.Constant"][35], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv31: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv30, metadata["relax.expr.Constant"][36], metadata["relax.expr.Constant"][37], metadata["relax.expr.Constant"][38], metadata["relax.expr.Constant"][39], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv32: R.Tensor((1, 128, 28, 28), dtype="float32") = lv31[0]
            lv33: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv29, lv32)
            lv34: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv33)
            # Stage 2, block 2: identity shortcut lv34.
            lv35: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv34, metadata["relax.expr.Constant"][40], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv36: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv35, metadata["relax.expr.Constant"][41], metadata["relax.expr.Constant"][42], metadata["relax.expr.Constant"][43], metadata["relax.expr.Constant"][44], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv37: R.Tensor((1, 128, 28, 28), dtype="float32") = lv36[0]
            lv38: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv37)
            lv39: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.conv2d(lv38, metadata["relax.expr.Constant"][45], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv40: R.Tuple(R.Tensor((1, 128, 28, 28), dtype="float32"), R.Tensor((128,), dtype="float32"), R.Tensor((128,), dtype="float32")) = R.nn.batch_norm(lv39, metadata["relax.expr.Constant"][46], metadata["relax.expr.Constant"][47], metadata["relax.expr.Constant"][48], metadata["relax.expr.Constant"][49], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv41: R.Tensor((1, 128, 28, 28), dtype="float32") = lv40[0]
            lv42: R.Tensor((1, 128, 28, 28), dtype="float32") = R.add(lv41, lv34)
            lv43: R.Tensor((1, 128, 28, 28), dtype="float32") = R.nn.relu(lv42)
            # Stage 3 (256 channels, 14x14), block 1 (downsampling): shortcut lv51-lv53 is a
            # stride-2, zero-padding conv + batch_norm on lv43.
            lv44: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, metadata["relax.expr.Constant"][50], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv45: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv44, metadata["relax.expr.Constant"][51], metadata["relax.expr.Constant"][52], metadata["relax.expr.Constant"][53], metadata["relax.expr.Constant"][54], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv46: R.Tensor((1, 256, 14, 14), dtype="float32") = lv45[0]
            lv47: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv46)
            lv48: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv47, metadata["relax.expr.Constant"][55], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv49: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv48, metadata["relax.expr.Constant"][56], metadata["relax.expr.Constant"][57], metadata["relax.expr.Constant"][58], metadata["relax.expr.Constant"][59], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv50: R.Tensor((1, 256, 14, 14), dtype="float32") = lv49[0]
            lv51: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv43, metadata["relax.expr.Constant"][60], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv52: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv51, metadata["relax.expr.Constant"][61], metadata["relax.expr.Constant"][62], metadata["relax.expr.Constant"][63], metadata["relax.expr.Constant"][64], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv53: R.Tensor((1, 256, 14, 14), dtype="float32") = lv52[0]
            lv54: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv50, lv53)
            lv55: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv54)
            # Stage 3, block 2: identity shortcut lv55.
            lv56: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv55, metadata["relax.expr.Constant"][65], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv57: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv56, metadata["relax.expr.Constant"][66], metadata["relax.expr.Constant"][67], metadata["relax.expr.Constant"][68], metadata["relax.expr.Constant"][69], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv58: R.Tensor((1, 256, 14, 14), dtype="float32") = lv57[0]
            lv59: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv58)
            lv60: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.conv2d(lv59, metadata["relax.expr.Constant"][70], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv61: R.Tuple(R.Tensor((1, 256, 14, 14), dtype="float32"), R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="float32")) = R.nn.batch_norm(lv60, metadata["relax.expr.Constant"][71], metadata["relax.expr.Constant"][72], metadata["relax.expr.Constant"][73], metadata["relax.expr.Constant"][74], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv62: R.Tensor((1, 256, 14, 14), dtype="float32") = lv61[0]
            lv63: R.Tensor((1, 256, 14, 14), dtype="float32") = R.add(lv62, lv55)
            lv64: R.Tensor((1, 256, 14, 14), dtype="float32") = R.nn.relu(lv63)
            # Stage 4 (512 channels, 7x7), block 1 (downsampling): shortcut lv72-lv74 is a
            # stride-2, zero-padding conv + batch_norm on lv64.
            lv65: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, metadata["relax.expr.Constant"][75], strides=[2, 2], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv66: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv65, metadata["relax.expr.Constant"][76], metadata["relax.expr.Constant"][77], metadata["relax.expr.Constant"][78], metadata["relax.expr.Constant"][79], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv67: R.Tensor((1, 512, 7, 7), dtype="float32") = lv66[0]
            lv68: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv67)
            lv69: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv68, metadata["relax.expr.Constant"][80], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv70: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv69, metadata["relax.expr.Constant"][81], metadata["relax.expr.Constant"][82], metadata["relax.expr.Constant"][83], metadata["relax.expr.Constant"][84], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv71: R.Tensor((1, 512, 7, 7), dtype="float32") = lv70[0]
            lv72: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv64, metadata["relax.expr.Constant"][85], strides=[2, 2], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv73: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv72, metadata["relax.expr.Constant"][86], metadata["relax.expr.Constant"][87], metadata["relax.expr.Constant"][88], metadata["relax.expr.Constant"][89], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv74: R.Tensor((1, 512, 7, 7), dtype="float32") = lv73[0]
            lv75: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv71, lv74)
            lv76: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv75)
            # Stage 4, block 2: identity shortcut lv76.
            lv77: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv76, metadata["relax.expr.Constant"][90], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv78: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv77, metadata["relax.expr.Constant"][91], metadata["relax.expr.Constant"][92], metadata["relax.expr.Constant"][93], metadata["relax.expr.Constant"][94], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv79: R.Tensor((1, 512, 7, 7), dtype="float32") = lv78[0]
            lv80: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv79)
            lv81: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.conv2d(lv80, metadata["relax.expr.Constant"][95], strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
            lv82: R.Tuple(R.Tensor((1, 512, 7, 7), dtype="float32"), R.Tensor((512,), dtype="float32"), R.Tensor((512,), dtype="float32")) = R.nn.batch_norm(lv81, metadata["relax.expr.Constant"][96], metadata["relax.expr.Constant"][97], metadata["relax.expr.Constant"][98], metadata["relax.expr.Constant"][99], axis=1, epsilon=1.0000000000000001e-05, center=True, scale=True, momentum=0.10000000000000001)
            lv83: R.Tensor((1, 512, 7, 7), dtype="float32") = lv82[0]
            lv84: R.Tensor((1, 512, 7, 7), dtype="float32") = R.add(lv83, lv76)
            lv85: R.Tensor((1, 512, 7, 7), dtype="float32") = R.nn.relu(lv84)
            # Head: adaptive average pool to 1x1, reshape to (1, 512), matmul with the transposed
            # constant [100] (512x1000), then add constant [101].
            # Presumably these are the torch fc layer's weight and bias; confirm against the model.
            lv86: R.Tensor((1, 512, 1, 1), dtype="float32") = R.nn.adaptive_avg_pool2d(lv85, output_size=[1, 1], layout="NCHW", out_layout="NCHW")
            lv87: R.Tensor((1, 512), dtype="float32") = R.reshape(lv86, R.shape([1, 512]))
            lv88: R.Tensor((512, 1000), dtype="float32") = R.permute_dims(metadata["relax.expr.Constant"][100], axes=None)
            lv89: R.Tensor((1, 1000), dtype="float32") = R.matmul(lv87, lv88, out_dtype="float32")
            lv90: R.Tensor((1, 1000), dtype="float32") = R.add(lv89, metadata["relax.expr.Constant"][101])
            gv: R.Tensor((1, 1000), dtype="float32") = lv90
            R.output(gv)
        return gv

# Metadata omitted. Use show_meta=True in script() method to show it.