A standalone tool to convert torchvision.resnet50 weights to flax format.
"""Tool to convert torchvision.resnet50 pretrained weights for flax.
This tool loads torchvision weights, unflattens them to a nested PyTree of
np.ndarray, which can be loaded to a ResNet50 flax model (assuming it has
the same tree structure).
A sample ResNet flax implementation which can load these weights is available
here: https://github.com/MasterSkepticista/detr/blob/main/models/resnet.py
Usage:
```shell
# 1. Install few packages
$> pip install torch torchvision absl-py
# 2. Convert
$> python convert_torchvision_weights.py --outdir artifacts/resnet50
```
"""
import collections
import os

import numpy as np
import torchvision
from absl import app, flags, logging

logging.set_verbosity("info")

flags.DEFINE_string(
    "outdir", default=None, help="Path to store converted weights.")
flags.mark_flags_as_required(["outdir"])
FLAGS = flags.FLAGS


def recover_tree(keys, values) -> dict:
  """Recovers a tree as a nested dict from flat names and values.

  Args:
    keys: a list of keys where `/` is used as a separator between nodes.
    values: a list of leaf values.

  Returns:
    A nested tree-like dict.
  """
  tree = {}
  sub_trees = collections.defaultdict(list)
  for k, v in zip(keys, values):
    if "/" not in k:
      tree[k] = v
    else:
      k_left, k_right = k.split("/", 1)
      sub_trees[k_left].append((k_right, v))
  for k, kv_pairs in sub_trees.items():
    k_subtree, v_subtree = zip(*kv_pairs)
    tree[k] = recover_tree(k_subtree, v_subtree)
  return tree
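

# Illustrative example (not part of the original tool), showing how path-like
# keys are unflattened into a nested dict:
#   recover_tree(["a/b/c", "a/d", "e"], [1, 2, 3])
#   == {"a": {"b": {"c": 1}, "d": 2}, "e": 3}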


def _convert(key: str, value: np.ndarray) -> tuple:
  """Converts a parameter to be Flax-compatible; returns the updated (key, value) pair.

  Common transforms done here:
  1. Flax names the `weight` of conv/fc layers `kernel`, and the `weight` of
     norm layers `scale`.
  2. `conv` weights are transposed from OIHW (torch) to HWIO (flax).
  3. `fc` weights are transposed from (outC, inC) to (inC, outC).
  4. `bias` terms are left unchanged.
  """
  if "conv" in key or "downsample/0" in key:
    key = key.replace("weight", "kernel")
    value = np.transpose(value, (2, 3, 1, 0))
    return key, value

  if "fc" in key and "weight" in key:
    key = key.replace("weight", "kernel")
    value = np.transpose(value, (1, 0))
    return key, value

  if "bn" in key or "downsample/1" in key:
    return key.replace("weight", "scale"), value

  logging.info(f"Using as-is `{key}`: {value.shape}")
  return key, value
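

# Illustrative example: torchvision's `conv1/weight` has shape (64, 3, 7, 7)
# (OIHW); _convert renames it to `conv1/kernel` and transposes it to
# (7, 7, 3, 64) (HWIO), which is the layout flax.linen.Conv expects.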


def main(unused_argv):
  # Load torch weights.
  logging.info("Loading weights...")
  model = torchvision.models.resnet50(
      weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)

  # Drop unused keys, change path separator, and convert to numpy.
  weights = {
      k.replace(".", "/"): v.numpy()
      for k, v in model.state_dict().items()
      if "num_batches_tracked" not in k
  }

  # Flax uses a separate `batch_stats` collection for mean/var; the rest are
  # parameters. We create lists of (name, value) pairs, which makes it easy to
  # filter/map/unflatten later.
  batch_stats = [(k.replace("running_", ""), v)
                 for k, v in weights.items() if "running_" in k]
  params = [(k, v) for k, v in weights.items() if "running_" not in k]

  # Convert params.
  logging.info("Converting params...")
  params = [_convert(k, v) for k, v in params]

  # We store params and batch_stats in a single numpy dictionary. Since
  # np.savez does not support nested dicts, we use a flattened version with
  # (path-like-key, value) pairs, prefixing each entry with its collection
  # name - params or batch_stats:
  # variables = {
  #     "params/a/0/kernel": np.ndarray,
  #     "params/a/1/kernel": np.ndarray,
  #     ...,
  #     "batch_stats/x/mean": np.ndarray,
  #     "batch_stats/x/var": np.ndarray,
  # }
  variables = {
      **{f"params/{k}": v for k, v in params},
      **{f"batch_stats/{k}": v for k, v in batch_stats},
  }
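
  # Optional sanity check (illustrative addition, not in the original tool):
  # renames and transposes should preserve the total number of elements.
  n_src = sum(v.size for v in weights.values())
  n_dst = sum(v.size for v in variables.values())
  assert n_src == n_dst, f"Element count mismatch: {n_src} vs {n_dst}"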

  os.makedirs(FLAGS.outdir, exist_ok=True)
  filename = os.path.join(FLAGS.outdir, "weights.npz")
  np.savez(filename, **variables)
  logging.info("Saved to %s", filename)

  # Loading back and using with flax: flax models use nested dictionaries of
  # params and batch stats arranged in a particular structure. We recover the
  # nested structure from the path-like names.
  variables = np.load(filename)
  names, values = zip(*list(variables.items()))
  variables = recover_tree(names, values)
  # Now you can directly do flax_model.apply(variables, ...)


if __name__ == "__main__":
  app.run(main)