This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _save_pred(self, logits, filenames): | |
# here need to extend from 1-dim to 3-dim in channel dimension | |
invert_mask_mapping = { | |
0: (255, 255, 255), # impervious surfaces | |
1: (0, 0, 255), # Buildings | |
2: (0, 255, 255), # Low Vegetation | |
3: (0, 255, 0), # Tree | |
4: (255, 255, 0), # Car |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
class ConvBnRelu(nn.Module): | |
def __init__(self, inc, outc, ksize, stride, pad, dilation=1, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.nn.init as init | |
import torchvision.models as models | |
from functools import partial | |
from opts import _nostride2dilation | |
from Models.layers import ASPP, _FCNHead, ConvBnRelu |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_one_epoch(self, epoch, train_log_func): | |
batch_time = AverageMeter() | |
data_time = AverageMeter() | |
losses = AverageMeter() | |
accs = AverageMeter() | |
mious = AverageMeter() | |
fscores = AverageMeter() | |
self.model.train() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import tqdm | |
import os | |
import glob | |
from Utils.tools import ensure_dir | |
from PIL import Image | |
from Configs.config import Configurations | |
import argparse | |
def image_cropper(configs, subset): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
from torch.utils.data import Dataset | |
import torchvision.transforms.functional as TF | |
import random | |
import glob | |
import os | |
from PIL import Image | |
class TrainingData_Dataset(Dataset): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from graphviz import Digraph | |
import torch | |
from torch.autograd import Variable, Function | |
def iter_graph(root, callback): | |
queue = [root] | |
seen = set() | |
while queue: | |
fn = queue.pop() | |
if fn in seen: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
# Module-level switch, off by default.
# NOTE(review): presumably selects the Adam optimizer (torch.optim is imported
# above) instead of another optimizer when True; the code that reads this flag
# is not visible in this excerpt -- confirm.
use_adam = False | |
class MyModel(nn.Module): | |
def __init__(self): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A weighted version of categorical_crossentropy for Keras (2.0.6). It lets you assign a weight to each class, which helps when the classes are unbalanced.
@url: https://gist.github.com/wassname/ce364fddfc8a025bfab4348cf5de852d
@author: wassname
"""
from keras import backend as K | |
def weighted_categorical_crossentropy(weights): | |
""" | |
A weighted version of keras.objectives.categorical_crossentropy | |