import argparse
import time

import torch


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--N", type=int, default=100_000_000)
    p.add_argument("--iters", type=int, default=10)
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    N = args.N
    iters = args.iters

    # Pinned host memory speeds up H2D copies, but pinning requires CUDA;
    # fall back to pageable memory on CPU-only runs (pin_memory=True would
    # fail there).
    pin = torch.cuda.is_available()
    A_cpu = torch.empty(N, dtype=torch.float32, pin_memory=pin).uniform_(-1, 1)
    B_cpu = torch.empty(N, dtype=torch.float32, pin_memory=pin).uniform_(-1, 1)
    C_cpu = torch.empty(N, dtype=torch.float32, pin_memory=pin).uniform_(-1, 1)
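    # An equivalent sketch for building one pinned input (assumes CUDA is
    # available; Tensor.pin_memory() returns a copy in pinned memory):
    #   A_cpu = torch.rand(N, dtype=torch.float32).mul_(2).sub_(1).pin_memory()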

    # non_blocking only takes effect for pinned-host -> CUDA copies.
    non_blocking = device.type == "cuda"
    A = A_cpu.to(device, non_blocking=non_blocking)
    B = B_cpu.to(device, non_blocking=non_blocking)
    C = C_cpu.to(device, non_blocking=non_blocking)
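    # Note: non_blocking copies return control to the host immediately, but
    # overlapping them with compute would additionally require issuing them
    # on their own CUDA stream; here they only avoid host-side stalls.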

    # Pre-allocate outputs once so the hot loop does no allocations.
    D = torch.empty_like(A)
    E = torch.empty_like(A)

    # Optional: allow TF32 on Ampere+ for matmul/conv throughput. The
    # elementwise ops and the reduction below are unaffected by TF32, so
    # this only matters if matmul work is added later.
    if device.type == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
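    # TF32 needs compute capability >= 8.0 (Ampere); a sketch of gating it
    # explicitly (torch.cuda.get_device_capability returns (major, minor)):
    #   major, _ = torch.cuda.get_device_capability(device)
    #   tf32_ok = major >= 8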

    # CUDA Graphs cut per-kernel launch overhead when shapes and addresses
    # are static: warm up on a side stream (so capture sees initialized
    # kernels and allocator state), then capture one full iteration and
    # replay it in the timed loop.
    graph = None
    if device.type == "cuda":
        stream = torch.cuda.Stream()
        torch.cuda.synchronize()
        with torch.cuda.stream(stream):
            D.copy_(A).add_(B, alpha=1.0)  # warm-up
            E.copy_(D).mul_(C).add_(B)
            _ = E.sum()
        torch.cuda.synchronize()

        # Capture via the torch.cuda.graph context manager, which runs the
        # capture on a non-default stream with a private memory pool
        # (capture must not run on the default stream).
        g = torch.cuda.CUDAGraph()
        s_buf = torch.zeros((), dtype=torch.float32, device=device)
        with torch.cuda.graph(g):
            D.copy_(A).add_(B, alpha=1.0)
            E.copy_(D).mul_(C).add_(B)
            s_buf.copy_(E.sum())
        graph = (g, s_buf)
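    # replay() re-launches the captured kernels against the tensors captured
    # above, so fresh inputs must be written in place (e.g. A.copy_(new_A))
    # rather than rebinding the Python names A/B/C.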

    # Timed region: sync first so setup work is not attributed to the loop.
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()

    total = 0.0
    for _ in range(iters):
        if graph is not None:
            g, s_buf = graph
            g.replay()
            # .item() blocks until the replayed graph has finished, so no
            # explicit synchronize is needed here.
            total += float(s_buf.item())
        else:
            # CPU path (no graph was captured).
            D.copy_(A).add_(B, alpha=1.0)
            E.copy_(D).mul_(C).add_(B)
            s = E.sum()
            total += float(s.item())
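    # To skip the per-iteration host sync, one could accumulate on device
    # and read back once (a sketch; `acc` is a name introduced here):
    #   acc = torch.zeros((), device=device)
    #   ... in the loop: acc += s_buf
    #   total = float(acc.item())  # single sync after the loop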

    if device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()

    ms = (t1 - t0) * 1000.0
    # Idealized traffic model: D = A + B touches 3 arrays and E = D * C + B
    # touches 4, i.e. 7 float32 accesses of N elements per iteration (the
    # in-place copy_/add_ form and the final sum read somewhat more).
    bytes_per_iter = 7.0 * N * 4.0
    gbps = (bytes_per_iter * iters) / (t1 - t0) / 1e9

    print("PyTorch optimized")
    print(f"device={device}")
    print(f"N={N} iters={iters}")
    print(f"time_ms={ms:.3f}")
    print(f"throughput_GBps={gbps:.3f}")
    print(f"result={total}")


if __name__ == "__main__":
    main()
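
# Example invocation (the script filename here is illustrative):
#   python bench.py --N 100000000 --iters 10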