# activation/benchmarks/run_cases.py
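"""Correctness check and forward/backward benchmarks for one activation case.

The script first runs a numerical diff check (`calculate_diff`) on the selected
case, then benchmarks its forward and backward passes (naive vs. CUDA). With
`--plot` it additionally renders the measured speedups as PNG bar charts.
"""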
import argparse
import glob
import importlib
import itertools
import os

import torch

from common.bench_framework import (make_bwd_benchmark_for_case,
                                    make_bwd_benchmark_plot_for_case,
                                    make_fwd_benchmark_for_case,
                                    make_fwd_benchmark_plot_for_case)
from common.diff_engine import DiffCase, calculate_diff
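# Note: `common.bench_framework` and `common.diff_engine` are repo-local helper
# packages (presumably siblings of this script), and each `--case` value maps
# to a module under `cases/` that exports a `CASE: DiffCase` object.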


def make_title_tag():
    """Build a short device / torch-version tag for plot titles."""
    if torch.cuda.is_available():
        dev_name = torch.cuda.get_device_name(0)
    else:
        dev_name = "CPU"
    torch_ver = torch.__version__
    return f"[{dev_name} | torch {torch_ver}]"


def plot_result(r_path):
    """Render the benchmark CSV at `r_path + ".csv"` as a grouped bar chart."""
    import matplotlib.pyplot as plt
    import pandas as pd

    df = pd.read_csv(r_path + ".csv")
    plt.figure(figsize=(12, 6))
    ax = df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=plt.gca())
    ax.set_title("Speedup over torch (higher is better)\n" + make_title_tag(),
                 fontsize=14,
                 fontweight="bold")
    ax.set_ylabel("Relative Speedup", fontsize=14)
    ax.set_xlabel("")
    plt.xticks(rotation=45, fontsize=12, ha="right", rotation_mode="anchor")
    # Label every bar with its speedup factor.
    for container in ax.containers:
        labels = [f"x{v.get_height():.2f}" for v in container]
        ax.bar_label(container, labels=labels, label_type="edge", fontsize=10)
    plt.tight_layout()
    plt.savefig(r_path + ".png", bbox_inches="tight")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--case",
                    choices=["rms", "add_rms", "poly", "mul_poly"],
                    required=True)
    ap.add_argument("--plot", action="store_true")
    ap.add_argument(
        "--save-path",
        type=str,
        default="./configs/",
        help="Path to save benchmark results",
    )
    args = ap.parse_args()

    torch.set_default_device("cuda")

    # Import the selected case module and verify numerical correctness before
    # spending time on benchmarks.
    mod = importlib.import_module(f"cases.{args.case}")
    case: DiffCase = mod.CASE
    calculate_diff(
        case,
        batch_size=2,
        seq_len=128,
        hidden_size=4096,
    )

    save_dir = os.path.join(args.save_path, args.case)
    if args.plot:
        # Plot mode: a small set of large configs, rendered as PNG bar charts.
        batch_size_range = [1]
        seq_length_range = [4096, 8192, 16384]
        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
        configs = list(
            itertools.product(batch_size_range, seq_length_range, dim))

        plot_name = f"plot_{args.case}-fwd-perf"
        bench = make_fwd_benchmark_plot_for_case(
            case=case,
            configs=configs,
            plot_name=plot_name,
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
            },
        )
        bench.run(print_data=True, save_path=save_dir)
        plot_result(os.path.join(save_dir, plot_name))

        plot_name = f"plot_{args.case}-bwd-perf"
        bench = make_bwd_benchmark_plot_for_case(
            case=case,
            configs=configs,
            plot_name=plot_name,
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
            },
        )
        bench.run(print_data=True, save_path=save_dir)
        plot_result(os.path.join(save_dir, plot_name))

        # Keep only the PNGs; drop the intermediate HTML/CSV artifacts.
        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
                os.path.join(save_dir, "*.csv")):
            os.remove(f)
    else:
        # Table mode: sweep batch size and sequence length, print speedups.
        batch_size_range = [2**i for i in range(0, 4, 1)]
        seq_length_range = [2**i for i in range(10, 14, 1)]
        dim = [8192, 16384] if "poly" in args.case else [2048, 4096]
        configs = list(
            itertools.product(dim, batch_size_range, seq_length_range))

        bench = make_fwd_benchmark_for_case(
            case=case,
            configs=configs,
            plot_name=f"{args.case}-fwd-perf",
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
                "speedup": "SpeedUp"
            },
        )
        bench.run(print_data=True, save_path=save_dir)

        bench = make_bwd_benchmark_for_case(
            case=case,
            configs=configs,
            plot_name=f"{args.case}-bwd-perf",
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
                "speedup": "SpeedUp"
            },
        )
        bench.run(print_data=True, save_path=save_dir)

        # Keep only the CSVs; drop the HTML/PNG artifacts.
        for f in glob.glob(os.path.join(save_dir, "*.html")) + glob.glob(
                os.path.join(save_dir, "*.png")):
            os.remove(f)


if __name__ == "__main__":
    main()
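
# Example usage (requires a CUDA device, since the default device is set to
# "cuda", plus the repo-local `common/` and `cases/` packages on PYTHONPATH):
#   python run_cases.py --case rms              # print fwd/bwd speedup tables
#   python run_cases.py --case mul_poly --plot  # save speedup bar charts
# Results land under <--save-path>/<case>/ ("./configs/<case>/" by default).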