Skip to content

Instantly share code, notes, and snippets.

View Chris-hughes10's full-sized avatar

Chris Hughes Chris-hughes10

View GitHub Profile
@Chris-hughes10
Chris-hughes10 / train_with_metrics_in_loop.py
Created November 24, 2021 11:11
pytorch_accelerated_blog_metrics_in_trainer_script
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/metrics/train_with_metrics_in_loop.py
import os
from torch import nn, optim
from torch.utils.data import random_split
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_accelerated import Trainer
@Chris-hughes10
Chris-hughes10 / trainer_with_metrics.py
Created November 24, 2021 11:02
pytorch_accelerated_blog_trainer_metrics_snippet
from torchmetrics import MetricCollection, Accuracy, Precision, Recall
class TrainerWithMetrics(Trainer):
def __init__(self, num_classes, *args, **kwargs):
super().__init__(*args, **kwargs)
# this will be moved to the correct device automatically by the
# MoveModulesToDeviceCallback callback, which is used by default
self.metrics = MetricCollection(
{
@Chris-hughes10
Chris-hughes10 / train_mnist.py
Last active November 24, 2021 10:53
pytorch-accelerated_blog_mnist_quickstart
# this example is taken from
# https://github.com/Chris-hughes10/pytorch-accelerated/blob/main/examples/train_mnist.py
import os
from torch import nn, optim
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST
@Chris-hughes10
Chris-hughes10 / EfficientDet Pytorch-lightning with EfficientNet v2 backbone Blog Post.ipynb
Last active April 22, 2024 08:57
EfficientDet Pytorch-lightning with EfficientNet v2 backbone Blog Post.ipynb
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from objdetecteval.metrics.coco_metrics import get_coco_stats
@patch
def validation_epoch_end(self: EfficientDetModel, outputs):
"""Compute and log training loss and accuracy at the epoch level."""
validation_loss_mean = torch.stack(
[output["loss"] for output in outputs]
).mean()
@Chris-hughes10
Chris-hughes10 / effdet_aggregate_outputs.py
Created July 16, 2021 09:52
Effdet_blog_aggregate_outputs
from fastcore.basics import patch
@patch
def aggregate_prediction_outputs(self: EfficientDetModel, outputs):
detections = torch.cat(
[output["batch_predictions"]["predictions"] for output in outputs]
)
image_ids = []
@Chris-hughes10
Chris-hughes10 / effdet_run_inference.py
Created July 16, 2021 09:46
Effdet_blog_inference
def _run_inference(self, images_tensor, image_sizes):
dummy_targets = self._create_dummy_inference_targets(
num_images=images_tensor.shape[0]
)
detections = self.model(images_tensor.to(self.device), dummy_targets)[
"detections"
]
(
predicted_bboxes,
@typedispatch
def predict(self, images: List):
"""
For making predictions from images
Args:
images: a list of PIL images
Returns: a tuple of lists containing bboxes, predicted_class_labels, predicted_class_confidences
"""
@Chris-hughes10
Chris-hughes10 / effdet_model_1.py
Created July 16, 2021 09:40
Effdet_blog_model_1
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
class EfficientDetModel(LightningModule):
def __init__(
self,
num_classes=1,
img_size=512,
prediction_confidence_threshold=0.2,
@Chris-hughes10
Chris-hughes10 / effdet_datamodule.py
Created July 16, 2021 09:36
Effdet_blog_datamodule
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
class EfficientDetDataModule(LightningDataModule):
def __init__(self,
train_dataset_adaptor,
validation_dataset_adaptor,
train_transforms=get_train_transforms(target_img_size=512),
valid_transforms=get_valid_transforms(target_img_size=512),