# Source code for tvm.topi.nn.fifo_buffer

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""FIFO buffer op"""
from __future__ import absolute_import as _abs
import tvm
from tvm import te
from .. import tag
from ..transform import concatenate, strided_slice


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",fifo_buffer")
def fifo_buffer(data, buffer, axis):
    """
    FIFO buffer to enable computation reuse in CNNs with sliding window input

    Compute equivalent of

    .. code-block:: python

        concat(buffer, data, axis=axis)
        .slice_axis(axis=axis,
                    begin=data.shape[axis],
                    end=data.shape[axis]+buffer.shape[axis])

    Useful for

    * Encoding explicit re-use of computation in convolution ops operated on a sliding window input
    * Implementing a FIFO queue to cache intermediate results, e.g. as in Fast WaveNet.

    Parameters
    ----------
    data : tvm.te.Tensor
        The input data. Must have the same rank as ``buffer``, the same extent on
        every dimension except ``axis``, and an extent along ``axis`` no larger
        than the buffer's.
    buffer : tvm.te.Tensor
        Previous value of the FIFO buffer
    axis : int
        Specify which axis should be used for buffering; must satisfy
        ``0 <= axis < len(buffer.shape)``.

    Returns
    -------
    result : tvm.te.Tensor
        Updated value for the buffer (same shape as ``buffer``)

    Raises
    ------
    AssertionError
        If ranks differ, the tensors are zero-dimensional, ``axis`` is out of
        range, or the per-dimension extents are incompatible.
    """
    assert len(data.shape) == len(buffer.shape), (
        f"buffer and data must have same number of dimensions, "
        f"buffer.shape = {buffer.shape}, data.shape = {data.shape}"
    )
    assert len(buffer.shape) >= 1, "Zero-dimension tensor not supported"
    assert 0 <= axis < len(buffer.shape), "buffer axis out of range"

    # Convert dims with int() (IntImm implements __int__) rather than int(str(dim)):
    # the printed form of a dim is not guaranteed to be a bare integer literal
    # (e.g. int64 dims), which made the str() round-trip fail.
    for dim, (data_dim, buffer_dim) in enumerate(zip(data.shape, buffer.shape)):
        if dim == axis:
            assert int(data_dim) <= int(buffer_dim), (
                f"data must not be longer than buffer along the buffering axis, "
                f"buffer.shape = {buffer.shape}, data.shape = {data.shape}"
            )
        else:
            assert int(data_dim) == int(buffer_dim), (
                f"buffer and data must match on non-buffering axes, "
                f"buffer.shape = {buffer.shape}, data.shape = {data.shape}"
            )

    if len(buffer.shape) <= 4:
        return _fifo_shift_compute(data, buffer, axis)

    # Ranks above 4 keep the concat + slice formulation, so the op graph produced
    # for those ranks is unchanged.
    begin = [0] * len(buffer.shape)
    begin[axis] = data.shape[axis]
    end = list(buffer.shape[:])
    end[axis] += data.shape[axis]
    return strided_slice(concatenate((buffer, data), axis=axis), begin=begin, end=end)


def _fifo_shift_compute(data, buffer, axis):
    """Build the updated FIFO buffer as a single te.compute of any rank.

    Along ``axis``, output position ``p`` reads ``buffer[p + data_size]`` while
    ``p < buflen - data_size``. Those elements are the retained tail of the old
    buffer. Every later position reads ``data[p - buflen + data_size]``, which
    is the newly appended data. All other axes are copied through unchanged.
    """
    buflen = buffer.shape[axis]
    data_size = data.shape[axis]
    keep = buflen - data_size  # number of old-buffer elements that survive the shift

    def _element(*indices):
        pos = indices[axis]
        from_buffer = list(indices)
        from_buffer[axis] = pos + data_size
        from_data = list(indices)
        from_data[axis] = pos - buflen + data_size
        return tvm.tir.if_then_else(
            pos < keep,
            buffer[tuple(from_buffer)],
            data[tuple(from_data)],
        )

    return te.compute(buffer.shape, _element, name="new_buffer")