MagpieTTS_Internal_Demo / external /patches /triton-lang_triton_6570_lazy_init.patch
subhankarg's picture
Upload folder using huggingface_hub
0558aa4 verified
/*
* Code imported via patch from https://github.com/triton-lang/triton/pull/6570, commit 2afae45951b74785b144151b31e91e6c82b0b02f.
* Copyright (c) 2018-2022 Philippe Tillet, OpenAI.
* Licensed under the MIT License.
*/
From 2afae45951b74785b144151b31e91e6c82b0b02f Mon Sep 17 00:00:00 2001
From: Han Zhu <[email protected]>
Date: Tue, 22 Apr 2025 18:42:23 -0700
Subject: [PATCH] [autotuner] Lazily initiailize do_bench
---
diff --git a/a/autotuner.py b/b/autotuner.py
index 0ee6bea09..b75c5e353 100644
--- a/a/autotuner.py
+++ b/b/autotuner.py
@@ -4,6 +4,7 @@ import builtins
import os
import time
import inspect
+from functools import cached_property
from typing import Dict, Tuple, List, Optional
from .jit import KernelInterface
@@ -94,6 +95,7 @@ class Autotuner(KernelInterface):
while not inspect.isfunction(self.base_fn):
self.base_fn = self.base_fn.fn
+ self._do_bench = do_bench
self.num_warmups = warmup
self.num_reps = rep
self.use_cuda_graph = use_cuda_graph
@@ -107,7 +109,7 @@ class Autotuner(KernelInterface):
stacklevel=1)
if use_cuda_graph:
from ..testing import do_bench_cudagraph
- self.do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
+ self._do_bench = lambda kernel_call, quantiles: do_bench_cudagraph(
kernel_call,
rep=rep if rep is not None else 100,
quantiles=quantiles,
@@ -115,7 +117,7 @@ class Autotuner(KernelInterface):
return
import triton.testing
- self.do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
+ self._do_bench = lambda kernel_call, quantiles: triton.testing.do_bench(
kernel_call,
warmup=warmup if warmup is not None else 25,
rep=rep if rep is not None else 100,
@@ -123,10 +125,11 @@ class Autotuner(KernelInterface):
)
return
- if do_bench is None:
- self.do_bench = driver.active.get_benchmarker()
- else:
- self.do_bench = do_bench
+ @cached_property
+ def do_bench(self):
+ if self._do_bench is None:
+ return driver.active.get_benchmarker()
+ return self._do_bench
def _bench(self, *args, config, **meta):
from ..compiler.errors import CompileTimeAssertionFailure