Skip to content

Instantly share code, notes, and snippets.

@Syzygianinfern0
Last active September 4, 2021 12:50
Show Gist options
Save Syzygianinfern0/8d31e98bed8c05384675d82907fbfa52 to your computer and use it in GitHub Desktop.
Quick Code Snippets

Quick Code Snippets

I keep a collection of useful code snippets that I reach for frequently, and I log them all here.

import os
# Create `directory` (including any missing parents) if it is not there yet.
# exist_ok=True replaces the separate os.path.exists() check: it removes the
# window in which another process could create the directory between the check
# and the makedirs() call, and it is a no-op when the directory already exists.
# NOTE: if `directory` exists but is a regular file, this raises
# FileExistsError instead of silently skipping the creation.
os.makedirs(directory, exist_ok=True)
import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
def get_mean_and_std(dataset, device):
    """
    Finds the per-channel mean and std of the dataset.
    This is used for Normalization of the dataset.
    The dataset must yield ``(data, target)`` pairs; batches are reduced over
    dims 0, 2 and 3, so ``data`` is expected to collate to a 4-D tensor with
    the channel axis at dim 1.
    Every batch is weighted by the number of values it contributes per channel,
    so a smaller final batch does not skew the statistics.
    :param dataset: dataset to calculate mean and std for
    :param device: CUDA or CPU?
    :return: tuple of mean and std (one entry per channel)
    :raises ZeroDivisionError: if the dataset yields no batches
    """
    dataloader = DataLoader(dataset, batch_size=4096)
    channels_sum, channels_sqrd_sum, num_values = 0, 0, 0
    for data, _ in tqdm(dataloader):
        data = data.to(device)
        # Number of values reduced into each channel for this batch.
        batch_values = data.size(0) * data.size(2) * data.size(3)
        # Turn the batch means back into sums so that batches of different
        # sizes (typically the last one) are weighted correctly.
        channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_values
        channels_sqrd_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * batch_values
        num_values += batch_values
    mean = channels_sum / num_values
    # Var[x] = E[x^2] - E[x]^2; rounding can push this slightly below zero for
    # (near-)constant channels, which would make the square root NaN.
    variance = torch.clamp(channels_sqrd_sum / num_values - mean ** 2, min=0)
    std = variance ** 0.5
    return mean, std
from collections import OrderedDict
def remove_dataparallel_wrapper(state_dict):
    """
    Converts a DataParallel model to a normal one by removing the "module."
    wrapper in the module dictionary.
    Only keys that actually start with "module." are renamed; any other key is
    copied unchanged, so passing an already unwrapped state dictionary is safe.
    :param state_dict: a torch.nn.DataParallel state dictionary
    :return: a torch.nn.Module state dictionary (a new OrderedDict, same order)
    """
    prefix = "module."
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip the prefix only when present; slicing blindly would chop the
        # first characters off keys that were never wrapped.
        name = k[len(prefix):] if k.startswith(prefix) else k
        new_state_dict[name] = v
    return new_state_dict
import os
import random
import numpy as np
import torch
from torch.backends import cudnn
def set_seed(seed):
    """
    Applies one seed to every random number source used here.
    Sets the PYTHONHASHSEED environment variable, seeds ``random``, NumPy and
    PyTorch (including the CUDA seeding call), and configures cuDNN with
    ``deterministic`` on and ``benchmark`` off.
    :param seed: the seed number to be used
    :return: None
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same seed for each generator, in a fixed order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # The two cuDNN flags are independent of each other.
    cudnn.benchmark = False
    cudnn.deterministic = True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment