import argparse
import time
from typing import Optional, Sequence

import torch


def _fused_step(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    """Return ``((A + B) * C + B).sum()`` as a zero-dimensional tensor.

    Keeping the intermediates local to this helper means they are released
    when it returns, instead of staying bound in the caller while the next
    pass allocates fresh ones.
    """
    D = A + B
    E = D * C + B
    return E.sum()


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Run the PyTorch baseline benchmark and print a timing report.

    Three float32 vectors of length ``--N`` are filled with uniform values
    in [-1, 1) under a fixed seed, moved to ``--device``, and the sequence
    ``D = A + B; E = D * C + B; E.sum()`` is executed once as a warmup and
    then ``--iters`` times under a wall-clock timer.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Prints:
        The device, the problem size, the elapsed time in milliseconds, the
        estimated throughput in GB/s, and the sum of the per-iteration
        reduction results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--N", type=int, default=100_000_000)
    parser.add_argument("--iters", type=int, default=10)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args(argv)

    # Fixed seed so the printed result is reproducible from run to run.
    torch.manual_seed(0)

    N = args.N
    iters = args.iters
    device = torch.device(args.device)
    on_cuda = device.type == "cuda"

    # Inputs are drawn on the CPU in the order A, B, C (this fixes the random
    # stream) and each one is then moved to the target device.
    A, B, C = (
        torch.empty(N, dtype=torch.float32).uniform_(-1, 1).to(device)
        for _ in range(3)
    )

    # Warmup pass: run the sequence once outside the timed region.
    _fused_step(A, B, C)
    if on_cuda:
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    total = 0.0
    for _ in range(iters):
        s = _fused_step(A, B, C)
        if on_cuda:
            # CUDA kernels are launched asynchronously; wait for them so the
            # timer covers the actual computation.
            torch.cuda.synchronize()
        # .item() already yields a Python float, so no extra float() is needed.
        total += s.item()
    t1 = time.perf_counter()

    elapsed = t1 - t0
    ms = elapsed * 1000.0

    # Traffic model: 7 array passes of N 4-byte (float32) elements per
    # iteration. NOTE(review): presumably 3 passes for D = A + B and 4 for
    # E = D * C + B, with the reduction's read of E not counted -- confirm
    # this is the intended model.
    bytes_per_iter = 7.0 * N * 4.0
    # Guard against a zero interval (e.g. --iters 0 on a coarse timer) so the
    # report is still printed instead of raising ZeroDivisionError.
    gbps = (bytes_per_iter * iters) / elapsed / 1e9 if elapsed > 0 else 0.0

    print("PyTorch baseline")
    print(f"device={device}")
    print(f"N={N} iters={iters}")
    print(f"time_ms={ms:.3f}")
    print(f"throughput_GBps={gbps:.3f}")
    print(f"result={total}")


if __name__ == "__main__":
    main()