Skip to content

Instantly share code, notes, and snippets.

@Hummer12007
Created December 18, 2021 13:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hummer12007/b93c34b4067cc23e48853f945231292b to your computer and use it in GitHub Desktop.
Exporting and using a Flash image classifier in torchvision
import click
@click.group(context_settings=dict(help_option_names=['-h', '--help']))
def main():
    """
    Flash classifier training utilities.

    Root click group for the train / export / infer_dir subcommands.
    """
    # NOTE: help_option_names REPLACES click's default ['--help']; listing
    # only '-h' here would disable '--help', so both spellings are kept.
@main.command('train', help='Train a flash image classifier')
@click.option('--data', default='data/images', help='Train dataset path')
@click.option('--checkpoint', default='checkpoints/train.pt', help='Checkpoint path')
@click.option('--batch-size', type=int, default=8, help='Batch size')
@click.option('--epochs', type=int, default=3, help='Number of epochs')
def train(data, checkpoint, batch_size, epochs):
    """Fine-tune a Flash ImageClassifier on <data>/train (validating on
    <data>/val) and save the trainer checkpoint to `checkpoint`.

    Expects the folder layout ImageClassificationData.from_folders reads:
    one subdirectory per class under train/ and val/.
    """
    import os

    from flash import Trainer
    from flash.image import ImageClassificationData, ImageClassifier

    datamodule = ImageClassificationData.from_folders(
        train_folder=f'{data}/train',
        val_folder=f'{data}/val',
        # BUG FIX: was hard-coded to 8, silently ignoring --batch-size.
        batch_size=batch_size,
    )
    model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")
    trainer = Trainer(max_epochs=epochs)
    # "freeze" keeps the pretrained backbone fixed and trains only the head.
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    # Create the checkpoint's parent directory if one was given.
    if os.path.dirname(checkpoint):
        os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
    trainer.save_checkpoint(checkpoint)
    trainer.validate()
@main.command('export', help='Export a trained flash ImageClassifier')
@click.option('--checkpoint', default='checkpoints/train.pt', help='Checkpoint path')
@click.option('--model-path', default='checkpoints/train.pk', help='Output pk model path')
def export_model(checkpoint, model_path):
    """Convert a Flash ImageClassifier checkpoint into a plain torchvision
    resnet18 and save {'model', 'labels', 'transform'} via torch.save.

    This lets inference run with torchvision alone, without flash installed.
    """
    from collections import OrderedDict

    import torch
    from torchvision.models import resnet18
    from flash import RunningStage
    from flash.image import ImageClassifier
    from flash.core.data.io.classification_input import ClassificationState
    from flash.image.classification import ImageClassificationInputTransform

    # Load the flash model and recover the class labels recorded at training time.
    flash_model = ImageClassifier.load_from_checkpoint(checkpoint)
    model_state = flash_model.get_state(ClassificationState)
    labels = model_state.labels

    # Build an equivalently-shaped torchvision model.
    model = resnet18(num_classes=model_state.num_classes, pretrained=False)

    # Rename flash keys to torchvision ones: strip the "backbone." prefix and
    # map the classification head "head.0" onto resnet's "fc" layer.
    # BUG FIX: the original used a *set* comprehension, which destroys key
    # order and deduplicates arbitrarily; use an order-preserving generator.
    flash_state_dict = OrderedDict(
        (k.replace("backbone.", "").replace("head.0", "fc"), v)
        for k, v in flash_model.adapter.state_dict().items()
    )
    model.load_state_dict(flash_state_dict)
    model.eval()

    # Capture the default per-sample predict transform so inference can
    # preprocess images exactly as training did.
    transform = ImageClassificationInputTransform(
        running_stage=RunningStage.PREDICTING
    ).input_per_sample_transform()

    torch.save({
        'model': model,
        'labels': labels,
        'transform': transform,
    }, model_path)
@main.command('infer_dir')
@click.option('--images', help='Unannotated images path')
@click.option('--model-path', default='checkpoints/train.pk', help='Model path')
def infer_dir(images, model_path):
    """Classify every image in `images` with an exported model bundle and
    print one "path label" line per file."""
    import os

    import torch
    from PIL import Image

    # SECURITY NOTE: torch.load unpickles arbitrary objects -- only load
    # model files you trust.
    bundle = torch.load(model_path)
    clf = bundle['model']
    labels = bundle['labels']
    preprocess = bundle['transform']

    files = [f'{images}/{f}' for f in os.listdir(images)]
    # Preprocess all images and stack them into one batch tensor.
    tensors = [preprocess(Image.open(f).convert('RGB')) for f in files]
    input_batch = torch.stack(tensors)

    with torch.no_grad():
        output = clf(input_batch)
    # FIX: torch.argmax takes `dim`, not numpy's `axis`; also dropped the
    # unused `crops` list from the original.
    indices = torch.argmax(output, dim=1)
    result = [labels[i] for i in indices.cpu().tolist()]
    for f, r in zip(files, result):
        print(f, r)
# Script entry point: dispatch to the click command group.
if __name__ == '__main__':
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment