import pytest
import tvm
from tvm import relay
from tvm.relay.testing.resnet import get_workload
from tvm.relay.testing import run_opt_pass
def get_conv_net():
    r"""Build a small network of two conv2d ops joined by an elementwise add.

    Graph layout::

              conv2d
              /  |
             /   |
        conv2d   |
            \    |
             \   |
          elemwise add
               |

    The first conv2d feeds both the second conv2d and the add, so the
    resulting module contains two ``nn.conv2d`` calls and one ``add`` call.
    (Raw docstring: the diagram's backslashes would otherwise be invalid
    escape sequences.)

    Returns:
        tvm.IRModule: a module built from the final ``add`` expression.
    """
    dshape = (1, 1, 5, 1)
    x = relay.var("x", shape=dshape)
    # Weight vars carry no shape; it is left to type inference.
    conv1 = relay.nn.conv2d(
        x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1
    )
    conv2 = relay.nn.conv2d(
        conv1, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1
    )
    # conv1 is reused here, giving the skip-connection shape drawn above.
    out = relay.add(conv1, conv2)
    return tvm.IRModule.from_expr(out)
def get_conv2d():
    """Return a module containing a single NHWC/HWIO 3x3 conv2d.

    The input var ``x`` has shape (1, 56, 56, 64) in NHWC layout, and the
    kernel var ``weight1`` has shape (3, 3, 64, 32) in HWIO layout, producing
    32 output channels.

    Returns:
        tvm.IRModule: a module built from the conv2d expression.
    """
    data = relay.var("x", shape=(1, 56, 56, 64))
    kernel = relay.var("weight1", shape=(3, 3, 64, 32))
    conv = relay.nn.conv2d(
        data,
        kernel,
        channels=32,
        kernel_size=(3, 3),
        padding=(1, 1),
        data_layout="NHWC",
        kernel_layout="HWIO",
    )
    return tvm.IRModule.from_expr(conv)
def test_extract_identity():
    """A single-conv2d module reports exactly one operator, seen once."""
    freqs = relay.analysis.list_op_freqs(get_conv2d())
    assert len(freqs) == 1
    assert freqs["nn.conv2d"] == 1
def test_extract_conv_net():
    """The conv2d/conv2d/add network yields two distinct operators."""
    expected = {"add": 1, "nn.conv2d": 2}
    freqs = relay.analysis.list_op_freqs(get_conv_net())
    assert len(freqs) == len(expected)
    for name, count in expected.items():
        assert freqs[name] == count
def test_extract_fused():
    """Operator counts are unchanged after InferType followed by FuseOps."""
    typed = relay.transform.InferType()(get_conv_net())
    fused = relay.transform.FuseOps(3)(typed)
    freqs = relay.analysis.list_op_freqs(fused)
    expected = {"add": 1, "nn.conv2d": 2}
    assert len(freqs) == len(expected)
    for name, count in expected.items():
        assert freqs[name] == count
def test_extract_resnet():
    """Check the per-operator counts extracted from the default ResNet workload.

    Every operator in ``expected_op_freqs`` must be present with exactly the
    listed count, and no other operator may appear.
    """
    mod, _params = get_workload()
    expected_op_freqs = {
        "nn.batch_norm": 19,
        "nn.conv2d": 21,
        "nn.relu": 18,
        "nn.max_pool2d": 1,
        "add": 8,
        "nn.global_avg_pool2d": 1,
        "nn.batch_flatten": 1,
        "nn.dense": 1,
        "nn.bias_add": 1,
        "nn.softmax": 1,
    }
    op_freqs = relay.analysis.list_op_freqs(mod)
    # Equal sizes plus a successful lookup for every expected key means both
    # mappings hold exactly the same set of operators.
    assert len(op_freqs) == len(expected_op_freqs), (
        f"expected {len(expected_op_freqs)} distinct operators, "
        f"got {len(op_freqs)}"
    )
    # Assert one operator at a time so a failure names the offending operator
    # instead of collapsing into a bare `all(...)` that is just False.
    for op_name, expected_count in expected_op_freqs.items():
        actual_count = op_freqs[op_name]
        assert actual_count == expected_count, (
            f"{op_name}: expected {expected_count}, got {actual_count}"
        )
if __name__ == "__main__":
    # This file never imports the tvm.testing submodule explicitly, and
    # importing a package does not by itself load its submodules, so import
    # it here before calling its main() entry point.
    import tvm.testing

    tvm.testing.main()