Created
September 1, 2021 02:53
-
-
Save e96031413/e6e3e766ad0b4aa89afdc8fdc41519cf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/pytorch/examples/blob/master/imagenet/main.py | |
import argparse | |
import os | |
import sys | |
import random | |
import shutil | |
import time | |
import warnings | |
from tqdm import tqdm | |
import pdb | |
import torch | |
import torch.nn as nn | |
import torch.nn.parallel | |
import torch.backends.cudnn as cudnn | |
import torch.distributed as dist | |
import torch.optim | |
import torch.multiprocessing as mp | |
import torch.utils.data | |
import torch.utils.data.distributed | |
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
import torchvision.models as models | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import json | |
def main():
    """Entry point: pin all RNG seeds for repeatability, then run the worker."""
    seed = 999
    random.seed(seed)
    torch.manual_seed(seed)
    # Deterministic cuDNN kernels are needed for exact reproducibility,
    # at the cost of speed — warn the user, as the upstream example does.
    cudnn.deterministic = True
    warnings.warn('You have chosen to seed training. '
                  'This will turn on the CUDNN deterministic setting, '
                  'which can slow down your training considerably! '
                  'You may see unexpected behavior when restarting '
                  'from checkpoints.')
    main_worker()
def main_worker():
    """Load the dataset description, build a Dataset, and print its
    per-channel mean/std (for use with transforms.Normalize)."""

    class CustomDataset(torch.utils.data.Dataset):
        """Dataset backed by a DataFrame with 'file_path' and 'class' columns."""

        def __init__(self, dataframe):
            self.dataframe = dataframe

        def __len__(self):
            return len(self.dataframe)

        def __getitem__(self, index):
            # Returns the raw PIL image (no transform applied) and its label.
            row = self.dataframe.iloc[index]
            image = Image.open(row["file_path"])
            label = np.asarray(row["class"])
            return (image, label)

    class UnNormalize(object):
        """Invert transforms.Normalize in place: t = t * std + mean per channel."""

        def __init__(self, mean, std):
            self.mean = mean
            self.std = std

        def __call__(self, tensor):
            """
            Args:
                tensor (Tensor): Tensor image of size (C, H, W) to be un-normalized.
            Returns:
                Tensor: Un-normalized image (modified in place).
            """
            for t, m, s in zip(tensor, self.mean, self.std):
                t.mul_(s).add_(m)  # inverse of the normalize step t.sub_(m).div_(s)
            return tensor

    def compute_mean_and_std(dataset):
        """Compute per-channel mean and std of a dataset of PIL images, scaled to [0, 1].

        NOTE: PIL yields RGB channel order, so channel 0 is red (the original
        code labelled the channels b/g/r — a BGR/OpenCV convention that does
        not apply here).  Both returned tuples are in channel order
        (ch0, ch1, ch2), i.e. RGB, which is what transforms.Normalize expects.
        """
        n_images = len(dataset)
        # First pass: mean over the per-image channel means.
        channel_mean = np.zeros(3)
        for img, _ in tqdm(dataset):
            img = np.asarray(img)  # PIL Image -> (H, W, 3) uint8 array, RGB
            for c in range(3):
                channel_mean[c] += np.mean(img[:, :, c])
        channel_mean /= n_images
        # Second pass: pooled variance across every pixel of every image.
        channel_sq_diff = np.zeros(3)
        n_pixels = 0
        for img, _ in tqdm(dataset):
            img = np.asarray(img)
            for c in range(3):
                channel_sq_diff[c] += np.sum(np.power(img[:, :, c] - channel_mean[c], 2))
            n_pixels += np.prod(img[:, :, 0].shape)
        channel_std = np.sqrt(channel_sq_diff / n_pixels)
        # Scale from [0, 255] to [0, 1]; .tolist() yields plain Python floats.
        mean = tuple((channel_mean / 255.0).tolist())
        std = tuple((channel_std / 255.0).tolist())
        print("mean:", mean)
        print("std:", std)
        return mean, std

    data_dir = "path/to/datasetFolder"
    print('==> Preparing data..')
    # dataset.json maps a relative file path to its metadata (class, component_name, ...).
    # `with` closes the handle (the original leaked it).
    with open("/root/notebooks/Folder/dataset.json", "r") as f:
        dataset_info = json.load(f)
    df = pd.DataFrame.from_dict(dataset_info, orient="index")
    df['file_path'] = df.index
    df["file_path"] = data_dir + df["file_path"].astype(str)
    # Map the six string class names to integer labels 0-5; values not in the
    # table are left untouched, matching the original per-class .loc updates.
    class_to_idx = {"good": 0, "missing": 1, "shift": 2,
                    "stand": 3, "broke": 4, "short": 5}
    df["class"] = df["class"].replace(class_to_idx)
    df = df.drop(columns=['component_name'])
    total_dataset = CustomDataset(df)
    compute_mean_and_std(total_dataset)
    return
# Script entry point: guarded so importing this module has no side effects.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment