Skip to content

Instantly share code, notes, and snippets.

@Chris-hughes10
Created July 16, 2021 09:40
Show Gist options
  • Save Chris-hughes10/d0b80e4c0949381f517a47fa362649e2 to your computer and use it in GitHub Desktop.
Save Chris-hughes10/d0b80e4c0949381f517a47fa362649e2 to your computer and use it in GitHub Desktop.
Effdet_blog_model_1
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
class EfficientDetModel(LightningModule):
    """LightningModule wrapping a detector built by ``create_model``.

    The wrapped ``self.model`` is called as ``self.model(images, annotations)``
    and is expected to return a dict containing ``"loss"``, ``"class_loss"``
    and ``"box_loss"`` (plus ``"detections"`` during validation).

    Args:
        num_classes: Number of object classes, forwarded to ``create_model``.
        img_size: Input image size, forwarded to ``create_model``. Also used
            as the target size of the default inference transforms.
        prediction_confidence_threshold: Stored as
            ``self.prediction_confidence_threshold``. Not used by the methods
            in this class; presumably consumed by prediction code elsewhere.
        learning_rate: Learning rate for the AdamW optimizer.
        wbf_iou_threshold: Stored as ``self.wbf_iou_threshold``. Not used by
            the methods in this class; presumably an IoU threshold for
            weighted-boxes fusion at prediction time (TODO confirm).
        inference_transforms: Transforms stored as ``self.inference_tfms``.
            When ``None`` (the default), a fresh
            ``get_valid_transforms(target_img_size=img_size)`` is built for
            this instance.
        model_architecture: Architecture name forwarded to ``create_model``.
    """

    def __init__(
        self,
        num_classes=1,
        img_size=512,
        prediction_confidence_threshold=0.2,
        learning_rate=0.0002,
        wbf_iou_threshold=0.44,
        inference_transforms=None,
        model_architecture='tf_efficientnetv2_l',
    ):
        super().__init__()
        self.img_size = img_size
        self.model = create_model(
            num_classes, img_size, architecture=model_architecture
        )
        self.prediction_confidence_threshold = prediction_confidence_threshold
        self.lr = learning_rate
        self.wbf_iou_threshold = wbf_iou_threshold
        # Build the default here rather than in the signature. A call in the
        # signature runs once at class-definition time, is shared by every
        # instance, and would be fixed at 512 regardless of ``img_size``.
        if inference_transforms is None:
            inference_transforms = get_valid_transforms(target_img_size=img_size)
        self.inference_tfms = inference_transforms

    # NOTE(review): ``auto_move_data`` is deprecated in newer
    # pytorch_lightning releases; confirm against the pinned version.
    @auto_move_data
    def forward(self, images, targets):
        """Run the wrapped model on ``images`` with ``targets``."""
        return self.model(images, targets)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters."""
        return torch.optim.AdamW(self.model.parameters(), lr=self.lr)

    def _log_losses(self, stage, losses, sync_dist=False):
        """Log ``{stage}_loss``, ``{stage}_class_loss`` and ``{stage}_box_loss``.

        Values are detached before logging so the logger never holds a
        reference to the autograd graph.
        """
        for key in ("loss", "class_loss", "box_loss"):
            self.log(
                f"{stage}_{key}",
                losses[key].detach(),
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                sync_dist=sync_dist,
            )

    def training_step(self, batch, batch_idx):
        """Compute and log training losses for one batch.

        ``batch`` is unpacked as a 4-tuple; only the images and annotations
        are used here.

        Returns:
            The total loss tensor (undetached) for backpropagation.
        """
        images, annotations, _, _ = batch
        losses = self.model(images, annotations)
        self._log_losses("train", losses)
        return losses['loss']

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Compute and log validation losses and collect detections.

        ``batch`` is unpacked as ``(images, annotations, targets, image_ids)``.
        Losses are logged with ``sync_dist=True`` so they are reduced across
        processes in distributed runs.

        Returns:
            A dict with the total ``'loss'`` and ``'batch_predictions'``, which
            holds the model's ``"detections"`` (format defined by
            ``create_model``'s model; TODO confirm) alongside the batch
            ``targets`` and ``image_ids``.
        """
        images, annotations, targets, image_ids = batch
        outputs = self.model(images, annotations)
        batch_predictions = {
            "predictions": outputs["detections"],
            "targets": targets,
            "image_ids": image_ids,
        }
        self._log_losses("valid", outputs, sync_dist=True)
        return {'loss': outputs["loss"], 'batch_predictions': batch_predictions}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment