Skip to content

Instantly share code, notes, and snippets.

@rohinarora
Last active June 28, 2020 04:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rohinarora/2cc4780d559a4da0f322d67d7e96036f to your computer and use it in GitHub Desktop.
Save rohinarora/2cc4780d559a4da0f322d67d7e96036f to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch import nn
from torch.nn import Conv2d,BatchNorm2d,ReLU,MaxPool2d
#from tqdm import tqdm
from tqdm import tqdm_notebook as tqdm
%matplotlib inline
def get_data_set(train_flag=True, root='./dataset', download=True):
    """Load the CIFAR-10 train or test split through torchvision.

    Args:
        train_flag: Truthy to load the training split, falsy for the test split.
        root: Dataset directory passed to torchvision (default matches the
            previously hard-coded './dataset').
        download: Forwarded to torchvision's ``download`` argument (default
            True, as before).

    Returns:
        A ``torchvision.datasets.CIFAR10`` instance; no ``transform`` is
        passed, so samples come back in torchvision's default form.
    """
    # The two splits differ only in the ``train`` flag, so pass it through
    # rather than duplicating the constructor call in an if/else. ``bool()``
    # keeps the exact True/False the original branching produced.
    return torchvision.datasets.CIFAR10(
        root=root,
        train=bool(train_flag),
        download=download,
    )
# Load both CIFAR-10 splits (module-level names kept for later notebook cells).
train_data_set = get_data_set(train_flag=True)
test_data_set = get_data_set(train_flag=False)

# For each split (train first, then test) print the mean and then the std of
# the raw pixel array, reduced over axes 0-2 and divided by 255.
# NOTE(review): assumes ``.data`` is an (N, H, W, C) uint8 array, so axis=(0, 1, 2)
# yields per-channel statistics rescaled to [0, 1] — confirm against torchvision.
for _split in (train_data_set, test_data_set):
    _pixels = _split.data
    print(_pixels.mean(axis=(0, 1, 2)) / 255)
    print(_pixels.std(axis=(0, 1, 2)) / 255)
# Drop the loop temporaries so the module namespace matches the original.
del _split, _pixels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment