This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import segmentation_models_pytorch as smp

# U-Net segmentation network with a ResNet-50 encoder.
# ``activation`` must be the Python object ``None`` rather than the string
# "None": smp maps ``None`` to "no activation" (raw logits are returned),
# while the string "None" is not a recognised activation name and is rejected.
model = smp.Unet(
    encoder_name="resnet50",     # encoder backbone
    encoder_weights="imagenet",  # pretrained encoder weights (optional)
    in_channels=3,               # number of input image channels
    classes=10,                  # number of output channels (segmentation classes)
    activation=None,             # None | "sigmoid" | "softmax" ...; None -> logits
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torchmetrics import MetricCollection, Accuracy, Precision, Recall


class MyModule(LightningModule):
    """LightningModule that owns a collection of classification metrics."""

    def __init__(self, num_classes=10):
        """Build the module and its metric collection.

        Args:
            num_classes: Number of target classes, shared by every metric so
                that accuracy, precision and recall agree. Defaults to 10.
        """
        # LightningModule is an nn.Module: its __init__ must run before any
        # sub-module (every torchmetrics Metric is one) is assigned as an
        # attribute, otherwise torch raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        ...  # remaining model setup elided in this snippet
        # Current torchmetrics classification metrics require ``task``; all
        # three receive the same ``num_classes``.
        self.metric_collection = MetricCollection({
            'acc': Accuracy(task='multiclass', num_classes=num_classes),
            'prec': Precision(task='multiclass', num_classes=num_classes, average='macro'),
            'rec': Recall(task='multiclass', num_classes=num_classes, average='macro'),
        })
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchmetrics


class MyAccuracy(torchmetrics.Metric):
    """Custom accuracy metric that keeps running counts of predictions.

    NOTE(review): only the state declarations are part of this snippet; a
    usable ``torchmetrics.Metric`` subclass must also implement ``update()``
    and ``compute()``.
    """

    def __init__(self):
        super().__init__()
        # Running count of correct predictions. ``dist_reduce_fx='sum'`` makes
        # state synchronisation across processes add the per-process counts.
        self.add_state('corrects', default=torch.tensor(0), dist_reduce_fx='sum')
        # Running count of all predictions, reduced the same way.
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

NUM_CLASSES = 10  # number of target classes; shared by every metric below

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel().to(device)

# Collection of all validation metrics. Current torchmetrics requires
# ``task`` on these metrics. The collection is moved to the model's device
# so its internal state tensors live on the same device as the model outputs.
metric_collection = MetricCollection({
    'acc': Accuracy(task='multiclass', num_classes=NUM_CLASSES),
    'prec': Precision(task='multiclass', num_classes=NUM_CLASSES, average='macro'),
    'rec': Recall(task='multiclass', num_classes=NUM_CLASSES, average='macro'),
}).to(device)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchmetrics

# NOTE(review): 10 classes assumed, matching the other snippets — confirm
# against YourModel's output size.
NUM_CLASSES = 10

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel().to(device)
# The metric's state tensors must live on the same device as the
# predictions and targets it is given, so move it alongside the model.
metric = torchmetrics.Accuracy(task='multiclass', num_classes=NUM_CLASSES).to(device)

# Validation pass: eval mode switches layers such as dropout/batch-norm to
# inference behaviour, and no_grad skips building an autograd graph for
# this inference-only loop.
model.eval()
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(val_dataloader):
        data, target = data.to(device), target.to(device)
        preds = model(data)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy
import open3d as o3d
from PIL import Image

# Load the input image from disk and run the depth model on it.
# NOTE(review): ``Model`` is not defined in this snippet — presumably a
# depth-estimation network defined or imported elsewhere; confirm.
image = Image.open('image.jpg')
model = Model()
output = model(image) # depth, 512 x 512 matrix
# Height and width of the predicted depth map, taken from its first two axes.
H, W = output.shape[0], output.shape[1]