Last active
September 24, 2023 04:24
-
-
Save xx025/2e41e8ae2bf0cd31e0d3dacb36c8d506 to your computer and use it in GitHub Desktop.
计算数据集的均值和标准差
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import json
import os

import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, PILToTensor
from tqdm import tqdm
def _accumulate_channel_stats(images):
    """Compute the per-channel mean and standard deviation over all pixels.

    Every pixel of every image contributes equally, so the result is the
    dataset-wide statistic rather than an average of per-image statistics
    (averaging per-image stds ignores between-image variance and
    underestimates the true spread).

    :param images: iterable of tensors shaped (C, H, W); all must share C
    :return: ``(means, stds)``, two float64 tensors of length C
    :raises ValueError: if ``images`` yields no pixels at all
    """
    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0
    for img in images:
        # float64 accumulators: summing many squared 0~255 values in float32
        # would lose precision.
        flat = img.to(torch.float64).flatten(start_dim=1)  # (C, H*W)
        if channel_sum is None:
            channel_sum = torch.zeros(flat.shape[0], dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += flat.pow(2).sum(dim=1)
        pixel_count += flat.shape[1]
    if pixel_count == 0:
        raise ValueError('no images to compute statistics from')
    means = channel_sum / pixel_count
    # Population variance Var[x] = E[x^2] - E[x]^2; clamp guards against tiny
    # negative values produced by floating-point rounding.
    variances = (channel_sq_sum / pixel_count - means.pow(2)).clamp(min=0)
    return means, variances.sqrt()


def main(data_path, save_name, out_dir='./', is_to01=False):
    """Compute the per-channel mean and std of an image dataset and save them as JSON.

    :param data_path: dataset root passed to ``ImageFolder``; a parent folder is
        supported, e.g. ``father/`` covering ``father/train`` + ``father/val``
    :param save_name: output file name, e.g. ``xx.json``
    :param out_dir: directory the JSON file is written to (created if missing;
        an empty string means the current directory)
    :param is_to01: if True, scale pixel values from 0~255 to 0~1 (``ToTensor``)
        before computing statistics; otherwise keep the raw pixel range
    :return: the dict ``{'means': [...], 'stds': [...]}`` that was saved
    """
    if is_to01:
        # ToTensor converts to a float tensor and scales 0~255 down to 0~1.
        transform = ToTensor()
    else:
        class PILToFloat(PILToTensor):
            """PILToTensor followed by a cast to float; the value range is unchanged."""

            def __call__(self, pic):
                return super().__call__(pic).float()

        transform = PILToFloat()

    dataset = ImageFolder(data_path, transform=transform)
    # Labels are irrelevant here; only the image tensors feed the statistics.
    images = (img for img, _ in tqdm(dataset, total=len(dataset)))
    means, stds = _accumulate_channel_stats(images)

    norm = dict(
        means=means.tolist(),
        stds=stds.tolist(),
    )
    print('计算完成!')
    print('均值:', norm['means'])
    print('标准差:', norm['stds'])

    # exist_ok avoids the check-then-create race; skip for '' (current dir),
    # where os.makedirs would raise.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, save_name)
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(norm, f)
    print('保存完成!', save_path)
    return norm
if __name__ == '__main__':
    # Command-line entry point. Defaults reproduce the original hard-coded run,
    # so invoking the script with no arguments behaves as before.
    parser = argparse.ArgumentParser(
        description='Compute per-channel mean and std of an image dataset.')
    parser.add_argument(
        'data_path', nargs='?',
        default=r"D:\workspace\DataSets\source\LLVIP\visible",
        help='dataset root; a parent folder such as father/ '
             '(father/train + father/val) is supported')
    parser.add_argument('--save-name', default='mean_std.json',
                        help='output JSON file name')
    parser.add_argument('--out-dir', default='data',
                        help='directory for the output JSON file')
    parser.add_argument('--to01', action='store_true',
                        help='scale pixel values to the 0~1 range first')
    args = parser.parse_args()
    main(data_path=args.data_path, save_name=args.save_name,
         out_dir=args.out_dir, is_to01=args.to01)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment