import argparse
import time

import torch


def _fused_step(A, B, C, D, E):
    """Compute ``E = (A + B) * C + B`` in place and return ``E.sum()``.

    ``D`` and ``E`` are caller-provided buffers that are overwritten; the
    only new tensor produced is the 0-dim result of ``sum()``.
    """
    D.copy_(A).add_(B, alpha=1.0)
    E.copy_(D).mul_(C).add_(B)
    return E.sum()


def main(argv=None):
    """Benchmark an element-wise ``(A + B) * C + B`` followed by a sum.

    Runs on CUDA when available (replaying a captured CUDA graph each
    iteration), otherwise on CPU, and prints timing, a throughput estimate
    and the accumulated sum.

    Args:
        argv: Optional list of command-line arguments (``--N``, ``--iters``).
            ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--N", type=int, default=100_000_000)
    p.add_argument("--iters", type=int, default=10)
    args = p.parse_args(argv)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(0)

    N = args.N
    iters = args.iters

    # Pin host memory only when a GPU is present: requesting pinned memory
    # without CUDA raises, which would make the CPU fallback unusable.
    A_cpu = torch.empty(N, dtype=torch.float32, pin_memory=use_cuda).uniform_(-1, 1)
    B_cpu = torch.empty(N, dtype=torch.float32, pin_memory=use_cuda).uniform_(-1, 1)
    C_cpu = torch.empty(N, dtype=torch.float32, pin_memory=use_cuda).uniform_(-1, 1)

    # Asynchronous H2D copies; every later use is preceded by a synchronize.
    A = A_cpu.to(device, non_blocking=use_cuda)
    B = B_cpu.to(device, non_blocking=use_cuda)
    C = C_cpu.to(device, non_blocking=use_cuda)

    # Pre-allocate outputs so the timed loop does not allocate them.
    D = torch.empty_like(A)
    E = torch.empty_like(A)

    if use_cuda:
        # NOTE(review): nothing below is a matmul or convolution, so these
        # flags are not expected to change the measured kernels.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # CUDA Graphs reduce launch overhead when shapes are static.
    graph = None
    s_buf = None
    if use_cuda:
        # Warm up on a side stream before capturing.
        warmup_stream = torch.cuda.Stream()
        torch.cuda.synchronize()
        with torch.cuda.stream(warmup_stream):
            _fused_step(A, B, C, D, E)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        s_buf = torch.zeros((), dtype=torch.float32, device=device)
        # torch.cuda.graph() performs the capture on a non-default stream;
        # calling capture_begin() directly on the default stream is an error.
        with torch.cuda.graph(graph):
            s_buf.copy_(_fused_step(A, B, C, D, E))

    # Timing
    if use_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    total = 0.0
    for _ in range(iters):
        if graph is not None:
            graph.replay()
            # .item() copies the scalar to the host, which waits for the
            # replayed work to finish.
            total += s_buf.item()
        else:
            # CPU path (no graph)
            total += _fused_step(A, B, C, D, E).item()

    # Drain the device *before* reading the clock so all work is included.
    if use_cuda:
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    elapsed = t1 - t0
    ms = elapsed * 1000.0
    # Traffic model used for the throughput figure: 7 float32 arrays of
    # length N per iteration.
    bytes_per_iter = 7.0 * N * 4.0
    gbps = (bytes_per_iter * iters) / elapsed / 1e9 if elapsed > 0 else 0.0

    print("PyTorch optimized")
    print(f"device={device}")
    print(f"N={N} iters={iters}")
    print(f"time_ms={ms:.3f}")
    print(f"throughput_GBps={gbps:.3f}")
    print(f"result={total}")


if __name__ == "__main__":
    main()