Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Last active February 16, 2022 18:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vfdev-5/951e46d48edf522400b69ae594715cbc to your computer and use it in GitHub Desktop.
Save vfdev-5/951e46d48edf522400b69ae594715cbc to your computer and use it in GitHub Desktop.
functorch `make_functional` + `grad` checks vs. PyTorch-computed gradients (torchvision models)
Torch: 1.12.0.dev20220215+cu111
torchvision: 0.13.0.dev20220215+cu111
Functorch: 0.2.0a0+c9d03e8
-- Check fasterrcnn_resnet50_fpn model
-- Check fasterrcnn_mobilenet_v3_large_320_fpn model
-- Check fasterrcnn_mobilenet_v3_large_fpn model
-- Check maskrcnn_resnet50_fpn model
-- Check keypointrcnn_resnet50_fpn model
-- Check retinanet_resnet50_fpn model
-- Check ssd300_vgg16 model
-- Check ssdlite320_mobilenet_v3_large model
-- Check fcos_resnet50_fpn model
import torch
import torch.nn as nn
import torchvision
import torchvision.models.detection as models
from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad
# Every public, lower-case, callable attribute of the detection package is
# treated as a model builder to test (presumably the builder functions);
# private names and capitalized names (classes) are skipped.
tested_models = [
    model_name
    for model_name, obj in models.__dict__.items()
    if not model_name.startswith("_")
    and not model_name[0].isupper()
    and callable(obj)
]
def compute_grads(model, images, targets):
    """Populate ``.grad`` of ``model``'s parameters with eager autograd.

    ``model(images, targets)`` must return a dict of loss tensors; their sum
    is back-propagated.

    Args:
        model: callable returning a dict of scalar loss tensors.
        images: first positional model input.
        targets: second positional model input.
    """
    # Seed 0 so stochastic layers such as dropout are reproducible.
    torch.manual_seed(0)
    total_loss = sum(model(images, targets).values())
    total_loss.backward()
device = 'cpu'


def check_grads_model(model_name, device):
    """Check functorch gradients against eager autograd for a detection model.

    Builds ``models.<model_name>`` (10 classes, no pretrained weights) and runs
    it on a small synthetic batch of images and targets. The parameter
    gradients of the summed loss dict are computed twice: once with functorch
    (``make_functional[_with_buffers]`` + ``grad``) and once with
    ``compute_grads`` (``loss.backward()``). The two results are compared.

    Args:
        model_name: name of a model builder in ``models``.
        device: device the model and the synthetic inputs are moved to.

    Raises:
        AssertionError: if the number of functorch gradients differs from the
            number of parameters, or if any gradient pair is not ``allclose``
            with ``atol=1e-5``. Raised explicitly rather than via ``assert``
            so the check is not stripped under ``python -O``.
    """
    num_images = 4
    size = (224, 224)
    torch.manual_seed(0)
    model = models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False)
    model = model.to(device)

    # Synthetic inputs are created directly on ``device`` so the check also works off-CPU.
    images = [torch.rand(3, *size, device=device) for _ in range(num_images)]
    targets = [
        {
            # torchvision documents detection box targets as float tensors.
            "boxes": torch.tensor(
                [[10 + i, 10 + i, 20 + i, 20 + i], [20 + i, 20 + i, 30 + i, 30 + i]],
                dtype=torch.float32,
                device=device,
            ),
            "labels": torch.tensor([(1 + i) % 10, (2 + i) % 10], device=device),
            "keypoints": torch.rand(2, 12, 3, device=device),
            # randint's upper bound is exclusive: (0, 2) is needed to get a
            # mask containing both 0s and 1s.
            "masks": torch.randint(0, 2, size=(2, *size), dtype=torch.uint8, device=device),
        }
        for i in range(num_images)
    ]

    # Models with buffers (e.g. norm-layer statistics) need the *_with_buffers variant.
    if next(model.buffers(), None) is not None:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None

    def compute_loss_ft(weights, buffers, images, targets):
        # Re-seed like compute_grads so random layers (e.g. dropout) draw the
        # same numbers in both passes.
        torch.manual_seed(0)
        if buffers is None:
            loss_dict = func_model(weights, images, targets)
        else:
            loss_dict = func_model(weights, buffers, images, targets)
        return sum(loss_dict.values())

    w_grads = grad(compute_loss_ft)(weights, buffers, images, targets)
    compute_grads(model, images, targets)

    named_params = list(model.named_parameters())
    if len(w_grads) != len(named_params):
        raise AssertionError(
            f"{len(w_grads)} functorch grads vs {len(named_params)} model parameters"
        )
    for wg, (n, p) in zip(w_grads, named_params):
        # p.grad stays None when backward() did not reach the parameter;
        # compare it as an all-zero gradient instead of crashing.
        eager_grad = p.grad if p.grad is not None else torch.zeros_like(p)
        if not eager_grad.allclose(wg, atol=1e-5):
            max_abs_diff = (eager_grad - wg).abs().max().item()
            raise AssertionError(
                f"grad mismatch for {n}: {eager_grad.mean()} vs {wg.mean()} "
                f"(max abs diff: {max_abs_diff})"
            )
# Report library versions, then check every collected model. Assertion
# failures are printed and the loop moves on to the next model.
print()
for label, version in (
    ("Torch:", torch.__version__),
    ("torchvision:", torchvision.__version__),
    ("Functorch:", ft_version),
):
    print(label, version)
print()
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as err:
        print(err)
import torch
import torch.nn as nn
import torchvision
import torchvision.models.segmentation as models
from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad
# Every public, lower-case, callable attribute of the segmentation package is
# treated as a model builder to test (presumably the builder functions);
# private names and capitalized names (classes) are skipped.
tested_models = [
    model_name
    for model_name, obj in models.__dict__.items()
    if not model_name.startswith("_")
    and not model_name[0].isupper()
    and callable(obj)
]
criterion = nn.CrossEntropyLoss()


def compute_grads(model, image, target):
    """Accumulate eager-autograd gradients of the segmentation loss into ``model``.

    Args:
        model: callable whose output dict has an ``"out"`` entry of logits.
        image: input batch passed to ``model``.
        target: per-pixel class indices for ``criterion``.
    """
    # Seed 0 so stochastic layers such as dropout are reproducible.
    torch.manual_seed(0)
    logits = model(image)["out"]
    criterion(logits, target).backward()
device = 'cpu'


def check_grads_model(model_name, device):
    """Check functorch gradients against eager autograd for a segmentation model.

    Builds ``models.<model_name>`` (10 classes, no pretrained weights) and runs
    it on a random batch. The parameter gradients of
    ``criterion(output["out"], targets)`` are computed twice: once with
    functorch (``make_functional[_with_buffers]`` + ``grad``) and once with
    ``compute_grads`` (``loss.backward()``). The two results are compared.

    Args:
        model_name: name of a model builder in ``models``.
        device: device the model and the synthetic inputs are moved to.

    Raises:
        AssertionError: if the number of functorch gradients differs from the
            number of parameters, or if any gradient pair is not ``allclose``
            with ``atol=1e-5``. Raised explicitly rather than via ``assert``
            so the check is not stripped under ``python -O``.
    """
    batch_size = 8
    size = (224, 224)
    torch.manual_seed(0)
    model = models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False)
    model = model.to(device)
    images = torch.randn(batch_size, 3, *size, device=device)
    # One class index in [0, 10) per pixel.
    targets = torch.randint(0, 10, (batch_size, *size), device=device)

    # Models with buffers (e.g. norm-layer statistics) need the *_with_buffers variant.
    if next(model.buffers(), None) is not None:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None

    def compute_loss_ft(weights, buffers, image, target):
        # Re-seed like compute_grads so random layers (e.g. dropout) draw the
        # same numbers in both passes.
        torch.manual_seed(0)
        if buffers is None:
            output = func_model(weights, image)
        else:
            output = func_model(weights, buffers, image)
        return criterion(output["out"], target)

    w_grads = grad(compute_loss_ft)(weights, buffers, images, targets)
    compute_grads(model, images, targets)

    named_params = list(model.named_parameters())
    if len(w_grads) != len(named_params):
        raise AssertionError(
            f"{len(w_grads)} functorch grads vs {len(named_params)} model parameters"
        )
    for wg, (n, p) in zip(w_grads, named_params):
        # p.grad stays None when backward() did not reach the parameter;
        # compare it as an all-zero gradient instead of crashing.
        eager_grad = p.grad if p.grad is not None else torch.zeros_like(p)
        if not eager_grad.allclose(wg, atol=1e-5):
            max_abs_diff = (eager_grad - wg).abs().max().item()
            raise AssertionError(
                f"grad mismatch for {n}: {eager_grad.mean()} vs {wg.mean()} "
                f"(max abs diff: {max_abs_diff})"
            )
# Report library versions, then check every collected model. Assertion
# failures are printed and the loop moves on to the next model.
print()
for label, version in (
    ("Torch:", torch.__version__),
    ("torchvision:", torchvision.__version__),
    ("Functorch:", ft_version),
):
    print(label, version)
print()
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as err:
        print(err)
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from functorch.version import __version__ as ft_version
from functorch import make_functional_with_buffers, make_functional, grad
# Every public, lower-case, callable attribute of ``models`` is treated as a
# model builder to test, except the listed task-specific sub-packages.
# Private names and capitalized names (classes) are skipped.
tested_models = [
    model_name
    for model_name, obj in models.__dict__.items()
    if not model_name.startswith("_")
    and not model_name[0].isupper()
    and model_name not in ("segmentation", "detection", "video", "quantization", "feature_extraction")
    and callable(obj)
]
criterion = nn.CrossEntropyLoss()


def compute_grads(model, image, target):
    """Accumulate eager-autograd gradients of the classification loss into ``model``.

    Args:
        model: callable mapping ``image`` to logits accepted by ``criterion``.
        image: input batch passed to ``model``.
        target: class-index targets for ``criterion``.
    """
    # Seed 0 so stochastic layers such as dropout are reproducible.
    torch.manual_seed(0)
    logits = model(image)
    criterion(logits, target).backward()
device = 'cpu'


def check_grads_model(model_name, device):
    """Check functorch gradients against eager autograd for a classification model.

    Builds ``models.<model_name>`` (10 classes) and runs it on a random batch.
    The parameter gradients of ``criterion(output, targets)`` are computed
    twice: once with functorch (``make_functional[_with_buffers]`` + ``grad``)
    and once with ``compute_grads`` (``loss.backward()``). The two results are
    compared.

    Args:
        model_name: name of a model builder in ``models``.
        device: device the model and the synthetic inputs are moved to.

    Raises:
        AssertionError: if the number of functorch gradients differs from the
            number of parameters, or if any gradient pair is not ``allclose``
            with ``atol=1e-5``. Raised explicitly rather than via ``assert``
            so the check is not stripped under ``python -O``.
    """
    batch_size = 8
    torch.manual_seed(0)
    size = (299, 299) if model_name == "inception_v3" else (224, 224)
    # Auxiliary heads are disabled so the forward output can go straight into
    # criterion (assumption: with aux heads on, the training-mode output is
    # not a plain logits tensor -- confirm against torchvision).
    kwargs = {"aux_logits": False} if model_name in ("inception_v3", "googlenet") else {}
    model = models.__dict__[model_name](num_classes=10, **kwargs)
    model = model.to(device)
    images = torch.randn(batch_size, 3, *size, device=device)
    targets = torch.randint(0, 10, (batch_size,), device=device)

    # Models with buffers (e.g. norm-layer statistics) need the *_with_buffers variant.
    if next(model.buffers(), None) is not None:
        func_model, weights, buffers = make_functional_with_buffers(model)
    else:
        func_model, weights = make_functional(model)
        buffers = None

    def compute_loss_ft(weights, buffers, image, target):
        # Re-seed like compute_grads so random layers (e.g. dropout) draw the
        # same numbers in both passes.
        torch.manual_seed(0)
        if buffers is None:
            output = func_model(weights, image)
        else:
            output = func_model(weights, buffers, image)
        return criterion(output, target)

    w_grads = grad(compute_loss_ft)(weights, buffers, images, targets)
    compute_grads(model, images, targets)

    named_params = list(model.named_parameters())
    if len(w_grads) != len(named_params):
        raise AssertionError(
            f"{len(w_grads)} functorch grads vs {len(named_params)} model parameters"
        )
    for wg, (n, p) in zip(w_grads, named_params):
        # p.grad stays None when backward() did not reach the parameter;
        # compare it as an all-zero gradient instead of crashing.
        eager_grad = p.grad if p.grad is not None else torch.zeros_like(p)
        if not eager_grad.allclose(wg, atol=1e-5):
            max_abs_diff = (eager_grad - wg).abs().max().item()
            raise AssertionError(
                f"grad mismatch for {n}: {eager_grad.mean()} vs {wg.mean()} "
                f"(max abs diff: {max_abs_diff})"
            )
# Report library versions, then check every collected model. Assertion
# failures are printed and the loop moves on to the next model.
print()
for label, version in (
    ("Torch:", torch.__version__),
    ("torchvision:", torchvision.__version__),
    ("Functorch:", ft_version),
):
    print(label, version)
print()
for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_grads_model(model_name, device=device)
    except AssertionError as err:
        print(err)
Torch: 1.12.0.dev20220209+cu111
torchvision: 0.13.0.dev20220209+cu111
Functorch: 0.2.0a0+da6ec37
-- Check alexnet model
-- Check convnext_tiny model
-- Check convnext_small model
-- Check convnext_base model
-- Check convnext_large model
-- Check resnet18 model
-- Check resnet34 model
-- Check resnet50 model
-- Check resnet101 model
-- Check resnet152 model
-- Check resnext50_32x4d model
-- Check resnext101_32x8d model
-- Check wide_resnet50_2 model
-- Check wide_resnet101_2 model
-- Check vgg11 model
-- Check vgg11_bn model
-- Check vgg13 model
-- Check vgg13_bn model
-- Check vgg16 model
-- Check vgg16_bn model
-- Check vgg19_bn model
-- Check vgg19 model
-- Check squeezenet1_0 model
-- Check squeezenet1_1 model
-- Check inception_v3 model
-- Check densenet121 model
-- Check densenet169 model
-- Check densenet201 model
-- Check densenet161 model
-- Check googlenet model
-- Check mobilenet_v2 model
-- Check mobilenet_v3_large model
-- Check mobilenet_v3_small model
-- Check mnasnet0_5 model
-- Check mnasnet0_75 model
-- Check mnasnet1_0 model
-- Check mnasnet1_3 model
-- Check shufflenet_v2_x0_5 model
-- Check shufflenet_v2_x1_0 model
-- Check shufflenet_v2_x1_5 model
-- Check shufflenet_v2_x2_0 model
-- Check efficientnet_b0 model
grad mismatch for features.1.0.block.0.0.weight: -0.12117128819227219 vs -0.12117135524749756
-- Check efficientnet_b1 model
grad mismatch for features.0.0.weight: -0.008784545585513115 vs -0.008784502744674683
-- Check efficientnet_b2 model
grad mismatch for features.0.0.weight: -0.17733778059482574 vs -0.17733803391456604
-- Check efficientnet_b3 model
grad mismatch for features.0.0.weight: 0.056723516434431076 vs 0.05672360211610794
-- Check efficientnet_b4 model
grad mismatch for features.0.0.weight: -0.008806591853499413 vs -0.008806349709630013
-- Check efficientnet_b5 model
grad mismatch for features.0.0.weight: 0.18076033890247345 vs 0.1807602196931839
-- Check efficientnet_b6 model
grad mismatch for features.0.0.weight: 0.1274760216474533 vs 0.12747612595558167
-- Check efficientnet_b7 model
grad mismatch for features.0.0.weight: -0.030335871502757072 vs -0.030335767194628716
-- Check regnet_y_400mf model
-- Check regnet_y_800mf model
-- Check regnet_y_1_6gf model
-- Check regnet_y_3_2gf model
-- Check regnet_y_8gf model
-- Check regnet_y_16gf model
-- Check regnet_y_32gf model
-- Check regnet_y_128gf model
-- Check regnet_x_400mf model
-- Check regnet_x_800mf model
-- Check regnet_x_1_6gf model
-- Check regnet_x_3_2gf model
-- Check regnet_x_8gf model
-- Check regnet_x_16gf model
-- Check regnet_x_32gf model
-- Check vit_b_16 model
-- Check vit_b_32 model
-- Check vit_l_16 model
-- Check vit_l_32 model
Torch: 1.12.0.dev20220209+cu111
torchvision: 0.13.0.dev20220209+cu111
Functorch: 0.2.0a0+da6ec37
-- Check fcn_resnet50 model
-- Check fcn_resnet101 model
-- Check deeplabv3_resnet50 model
-- Check deeplabv3_resnet101 model
-- Check deeplabv3_mobilenet_v3_large model
-- Check lraspp_mobilenet_v3_large model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment