Spaces:
Runtime error
Runtime error
| /* | |
| * Code imported via patch from https://github.com/triton-lang/triton/pull/6570, commit 2afae45951b74785b144151b31e91e6c82b0b02f. | |
| * Copyright (c) 2018-2022 Philippe Tillet, OpenAI. | |
| * Licensed under the MIT License. | |
| */ | |
| From 2afae45951b74785b144151b31e91e6c82b0b02f Mon Sep 17 00:00:00 2001 | |
| From: Han Zhu <[email protected]> | |
| Date: Tue, 22 Apr 2025 18:42:23 -0700 | |
| Subject: [PATCH] [autotuner] Lazily initiailize do_bench | |
| --- | |
| diff --git a/a/autotuner.py b/b/autotuner.py | |
| index 0ee6bea09..b75c5e353 100644 | |
| --- a/a/autotuner.py | |
| +++ b/b/autotuner.py | |
| import builtins | |
| import os | |
| import time | |
| import inspect | |
| +from functools import cached_property | |
| from typing import Dict, Tuple, List, Optional | |
| from .jit import KernelInterface | |
| class Autotuner(KernelInterface): | |
| while not inspect.isfunction(self.base_fn): | |
| self.base_fn = self.base_fn.fn | |
| + self._do_bench = do_bench | |
| self.num_warmups = warmup | |
| self.num_reps = rep | |
| self.use_cuda_graph = use_cuda_graph | |
| class Autotuner(KernelInterface): | |
| stacklevel=1) | |
| if use_cuda_graph: | |
| from ..testing import do_bench_cudagraph | |
| - self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( | |
| + self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph( | |
| kernel_call, | |
| rep=rep if rep is not None else 100, | |
| quantiles=quantiles, | |
| class Autotuner(KernelInterface): | |
| return | |
| import triton.testing | |
| - self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( | |
| + self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench( | |
| kernel_call, | |
| warmup=warmup if warmup is not None else 25, | |
| rep=rep if rep is not None else 100, | |
| class Autotuner(KernelInterface): | |
| ) | |
| return | |
| - if do_bench is None: | |
| - self.do_bench = driver.active.get_benchmarker() | |
| - else: | |
| - self.do_bench = do_bench | |
| + @cached_property | |
| + def do_bench(self): | |
| + if self._do_bench is None: | |
| + return driver.active.get_benchmarker() | |
| + return self._do_bench | |
| def _bench(self, *args, config, **meta): | |
| from ..compiler.errors import CompileTimeAssertionFailure | |