# Source code for vta.exec.rpc_server

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""VTA customized TVM RPC Server

Provides additional runtime function and library loading.
"""
from __future__ import absolute_import

import logging
import argparse
import os
import ctypes
import json
import tvm
from tvm import rpc
from tvm.contrib import cc
from vta import program_bitstream

from ..environment import get_env, pkg_config
from ..libinfo import find_libvta


def server_start():
    """VTA RPC server extension.

    Registers the VTA-specific global functions on the RPC server:
    module loading, the ``ext_dev`` device API, FPGA programming,
    server shutdown and runtime reconfiguration.  The VTA runtime
    library is loaded lazily, the first time one of the callbacks
    needs it, and the handle is shared between all callbacks through
    the ``runtime_dll`` list.
    """
    # pylint: disable=unused-variable
    curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_path, "../../../../"))
    dll_path = find_libvta("libvta")[0]
    cfg_path = os.path.abspath(os.path.join(proj_root, "3rdparty/vta-hw/config/vta_config.json"))
    # Holds at most one ctypes handle; an empty list means "not loaded yet".
    runtime_dll = []
    # Capture the stock loader before it is overridden below.
    _load_module = tvm.get_global_func("tvm.rpc.server.load_module")

    def load_vta_dll():
        """Load the VTA library on first use and return its handle."""
        if not runtime_dll:
            runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
        logging.info("Loading VTA library: %s", dll_path)
        return runtime_dll[0]

    @tvm.register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load the VTA library, then defer to the original module loader."""
        load_vta_dll()
        return _load_module(file_name)

    @tvm.register_func("device_api.ext_dev")
    def ext_dev_callback():
        """Load the VTA library, then call the ``device_api.ext_dev`` global."""
        load_vta_dll()
        # The global is looked up again after the library has been loaded.
        return tvm.get_global_func("device_api.ext_dev")()

    @tvm.register_func("tvm.contrib.vta.init", override=True)
    def program_fpga(file_name):
        """Program the FPGA with the bitstream named ``file_name``.

        Parameters
        ----------
        file_name : str
            Bitstream file name, resolved through
            ``tvm.rpc.server.workpath``.
        """
        # pylint: disable=import-outside-toplevel
        env = get_env()
        if env.TARGET == "pynq":
            from pynq import xlnk

            # Reset xilinx driver
            xlnk.Xlnk().xlnk_reset()
        elif env.TARGET == "de10nano":
            # Load the de10nano program function.
            load_vta_dll()
        path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
        program_bitstream.bitstream_program(env.TARGET, path)
        logging.info("Program FPGA with %s ", file_name)

    @tvm.register_func("tvm.rpc.server.shutdown", override=True)
    def server_shutdown():
        """Shut the VTA runtime down and forget the library handle."""
        if runtime_dll:
            runtime_dll[0].VTARuntimeShutdown()
            runtime_dll.pop()

    @tvm.register_func("tvm.contrib.vta.reconfig_runtime", override=True)
    def reconfig_runtime(cfg_json):
        """Rebuild and reload runtime with new configuration.

        Parameters
        ----------
        cfg_json : str
            JSON string used for configurations.

        Raises
        ------
        RuntimeError
            If the VTA library has already been loaded in this session.
        """
        env = get_env()
        if runtime_dll:
            if env.TARGET == "de10nano":
                print("Please reconfigure the runtime AFTER programming a bitstream.")
            raise RuntimeError("Can only reconfig in the beginning of session...")
        cfg = json.loads(cfg_json)
        cfg["TARGET"] = env.TARGET
        pkg = pkg_config(cfg)
        # check if the configuration is already the same
        if os.path.isfile(cfg_path):
            # Use a context manager so the config file handle is always closed.
            with open(cfg_path, "r") as cfg_file:
                old_cfg = json.load(cfg_file)
            if pkg.same_config(old_cfg):
                logging.info("Skip reconfig_runtime due to same config.")
                return
        cflags = ["-O2", "-std=c++17"]
        cflags += pkg.cflags
        ldflags = pkg.ldflags
        source = pkg.lib_source
        logging.info(
            "Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
            dll_path,
            "\n\t".join(cflags),
            "\n\t".join(source),
            "\n\t".join(ldflags),
        )
        # Rebuild the shared library in place, then persist the new config.
        cc.create_shared(dll_path, source, cflags + ldflags)
        with open(cfg_path, "w") as outputfile:
            outputfile.write(pkg.cfg_json)


def main():
    """Main function: parse the command line and run the VTA RPC server.

    Starts an ``rpc.Server`` with the requested host, port range, key
    and optional tracker address, then blocks until the server process
    exits.

    Raises
    ------
    ValueError
        If ``--tracker`` is not of the form ``host:port`` or the port is
        not an integer.
    RuntimeError
        If ``--tracker`` is given without ``--key``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="The host IP address the server binds to"
    )
    parser.add_argument("--port", type=int, default=9091, help="The port of the RPC")
    parser.add_argument("--port-end", type=int, default=9199, help="The end search port of the RPC")
    parser.add_argument(
        "--key", type=str, default="", help="RPC key used to identify the connection type."
    )
    parser.add_argument("--tracker", type=str, default="", help="Report to RPC tracker")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if args.tracker:
        # Split on the LAST colon so hosts that contain colons themselves
        # (e.g. IPv6 literals) still yield a (host, port) pair.
        url, sep, port = args.tracker.rpartition(":")
        if not sep:
            raise ValueError("Tracker address must be of the form host:port, got %r" % args.tracker)
        port = int(port)
        tracker_addr = (url, port)
        if not args.key:
            raise RuntimeError("Need key to present type of resource when tracker is available")
    else:
        tracker_addr = None

    # register the initialization callback
    def server_init_callback():
        """Make the server use the VTA ``server_start`` as its start hook."""
        # pylint: disable=redefined-outer-name, reimported, import-outside-toplevel, import-self
        import tvm
        import vta.exec.rpc_server

        tvm.register_func("tvm.rpc.server.start", vta.exec.rpc_server.server_start, override=True)

    server = rpc.Server(
        args.host,
        args.port,
        args.port_end,
        key=args.key,
        tracker_addr=tracker_addr,
        server_init_callback=server_init_callback,
    )
    # Block until the server process terminates.
    server.proc.join()


# Run the RPC server only when this module is executed as a script.
if __name__ == "__main__":
    main()