Log autotune time in scuba (pytorch#122637)
Summary:
This diff
* Refactors the Triton and autotune caches into child classes of the original memcache-based cache infra
* Swaps the scuba table used for autotune
* Adds autotune time spent/saved to the scuba table (see the sketch below)
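
The core of the change is timing the autotuning pass and handing the elapsed time to the cache-save hook, which records it (in milliseconds) alongside the chosen config. Below is a minimal, self-contained sketch of that pattern, not the PyTorch implementation itself; the `cache` dict and the lambda benchmark are stand-ins for the real remote cache backends and Triton benchmarking.

```
import json
import time
from typing import Any, Callable, Dict

# Stand-in for the remote cache; the real backends live in triton.runtime
# (e.g. FbMemcacheRemoteAutotuneCacheBackend / RedisRemoteCacheBackend).
cache: Dict[str, str] = {}


def save_cache_hook(cfg: Dict[str, Any], time_taken_ns: int, found_by_coordesc: bool = False) -> None:
    # Store the winning config together with how long autotuning took (in ms).
    data = {
        **cfg,
        "found_by_coordesc": found_by_coordesc,
        "time_taken_ms": time_taken_ns // 1_000_000,  # convert ns -> ms
    }
    cache["autotune-best-config-v2"] = json.dumps(data)


def autotune_to_one_config(benchmark_all_configs: Callable[[], Dict[str, float]]) -> str:
    # Time the full benchmarking pass and forward the elapsed time to the hook.
    start_time = time.time_ns()
    timings = benchmark_all_configs()
    time_taken_ns = time.time_ns() - start_time
    best = min(timings, key=timings.get)
    save_cache_hook({"config": best}, time_taken_ns)
    return best


if __name__ == "__main__":
    print(autotune_to_one_config(lambda: {"cfg_a": 1.3, "cfg_b": 0.7}))  # -> cfg_b
    print(cache)  # JSON blob including "time_taken_ms"
```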

Test Plan:
Local testing using:
```
buck run mode/opt fbcode//caffe2/test/inductor/:max_autotune -- -r test_max_autotune_remote_caching_dynamic_False
```
and
```
TORCH_INDUCTOR_AUTOTUNE_REMOTE_CACHE=1 buck2 run mode/opt //scripts/oulgen:runner
```

Differential Revision: D55332620

Pull Request resolved: pytorch#122637
Approved by: https://github.com/jamesjwu
oulgen authored and pytorchmergebot committed Mar 26, 2024
1 parent 1f5fcb4 commit e61aaab
Showing 2 changed files with 31 additions and 24 deletions.
9 changes: 6 additions & 3 deletions test/inductor/test_max_autotune.py
@@ -1,4 +1,5 @@
 # Owner(s): ["module: inductor"]
+import json
 import os
 import unittest

@@ -252,18 +253,20 @@ def __init__(self, key, is_autotune=False):
     def get(self, filenames):
         nonlocal cache
         nonlocal num_get
-        ret = {file: cache[file] for file in filenames if file in cache}
+        ret = {
+            file: json.loads(cache[file]) for file in filenames if file in cache
+        }
         num_get += len(ret)
         return ret

     def put(self, filename, data):
         nonlocal cache
         nonlocal num_put
-        cache[filename] = data
+        cache[filename] = json.dumps(data)
         num_put += 1

 cache_module = (
-    "triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend"
+    "triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend"
     if config.is_fbcode()
     else "triton.runtime.cache.RedisRemoteCacheBackend"
 )
46 changes: 25 additions & 21 deletions torch/_inductor/triton_heuristics.py
@@ -11,6 +11,7 @@
 import os.path
 import re
 import threading
+import time
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple

@@ -528,10 +529,12 @@ def benchmark_all_configs(self, *args, **kwargs):

     def autotune_to_one_config(self, *args, **kwargs):
         """Do the actual autotuning"""
+        start_time = time.time_ns()
         timings = self.benchmark_all_configs(*args, **kwargs)
+        time_taken_ns = time.time_ns() - start_time
         self.launchers = [builtins.min(timings, key=timings.get)]
         if self.save_cache_hook:
-            self.save_cache_hook(self.launchers[0].config)
+            self.save_cache_hook(self.launchers[0].config, time_taken_ns)

     def save_cuda_kernel(self, grid, stream, launcher):
         if callable(grid):
@@ -618,13 +621,15 @@ def benchmark_one_config(config):
             self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
             and "RBLOCK" in launcher.config.kwargs
         ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have RBLOCK"
+        start_time = time.time_ns()
         best_config = self.coordesc_tuner.autotune(
             benchmark_one_config, launcher.config, None
         )
+        time_taken_ns = time.time_ns() - start_time
         best_config.found_by_coordesc = True

         if self.save_cache_hook:
-            self.save_cache_hook(best_config, found_by_coordesc=True)
+            self.save_cache_hook(best_config, time_taken_ns, found_by_coordesc=True)
         return config2launcher.get(best_config)

     def run(self, *args, grid, stream, **kwargs):
@@ -800,6 +805,9 @@ def load_cached_autotuning(
     if best_config.pop("configs_hash", None) != configs_hash:
         return None

+    # Remove time taken for comparison
+    best_config.pop("time_taken_ms", None)
+
     if config.coordinate_descent_tuning and best_config.pop("found_by_coordesc", False):
         num_warps = best_config.pop("num_warps")
         num_stages = best_config.pop("num_stages")
@@ -850,7 +858,7 @@ def cached_autotune(
     """
     configs = unique_configs(configs)
     assert len(configs) == 1 or filename
-    save_cache_hook: Optional[Callable[[Any, Any], Any]]
+    save_cache_hook: Optional[Callable[[Any, Any, Any], Any]]
     inductor_meta = {} if inductor_meta is None else inductor_meta

     # on disk caching logic and/or remote caching
@@ -865,15 +873,13 @@
     if should_use_remote_autotune_cache():
         backend_hash = inductor_meta.get("backend_hash", None)
         if backend_hash is not None:
-            key = backend_hash + configs_hash + "autotune-best-config"
+            key = backend_hash + configs_hash + "autotune-best-config-v2"
             key = hashlib.sha256(key.encode("utf-8")).hexdigest()

             try:
                 if config.is_fbcode():
-                    remote_cache = (
-                        triton.runtime.fb_memcache.FbMemcacheRemoteCacheBackend(
-                            key, is_autotune=True
-                        )
+                    remote_cache = triton.runtime.fb_memcache.FbMemcacheRemoteAutotuneCacheBackend(
+                        key
                     )
                 else:
                     remote_cache = triton.runtime.cache.RedisRemoteCacheBackend(key)
@@ -893,26 +899,24 @@
                 best_config = json.loads(fd.read())
         elif remote_cache is not None and remote_cache_key is not None:
             cache_outs = remote_cache.get([remote_cache_key])
-            cache_out = cache_outs.get(remote_cache_key, None)
-            best_config = json.loads(cache_out) if cache_out else None
+            best_config = cache_outs.get(remote_cache_key, None)

         best_config = load_cached_autotuning(best_config, configs_hash, configs)
         if best_config:
             configs = [best_config]

-        def save_cache_hook(cfg, found_by_coordesc=False):
-            data = json.dumps(
-                {
-                    **cfg.kwargs,
-                    "num_warps": cfg.num_warps,
-                    "num_stages": cfg.num_stages,
-                    "configs_hash": configs_hash,
-                    "found_by_coordesc": found_by_coordesc,
-                }
-            )
+        def save_cache_hook(cfg, time_taken_ns, found_by_coordesc=False):
+            data = {
+                **cfg.kwargs,
+                "num_warps": cfg.num_warps,
+                "num_stages": cfg.num_stages,
+                "configs_hash": configs_hash,
+                "found_by_coordesc": found_by_coordesc,
+                "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
+            }
             if cache_filename is not None:
                 with open(cache_filename, "w") as fd:
-                    fd.write(data)
+                    fd.write(json.dumps(data))
             if remote_cache is not None and remote_cache_key is not None:
                 remote_cache.put(remote_cache_key, data)

