File size: 2,478 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
/*
 * Code imported via patch from https://github.com/triton-lang/triton/pull/6570, commit 2afae45951b74785b144151b31e91e6c82b0b02f.
 * Copyright (c) 2018-2022 Philippe Tillet, OpenAI.
 * Licensed under the MIT License.
 */

From 2afae45951b74785b144151b31e91e6c82b0b02f Mon Sep 17 00:00:00 2001
From: Han Zhu <[email protected]>
Date: Tue, 22 Apr 2025 18:42:23 -0700
Subject: [PATCH] [autotuner] Lazily initiailize do_bench

---
diff --git a/a/autotuner.py b/b/autotuner.py
index 0ee6bea09..b75c5e353 100644
--- a/a/autotuner.py
+++ b/b/autotuner.py
@@ -4,6 +4,7 @@ import builtins
 import os
 import time
 import inspect
+from functools import cached_property
 from typing import Dict, Tuple, List, Optional

 from .jit import KernelInterface
@@ -94,6 +95,7 @@ class Autotuner(KernelInterface):
         while not inspect.isfunction(self.base_fn):
             self.base_fn = self.base_fn.fn

+        self._do_bench = do_bench
         self.num_warmups = warmup
         self.num_reps = rep
         self.use_cuda_graph = use_cuda_graph
@@ -107,7 +109,7 @@ class Autotuner(KernelInterface):
                           stacklevel=1)
             if use_cuda_graph:
                 from ..testing import do_bench_cudagraph
-                self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+                self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
                     kernel_call,
                     rep=rep if rep is not None else 100,
                     quantiles=quantiles,
@@ -115,7 +117,7 @@ class Autotuner(KernelInterface):
                 return

             import triton.testing
-            self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+            self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
                 kernel_call,
                 warmup=warmup if warmup is not None else 25,
                 rep=rep if rep is not None else 100,
@@ -123,10 +125,11 @@ class Autotuner(KernelInterface):
             )
             return

-        if do_bench is None:
-            self.do_bench = driver.active.get_benchmarker()
-        else:
-            self.do_bench = do_bench
+    @cached_property
+    def do_bench(self):
+        if self._do_bench is None:
+            return driver.active.get_benchmarker()
+        return self._do_bench

     def _bench(self, *args, config, **meta):
         from ..compiler.errors import CompileTimeAssertionFailure