Skip to content

Instantly share code, notes, and snippets.

@daniil-lyakhov
Created June 19, 2024 12:51
Show Gist options
  • Save daniil-lyakhov/e0d024d3a61bfa0052895cffee0fef8e to your computer and use it in GitHub Desktop.
Save daniil-lyakhov/e0d024d3a61bfa0052895cffee0fef8e to your computer and use it in GitHub Desktop.
Output difference between YOLOv8 and YOLOv8 compiled with `torch.compile` on the same input
# torch==2.3.1
# ultralytics==8.2.35
import torch
from ultralytics.models.yolo import YOLO
# Fix the global RNG seed so the random input tensor created under __main__
# (torch.rand) is identical from run to run.
torch.manual_seed(42)
def run_yolo(torch_fx, inputs, model_name="yolov8n"):
    """Build a fresh YOLO model, optionally compile it, and run it on ``inputs``.

    Args:
        torch_fx: If True, wrap the underlying ``torch.nn.Module`` with
            ``torch.compile`` before running it; otherwise run it eagerly.
        inputs: Tensor passed straight to the model (the script below uses a
            ``(1, 3, 640, 640)`` float tensor).
        model_name: Ultralytics weights name or path handed to ``YOLO``.
            Defaults to ``"yolov8n"``, the model the comparison was written for.

    Returns:
        The first element of the model's output.
    """
    yolo_model = YOLO(model_name)
    model = yolo_model.model
    # Put the module in inference mode explicitly so the eager and compiled
    # runs are compared under the same layer behaviour (BatchNorm, Dropout).
    model.eval()
    if torch_fx:
        model = torch.compile(model)
    # Disable autograd: the comparison is inference-only. Without this the
    # eager output would carry a grad graph, and torch.compile would trace a
    # training-capable graph instead of an inference graph.
    with torch.no_grad():
        return model(inputs)[0]
if __name__ == "__main__":
    # A single shared random input, so both runs see exactly the same data.
    sample = torch.rand((1, 3, 640, 640))

    print("Run Torch model...")
    eager_out = run_yolo(torch_fx=False, inputs=sample)
    print("Run Torch FX model...")
    compiled_out = run_yolo(torch_fx=True, inputs=sample)

    # Element-wise absolute error between the eager and compiled outputs.
    abs_diff = (eager_out - compiled_out).abs()
    # argmax without a dim returns an index into the flattened tensor, so
    # every lookup below goes through a flattened (view(-1)) tensor.
    worst = abs_diff.argmax()
    print(f"argmax idx: {worst}")

    eager_flat = eager_out.view(-1)
    print(f"torch value: {eager_flat[worst]}")
    compiled_flat = compiled_out.view(-1)
    print(f"torch FX value: {compiled_flat[worst]}")
    diff_flat = abs_diff.view(-1)
    print(f"abs diff: {diff_flat[worst]}")

    # A high quantile shows how large the error is for most elements,
    # as opposed to the single worst element reported above.
    q96 = torch.quantile(abs_diff, 0.96)
    print(f"torch.quantile(abs_diff, 0.96) {q96}")
@daniil-lyakhov
Copy link
Author

Run Torch model...
Run Torch FX model...
argmax idx: 25132
torch value: 490.80194091796875
torch FX value: 855.9827270507812
abs diff: 365.1807861328125
torch.quantile(abs_diff, 0.96) 2.0144500732421875

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment