functorch combine_state_for_ensemble + vmap checks vs for-loop pytorch computations (torchvision models)
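The pattern under test, in miniature: combine_state_for_ensemble stacks the parameters and buffers of N identical-architecture models, and vmap maps a functional version of the model over that stacked state, which should match running each model in a plain Python loop. A minimal sketch of the idea (toy nn.Linear ensemble, not part of the gist):

import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap

models = [nn.Linear(4, 2).eval() for _ in range(3)]
x = torch.randn(8, 4)

# For-loop reference: run every model separately and stack the results
ref = torch.stack([m(x) for m in models])

# vmapped ensemble: stack the states once, run one batched call over them
fmodel, params, buffers = combine_state_for_ensemble(models)
out = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
assert torch.allclose(out, ref, atol=1e-6)

The scripts below apply exactly this check, with looser tolerances, across torchvision's detection, segmentation, and classification model zoos.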
# --- Check torchvision detection models ---
import torch
import torch.nn as nn
import torchvision
import torchvision.models.detection as tv_models
import functorch
from functorch import combine_state_for_ensemble, vmap

# Collect all public detection model builders (skip private names and classes)
tested_models = []
for model_name in tv_models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if not callable(tv_models.__dict__[model_name]):
        continue
    tested_models.append(model_name)


def compute_same_batch_preds(models, batch):
    # Fix seed to fix dropout
    torch.manual_seed(0)
    output = [model(batch) for model in models]
    return output


num_models = 5
device = 'cpu'


def check_ensembling_model(model_name, device):
    torch.manual_seed(0)
    size = (224, 224)
    models = [
        tv_models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False).to(device)
        for _ in range(num_models)
    ]
    # Disable training mode:
    [m.eval() for m in models]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    # Detection models take a list of image tensors rather than a batched tensor
    images = [torch.rand(3, *size) for _ in range(4)]
    ref_preds = compute_same_batch_preds(models, images)
    # NOTE: in eval mode detection models return List[Dict[str, Tensor]], so the
    # stack/allclose below only applies cleanly to tensor outputs (see the note after this script)
    ref_preds = torch.stack(ref_preds)
    out_preds = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, images)
    assert torch.allclose(out_preds, ref_preds, atol=1e-3, rtol=1e-5), f"Output does not match reference: {out_preds.mean()} vs {ref_preds.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", functorch.__version__)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_ensembling_model(model_name, device=device)
    except AssertionError as e:
        print(e)
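A caveat on this first script: in eval mode, torchvision detection models return a list of per-image dicts (keys such as "boxes", "labels", "scores"), not a tensor, so the torch.stack/allclose comparison above only works for tensor outputs and the detection checks will generally error out rather than fail the assert. A field-wise comparison would look roughly like this (a sketch, assuming both paths yield the same List[Dict[str, Tensor]] structure):

def allclose_detection(preds_a, preds_b, atol=1e-3, rtol=1e-5):
    # preds_a, preds_b: List[Dict[str, Tensor]], one dict per input image
    for da, db in zip(preds_a, preds_b):
        for key in da:
            if not torch.allclose(da[key], db[key], atol=atol, rtol=rtol):
                return False
    return True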
# --- Check torchvision segmentation models ---
import torch
import torch.nn as nn
import torchvision
import torchvision.models.segmentation as tv_models
import functorch
from functorch import combine_state_for_ensemble, vmap

# Collect all public segmentation model builders (skip private names and classes)
tested_models = []
for model_name in tv_models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if not callable(tv_models.__dict__[model_name]):
        continue
    tested_models.append(model_name)


def compute_same_batch_preds(models, batch):
    # Fix seed to fix dropout
    torch.manual_seed(0)
    # Segmentation models return a dict; the main prediction is under "out"
    output = [model(batch)["out"] for model in models]
    return output


num_models = 5
device = 'cpu'


def check_ensembling_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)
    size = (224, 224)
    models = [
        tv_models.__dict__[model_name](num_classes=10, pretrained=False, pretrained_backbone=False).to(device)
        for _ in range(num_models)
    ]
    # Disable training mode:
    [m.eval() for m in models]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    images = torch.randn(batch_size, 3, *size, device=device)
    ref_preds = compute_same_batch_preds(models, images)
    ref_preds = torch.stack(ref_preds)
    # Extract the same "out" head inside the vmapped call so both sides are tensors
    out_preds = vmap(lambda p, b, x: fmodel(p, b, x)["out"], in_dims=(0, 0, None))(params, buffers, images)
    assert torch.allclose(out_preds, ref_preds, atol=1e-3, rtol=1e-5), f"Output does not match reference: {out_preds.mean()} vs {ref_preds.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", functorch.__version__)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_ensembling_model(model_name, device=device)
    except AssertionError as e:
        print(e)
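The segmentation script only compares the "out" head, since the models are built with aux_loss left at its default. For a model constructed with aux_loss=True, the auxiliary head could be compared the same way (hypothetical extension, not in the gist):

# Hypothetical: also compare the auxiliary head, for models built with aux_loss=True
ref_aux = torch.stack([model(images)["aux"] for model in models])
out_aux = vmap(lambda p, b, x: fmodel(p, b, x)["aux"], in_dims=(0, 0, None))(params, buffers, images)
assert torch.allclose(out_aux, ref_aux, atol=1e-3, rtol=1e-5)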
Output of the classification-models script (the third script, below):

Torch: 1.12.0.dev20220215+cu111
torchvision: 0.13.0.dev20220215+cu111
Functorch: 0.2.0a0+c9d03e8
-- Check alexnet model
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1279: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::dropout. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /ft/functorch/csrc/BatchedFallback.cpp:82.)
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
-- Check convnext_tiny model
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1241: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::as_strided_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /ft/functorch/csrc/BatchedFallback.cpp:82.)
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
-- Check convnext_small model
-- Check convnext_base model
-- Check convnext_large model
-- Check resnet18 model
-- Check resnet34 model
-- Check resnet50 model
-- Check resnet101 model
Output does not match reference: 2613.704345703125 vs 2613.70458984375
-- Check resnet152 model
Output does not match reference: 3993238.0 vs 3993238.5
-- Check resnext50_32x4d model
-- Check resnext101_32x8d model
-- Check wide_resnet50_2 model
-- Check wide_resnet101_2 model
Output does not match reference: 5651.0712890625 vs 5651.06982421875
-- Check vgg11 model
-- Check vgg11_bn model
-- Check vgg13 model
-- Check vgg13_bn model
-- Check vgg16 model
-- Check vgg16_bn model
-- Check vgg19_bn model
-- Check vgg19 model
-- Check squeezenet1_0 model
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
-- Check squeezenet1_1 model
-- Check inception_v3 model
/usr/local/lib/python3.8/dist-packages/torchvision/models/inception.py:44: FutureWarning: The default weight initialization of inception_v3 will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
warnings.warn(
Output does not match reference: 12340685824.0 vs 12340652032.0
-- Check densenet121 model
-- Check densenet169 model
-- Check densenet201 model
-- Check densenet161 model
-- Check googlenet model
/usr/local/lib/python3.8/dist-packages/torchvision/models/googlenet.py:46: FutureWarning: The default weight initialization of GoogleNet will be changed in future releases of torchvision. If you wish to keep the old behavior (which leads to long initialization times due to scipy/scipy#11299), please set init_weights=True.
warnings.warn(
-- Check mobilenet_v2 model
-- Check mobilenet_v3_large model
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py:1279: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::dropout_. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /ft/functorch/csrc/BatchedFallback.cpp:82.)
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
-- Check mobilenet_v3_small model
-- Check mnasnet0_5 model
-- Check mnasnet0_75 model
-- Check mnasnet1_0 model
-- Check mnasnet1_3 model
-- Check shufflenet_v2_x0_5 model
-- Check shufflenet_v2_x1_0 model
-- Check shufflenet_v2_x1_5 model
-- Check shufflenet_v2_x2_0 model
-- Check efficientnet_b0 model
-- Check efficientnet_b1 model
-- Check efficientnet_b2 model
-- Check efficientnet_b3 model
-- Check efficientnet_b4 model
-- Check efficientnet_b5 model
-- Check efficientnet_b6 model
-- Check efficientnet_b7 model
-- Check regnet_y_400mf model
-- Check regnet_y_800mf model
-- Check regnet_y_1_6gf model
-- Check regnet_y_3_2gf model
-- Check regnet_y_8gf model
-- Check regnet_y_16gf model
-- Check regnet_y_32gf model
-- Check regnet_y_128gf model
-- Check regnet_x_400mf model
-- Check regnet_x_800mf model
-- Check regnet_x_1_6gf model
-- Check regnet_x_3_2gf model
-- Check regnet_x_8gf model
-- Check regnet_x_16gf model
-- Check regnet_x_32gf model
-- Check vit_b_16 model
-- Check vit_b_32 model
-- Check vit_l_16 model
-- Check vit_l_32 model
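The mismatches above are tiny in relative terms: the printed values are the means of the two outputs, and they differ at around float32 epsilon, which points at reduction-order differences between the vmapped and looped paths rather than a real divergence. Illustrative arithmetic on the resnet101 line:

a, b = 2613.704345703125, 2613.70458984375
print(abs(a - b) / abs(b))  # ~9.3e-08; float32 epsilon is ~1.2e-07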
# --- Check torchvision classification models ---
import torch
import torch.nn as nn
import torchvision
import torchvision.models as tv_models
import functorch
from functorch import combine_state_for_ensemble, vmap

# Collect all public classification model builders (skip private names,
# classes, and the non-classification submodules)
tested_models = []
for model_name in tv_models.__dict__:
    if model_name.startswith("_") or model_name[0].isupper():
        continue
    if model_name in ["segmentation", "detection", "video", "quantization", "feature_extraction"]:
        continue
    if not callable(tv_models.__dict__[model_name]):
        continue
    tested_models.append(model_name)


def compute_same_batch_preds(models, batch):
    # Fix seed to fix dropout
    torch.manual_seed(0)
    output = [model(batch) for model in models]
    return output


num_models = 5
device = 'cpu'


def check_ensembling_model(model_name, device):
    batch_size = 8
    torch.manual_seed(0)
    # inception_v3 needs 299x299 inputs; disable aux heads so the output is a single tensor
    if model_name == "inception_v3":
        size = (299, 299)
        kwargs = {"aux_logits": False}
    elif model_name == "googlenet":
        size = (224, 224)
        kwargs = {"aux_logits": False}
    else:
        size = (224, 224)
        kwargs = {}
    models = [
        tv_models.__dict__[model_name](num_classes=10, **kwargs).to(device)
        for _ in range(num_models)
    ]
    # Disable training mode:
    [m.eval() for m in models]
    fmodel, params, buffers = combine_state_for_ensemble(models)
    images = torch.randn(batch_size, 3, *size, device=device)
    ref_preds = compute_same_batch_preds(models, images)
    ref_preds = torch.stack(ref_preds)
    out_preds = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, images)
    assert torch.allclose(out_preds, ref_preds, atol=1e-3, rtol=1e-5), f"Output does not match reference: {out_preds.mean()} vs {ref_preds.mean()}"


print("")
print("Torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("Functorch:", functorch.__version__)
print("")

for model_name in tested_models:
    print(f"-- Check {model_name} model")
    try:
        check_ensembling_model(model_name, device=device)
    except AssertionError as e:
        print(e)
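Since the comparison is vmap ensembling vs a plain for-loop, a rough timing harness can be bolted onto the same setup (a sketch; resnet18, the batch shape, and the iteration count are arbitrary choices, not from the gist):

import time

import torch
import torchvision.models as tv_models
from functorch import combine_state_for_ensemble, vmap

models = [tv_models.resnet18(num_classes=10).eval() for _ in range(5)]
fmodel, params, buffers = combine_state_for_ensemble(models)
x = torch.randn(8, 3, 224, 224)

def timeit(fn, n=10):
    fn()  # warm-up call
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

with torch.no_grad():
    t_loop = timeit(lambda: [m(x) for m in models])
    t_vmap = timeit(lambda: vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x))
print(f"for-loop: {t_loop:.4f} s/iter, vmap: {t_vmap:.4f} s/iter")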