# Source code for vta.environment
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configurable VTA Hardware Environment scope."""
# pylint: disable=invalid-name, exec-used
from __future__ import absolute_import as _abs
import os
import json
import copy
import tvm
from tvm import te
from . import intrin
[文档]
def get_vta_hw_path():
    """Return the absolute path of the VTA hardware directory.

    The ``VTA_HW_PATH`` environment variable takes precedence. When it is
    not set, the path falls back to ``3rdparty/vta-hw`` resolved three
    directory levels above the directory containing this file.
    """
    this_dir = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    fallback = os.path.abspath(os.path.join(this_dir, "../../../3rdparty/vta-hw"))
    return os.path.abspath(os.environ.get("VTA_HW_PATH", fallback))
[文档]
def pkg_config(cfg):
    """Build the ``PkgConfig`` object for a VTA configuration.

    The ``PkgConfig`` class is not imported as a module; it is obtained by
    executing ``config/pkg_config.py`` found under the directory returned
    by ``get_vta_hw_path()`` and reading the ``PkgConfig`` name from the
    resulting namespace.

    Parameters
    ----------
    cfg : dict
        Configuration passed unchanged to the ``PkgConfig`` constructor.

    Returns
    -------
    pkg : PkgConfig
        The object created by ``PkgConfig(cfg)``.
    """
    pkg_config_py = os.path.join(get_vta_hw_path(), "config/pkg_config.py")
    # Read through a context manager so the file handle is closed right
    # away instead of being left for the garbage collector.
    with open(pkg_config_py, "rb") as script:
        source = script.read()
    # The script gets its own namespace, pre-seeded with "__file__"
    # (presumably so it can locate files relative to itself).
    libpkg = {"__file__": pkg_config_py}
    # NOTE(review): this executes whatever file VTA_HW_PATH resolves to;
    # only point that variable at a trusted vta-hw checkout.
    exec(compile(source, pkg_config_py, "exec"), libpkg, libpkg)
    return libpkg["PkgConfig"](cfg)
[文档]
class DevContext(object):
    """Compiler-internal development context.

    Holds the non-user-facing compiler state that is owned by an
    Environment, so that developer-related attributes stay clearly
    separated from the user-facing ones.

    Parameters
    ----------
    env : Environment
        The environment hosting the DevContext.
    """

    # NOTE(review): the section comments below have no definitions under
    # them in this listing; confirm against the original file.
    # Memory id for DMA
    # VTA ALU Opcodes
    # Task queue id (pipeline stage)

    def __init__(self, env):
        self.vta_axis = te.thread_axis("vta")
        self.vta_push_uop = tvm.tir.StringImm("VTAPushGEMMOp")
        handle_call = tvm.tir.call_intrin("handle", "tir.vta.command_handle")
        self.command_handle = tvm.tir.Call(
            "handle", "tir.tvm_thread_context", [handle_call]
        )
        self.DEBUG_NO_SYNC = False
        # Register this context on the environment before the GEMM
        # intrinsic is built from that same environment.
        env._dev_ctx = self
        self.gemm = intrin.gemm(env, env.mock_mode)

    def get_task_qid(self, qid):
        """Get transformed queue index.

        Returns 1 when ``DEBUG_NO_SYNC`` is set, otherwise ``qid`` unchanged.
        """
        if self.DEBUG_NO_SYNC:
            return 1
        return qid
[文档]
class Environment(object):
    """Hardware configuration object.

    This object contains all the information
    needed for compiling to a specific VTA backend.

    It can be used as a context manager: entering makes it the value of
    ``Environment.current`` and exiting restores the previous value.

    Parameters
    ----------
    cfg : dict of str to value.
        The configuration parameters.

    Example
    --------
    .. code-block:: python

      # the following code reconfigures the environment
      # temporarily to attributes specified in new_cfg.json
      new_cfg = json.load(open("new_cfg.json"))
      with vta.Environment(new_cfg):
          # env works on the new environment
          env = vta.get_env()
    """

    # Currently active environment; swapped by __enter__ / __exit__.
    current = None
    # constants
    # debug flags
    DEBUG_DUMP_INSN = 1 << 1
    DEBUG_DUMP_UOP = 1 << 2
    DEBUG_SKIP_READ_BARRIER = 1 << 3
    DEBUG_SKIP_WRITE_BARRIER = 1 << 4
    # memory scopes
    inp_scope = "local.inp_buffer"
    wgt_scope = "local.wgt_buffer"
    acc_scope = "local.acc_buffer"

    # initialization function
    def __init__(self, cfg):
        # Produce the derived parameters and update dict.
        # The LOG_* and TARGET attributes read below are not assigned
        # anywhere else in this method, so they must come from cfg_dict.
        self.pkg = pkg_config(cfg)
        self.__dict__.update(self.pkg.cfg_dict)
        # data type width
        self.INP_WIDTH = 1 << self.LOG_INP_WIDTH
        self.WGT_WIDTH = 1 << self.LOG_WGT_WIDTH
        self.ACC_WIDTH = 1 << self.LOG_ACC_WIDTH
        self.OUT_WIDTH = 1 << self.LOG_OUT_WIDTH
        # tensor intrinsic shape
        self.BATCH = 1 << self.LOG_BATCH
        self.BLOCK_IN = 1 << self.LOG_BLOCK_IN
        self.BLOCK_OUT = 1 << self.LOG_BLOCK_OUT
        # buffer size
        self.UOP_BUFF_SIZE = 1 << self.LOG_UOP_BUFF_SIZE
        self.INP_BUFF_SIZE = 1 << self.LOG_INP_BUFF_SIZE
        self.WGT_BUFF_SIZE = 1 << self.LOG_WGT_BUFF_SIZE
        self.ACC_BUFF_SIZE = 1 << self.LOG_ACC_BUFF_SIZE
        self.OUT_BUFF_SIZE = 1 << self.LOG_OUT_BUFF_SIZE
        # size of one buffer element, first in bits and then in bytes
        self.INP_ELEM_BITS = self.BATCH * self.BLOCK_IN * self.INP_WIDTH
        self.WGT_ELEM_BITS = self.BLOCK_OUT * self.BLOCK_IN * self.WGT_WIDTH
        self.ACC_ELEM_BITS = self.BATCH * self.BLOCK_OUT * self.ACC_WIDTH
        self.OUT_ELEM_BITS = self.BATCH * self.BLOCK_OUT * self.OUT_WIDTH
        self.INP_ELEM_BYTES = self.INP_ELEM_BITS // 8
        self.WGT_ELEM_BYTES = self.WGT_ELEM_BITS // 8
        self.ACC_ELEM_BYTES = self.ACC_ELEM_BITS // 8
        self.OUT_ELEM_BYTES = self.OUT_ELEM_BITS // 8
        # dtypes
        self.acc_dtype = "int%d" % self.ACC_WIDTH
        self.inp_dtype = "int%d" % self.INP_WIDTH
        self.wgt_dtype = "int%d" % self.WGT_WIDTH
        self.out_dtype = "int%d" % self.OUT_WIDTH
        # bitstream name
        self.BITSTREAM = self.pkg.bitstream
        # model string
        self.MODEL = self.TARGET + "_" + self.BITSTREAM
        # lazy cached members
        self.mock_mode = False
        self._mock_env = None
        self._dev_ctx = None
        self._last_env = None

    def __enter__(self):
        """Make this environment current, remembering the previous one."""
        self._last_env = Environment.current
        Environment.current = self
        return self

    def __exit__(self, ptype, value, trace):
        """Restore the environment that was current before ``__enter__``."""
        Environment.current = self._last_env

    @property
    def cfg_dict(self):
        """Configuration dictionary held by the package config object."""
        return self.pkg.cfg_dict

    @property
    def dev(self):
        """Developer context"""
        if self._dev_ctx is None:
            # DevContext.__init__ also stores itself on env._dev_ctx.
            self._dev_ctx = DevContext(self)
        return self._dev_ctx

    @property
    def mock(self):
        """A mock version of the Environment

        The ALU, dma_copy and intrinsics will be
        mocked to be nop.
        """
        if self.mock_mode:
            return self
        if self._mock_env is None:
            # Shallow copy; the developer context is reset so the mock
            # builds its own one with mock_mode enabled.
            self._mock_env = copy.copy(self)
            self._mock_env._dev_ctx = None
            self._mock_env.mock_mode = True
        return self._mock_env

    @property
    def dma_copy(self):
        """DMA copy pragma"""
        return "dma_copy" if not self.mock_mode else "skip_dma_copy"

    @property
    def alu(self):
        """ALU pragma"""
        return "alu" if not self.mock_mode else "skip_alu"

    @property
    def gemm(self):
        """GEMM intrinsic"""
        return self.dev.gemm

    @property
    def target(self):
        """VTA target created with ``MODEL`` as its model string."""
        return tvm.target.vta(model=self.MODEL)

    @property
    def target_host(self):
        """The target host

        Raises
        ------
        ValueError
            If ``TARGET`` is not one of the names handled here.
        """
        if self.TARGET in ["pynq", "de10nano"]:
            return "llvm -mtriple=armv7-none-linux-gnueabihf"
        if self.TARGET == "ultra96":
            return "llvm -mtriple=aarch64-linux-gnu"
        if self.TARGET in ["sim", "tsim", "intelfocl"]:
            return "llvm"
        raise ValueError("Unknown target %s" % self.TARGET)

    @property
    def target_vta_cpu(self):
        """ARM CPU target created with ``TARGET`` as its model."""
        return tvm.target.arm_cpu(model=self.TARGET)
[文档]
def get_env():
    """Get the current VTA Environment.

    Returns
    -------
    env : Environment
        The object stored in ``Environment.current``.
    """
    active_env = Environment.current
    return active_env
[文档]
def _init_env():
    """Initialize the default global env.

    Loads ``config/vta_config.json`` from the directory returned by
    ``get_vta_hw_path()`` and builds an Environment from it.

    Returns
    -------
    env : Environment
        Environment constructed from the parsed JSON configuration.

    Raises
    ------
    RuntimeError
        If the configuration file does not exist.
    """
    config_path = os.path.join(get_vta_hw_path(), "config/vta_config.json")
    if not os.path.exists(config_path):
        raise RuntimeError("Cannot find config in %s" % str(config_path))
    # Use a context manager so the file handle is closed once parsed.
    with open(config_path) as config_file:
        cfg = json.load(config_file)
    return Environment(cfg)
# Build the default environment and install it as the current one; this
# runs as a module-level statement.
Environment.current = _init_env()