Skip to content

Instantly share code, notes, and snippets.

@chenyaofo
Created September 7, 2023 11:11
Show Gist options
  • Save chenyaofo/a8578931e4e3b4c83152652c4d1264a2 to your computer and use it in GitHub Desktop.
Save chenyaofo/a8578931e4e3b4c83152652c4d1264a2 to your computer and use it in GitHub Desktop.
serialize transforms into files.
import torch.package as package
import torch
import torchvision.transforms as T
def get_train_transforms(crop_size, mean, std, is_training):
pipelines = []
if is_training:
pipelines.append(T.RandomResizedCrop(crop_size))
pipelines.append(T.RandomHorizontalFlip())
else:
pipelines.append(T.Resize(int(crop_size/7*8)))
pipelines.append(T.CenterCrop(crop_size))
pipelines.append(T.ToTensor())
pipelines.append(T.Normalize(mean=mean, std=std))
return T.Compose(pipelines)
pre_processing = get_train_transforms(224, [0.5, 0.5, 0.5], [1, 1, 1], False)
with package.PackageExporter("pre_processing.pt.package") as exporter:
exporter.intern("codebase.**")
exporter.intern("torchvision.**")
exporter.extern("numpy.**")
exporter.extern("PIL.**")
exporter.save_pickle("pre_processing", "transform.pkl", get_train_transforms)
importer = package.PackageImporter("pre_processing.pt.package")
f = importer.load_pickle("pre_processing", "transform.pkl")
pre_processing_pkl = f(224, [0.5, 0.5, 0.5], [1, 1, 1], False)
from PIL import Image
x = Image.open("img.png").convert("RGB")
y1 = pre_processing(x)
y2 = pre_processing_pkl(x)
print(torch.allclose(y1,y2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment