Last active
July 5, 2024 09:17
-
-
Save MasterSkepticista/c854bce837a5cb5ca0489bd33b3a2259 to your computer and use it in GitHub Desktop.
A standalone tool to convert torchvision.resnet50 weights to flax format.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tool to convert torchvision.resnet50 pretrained weights for flax. | |
This tool loads torchvision weights and unflattens them into a nested PyTree
of np.ndarray, which can be loaded into a ResNet50 flax model (assuming it has
the same tree structure).

A sample ResNet flax implementation which can load these weights is available
here: https://github.com/MasterSkepticista/detr/blob/main/models/resnet.py

Usage:
```shell
# 1. Install a few packages
$> pip install torch torchvision absl-py
# 2. Convert
$> python convert_torchvision_weights.py --outdir artifacts/resnet50
```
""" | |
import collections | |
import os | |
import numpy as np | |
import torchvision | |
from absl import app, flags, logging | |
# Handle through which `main` reads the parsed command-line flag values.
FLAGS = flags.FLAGS

# Set absl's log verbosity to "info" so the progress messages are shown.
logging.set_verbosity("info")

# `--outdir` has no default (None), so it is registered as a required flag.
flags.DEFINE_string("outdir", None, "Path to store converted weights.")
flags.mark_flags_as_required(["outdir"])
def recover_tree(keys, values) -> dict:
  """Rebuilds a nested dict from flat, `/`-separated names and their values.

  Names without a `/` become leaves of the current level. Names with a `/`
  are grouped by their first path component and each group is rebuilt
  recursively from the remainder of the path.

  Args:
    keys: an iterable of names where `/` separates tree nodes.
    values: an iterable of leaf values, aligned with `keys`.

  Returns:
    A nested dict. Leaves of the current level come first (in input order),
    followed by sub-trees in order of first appearance. If a name is both a
    leaf and a prefix of other names, the sub-tree replaces the leaf.
  """
  result = {}
  grouped = {}
  for path, leaf in zip(keys, values):
    head, sep, rest = path.partition("/")
    if not sep:
      # No separator: this name is a leaf at the current level.
      result[path] = leaf
      continue
    grouped.setdefault(head, []).append((rest, leaf))

  # Sub-trees are attached after all leaves, mirroring a two-pass build.
  for head, children in grouped.items():
    child_keys = tuple(rest for rest, _ in children)
    child_values = tuple(leaf for _, leaf in children)
    result[head] = recover_tree(child_keys, child_values)
  return result
def _convert(key: str, value: np.ndarray) -> tuple: | |
"""Converts parameters to be compatible with Flax, returns updated (key, value) pair. | |
Common transforms done here: | |
1. Flax calls `weight` as `kernel` for conv/fc, `weight` as `scale` for norms. | |
2. `conv` weights are converted from NCHW to NHWC. | |
3. `fc` weights are converted from (inC, outC) to (outC, inC). | |
4. `bias` terms are left unchanged. | |
""" | |
if "conv" in key or "downsample/0" in key: | |
key = key.replace("weight", "kernel") | |
value = np.transpose(value, (2, 3, 1, 0)) | |
return key, value | |
if "fc" in key and "weight" in key: | |
key = key.replace("weight", "kernel") | |
value = np.transpose(value, (1, 0)) | |
return key, value | |
if "bn" in key or "downsample/1" in key: | |
return key.replace("weight", "scale"), value | |
logging.info(f"Using as-is `{key}`: {value.shape}") | |
return key, value | |
def main(unused_argv):
  """Converts torchvision ResNet50 weights and writes them to `weights.npz`.

  Loads the `IMAGENET1K_V1` ResNet50 weights from torchvision, splits them
  into parameters and batch statistics, converts the parameters with
  `_convert`, and saves everything as a flat npz archive under
  `FLAGS.outdir`. Finally the archive is read back and unflattened with
  `recover_tree` to show how the file is meant to be consumed.

  Args:
    unused_argv: positional command-line arguments passed by `app.run`;
      not used.
  """
  # Load torch weights.
  logging.info("Loading weights...")
  model = torchvision.models.resnet50(
      weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)

  # Drop unused keys, change path separator, and convert to numpy.
  weights = {
      k.replace(".", "/"): v.numpy()
      for k, v in model.state_dict().items()
      if "num_batches_tracked" not in k
  }

  # Flax uses a separate `batch_stats` collection for mean/var.
  # Rest are parameters. We create a list of (name, value) pairs which makes
  # it easy to filter/map/unflatten later.
  batch_stats = [
      (k.replace("running_", ""), v)
      for k, v in weights.items()
      if "running_" in k
  ]
  params = [(k, v) for k, v in weights.items() if "running_" not in k]

  # Convert params.
  logging.info("Converting params...")
  params = [_convert(k, v) for k, v in params]

  # We can store params and batch_stats as a numpy dictionary.
  # Since numpy does not support nested dicts, we use the flattened version
  # with (path-like-key, value) pairs. We prefix all parameters with their type
  # name - as params or batch_stats.
  # variables = {
  #   "params/a/0/kernel": np.ndarray,
  #   "params/a/1/kernel": np.ndarray,
  #   ...,
  #   "batch_stats/x/mean": np.ndarray,
  #   "batch_stats/x/var": np.ndarray,
  # }
  variables = {
      **{f"params/{k}": v for k, v in params},
      **{f"batch_stats/{k}": v for k, v in batch_stats},
  }
  os.makedirs(FLAGS.outdir, exist_ok=True)
  filename = os.path.join(FLAGS.outdir, "weights.npz")
  np.savez(filename, **variables)
  logging.info("Saved to %s", filename)

  # Loading back and using with flax: Flax models use nested dictionaries of
  # params and batch stats arranged in a particular order. We recover the nested
  # structure from the path-like name.
  # The archive is opened as a context manager so its file handle is closed;
  # the arrays are read while the archive is still open.
  with np.load(filename) as archive:
    names, values = zip(*archive.items())
  variables = recover_tree(names, values)
  # Now you can directly do flax_model.apply(variables, ...)
# Script entry point: hand `main` to absl's `app.run`.
if __name__ == "__main__":
  app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment