This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
from torchvision import datasets, transforms | |
from torch.utils.data import DataLoader, random_split | |
import os | |
from PIL import Image, ImageOps | |
from sklearn.metrics import f1_score, recall_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.models as models | |
from torchvision import datasets, transforms | |
from torch.utils.data import DataLoader, random_split | |
import os | |
from PIL import Image, ImageOps | |
from sklearn.metrics import f1_score, recall_score | |
import numpy as np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
The inference transforms are available at GoogLeNet_Weights.IMAGENET1K_V1.transforms | |
and perform the following preprocessing operations: they accept PIL.Image objects as well as | |
batched (B, C, H, W) and single (C, H, W) torch.Tensor images. | |
The images are resized to resize_size=[256] using interpolation=InterpolationMode.BILINEAR, | |
followed by a central crop of crop_size=[224]. Finally, the values are rescaled | |
to [0.0, 1.0] and then normalized using mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225]. | |
""" | |
import os |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision.models as models | |
# Define your model architecture (ensure it matches AlexNet) | |
class MyAlexNet(torch.nn.Module): | |
def __init__(self): | |
super(MyAlexNet, self).__init__() | |
self.features = torch.nn.Sequential( | |
torch.nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), | |
torch.nn.ReLU(inplace=True), |