|
from argparse import ArgumentParser |
|
import time |
|
import os |
|
import torch |
|
import torchvision.transforms as transforms |
|
from contextlib import nullcontext |
|
import json |
|
from models import get_model |
|
|
|
|
|
# Command-line interface for the inference-speed benchmark.
parser = ArgumentParser(description="Benchmark the inference speed (latency / FPS) of an EBC model.")
parser.add_argument("--model_info_path", type=str, required=True, help="Path to the model information file.")

# Shape of the synthetic input batch.
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for the model.")
parser.add_argument("--height", type=int, default=768, help="Height of the input image.")
parser.add_argument("--width", type=int, default=1024, help="Width of the input image.")

# Number of measured / discarded forward passes.
parser.add_argument("--num_iterations", type=int, default=200, help="Number of iterations to run the model.")
parser.add_argument("--num_warmup", type=int, default=20, help="Dispose of the first N iterations.")

# Execution settings. ``--device`` has an explicit default so that omitting it
# does not leave ``args.device`` as ``None``.
parser.add_argument("--device", type=str, default="cpu", choices=["cpu", "cuda", "mps"], help="Device to run the model on. Options are 'cpu', 'cuda', or 'mps'. Defaults to 'cpu'.")
parser.add_argument("--amp", action="store_true", help="Enable autocast mixed precision (fp16/bf16).")
parser.add_argument("--half", action="store_true", help="Use half precision for the model.")
parser.add_argument("--channels_last", action="store_true", help="Use NHWC memory format (recommended for CUDA).")
parser.add_argument("--compile", action="store_true", help="Enable torch.compile if available.")
parser.add_argument("--threads", type=int, default=None, help="torch.set_num_threads(threads) for CPU")
parser.add_argument("--sleep_time", type=float, default=0.0, help="Seconds to sleep after *each* iteration (cool-down).")
|
|
|
# Per-channel normalization applied to the synthetic input batch.
# NOTE(review): the constants look like the usual ImageNet mean/std - confirm
# they match the preprocessing the model expects.
_normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
|
|
|
|
|
def _dummy_input(bs, h, w, device, half, channels_last):
    """Create a random, normalized input batch for benchmarking.

    Args:
        bs: Batch size.
        h: Image height.
        w: Image width.
        device: Device on which the tensor is allocated.
        half: If true, convert the batch with ``Tensor.half()``.
        channels_last: If true, convert the batch to the
            ``torch.channels_last`` memory format.

    Returns:
        A tensor of shape ``(bs, 3, h, w)`` produced by ``torch.rand`` and
        passed through the module-level ``_normalize`` transform.
    """
    batch = _normalize(torch.rand(bs, 3, h, w, device=device))
    if half:
        batch = batch.half()
    if not channels_last:
        return batch
    return batch.to(memory_format=torch.channels_last)
|
|
|
|
|
def _maybe_sync(dev):
    """Wait for pending work on ``dev`` when its backend runs asynchronously.

    CUDA and MPS queue work and return to Python before it has finished, so
    wall-clock measurements need an explicit synchronization point. Any other
    device type (e.g. ``"cpu"``) is a no-op.

    Args:
        dev: A ``torch.device`` (only its ``type`` attribute is inspected).
    """
    if dev.type == "cuda":
        # Synchronize the device actually in use, not just the current one.
        torch.cuda.synchronize(dev)
    elif dev.type == "mps" and hasattr(torch, "mps"):
        # ``torch.mps`` is missing on older PyTorch builds; skip it there.
        torch.mps.synchronize()
|
|
|
|
|
@torch.inference_mode()
def benchmark(
    model: torch.nn.Module,
    inp: torch.Tensor,
    warmup: int,
    steps: int,
    amp: bool,
    sleep_time: float = 0.0
):
    """Measure the average forward-pass latency of ``model`` on ``inp``.

    Args:
        model: Module to benchmark; it is called as ``model(inp)``.
        inp: Input tensor. Its device decides the autocast device type and
            which synchronization is performed.
        warmup: Number of untimed forward passes run before measuring.
        steps: Number of timed forward passes. Must be positive.
        amp: If true, run every forward pass under ``torch.autocast``.
        sleep_time: Seconds to sleep after each timed pass. The sleep is not
            included in the measured time.

    Returns:
        A ``(fps, avg_latency)`` tuple: ``steps / total_time`` and
        ``total_time / steps`` with ``total_time`` in seconds.

    Raises:
        ValueError: If ``steps`` is not positive.
    """
    if steps <= 0:
        raise ValueError(f"steps must be a positive integer, got {steps}")

    device = inp.device
    cm = torch.autocast(device_type=device.type) if amp else nullcontext()

    for _ in range(warmup):
        with cm:
            model(inp)
        _maybe_sync(device)

    total_time = 0.0
    for _ in range(steps):
        tic = time.perf_counter()
        with cm:
            model(inp)
        # Accelerator backends return before the work is done; wait for it
        # *before* stopping the clock so the full forward pass is timed.
        _maybe_sync(device)
        total_time += time.perf_counter() - tic

        if sleep_time > 0:
            time.sleep(sleep_time)

    return steps / total_time, total_time / steps
|
|
|
|
|
def main(args):
    """Build the model, benchmark it and print the configuration and results.

    Args:
        args: Parsed command-line namespace (see the module-level ``parser``).
            It is not modified.

    Raises:
        FileNotFoundError: If ``args.model_info_path`` is not an existing file.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not os.path.isfile(args.model_info_path):
        raise FileNotFoundError(f"{args.model_info_path} not found")

    # Configure threading first: the inter-op thread count can only be set
    # before any inter-op parallel work has started.
    if args.threads:
        torch.set_num_threads(args.threads)
        torch.set_num_interop_threads(1)

    model = get_model(model_info_path=args.model_info_path)
    model.eval()

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)
    if args.half:
        model = model.half()

    # Fall back to the CPU when no device was given (``args.device`` is None).
    device = torch.device(args.device or "cpu")
    model = model.to(device)

    if args.compile and hasattr(torch, "compile"):
        model = torch.compile(model, mode="reduce-overhead")

    inp = _dummy_input(
        args.batch_size,
        args.height,
        args.width,
        device,
        args.half,
        args.channels_last
    )

    fps, t_avg = benchmark(
        model,
        inp,
        warmup=args.num_warmup,
        steps=args.num_iterations,
        amp=args.amp,
        sleep_time=args.sleep_time
    )

    # Report from a copy so the caller's namespace is left untouched.
    cfg = {key: value for key, value in vars(args).items() if key != "model_info_path"}
    cfg["device"] = str(device)  # show the device that was actually used
    print(json.dumps(cfg, indent=2))
    print(f"\nAverage latency: {t_avg*1000:6.2f} ms | FPS: {fps:,.2f}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the benchmark.
    cli_args = parser.parse_args()
    main(cli_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|