"""Triton benchmark factories for timing DiffCase forward/backward passes.

Each factory returns a ``triton.testing.perf_report``-decorated function
whose ``.run(...)`` method executes the benchmark over the given
(batch_size, seq_len, dim) configs.
"""

import collections
import math
import re
from typing import Sequence

import torch
import triton

from .diff_engine import DiffCase


def make_fwd_key(batch_size, seq_len, dim):
    return f"forward : ({batch_size}, {seq_len}, {dim})"


def make_bwd_key(batch_size, seq_len, dim):
    return f"backward : ({batch_size}, {seq_len}, {dim})"


def parse_config_string(config_str):
    """Recover (batch_size, seq_len, dim) from a key such as "forward : (4, 512, 1024)"."""
    match = re.match(r"(\w+)\s*:\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)", config_str)
    if not match:
        raise ValueError(f"Invalid config string: {config_str}")
    _, bs, sl, d = match.groups()
    return int(bs), int(sl), int(d)


def make_fwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Time the forward pass per provider. Latency is reported as
    ``time_unit_scale`` * ms; the default of 1000 yields microseconds."""
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "dim"],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_vals,
            line_names=[line_names[v] for v in line_vals],
            ylabel=ylabel,
            plot_name=plot_name,
            args={},
        ))
    def bench(batch_size, seq_len, dim, provider):
        key = make_fwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # Providers run in ``line_vals`` order, so both timings for this
            # config have already been recorded.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        ms = triton.testing.do_bench(lambda: case.forward(obj, inputs))
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench


def make_fwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Plot the forward-pass naive/cuda speedup per config, with the naive
    baseline pinned at 1.00 and a trailing geometric-mean column."""
    timings_ms = collections.defaultdict(dict)
    speedup_ratios = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_fwd_key(*cfg) for cfg in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["config"],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_vals,
            line_names=[line_names[v] for v in line_vals],
            ylabel=ylabel,
            plot_name=plot_name,
            args={},
        ))
    def bench(config, provider):
        if config == "Geometric Mean":
            # This is the last x value, so every per-config ratio has been
            # collected by the time it runs.
            if provider == "cuda":
                return round(math.prod(speedup_ratios) ** (1 / len(speedup_ratios)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        ms = triton.testing.do_bench(lambda: case.forward(obj, inputs))
        timings_ms[provider][config] = ms
        if provider == "cuda":
            # "naive" ran first for this config (``line_vals`` order).
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            speedup_ratios.append(ratio)
            return round(ratio, 2)
        return 1.00

    return bench
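# Usage sketch (hedged): ``MyCase`` stands in for any concrete DiffCase
# implementation -- it is not defined in this module -- and configs are
# (batch_size, seq_len, dim) tuples. ``.run`` is the entry point that
# ``triton.testing.perf_report`` attaches to the decorated function.
#
#     configs = [(4, 512, 1024), (4, 1024, 2048), (8, 2048, 4096)]
#     fwd = make_fwd_benchmark_for_case(
#         case=MyCase(),
#         configs=configs,
#         plot_name="my_case_fwd",
#     )
#     fwd.run(print_data=True, save_path="bench_out/")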
def make_bwd_benchmark_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "us",
    line_vals=("naive", "cuda", "speedup"),
    line_names: dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
    time_unit_scale: float = 1000,
):
    """Time the backward pass per provider. The forward pass runs once
    outside the timed region; only ``torch.autograd.grad`` is benchmarked.
    Latency is reported as ``time_unit_scale`` * ms (default: microseconds)."""
    timings_ms = collections.defaultdict(dict)
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [list(cfg) for cfg in configs]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size", "seq_len", "dim"],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_vals,
            line_names=[line_names[v] for v in line_vals],
            ylabel=ylabel,
            plot_name=plot_name,
            args={},
        ))
    def bench(batch_size, seq_len, dim, provider):
        key = make_bwd_key(batch_size, seq_len, dim)
        if provider == "speedup":
            # Providers run in ``line_vals`` order, so both timings for this
            # config have already been recorded.
            return timings_ms["naive"][key] / timings_ms["cuda"][key]
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        y = case.forward(obj, inputs)
        # Differentiate w.r.t. both the case's inputs and the module parameters.
        grad_targets = list(case.grad_inputs(inputs)) + list(obj.parameters())
        if isinstance(y, torch.Tensor):
            grad_outputs = [torch.randn_like(y)]
        else:
            grad_outputs = [torch.randn_like(r) for r in y]

        def run():
            # retain_graph=True lets do_bench replay the backward pass.
            return torch.autograd.grad(
                y, grad_targets, grad_outputs,
                retain_graph=True, create_graph=False, allow_unused=False,
            )

        ms = triton.testing.do_bench(run)
        timings_ms[provider][key] = ms
        return time_unit_scale * ms

    return bench


def make_bwd_benchmark_plot_for_case(
    *,
    case: DiffCase,
    configs: Sequence[tuple[int, int, int]],
    plot_name: str,
    ylabel: str = "Relative Speedup",
    line_vals=("naive", "cuda"),
    line_names: dict[str, str] | None = None,
    dtype=torch.bfloat16,
    eps: float = 1e-6,
):
    """Plot the backward-pass naive/cuda speedup per config, with the naive
    baseline pinned at 1.00 and a trailing geometric-mean column."""
    timings_ms = collections.defaultdict(dict)
    speedup_ratios = []
    line_vals = list(line_vals)
    line_names = line_names or {v: v.title() for v in line_vals}
    x_vals = [make_bwd_key(*cfg) for cfg in configs]
    x_vals.append("Geometric Mean")

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["config"],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=line_vals,
            line_names=[line_names[v] for v in line_vals],
            ylabel=ylabel,
            plot_name=plot_name,
            args={},
        ))
    def bench(config, provider):
        if config == "Geometric Mean":
            # This is the last x value, so every per-config ratio has been
            # collected by the time it runs.
            if provider == "cuda":
                return round(math.prod(speedup_ratios) ** (1 / len(speedup_ratios)), 2)
            return 1.00
        batch_size, seq_len, dim = parse_config_string(config)
        inputs = case.build_inputs(batch_size, seq_len, dim, dtype, eps)
        obj = case.make_naive(inputs) if provider == "naive" else case.make_cuda(inputs)
        y = case.forward(obj, inputs)
        # Differentiate w.r.t. both the case's inputs and the module parameters.
        grad_targets = list(case.grad_inputs(inputs)) + list(obj.parameters())
        if isinstance(y, torch.Tensor):
            grad_outputs = [torch.randn_like(y)]
        else:
            grad_outputs = [torch.randn_like(r) for r in y]

        def run():
            # retain_graph=True lets do_bench replay the backward pass.
            return torch.autograd.grad(
                y, grad_targets, grad_outputs,
                retain_graph=True, create_graph=False, allow_unused=False,
            )

        ms = triton.testing.do_bench(run)
        timings_ms[provider][config] = ms
        if provider == "cuda":
            # "naive" ran first for this config (``line_vals`` order).
            ratio = timings_ms["naive"][config] / timings_ms["cuda"][config]
            speedup_ratios.append(ratio)
            return round(ratio, 2)
        return 1.00

    return bench
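# Usage sketch for the backward speedup plot (same hypothetical ``MyCase``
# and configs as in the forward sketch above; neither is defined here). The
# "Geometric Mean" column is appended by the factory itself:
#
#     bwd = make_bwd_benchmark_plot_for_case(
#         case=MyCase(),
#         configs=configs,
#         plot_name="my_case_bwd_speedup",
#     )
#     bwd.run(print_data=True, save_path="bench_out/")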