@SannaPersson, created March 21, 2021 10:47
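import torch

# YOLOv3 is assumed to be defined elsewhere in this module (not shown in this gist).


def test():
    num_classes = 20
    model = YOLOv3(num_classes=num_classes)
    img_size = 416
    # Batch of 2 random RGB images, 416x416 pixels each
    x = torch.randn((2, 3, img_size, img_size))
    out = model(x)
    # One output per scale (strides 32, 16, 8 -> 13x13, 26x26, 52x52 grids).
    # Each grid cell has 3 anchors; each anchor predicts objectness,
    # 4 box coordinates and one score per class (5 + num_classes values).
    assert out[0].shape == (2, 3, img_size // 32, img_size // 32, 5 + num_classes)
    assert out[1].shape == (2, 3, img_size // 16, img_size // 16, 5 + num_classes)
    assert out[2].shape == (2, 3, img_size // 8, img_size // 8, 5 + num_classes)


if __name__ == "__main__":
    test()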
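The asserted shapes follow from simple arithmetic: each of the three detection scales divides the input side by its stride (32, 16, 8), and every grid cell holds 3 anchors with 5 + `num_classes` values each. The sketch below reproduces that arithmetic in plain Python so the expected shapes can be checked without building the model. The helper names `expected_yolov3_shapes` and `total_predictions` are hypothetical (not part of the original gist), and the sketch assumes the input side is a multiple of 32, as 416 is.

```python
# Hypothetical helpers (not part of the original gist): compute the output
# shapes the test above asserts, assuming img_size is a multiple of 32.

def expected_yolov3_shapes(batch_size, img_size, num_classes,
                           num_anchors=3, strides=(32, 16, 8)):
    if img_size % max(strides) != 0:
        raise ValueError("img_size must be divisible by the largest stride")
    # Per anchor: objectness score + 4 box coordinates + one score per class
    depth = 5 + num_classes
    return [
        (batch_size, num_anchors, img_size // s, img_size // s, depth)
        for s in strides
    ]


def total_predictions(img_size, num_anchors=3, strides=(32, 16, 8)):
    # Number of candidate boxes across all scales, before any filtering
    return sum(num_anchors * (img_size // s) ** 2 for s in strides)


print(expected_yolov3_shapes(batch_size=2, img_size=416, num_classes=20))
# [(2, 3, 13, 13, 25), (2, 3, 26, 26, 25), (2, 3, 52, 52, 25)]
print(total_predictions(416))
# 10647
```

At 416x416 this gives 507 + 2028 + 8112 = 10647 candidate boxes, which is why post-processing steps such as confidence thresholding and non-max suppression are normally applied to the raw outputs.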