PyTorch 基准测试#

参考:benchmark & github

定义同一功能两种不同实现以备后续测试#

import torch


def batched_dot_mul_sum(a, b):
    '''Computes batched dot by multiplying and summing'''
    return a.mul(b).sum(-1)


def batched_dot_bmm(a, b):
    '''Computes batched dot by reducing to ``bmm``'''
    a = a.reshape(-1, 1, a.shape[-1])
    b = b.reshape(-1, b.shape[-1], 1)
    return torch.bmm(a, b).flatten(-3)


# Input for benchmarking
x = torch.randn(10000, 64)

# Ensure that both functions compute the same output
assert batched_dot_mul_sum(x, x).allclose(batched_dot_bmm(x, x))

timeit 测试#

import timeit

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
mul_sum(x, x):  106.9 us
bmm(x, x):      117.7 us

使用 torch.utils.benchmark.Timer 测试#

from torch.utils import benchmark

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

print(t0.timeit(100))
print(t1.timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd2045ebfa0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  471.74 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34474b580>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  888.97 us
  1 measurement, 100 runs , 1 thread

尽管基本功能方面的 API 相同,但存在一些重要差异。benchmark.Timer.timeit() 返回每次运行的时间,而不是像 timeit.Timer.timeit() 那样返回总的运行时间。PyTorch 基准测试模块还为打印结果提供了格式化字符串表示。

另一个重要差异,也是结果不同的原因,是 PyTorch 基准测试模块默认在单个线程中运行。我们可以通过 num_threads 参数更改线程数。

torch.utils.benchmark.Timer 还接受几个额外的参数,包括:labelsub_labeldescriptionenv,它们会改变返回的测量对象的 __repr__,并用于对结果进行分组(稍后会详细介绍)。

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')

t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using mul and sum')

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x},
    num_threads=num_threads,
    label='Multithreaded batch dot',
    sub_label='Implemented using bmm')

print(t0.timeit(100))
print(t1.timeit(100))
Benchmarking on 24 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd2045eaf80>
Multithreaded batch dot: Implemented using mul and sum
setup: from __main__ import batched_dot_mul_sum
  102.00 us
  1 measurement, 100 runs , 24 threads
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd2045eb2b0>
Multithreaded batch dot: Implemented using bmm
setup: from __main__ import batched_dot_bmm
  101.13 us
  1 measurement, 100 runs , 24 threads

使用所有可用线程运行基准测试与 timeit 模块的结果相似。更重要的是,哪个版本更快取决于我们用多少个线程运行代码。这就是为什么使用代表实际用例的线程设置对代码进行基准测试很重要的原因。另一个要记住的重要事情是在 GPU 上进行基准测试时要同步 CPU 和 CUDA。让我们再次在 CUDA 张量上运行上述基准测试,看看会发生什么。

x = torch.randn(10000, 1024, device='cuda')

t0 = timeit.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = timeit.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Ran each twice to show difference before/after warm-up
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'mul_sum(x, x):  {t0.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
print(f'bmm(x, x):      {t1.timeit(100) / 100 * 1e6:>5.1f} us')
mul_sum(x, x):  634.3 us
mul_sum(x, x):   35.7 us
bmm(x, x):      2649.5 us
bmm(x, x):       37.3 us
t0 = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='from __main__ import batched_dot_mul_sum',
    globals={'x': x})

t1 = benchmark.Timer(
    stmt='batched_dot_bmm(x, x)',
    setup='from __main__ import batched_dot_bmm',
    globals={'x': x})

# Run only once since benchmark module does warm-up for us
print(t0.timeit(100))
print(t1.timeit(100))
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34db777c0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  150.93 us
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34c0c2650>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  50.30 us
  1 measurement, 100 runs , 1 thread

结果揭示了一些有趣的事情。使用 timeit 模块的 bmm 版本第一次运行比第二次运行花费更长的时间。这是因为 bmm 调用 cuBLAS,而 cuBLAS 需要在第一次调用时加载,这需要一些时间。这就是为什么在基准测试之前进行预热运行很重要的原因,幸运的是,PyTorch 的基准测试模块可以处理这个问题。

timeitbenchmark 模块之间的结果差异是因为 timeit 模块没有同步 CUDA,因此只计时内核启动的时间。PyTorch 的基准测试模块为我们执行了同步操作。

使用Blocked Autorange 进行基准测试#

timeit.Timer.autorange() 至少进行 0.2 秒的连续测量,而 torch.utils.benchmark.blocked_autorange() 进行多次测量,其总时间至少为 0.2 秒(可以通过 min_run_time 参数进行更改),但受到计时开销占整体测量一小部分的限制。这是通过首先以不断增加的循环次数运行来实现的,直到运行时间远大于测量开销(这也作为预热),然后进行测量,直到达到目标时间。这种方法具有有用的特性,即浪费较少的数据,并允许我们计算统计数据来估计测量的可靠性。

m0 = t0.blocked_autorange()
m1 = t1.blocked_autorange()

print(m0)
print(m1)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34c0c1db0>
batched_dot_mul_sum(x, x)
setup: from __main__ import batched_dot_mul_sum
  149.64 us
  1 measurement, 10000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34c0c1ba0>
batched_dot_bmm(x, x)
setup: from __main__ import batched_dot_bmm
  49.85 us
  1 measurement, 10000 runs , 1 thread

我们还可以检查从返回的测量对象中获取的单个统计数据。

print(f"Mean:   {m0.mean * 1e6:6.2f} us")
print(f"Median: {m0.median * 1e6:6.2f} us")
Mean:   149.64 us
Median: 149.64 us

比较基准测试结果#

到目前为止,我们一直在将两个版本的批量点积与单个输入进行比较。在实践中,我们还希望尝试不同数量的输入和线程的组合。Compare 类可以帮助以格式化的表格形式显示许多测量结果。它使用上述注释(标签、子标签、线程数等)以及描述来对表格进行分组和组织。让我们使用 Compare 来看看我们的函数在不同输入大小和线程数下的性能如何。

from itertools import product

# Compare takes a list of measurements which we'll save in results.
results = []

sizes = [1, 64, 1024, 10000]
for b, n in product(sizes, sizes):
    # label and sub_label are the rows
    # description is the column
    label = 'Batched dot'
    sub_label = f'[{b}, {n}]'
    x = torch.ones((b, n))
    for num_threads in [1, 4, 16, 32]:
        results.append(benchmark.Timer(
            stmt='batched_dot_mul_sum(x, x)',
            setup='from __main__ import batched_dot_mul_sum',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='mul/sum',
        ).blocked_autorange(min_run_time=1))
        results.append(benchmark.Timer(
            stmt='batched_dot_bmm(x, x)',
            setup='from __main__ import batched_dot_bmm',
            globals={'x': x},
            num_threads=num_threads,
            label=label,
            sub_label=sub_label,
            description='bmm',
        ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
[--------------- Batched dot ----------------]
                      |  mul/sum   |    bmm   
1 threads: -----------------------------------
      [1, 1]          |       5.8  |       9.2
      [1, 64]         |       6.2  |       9.6
      [1, 1024]       |       6.5  |      10.4
      [1, 10000]      |       9.2  |      11.7
      [64, 1]         |       6.4  |       9.6
      [64, 64]        |       8.2  |      13.8
      [64, 1024]      |      36.5  |     229.7
      [64, 10000]     |     289.0  |    2114.2
      [1024, 1]       |       7.3  |      15.2
      [1024, 64]      |      51.8  |      90.8
      [1024, 1024]    |     461.1  |    3465.2
      [1024, 10000]   |   27525.8  |   33693.0
      [10000, 1]      |      24.7  |      73.8
      [10000, 64]     |     327.4  |     684.2
      [10000, 1024]   |   28438.8  |   35480.5
      [10000, 10000]  |  288529.2  |  378751.1
4 threads: -----------------------------------
      [1, 1]          |       5.8  |       9.6
      [1, 64]         |       6.3  |       9.3
      [1, 1024]       |       6.4  |      10.5
      [1, 10000]      |       9.3  |      11.9
      [64, 1]         |       6.4  |       9.6
      [64, 64]        |       8.3  |      20.3
      [64, 1024]      |      48.1  |     307.2
      [64, 10000]     |      90.0  |    3706.0
      [1024, 1]       |      11.2  |      21.0
      [1024, 64]      |      49.4  |      60.9
      [1024, 1024]    |     139.1  |     944.6
      [1024, 10000]   |   10918.2  |    8791.5
      [10000, 1]      |      24.1  |      67.9
      [10000, 64]     |     111.4  |     196.7
      [10000, 1024]   |   10966.8  |    9261.8
      [10000, 10000]  |  106744.5  |   91647.6
16 threads: ----------------------------------
      [1, 1]          |       8.1  |      11.9
      [1, 64]         |       6.3  |      10.3
      [1, 1024]       |      10.1  |      17.2
      [1, 10000]      |       9.9  |      12.3
      [64, 1]         |       6.4  |       9.6
      [64, 64]        |       8.4  |      13.9
      [64, 1024]      |      42.2  |     461.2
      [64, 10000]     |      65.5  |    4675.4
      [1024, 1]       |       7.6  |      15.3
      [1024, 64]      |      51.3  |      58.8
      [1024, 1024]    |      68.5  |     308.4
      [1024, 10000]   |    9078.2  |    2671.1
      [10000, 1]      |      18.4  |      66.3
      [10000, 64]     |      73.4  |     104.2
      [10000, 1024]   |    8494.2  |    2594.6
      [10000, 10000]  |   73330.4  |   24553.6
32 threads: ----------------------------------
      [1, 1]          |       5.8  |       9.5
      [1, 64]         |       8.7  |       9.5
      [1, 1024]       |      10.4  |      11.1
      [1, 10000]      |       9.8  |      11.8
      [64, 1]         |       6.3  |       9.6
      [64, 64]        |       8.4  |      14.1
      [64, 1024]      |      59.2  |     706.2
      [64, 10000]     |     124.4  |    7149.6
      [1024, 1]       |       8.0  |      15.4
      [1024, 64]      |      58.5  |      80.3
      [1024, 1024]    |      73.5  |     211.3
      [1024, 10000]   |    8362.5  |    1796.2
      [10000, 1]      |      18.2  |      67.6
      [10000, 64]     |     142.1  |     101.5
      [10000, 1024]   |    7403.5  |    1360.9
      [10000, 10000]  |   73554.7  |   15836.1

Times are in microseconds (us).

上述结果表明,对于在多个线程上运行的大型张量,可以简化为 bmm 的版本更好,而对于较小和/或单线程代码,另一个版本更好。

Compare 还提供了用于更改表格格式的函数。

compare.trim_significant_figures()
compare.colorize()
compare.print()
[-------------- Batched dot --------------]
                      |  mul/sum  |   bmm  
1 threads: --------------------------------
      [1, 1]          |        6  |       9
      [1, 64]         |        6  |      10
      [1, 1024]       |        6  |      10
      [1, 10000]      |        9  |      10
      [64, 1]         |        6  |      10
      [64, 64]        |        8  |      14
      [64, 1024]      |       36  |     230
      [64, 10000]     |      289  |    2100
      [1024, 1]       |        7  |      15
      [1024, 64]      |       52  |      91
      [1024, 1024]    |      461  |    3500
      [1024, 10000]   |    28000  |   34000
      [10000, 1]      |       25  |      74
      [10000, 64]     |      327  |     680
      [10000, 1024]   |    28400  |   40000
      [10000, 10000]  |   300000  |  400000
4 threads: --------------------------------
      [1, 1]          |        6  |      10
      [1, 64]         |        6  |       9
      [1, 1024]       |        6  |      10
      [1, 10000]      |        9  |      10
      [64, 1]         |        6  |      10
      [64, 64]        |        8  |      20
      [64, 1024]      |       50  |     300
      [64, 10000]     |       90  |    4000
      [1024, 1]       |       11  |      21
      [1024, 64]      |       49  |      60
      [1024, 1024]    |      140  |     940
      [1024, 10000]   |    10920  |    9000
      [10000, 1]      |       24  |      70
      [10000, 64]     |      111  |     197
      [10000, 1024]   |    10970  |    9260
      [10000, 10000]  |   107000  |   91600
16 threads: -------------------------------
      [1, 1]          |        8  |      10
      [1, 64]         |        6  |      10
      [1, 1024]       |       10  |      20
      [1, 10000]      |       10  |      10
      [64, 1]         |        6  |      10
      [64, 64]        |        8  |      14
      [64, 1024]      |       40  |     500
      [64, 10000]     |       66  |    5000
      [1024, 1]       |        8  |      15
      [1024, 64]      |       50  |      60
      [1024, 1024]    |       70  |     308
      [1024, 10000]   |     9100  |    3000
      [10000, 1]      |       18  |      66
      [10000, 64]     |       73  |     100
      [10000, 1024]   |     8000  |    3000
      [10000, 10000]  |    70000  |   25000
32 threads: -------------------------------
      [1, 1]          |        6  |      10
      [1, 64]         |        9  |       9
      [1, 1024]       |       10  |      10
      [1, 10000]      |       10  |      10
      [64, 1]         |        6  |      10
      [64, 64]        |        8  |      14
      [64, 1024]      |       60  |     700
      [64, 10000]     |      124  |    7000
      [1024, 1]       |        8  |      15
      [1024, 64]      |       58  |      80
      [1024, 1024]    |       74  |     210
      [1024, 10000]   |     8400  |    2000
      [10000, 1]      |       18  |      70
      [10000, 64]     |      140  |     101
      [10000, 1024]   |     7000  |    1000
      [10000, 10000]  |    74000  |   16000

Times are in microseconds (us).

保存/加载基准测试结果#

测量结果可以通过 pickle 模块进行序列化。这使得A/B测试变得简单,因为您可以从两个单独的环境中收集测量结果,将其 pickle 化,然后在单个环境中加载它们。Timer 甚至接受 env 构造函数参数,以便这种 A/B 测试可以无缝地工作。

让我们想象一下,不是使用两个 Python 函数,而是将 add/sumbmm 方法分别添加到 PyTorch 的两个不同构建中。下面的示例演示了如何对它们进行 A/B 测试。为了简单起见,我们只使用形状的子集,并简单地通过 pickle 来回传递结果,而不是实际使用多个环境并将结果写入磁盘。

import pickle

ab_test_results = []
for env in ('environment A: mul/sum', 'environment B: bmm'):
    for b, n in ((1, 1), (1024, 10000), (10000, 1)):
        x = torch.ones((b, n))
        dot_fn = (batched_dot_mul_sum if env == 'environment A: mul/sum' else batched_dot_bmm)
        m = benchmark.Timer(
            stmt='batched_dot(x, x)',
            globals={'x': x, 'batched_dot': dot_fn},
            num_threads=1,
            label='Batched dot',
            description=f'[{b}, {n}]',
            env=env,
        ).blocked_autorange(min_run_time=1)
        ab_test_results.append(pickle.dumps(m))

ab_results = [pickle.loads(i) for i in ab_test_results]
compare = benchmark.Compare(ab_results)
compare.trim_significant_figures()
compare.colorize()
compare.print()
[------------------------------------- Batched dot -------------------------------------]
                                               |  [1, 1]  |  [1024, 10000]  |  [10000, 1]
1 threads: ------------------------------------------------------------------------------
  (environment A: mul/sum)  batched_dot(x, x)  |     6    |      29000      |      20    
  (environment B: bmm)      batched_dot(x, x)  |    15    |      30000      |      73    

Times are in microseconds (us).
# And just to show that we can round trip all of the results from earlier:
round_tripped_results = pickle.loads(pickle.dumps(results))
assert(str(benchmark.Compare(results)) == str(benchmark.Compare(round_tripped_results)))

生成带有模糊参数的输入#

正如我们在上一节中所看到的,根据输入张量的不同,可能会产生一些明显的性能差异。因此,在多个不同的输入上运行基准测试是一个好主意。然而,创建所有这些输入张量可能会很繁琐,这就是 torch.utils.benchmark.Fuzzer 和相关类的作用所在。让我们看看如何使用 Fuzzer 为基准测试创建一些测试用例。

from torch.utils.benchmark import Fuzzer, FuzzedParameter, FuzzedTensor, ParameterAlias

# Generates random tensors with 128 to 10000000 elements and sizes k0 and k1 chosen from a
# ``loguniform`` distribution in [1, 10000], 40% of which will be discontiguous on average.
example_fuzzer = Fuzzer(
    parameters = [
        FuzzedParameter('k0', minval=1, maxval=10000, distribution='loguniform'),
        FuzzedParameter('k1', minval=1, maxval=10000, distribution='loguniform'),
    ],
    tensors = [
        FuzzedTensor('x', size=('k0', 'k1'), min_elements=128, max_elements=10000000, probability_contiguous=0.6)
    ],
    seed=0,
)

results = []
for tensors, tensor_params, params in example_fuzzer.take(10):
    # description is the column label
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()
[--------------------- Batched dot ---------------------]
                                     |  mul/sum  |   bmm 
1 threads: ----------------------------------------------
      725    x 257                   |     102   |    199
      49     x 383                   |      21   |     36
      34     x 1468                  |      40   |    190
      187    x 5039                  |     430   |   3100
      2140   x 1296 (discontiguous)  |    1820   |  82000
      78     x 1598                  |      63   |    420
      519    x 763                   |     182   |   1320
      141    x 1082                  |      76   |    510
      78     x 5    (discontiguous)  |       7   |     13
      187    x 1                     |       7   |     11

Times are in microseconds (us).

定义自己的模糊器具有很大的灵活性,这对于创建一组强大的输入来进行基准测试非常有用。但是为了让事情变得更简单,PyTorch基准测试模块附带了一些内置的模糊器,以满足常见的基准测试需求。让我们看看如何使用其中一个内置的模糊器。

from torch.utils.benchmark.op_fuzzers import binary

results = []
for tensors, tensor_params, params in binary.BinaryOpFuzzer(seed=0).take(10):
    sub_label=f"{params['k0']:<6} x {params['k1']:<4} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
    results.append(benchmark.Timer(
        stmt='batched_dot_mul_sum(x, x)',
        setup='from __main__ import batched_dot_mul_sum',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='mul/sum',
    ).blocked_autorange(min_run_time=1))
    results.append(benchmark.Timer(
        stmt='batched_dot_bmm(x, x)',
        setup='from __main__ import batched_dot_bmm',
        globals=tensors,
        label='Batched dot',
        sub_label=sub_label,
        description='bmm',
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.trim_significant_figures()
compare.colorize(rowwise=True)
compare.print()
[----------------------- Batched dot ------------------------]
                                         |  mul/sum  |   bmm  
1 threads: ---------------------------------------------------
      64     x 473  (discontiguous)      |    13900  |  101000
      16384  x 12642115 (discontiguous)  |       32  |     122
      8192   x 892                       |     6930  |   24000
      512    x 64   (discontiguous)      |   104000  |  369000
      493    x 27   (discontiguous)      |     2020  |    4500
      118    x 32   (discontiguous)      |      903  |    2650
      16     x 495  (discontiguous)      |    20000  |   35000
      488    x 62374                     |    84300  |   98700
      240372 x 69                        |    48000  |   17900
      40156  x 32   (discontiguous)      |     1890  |    5100

Times are in microseconds (us).

收集带有Callgrind的指令计数#

优化代码的一个挑战是墙钟的时间变化和不透明度。有很多非确定性来源,从自适应时钟速度到与其他进程的资源竞争。此外,端到端时间无法提供关于在哪里花费时间的洞察,这是我们在优化代码时真正感兴趣的。

补充方法是同时收集指令计数。这些计数是代理指标,并不捕获所有性能方面(例如内存或I/O绑定任务),但它们确实具有一些有用的属性。指令计数是可重复的,对环境变化不敏感,并提供了程序在何处花费周期的细粒度洞察。

要查看指令计数的效用,让我们看看如何减少 batched_dot_mul_sum 的开销。显而易见的解决方案是将它移动到C++,这样我们就可以避免在Python和C++之间多次往返。

幸运的是,源代码几乎相同。我们在 C++ 中必须问一个问题:我们应该通过值还是引用传递参数?

batched_dot_src = """\
/* ---- Python ---- */
// def batched_dot_mul_sum(a, b):
//     return a.mul(b).sum(-1)

torch::Tensor batched_dot_mul_sum_v0(
    const torch::Tensor a,
    const torch::Tensor b) {
  return a.mul(b).sum(-1);
}

torch::Tensor batched_dot_mul_sum_v1(
    const torch::Tensor& a,
    const torch::Tensor& b) {
  return a.mul(b).sum(-1);
}
"""


# PyTorch makes it easy to test our C++ implementations by providing a utility
# to JIT compile C++ source into Python extensions:
import os
from torch.utils import cpp_extension
cpp_lib = cpp_extension.load_inline(
    name='cpp_lib',
    cpp_sources=batched_dot_src,
    extra_cflags=['-O3'],
    extra_include_paths=[
        # `load_inline` needs to know where to find ``pybind11`` headers.
        os.path.join(os.getenv('CONDA_PREFIX'), 'include')
    ],
    functions=['batched_dot_mul_sum_v0', 'batched_dot_mul_sum_v1']
)

# `load_inline` will create a shared object that is loaded into Python. When we collect
# instruction counts Timer will create a subprocess, so we need to re-import it. The
# import process is slightly more complicated for C extensions, but that's all we're
# doing here.
module_import_str = f"""\
# https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
import importlib.util
spec = importlib.util.spec_from_file_location("cpp_lib", {repr(cpp_lib.__file__)})
cpp_lib = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cpp_lib)"""

import textwrap
def pretty_print(result):
    """Import machinery for ``cpp_lib.so`` can get repetitive to look at."""
    print(repr(result).replace(textwrap.indent(module_import_str, "  "), "  import cpp_lib"))


t_baseline = benchmark.Timer(
    stmt='batched_dot_mul_sum(x, x)',
    setup='''\
from __main__ import batched_dot_mul_sum
x = torch.randn(2, 2)''')

t0 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v0(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

t1 = benchmark.Timer(
    stmt='cpp_lib.batched_dot_mul_sum_v1(x, x)',
    setup=f'''\
{module_import_str}
x = torch.randn(2, 2)''')

# Moving to C++ did indeed reduce overhead, but it's hard to tell which
# calling convention is more efficient. v1 (call with references) seems to
# be a bit faster, but it's within measurement error.
pretty_print(t_baseline.blocked_autorange())
pretty_print(t0.blocked_autorange())
pretty_print(t1.blocked_autorange())
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd1dbbe9b10>
batched_dot_mul_sum(x, x)
setup:
  from __main__ import batched_dot_mul_sum
  x = torch.randn(2, 2)

  8.92 us
  1 measurement, 100000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd1dbbea680>
cpp_lib.batched_dot_mul_sum_v0(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  7.33 us
  1 measurement, 100000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fd34474b760>
cpp_lib.batched_dot_mul_sum_v1(x, x)
setup:
  import cpp_lib
  x = torch.randn(2, 2)

  7.51 us
  1 measurement, 100000 runs , 1 thread
# Let's use ``Callgrind`` to determine which is better.
stats_v0 = t0.collect_callgrind()
stats_v1 = t1.collect_callgrind()

pretty_print(stats_v0)
pretty_print(stats_v1)

# `.as_standardized` removes file names and some path prefixes, and makes
# it easier to read the function symbols.
stats_v0 = stats_v0.as_standardized()
stats_v1 = stats_v1.as_standardized()

# `.delta` diffs the instruction counts, and `.denoise` removes several
# functions in the Python interpreter that are known to have significant
# jitter.
delta = stats_v1.delta(stats_v0).denoise()

# `.transform` is a convenience API for transforming function names. It is
# useful for increasing cancelation when ``diff-ing`` instructions, as well as
# just generally improving readability.
replacements = (
    ("???:void pybind11", "pybind11"),
    ("batched_dot_mul_sum_v0", "batched_dot_mul_sum_v1"),
    ("at::Tensor, at::Tensor", "..."),
    ("at::Tensor const&, at::Tensor const&", "..."),
    ("auto torch::detail::wrap_pybind_function_impl_", "wrap_pybind_function_impl_"),
)
for before, after in replacements:
    delta = delta.transform(lambda l: l.replace(before, after))

# We can use print options to control how much of the function to display.
torch.set_printoptions(linewidth=160)

# Once parsed, the instruction counts make clear that passing `a` and `b`
# by reference is more efficient as it skips some ``c10::TensorImpl`` bookkeeping
# for the intermediate Tensors, and is also works better with ``pybind11``. This
# is consistent with our noisy wall time observations.
print(delta)
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
/media/pc/data/lxw/ai/torch-book/doc/recipes/benchmark.ipynb 单元格 30 line 2
      <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/lxw/ai/torch-book/doc/recipes/benchmark.ipynb#X60sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a> # Let's use ``Callgrind`` to determine which is better.
----> <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/lxw/ai/torch-book/doc/recipes/benchmark.ipynb#X60sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1'>2</a> stats_v0 = t0.collect_callgrind()
      <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/lxw/ai/torch-book/doc/recipes/benchmark.ipynb#X60sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2'>3</a> stats_v1 = t1.collect_callgrind()
      <a href='vscode-notebook-cell://ssh-remote%2B10.16.11.3/media/pc/data/lxw/ai/torch-book/doc/recipes/benchmark.ipynb#X60sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> pretty_print(stats_v0)

File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py:486, in Timer.collect_callgrind(self, number, repeats, collect_baseline, retain_out_file)
    484 is_python = (self._language == Language.PYTHON)
    485 assert is_python or not self._globals
--> 486 result = valgrind_timer_interface.wrapper_singleton().collect_callgrind(
    487     task_spec=self._task_spec,
    488     globals=self._globals,
    489     number=number,
    490     repeats=repeats or 1,
    491     collect_baseline=collect_baseline and is_python,
    492     is_python=is_python,
    493     retain_out_file=retain_out_file,
    494 )
    496 return (result[0] if repeats is None else result)

File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py:527, in _ValgrindWrapper.collect_callgrind(self, task_spec, globals, number, repeats, collect_baseline, is_python, retain_out_file)
    515 def collect_callgrind(
    516     self,
    517     task_spec: common.TaskSpec,
   (...)
    524     retain_out_file: bool,
    525 ) -> Tuple[CallgrindStats, ...]:
    526     """Collect stats, and attach a reference run which can be used to filter interpreter overhead."""
--> 527     self._validate()
    528     assert is_python or not collect_baseline
    530     *task_stats, baseline_stats = self._invoke(
    531         task_spec=task_spec,
    532         globals=globals,
   (...)
    537         retain_out_file=retain_out_file,
    538     )

File /media/pc/data/tmp/cache/conda/envs/tvmz/lib/python3.10/site-packages/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py:513, in _ValgrindWrapper._validate(self)
    511 missing_cmds = [cmd for cmd, available in self._commands_available.items() if not available]
    512 if missing_cmds:
--> 513     raise OSError("Missing: " + ", ".join(missing_cmds))

OSError: Missing: valgrind, callgrind_control, callgrind_annotate