Created December 18, 2021 13:45
-
-
Save Hummer12007/b93c34b4067cc23e48853f945231292b to your computer and use it in GitHub Desktop.
Exporting and using a Flash image classifier in torchvision
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
@click.group(context_settings={'help_option_names': ['-h']})
def main():
    """Flash classifier training utilities"""
    # Pure command group: click uses the docstring above as the --help text,
    # and the subcommands attach themselves via @main.command(...).
@main.command('train', help='Train a flash image classifier')
@click.option('--data', default='data/images', help='Train dataset path')
@click.option('--checkpoint', default='checkpoints/train.pt', help='Checkpoint path')
@click.option('--batch-size', type=int, default=8, help='Batch size')
@click.option('--epochs', type=int, default=3, help='Number of epochs')
def train(data, checkpoint, batch_size, epochs):
    """Fine-tune a resnet18 Flash ImageClassifier and save a checkpoint.

    Reads images from ``<data>/train`` and ``<data>/val`` via
    ``ImageClassificationData.from_folders``, fine-tunes for *epochs* epochs
    with the given *batch_size*, writes the checkpoint to *checkpoint*
    (creating its parent directory if needed), then validates the model.
    """
    import os

    from flash import Trainer
    from flash.image import ImageClassifier, ImageClassificationData

    datamodule = ImageClassificationData.from_folders(
        train_folder=f'{data}/train',
        val_folder=f'{data}/val',
        # Honour --batch-size (it used to be silently hard-coded to 8).
        batch_size=batch_size,
    )
    model = ImageClassifier(num_classes=datamodule.num_classes, backbone="resnet18")
    trainer = Trainer(max_epochs=epochs)
    # Flash's "freeze" finetuning strategy (see Flash finetuning docs).
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    checkpoint_dir = os.path.dirname(checkpoint)
    if checkpoint_dir:  # a bare filename has no directory to create
        os.makedirs(checkpoint_dir, exist_ok=True)
    trainer.save_checkpoint(checkpoint)

    # Validate exactly the model/data just trained and saved, instead of
    # relying on whatever state finetune() left attached to the trainer.
    trainer.validate(model, datamodule=datamodule)
@main.command('export', help='Export a trained flash ImageClassifier')
@click.option('--checkpoint', default='checkpoints/train.pt', help='Checkpoint path')
@click.option('--model-path', default='checkpoints/train.pk', help='Output pk model path')
def export_model(checkpoint, model_path):
    """Convert a Flash ImageClassifier checkpoint into a torchvision bundle.

    Loads *checkpoint*, copies its adapter weights into a plain torchvision
    ``resnet18`` (put in eval mode), and ``torch.save``s a dict to
    *model_path* with keys:

    - ``'model'``: the torchvision resnet18 module,
    - ``'labels'``: labels taken from the Flash ``ClassificationState``,
    - ``'transform'``: Flash's default per-sample predict-stage transform.

    The parent directory of *model_path* is created if needed.
    """
    import os
    from collections import OrderedDict

    import torch
    from torchvision.models import resnet18

    from flash import RunningStage
    from flash.image import ImageClassifier
    from flash.core.data.io.classification_input import ClassificationState
    from flash.image.classification import ImageClassificationInputTransform

    # Load the Flash model and the class metadata it stored during training.
    flash_model = ImageClassifier.load_from_checkpoint(checkpoint)
    model_state = flash_model.get_state(ClassificationState)
    labels = model_state.labels

    # Build an equivalent, untrained torchvision network to receive the weights.
    model = resnet18(num_classes=model_state.num_classes, pretrained=False)

    # Rename Flash adapter keys to torchvision's layout: strip "backbone." and
    # map the head ("head.0") onto "fc". An ordered generator keeps the source
    # key order; the previous set comprehension discarded it.
    flash_state_dict = OrderedDict(
        (key.replace("backbone.", "").replace("head.0", "fc"), value)
        for key, value in flash_model.adapter.state_dict().items()
    )
    model.load_state_dict(flash_state_dict)
    model.eval()

    # Default preprocessing Flash applies to each sample at predict time.
    transform = ImageClassificationInputTransform(
        running_stage=RunningStage.PREDICTING
    ).input_per_sample_transform()

    # Match `train`: make sure the output directory exists before saving.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    torch.save({
        'model': model,
        'labels': labels,
        'transform': transform,
    }, model_path)
@main.command('infer_dir', help='Classify every image in a directory with an exported model')
@click.option('--images', required=True, help='Unannotated images path')
@click.option('--model-path', default='checkpoints/train.pk', help='Model path')
def infer_dir(images, model_path):
    """Classify each file in *images* with a bundle produced by ``export``.

    Loads the ``{'model', 'labels', 'transform'}`` dict from *model_path*,
    preprocesses every regular file in the *images* directory (sorted by
    path) as an RGB image, runs a single batched forward pass, and prints
    ``<path> <label>`` per file. If the directory has no files, a message
    is written to stderr and nothing is classified.
    """
    import inspect
    import os

    import torch
    from PIL import Image

    # The bundle holds a pickled nn.Module and transform, not just tensors.
    # torch >= 2.6 defaults to weights_only=True, which refuses to load it;
    # older torch has no such parameter, so pass it only when supported.
    # SECURITY: this unpickles arbitrary objects -- only load trusted files.
    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    bundle = torch.load(model_path, **load_kwargs)
    clf = bundle['model']
    labels = bundle['labels']
    preprocess = bundle['transform']

    # Regular files only (sub-directories would crash Image.open); sorted so
    # output order is deterministic across runs and platforms.
    files = sorted(
        os.path.join(images, name)
        for name in os.listdir(images)
        if os.path.isfile(os.path.join(images, name))
    )
    if not files:
        # torch.stack([]) would raise; report the empty input instead.
        click.echo(f'No image files found in {images}', err=True)
        return

    tensors = []
    for path in files:
        # Context manager releases the file handle once the image is decoded.
        with Image.open(path) as img:
            tensors.append(preprocess(img.convert('RGB')))
    input_batch = torch.stack(tensors)

    clf.eval()  # defensive: inference must not use training-mode layers
    with torch.no_grad():
        output = clf(input_batch)
    predictions = torch.argmax(output, dim=1).cpu().tolist()
    for path, index in zip(files, predictions):
        print(path, labels[index])
if __name__ == "__main__":
    # Script entry point: dispatch to the click command group.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.