vta.top.graphpack#
A Relay implementation of graph packing.
Exceptions#
BT | Common base class for all non-exit exceptions.
Classes#
ExprDeviceAnnot | Visitor to perform graph annotation on an AST.
ExprLocator | Visitor to locate op on an AST.
ExprPack | Visitor to perform graph packing on an AST.
Functions#
_channel_const_match | Round the channel constant if its value is not divisible by cfactor_out.
_const_shape_match | Pad the constant if shape[0] is not divisible by cfactor_out.
_get_tensor_shape | Get node shape.
_get_tensor_type | Get node type.
_operator_idx_inc | Increase the operator index.
_pack_batch_channel | Pack the data channel dimension.
_pack_const | Pack a constant parameter.
_pack_weight | Pack the weight into packed format.
_pack_weight_conv2d_transpose | Pack the weight into packed format.
_to_shape | Convert shape into a tuple.
_unpack_batch_channel | Unpack the data channel dimension.
_weight_shape_match | Pad the weight if shape[0] is not divisible by cfactor_out.
_weight_shape_match_transpose | Pad the weight if shape[1] is not divisible by cfactor_out.
get_subgraph | Assumes stop_name appears only once, for simplicity.
graph_pack | Pack the graph into the batch- and channel-packed format.
run_opt_pass | Execute a Relay pass.
Module Contents#
- exception vta.top.graphpack.BT
Bases: Exception
Common base class for all non-exit exceptions.
- class vta.top.graphpack.ExprDeviceAnnot(start=-1, end=-1)
Bases: tvm.relay.ExprMutator
Visitor to perform graph annotation on an AST.
Parameters#
- start: int
The start location of the region marked to run on VTA (inclusive).
- end: int
The end location of the region marked to run on VTA (exclusive).
Returns#
None
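The half-open range described by start and end can be pictured with a toy model. The sketch below is not the Relay visitor itself: it assumes operators are numbered in visit order, and `annotate_by_index` is a hypothetical helper that labels a plain list of operator names. The labels 'ext_dev' and 'cpu' follow the device names used in the annotation parameters of graph_pack.

```python
def annotate_by_index(op_names, start=-1, end=-1):
    """Label each operator with a device using the half-open range [start, end)."""
    annotated = []
    for idx, name in enumerate(op_names):
        # start is inclusive and end is exclusive, matching the parameter list.
        device = "ext_dev" if start <= idx < end else "cpu"
        annotated.append((name, device))
    return annotated


# Illustrative operator sequence (an assumption, not taken from a real model).
ops = ["nn.conv2d", "nn.relu", "nn.conv2d", "nn.dense"]
result = annotate_by_index(ops, start=1, end=3)
```

With the default values start=-1 and end=-1 the range is empty in this model, so every operator stays on 'cpu'.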
- class vta.top.graphpack.ExprLocator
Bases: tvm.relay.ExprMutator
Visitor to locate op on an AST.
- class vta.top.graphpack.ExprPack(bfactor, cfactor, weight_bits)
Bases: tvm.relay.ExprMutator
Visitor to perform graph packing on an AST.
- vta.top.graphpack._channel_const_match(channel_length, cfactor_out)
Round the channel constant if its value is not divisible by cfactor_out.
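As an illustration of the arithmetic involved, here is a minimal sketch of rounding a channel length to a multiple of the packing factor. It assumes the rounding goes upward to the next multiple and that the amount of padding is worth returning alongside the new length. Both points are assumptions, and `round_up_to_multiple` is a hypothetical name rather than the module's function.

```python
def round_up_to_multiple(channel_length, cfactor_out):
    """Return (padding, rounded_length) with rounded_length % cfactor_out == 0."""
    remainder = channel_length % cfactor_out
    # No padding is needed when the length is already a multiple of the factor.
    padding = 0 if remainder == 0 else cfactor_out - remainder
    return padding, channel_length + padding


already_aligned = round_up_to_multiple(32, 16)
needs_padding = round_up_to_multiple(10, 16)
```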
- vta.top.graphpack._const_shape_match(data, dshape, cfactor_out)
Pad the constant if shape[0] is not divisible by cfactor_out.
- vta.top.graphpack._operator_idx_inc(expr, count_meta, operator_current_idx)
Increase the operator index.
- vta.top.graphpack._pack_batch_channel(data, dshape, bfactor, cfactor)
Pack the data channel dimension.
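To make the idea of packing concrete, the following sketch computes the shape and element position that a batch- and channel-packed tensor would have. It assumes an NCHW input and a packed layout of (N // bfactor, C // cfactor, H, W, bfactor, cfactor). The layout and the helper names `packed_shape` and `packed_index` are assumptions for illustration, not the module's API.

```python
def packed_shape(dshape, bfactor, cfactor):
    """Shape of an NCHW tensor after splitting batch and channel into tiles."""
    n, c, h, w = dshape
    if n % bfactor or c % cfactor:
        raise ValueError("batch and channel sizes must be divisible by their factors")
    return (n // bfactor, c // cfactor, h, w, bfactor, cfactor)


def packed_index(index, bfactor, cfactor):
    """Map an (n, c, h, w) element index to its position in the packed layout."""
    n, c, h, w = index
    # The outer indices pick the tile, the two trailing indices pick the slot inside it.
    return (n // bfactor, c // cfactor, h, w, n % bfactor, c % cfactor)


shape = packed_shape((1, 64, 56, 56), bfactor=1, cfactor=16)
position = packed_index((0, 35, 2, 3), bfactor=1, cfactor=16)
```

In this model, channel 35 lands in channel tile 2 at slot 3, because 35 = 2 * 16 + 3.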
- vta.top.graphpack._pack_const(data, dshape, dtype, bfactor, cfactor)
Pack a constant parameter.
- vta.top.graphpack._pack_weight_conv2d_transpose(data, dshape, cfactor)
Pack the weight into packed format.
- vta.top.graphpack._unpack_batch_channel(data, old_shape, unpack_transpose=False)
Unpack the data channel dimension.
- vta.top.graphpack._weight_shape_match(data, dshape, channels, cfactor_out, transpose=False)
Pad the weight if shape[0] is not divisible by cfactor_out.
- vta.top.graphpack._weight_shape_match_transpose(data, dshape, channels, cfactor_out)
Pad the weight if shape[1] is not divisible by cfactor_out.
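The two padding helpers differ only in which axis they align. A minimal sketch of the shape arithmetic follows, assuming zero padding is appended at the end of the chosen axis. `pad_widths` is a hypothetical helper, and the (before, after) pairs it returns are just one convenient way to describe the padding.

```python
def pad_widths(dshape, cfactor_out, axis=0):
    """Return per-axis (before, after) padding and the resulting padded shape."""
    remainder = dshape[axis] % cfactor_out
    extra = 0 if remainder == 0 else cfactor_out - remainder
    # Only the selected axis receives padding; all other axes are left unchanged.
    widths = [(0, extra if i == axis else 0) for i in range(len(dshape))]
    new_shape = tuple(dim + after for dim, (_, after) in zip(dshape, widths))
    return widths, new_shape


# Align shape[0], as in the non-transposed case.
widths_axis0, shape_axis0 = pad_widths((60, 32, 3, 3), 16, axis=0)
# Align shape[1], as in the transposed case.
widths_axis1, shape_axis1 = pad_widths((32, 60, 3, 3), 16, axis=1)
```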
- vta.top.graphpack.get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
This function assumes, for simplicity, that stop_name appears only once; this constraint will be lifted in the future. bitpack_start and bitpack_end are both inclusive.
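The inclusive start and stop semantics can be sketched on a flattened list of operator names. This is a toy model only: the real function walks a Relay expression, whereas the hypothetical `select_window` below assumes the operators have already been listed in order and that stop_name occurs once after start_name.

```python
def select_window(op_names, start_name, stop_name):
    """Return the inclusive run of operators from start_name to stop_name."""
    start = op_names.index(start_name)       # first occurrence of the start operator
    stop = op_names.index(stop_name, start)  # first stop operator at or after start
    return op_names[start:stop + 1]          # both endpoints are included


# Illustrative operator sequence (an assumption, not taken from a real model).
ops = [
    "nn.conv2d",
    "nn.max_pool2d",
    "nn.conv2d",
    "nn.relu",
    "nn.global_avg_pool2d",
    "nn.dense",
]
window = select_window(ops, "nn.max_pool2d", "nn.global_avg_pool2d")
```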
- vta.top.graphpack.graph_pack(expr, bfactor, cfactor, weight_bits, start_name='nn.max_pool2d', stop_name='nn.global_avg_pool2d', start_name_idx=None, stop_name_idx=None, count_meta=False, device_annot=False, annot_start_name='nn.conv2d', annot_end_name='annotation.stop_fusion')
Pack the graph into the batch- and channel-packed format.
Parameters#
- expr: relay.Expr
The input program.
- bfactor: int
The packing factor in batch.
- cfactor: int
The packing factor in channel.
- weight_bits: int
The bit-width of the weights.
- start_name: str, optional
Start packing from this known node when start_name_idx is None.
- stop_name: str, optional
Stop packing at this known node when stop_name_idx is None.
- start_name_idx: int, optional
When start_name_idx is not None, start packing only when the node name equals start_name and the node index equals start_name_idx.
- stop_name_idx: int, optional
When stop_name_idx is not None, stop packing only when the node name equals stop_name and the node index equals stop_name_idx.
- count_meta: boolean, optional
When count_meta is False, the operator counting logic does not count meta nodes of type 'relay.expr.Constant', and start_name_idx and stop_name_idx follow the indices from 'expr.astext(show_meta_data=False)'. When count_meta is True, the operator counting logic counts meta nodes as well.
- device_annot: boolean, optional
Whether to annotate the device type.
- annot_start_name: str, optional
Device annotation start node, from which nodes are marked as 'ext_dev'.
- annot_end_name: str, optional
Device annotation end node, after which nodes are marked as 'cpu'.
Returns#
- expr: Expr
The transformed expression.
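A call might look like the sketch below, which only relies on the signature documented above. The wrapper `pack_for_vta` is hypothetical, and the default factor values (1, 16, 8) are illustrative assumptions; in practice they would come from the target hardware configuration. The import is done inside the function so the sketch can be loaded on a machine without TVM and VTA installed.

```python
def pack_for_vta(relay_expr, bfactor=1, cfactor=16, weight_bits=8):
    """Pack relay_expr between the documented default start and stop operators."""
    # Lazy import: requires a TVM build with the VTA Python package available.
    from vta.top.graphpack import graph_pack

    return graph_pack(
        relay_expr,
        bfactor,
        cfactor,
        weight_bits,
        start_name="nn.max_pool2d",
        stop_name="nn.global_avg_pool2d",
    )
```

When the start or stop operator occurs more than once in the program, start_name_idx and stop_name_idx can be passed as well to select a specific occurrence.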