diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py
index 7835d1a27bb0d..25285c2cc04b8 100644
--- a/test/inductor/test_max_autotune.py
+++ b/test/inductor/test_max_autotune.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import json
 import os
 import unittest
 
@@ -252,18 +253,20 @@ def __init__(self, key, is_autotune=False):
             def get(self, filenames):
                 nonlocal cache
                 nonlocal num_get
-                ret = {file: cache[file] for file in filenames if file in cache}
+                ret = {
+                    file: json.loads(cache[file]) for file in filenames if file in cache
+                }
                 num_get += len(ret)
                 return ret
 
             def put(self, filename, data):
                 nonlocal cache
                 nonlocal num_put
-                cache[filename] = data
+                cache[filename] = json.dumps(data)
                 num_put += 1
 
         cache_module = (
-            "triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend"
+            "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
             if config.is_fbcode()
             else "triton.runtime.cache.RedisRemoteCacheBackend"
         )
diff --git a/torch/_inductor/triton_heuristics.py b/torch/_inductor/triton_heuristics.py
index 35e3ac1e313db..d92189f4202df 100644
--- a/torch/_inductor/triton_heuristics.py
+++ b/torch/_inductor/triton_heuristics.py
@@ -11,6 +11,7 @@
 import os.path
 import re
 import threading
+import time
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple
 
@@ -528,10 +529,12 @@ def benchmark_all_configs(self, *args, **kwargs):
 
     def autotune_to_one_config(self, *args, **kwargs):
         """Do the actual autotuning"""
+        start_time = time.time_ns()
         timings = self.benchmark_all_configs(*args, **kwargs)
+        time_taken_ns = time.time_ns() - start_time
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
-            self.save_cache_hook(self.launchers[0].config)
+            self.save_cache_hook(self.launchers[0].config, time_taken_ns)
 
     def save_cuda_kernel(self, grid, stream, launcher):
         if callable(grid):
@@ -618,13 +621,15 @@ def benchmark_one_config(config):
             self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
             and "RBLOCK" in launcher.config.kwargs
         ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
+        start_time = time.time_ns()
         best_config = self.coordesc_tuner.autotune(
             benchmark_one_config, launcher.config, None
         )
+        time_taken_ns = time.time_ns() - start_time
         best_config.found_by_coordesc = True
 
         if self.save_cache_hook:
-            self.save_cache_hook(best_config, found_by_coordesc=True)
+            self.save_cache_hook(best_config, time_taken_ns, found_by_coordesc=True)
         return config2launcher.get(best_config)
 
     def run(self, *args, grid, stream, **kwargs):
@@ -800,6 +805,9 @@ def load_cached_autotuning(
     if best_config.pop("configs_hash", None) != configs_hash:
         return None
 
+    # Remove time taken for comparison
+    best_config.pop("time_taken_ms", None)
+
     if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
         num_warps = best_config.pop("num_warps")
         num_stages = best_config.pop("num_stages")
@@ -850,7 +858,7 @@ def cached_autotune(
     """
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
-    save_cache_hook: Optional[Callable[[Any, Any], Any]]
+    save_cache_hook: Optional[Callable[[Any, Any, Any], Any]]
    inductor_meta = {} if inductor_meta is None else inductor_meta
 
     # on disk caching logic and/or remote caching
@@ -865,15 +873,13 @@
         if should_use_remote_autotune_cache():
             backend_hash = inductor_meta.get("backend_hash", None)
             if backend_hash is not None:
-                key = backend_hash + configs_hash + "autotune-best-config"
+                key = backend_hash + configs_hash + "autotune-best-config-v2"
                 key = hashlib.sha256(key.encode("utf-8")).hexdigest()
 
                 try:
                     if config.is_fbcode():
-                        remote_cache = (
-                            triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend(
-                                key, is_autotune=True
-                            )
+                        remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
+                            key
                         )
                     else:
                         remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key)
@@ -893,26 +899,24 @@
                 best_config = json.loads(fd.read())
         elif remote_cache is not None and remote_cache_key is not None:
             cache_outs = remote_cache.get([remote_cache_key])
-            cache_out = cache_outs.get(remote_cache_key, None)
-            best_config = json.loads(cache_out) if cache_out else None
+            best_config = cache_outs.get(remote_cache_key, None)
 
         best_config = load_cached_autotuning(best_config, configs_hash, configs)
         if best_config:
             configs = [best_config]
 
-        def save_cache_hook(cfg, found_by_coordesc=False):
-            data = json.dumps(
-                {
-                    **cfg.kwargs,
-                    "num_warps": cfg.num_warps,
-                    "num_stages": cfg.num_stages,
-                    "configs_hash": configs_hash,
-                    "found_by_coordesc": found_by_coordesc,
-                }
-            )
+        def save_cache_hook(cfg, time_taken_ns, found_by_coordesc=False):
+            data = {
+                **cfg.kwargs,
+                "num_warps": cfg.num_warps,
+                "num_stages": cfg.num_stages,
+                "configs_hash": configs_hash,
+                "found_by_coordesc": found_by_coordesc,
+                "time_taken_ms": time_taken_ns // 1000000,  # Convert from ns to ms
+            }
             if cache_filename is not None:
                 with open(cache_filename, "w") as fd:
-                    fd.write(data)
+                    fd.write(json.dumps(data))
             if remote_cache is not None and remote_cache_key is not None:
                 remote_cache.put(remote_cache_key, data)
 
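
Taken together, the `triton_heuristics.py` hunks thread a wall-clock measurement through the caching path: the benchmarking pass is timed with `time.time_ns()`, the nanosecond count is handed to `save_cache_hook`, persisted as `time_taken_ms`, and popped again in `load_cached_autotuning` so it never participates in config comparison. The sketch below is a minimal, self-contained model of that round-trip, not PyTorch's actual API; `autotune_and_save`, `load_best_config`, the fake benchmark, and the `demo.best_config` path are all hypothetical.

```python
import json
import time


def autotune_and_save(configs, benchmark, cache_path, configs_hash):
    # Time the whole benchmarking pass in nanoseconds, as
    # autotune_to_one_config now does.
    start_time = time.time_ns()
    timings = [(cfg, benchmark(cfg)) for cfg in configs]
    time_taken_ns = time.time_ns() - start_time

    best_cfg, _ = min(timings, key=lambda pair: pair[1])
    data = {
        **best_cfg,
        "configs_hash": configs_hash,
        "found_by_coordesc": False,
        "time_taken_ms": time_taken_ns // 1_000_000,  # convert ns to ms
    }
    with open(cache_path, "w") as fd:
        fd.write(json.dumps(data))
    return best_cfg


def load_best_config(cache_path, configs_hash):
    with open(cache_path) as fd:
        best_config = json.loads(fd.read())
    # A stale hash means the candidate configs changed; ignore the entry.
    if best_config.pop("configs_hash", None) != configs_hash:
        return None
    # Strip bookkeeping fields so only tunable parameters remain for
    # comparison against the live config list.
    best_config.pop("time_taken_ms", None)
    best_config.pop("found_by_coordesc", None)
    return best_config


if __name__ == "__main__":
    configs = [
        {"XBLOCK": 64, "num_warps": 4, "num_stages": 2},
        {"XBLOCK": 128, "num_warps": 8, "num_stages": 3},
    ]
    # Fake benchmark: pretend larger blocks run faster.
    best = autotune_and_save(
        configs, lambda cfg: 1.0 / cfg["XBLOCK"], "demo.best_config", "abc123"
    )
    assert load_best_config("demo.best_config", "abc123") == best
```

Popping `time_taken_ms` on load mirrors the existing handling of `configs_hash` and `found_by_coordesc`: anything that is metadata rather than a tunable kernel parameter must be removed before the cached dict is matched against candidate configs, which is also why the remote key is bumped to `autotune-best-config-v2` for the new payload shape.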
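The test-side change encodes a new backend contract: `put()` now receives a plain Python dict and the backend owns JSON serialization, while `get()` hands back decoded objects rather than raw strings (which is why `cached_autotune` above no longer calls `json.loads` on the remote result). A toy backend satisfying that contract might look like the following; `InMemoryAutotuneCache` is a hypothetical stand-in for the Redis/memcache backends named in the diff.

```python
import json


class InMemoryAutotuneCache:
    """Hypothetical in-memory stand-in for the remote autotune cache backends."""

    def __init__(self, key):
        self.key = key      # hashed namespace; unused in this toy version
        self._store = {}    # filename -> JSON string

    def get(self, filenames):
        # Decode on the way out, mirroring the test's json.loads().
        return {
            f: json.loads(self._store[f]) for f in filenames if f in self._store
        }

    def put(self, filename, data):
        # Encode on the way in, mirroring the test's json.dumps().
        self._store[filename] = json.dumps(data)


cache = InMemoryAutotuneCache("deadbeef")
cache.put("kernel.best_config", {"num_warps": 4, "time_taken_ms": 12})
assert cache.get(["kernel.best_config"])["kernel.best_config"]["num_warps"] == 4
```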