|
import argparse |
|
import glob |
|
import importlib |
|
import itertools |
|
import os |
|
|
|
import torch |
|
from common.bench_framework import (make_bwd_benchmark_for_case, |
|
make_bwd_benchmark_plot_for_case, |
|
make_fwd_benchmark_for_case, |
|
make_fwd_benchmark_plot_for_case) |
|
from common.diff_engine import DiffCase, calculate_diff |
|
|
|
|
|
def make_title_tag():
    """Build the bracketed hardware/software tag appended to plot titles.

    The device part is the name reported by ``torch.cuda.get_device_name(0)``
    when CUDA is available, and the literal ``"CPU"`` otherwise. The second
    part is the installed torch version string.

    Returns:
        A string of the form ``"[<device> | torch <version>]"``.
    """
    device_label = (torch.cuda.get_device_name(0)
                    if torch.cuda.is_available() else "CPU")
    version_label = torch.__version__
    return f"[{device_label} | torch {version_label}]"
|
|
|
|
|
def plot_result(r_path):
    """Render a benchmark CSV as a grouped bar chart and save it as a PNG.

    Reads ``r_path + ".csv"``, plots the ``"Naive"`` and ``"Cuda"`` columns
    against the ``"config"`` column as bars, labels every bar with its height
    formatted as ``x<value>`` (two decimals), and writes the image to
    ``r_path + ".png"``.

    Args:
        r_path: Path of the result files *without* extension; the ``.csv``
            input and the ``.png`` output share this stem.
    """
    # Imported lazily so the plotting stack is only needed when plotting.
    import matplotlib.pyplot as plt
    import pandas as pd

    df = pd.read_csv(r_path + ".csv")

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        df.plot(x="config", y=["Naive", "Cuda"], kind="bar", ax=ax)
        ax.set_title("Speedup over torch (higher is better)\n" +
                     make_title_tag(),
                     fontsize=14,
                     fontweight="bold")
        ax.set_ylabel("Relative Speedup", fontsize=14)
        ax.set_xlabel("")
        # `fig` is the current figure here, so pyplot's xticks targets `ax`.
        plt.xticks(rotation=45,
                   fontsize=12,
                   ha="right",
                   rotation_mode="anchor")
        for container in ax.containers:
            labels = [f"x{v.get_height():.2f}" for v in container]
            ax.bar_label(container,
                         labels=labels,
                         label_type="edge",
                         fontsize=10)
        fig.tight_layout()
        fig.savefig(r_path + ".png", bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed; this function is
        # called once per benchmark direction, so release the figure even if
        # plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def _purge_outputs(directory, *extensions):
    """Delete the files directly inside *directory* ending in *extensions*."""
    # Collect every match first, then delete, so globbing never races with
    # the removals.
    doomed = [
        path for ext in extensions
        for path in glob.glob(os.path.join(directory, "*" + ext))
    ]
    for path in doomed:
        os.remove(path)


def main():
    """Command-line entry point: verify one case, then benchmark it.

    Steps, in order:

    1. Parse ``--case`` (required), ``--plot`` and ``--save-path``.
    2. Set torch's default device to ``"cuda"`` and load ``CASE`` from the
       module ``cases.<case>``.
    3. Run ``calculate_diff`` on the case with a fixed small shape.
    4. Run the forward and then the backward benchmark, saving results under
       ``<save-path>/<case>``.

    With ``--plot`` the plot-oriented benchmark factories are used, each
    result is turned into a bar chart by ``plot_result``, and the ``.html``
    and ``.csv`` files in the output directory are removed afterwards.
    Without ``--plot`` the table-oriented factories are used and the
    ``.html`` and ``.png`` files are removed instead.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--case",
                        choices=["rms", "add_rms", "poly", "mul_poly"],
                        required=True)
    parser.add_argument("--plot", action="store_true")
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/",
        help="Path to save benchmark results",
    )
    args = parser.parse_args()

    torch.set_default_device("cuda")
    case_module = importlib.import_module(f"cases.{args.case}")
    case: DiffCase = case_module.CASE

    calculate_diff(
        case,
        batch_size=2,
        seq_len=128,
        hidden_size=4096,
    )

    save_dir = os.path.join(args.save_path, args.case)
    # Both modes use the same hidden sizes for a given case.
    hidden_dims = [8192, 16384] if "poly" in args.case else [2048, 4096]

    if args.plot:
        batch_sizes = [1]
        seq_lengths = [4096, 8192, 16384]
        # Plot mode orders each config as (batch, seq_len, dim).
        configs = list(itertools.product(batch_sizes, seq_lengths,
                                         hidden_dims))
        runs = (
            (f"plot_{args.case}-fwd-perf", make_fwd_benchmark_plot_for_case),
            (f"plot_{args.case}-bwd-perf", make_bwd_benchmark_plot_for_case),
        )
        for plot_name, build_benchmark in runs:
            bench = build_benchmark(
                case=case,
                configs=configs,
                plot_name=plot_name,
                # A fresh dict per call, exactly as before.
                line_names={
                    "naive": "Naive",
                    "cuda": "Cuda",
                },
            )
            bench.run(print_data=True, save_path=save_dir)
            plot_result(os.path.join(save_dir, plot_name))
        _purge_outputs(save_dir, ".html", ".csv")
        return

    batch_sizes = [2**exp for exp in range(0, 4)]
    seq_lengths = [2**exp for exp in range(10, 14)]
    # Table mode orders each config as (dim, batch, seq_len).
    configs = list(itertools.product(hidden_dims, batch_sizes, seq_lengths))
    runs = (
        (f"{args.case}-fwd-perf", make_fwd_benchmark_for_case),
        (f"{args.case}-bwd-perf", make_bwd_benchmark_for_case),
    )
    for plot_name, build_benchmark in runs:
        bench = build_benchmark(
            case=case,
            configs=configs,
            plot_name=plot_name,
            line_names={
                "naive": "Naive",
                "cuda": "Cuda",
                "speedup": "SpeedUp"
            },
        )
        bench.run(print_data=True, save_path=save_dir)
    _purge_outputs(save_dir, ".html", ".png")
|
|
|
|
|
# Run the benchmark command-line driver only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|
|