Skip to content

Instantly share code, notes, and snippets.

@ArneNx
Last active July 26, 2023 14:14
Show Gist options
  • Save ArneNx/fd91d60cef787a7909c4e8fab2755d25 to your computer and use it in GitHub Desktop.
Save ArneNx/fd91d60cef787a7909c4e8fab2755d25 to your computer and use it in GitHub Desktop.
Does reproducing ImageNet-C results depend on the Python version?
# -*- coding: utf-8 -*-
import argparse
import os
import time
import torch
from torch.autograd import Variable as V
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as trn
import torchvision.models as models
import torch.utils.model_zoo as model_zoo
import numpy as np
import sys
print("Python Version:", sys.version)
print("TORCH VERSION:", torch.__version__)
# Command-line interface: which architecture to evaluate and how many GPUs to use.
parser = argparse.ArgumentParser(
    description="Evaluates robustness of various nets on ImageNet",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Architecture
parser.add_argument(
    "--model-name",
    "-m",
    type=str,
    # Without a model name the setup code below has nothing to build, so fail
    # at argument parsing instead of crashing later on a None model name.
    required=True,
    choices=[
        "alexnet",
        "squeezenet1.0",
        "squeezenet1.1",
        "condensenet4",
        "condensenet8",
        "vgg11",
        "vgg",
        "vggbn",
        "densenet121",
        "densenet169",
        "densenet201",
        "densenet161",
        "densenet264",
        "resnet18",
        "resnet34",
        "resnet50",
        "resnet101",
        "resnet152",
        "resnext50",
        "resnext101",
        "resnext101_64",
    ],
)
# Acceleration
parser.add_argument("--ngpu", type=int, default=1, help="0 = CPU.")
args = parser.parse_args()
print(args)
# /////////////// Model Setup ///////////////
# Checkpoint cache directories, exactly as the per-model branches used them.
_DAN_MODEL_DIR = "/share/data/lang/users/dan/.torch/models"
_WORK_MODEL_DIR = "/work/models"

# Architectures whose weights are fetched with model_zoo:
#   model name -> (constructor, checkpoint URL, cache directory, eval batch size)
# NOTE(review): recent torchvision releases rename DenseNet state-dict keys
# when loading these checkpoints; confirm that a plain load_state_dict() still
# accepts the densenet* files with the installed torchvision.
_ZOO_MODELS = {
    "alexnet": (
        models.AlexNet,
        "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
        _DAN_MODEL_DIR,
        256,
    ),
    # The squeezenet factory functions are used instead of
    # models.SqueezeNet(version=1.0 / 1.1): current torchvision expects the
    # version as the string "1_0"/"1_1" and rejects the float form.
    "squeezenet1.0": (
        models.squeezenet1_0,
        "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth",
        _DAN_MODEL_DIR,
        256,
    ),
    "squeezenet1.1": (
        models.squeezenet1_1,
        "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth",
        _DAN_MODEL_DIR,
        256,
    ),
    # "vgg" -> VGG-19, "vggbn" -> VGG-19 with batch norm, "vgg11" -> VGG-11.
    # Explicit keys replace the substring tests, under which "vgg11" always
    # matched the VGG-19 branch first and the VGG-11 branch was unreachable.
    "vgg": (
        models.vgg19,
        "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        _WORK_MODEL_DIR,
        64,
    ),
    "vgg11": (
        models.vgg11,
        "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        _WORK_MODEL_DIR,
        64,
    ),
    "vggbn": (
        models.vgg19_bn,
        "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        _WORK_MODEL_DIR,
        64,
    ),
    "densenet121": (
        models.densenet121,
        "https://download.pytorch.org/models/densenet121-a639ec97.pth",
        _WORK_MODEL_DIR,
        128,
    ),
    "densenet169": (
        models.densenet169,
        "https://download.pytorch.org/models/densenet169-6f0f7f60.pth",
        _WORK_MODEL_DIR,
        128,
    ),
    "densenet201": (
        models.densenet201,
        "https://download.pytorch.org/models/densenet201-c1103571.pth",
        _WORK_MODEL_DIR,
        64,
    ),
    "densenet161": (
        models.densenet161,
        "https://download.pytorch.org/models/densenet161-8d451a50.pth",
        _WORK_MODEL_DIR,
        64,
    ),
    "resnet18": (
        models.resnet18,
        "https://download.pytorch.org/models/resnet18-5c106cde.pth",
        _WORK_MODEL_DIR,
        256,
    ),
    "resnet34": (
        models.resnet34,
        "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
        _WORK_MODEL_DIR,
        128,
    ),
    "resnet50": (
        models.resnet50,
        "https://download.pytorch.org/models/resnet50-19c8e357.pth",
        "./models",
        128,
    ),
}


def _build_condensenet(factor):
    """Build a CondenseNet and load ``./converted_condensenet_<factor>.pth``.

    ``factor`` is used for the 1x1 group count, the 3x3 group count and the
    condense factor. The hyper-parameters are written onto the global ``args``
    namespace because the CondenseNet constructor is handed that namespace.
    """
    args.evaluate = True
    args.stages = [4, 6, 8, 10, 8]
    args.growth = [8, 16, 32, 64, 128]
    args.data = "imagenet"
    args.num_classes = 1000
    args.bottleneck = 4
    args.group_1x1 = factor
    args.group_3x3 = factor
    args.reduction = 0.5
    args.condense_factor = factor
    # NOTE(review): CondenseNet is not imported in the visible part of this
    # file; it has to be provided before these models can be selected.
    model = CondenseNet(args)
    state_dict = torch.load("./converted_condensenet_{}.pth".format(factor))[
        "state_dict"
    ]
    # Rotate once through the ordered dict, re-inserting every entry under its
    # key with the first 7 characters ('module.') dropped.
    for _ in range(len(state_dict)):
        name, value = state_dict.popitem(False)
        state_dict[name[7:]] = value
    model.load_state_dict(state_dict)
    return model


if args.model_name in _ZOO_MODELS:
    _build, _url, _model_dir, args.test_bs = _ZOO_MODELS[args.model_name]
    net = _build()
    net.load_state_dict(model_zoo.load_url(_url, model_dir=_model_dir))
elif args.model_name == "condensenet4":
    net = _build_condensenet(4)
    args.test_bs = 256
elif args.model_name == "condensenet8":
    net = _build_condensenet(8)
    args.test_bs = 256
elif args.model_name == "densenet264":
    # NOTE(review): densenet_cosine_264_k48 is not defined in the visible part
    # of this file; it is used as-is (not called), so it is presumably an
    # already constructed module object -- confirm where it comes from.
    net = densenet_cosine_264_k48
    net.load_state_dict(
        model_zoo.load_url(
            "https://download.pytorch.org/models/densenet_cosine_264_k48.pth",
            model_dir="/work/models",
        )
    )
    args.test_bs = 64
else:
    # Some names accepted by --model-name have no setup code here; report that
    # clearly instead of failing below with a NameError on `net`.
    parser.error(
        f"no model setup is defined for --model-name {args.model_name!r}"
    )

args.prefetch = 4  # number of DataLoader worker processes

# Evaluation only: parameters never need gradients. (The old
# `p.volatile = True` flag was removed from PyTorch and has no effect.)
for p in net.parameters():
    p.requires_grad_(False)

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))
if args.ngpu > 0:
    net.cuda()

# Fixed seeds for the torch and numpy random number generators.
torch.manual_seed(1)
np.random.seed(1)
if args.ngpu > 0:
    torch.cuda.manual_seed(1)

net.eval()
cudnn.benchmark = True  # fire on all cylinders
print("Model Loaded")
# /////////////// Data Loader ///////////////
# Per-channel statistics handed to trn.Normalize; show_performance() reuses them.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Clean validation images: resize, centre-crop to 224, convert, normalise.
clean_transform = trn.Compose(
    [
        trn.Resize(256),
        trn.CenterCrop(224),
        trn.ToTensor(),
        trn.Normalize(mean, std),
    ]
)
clean_dataset = dset.ImageFolder(
    root="/var/sinz-shared/image_classification/ImageNet/val",
    transform=clean_transform,
)
clean_loader = torch.utils.data.DataLoader(
    clean_dataset,
    batch_size=args.test_bs,
    shuffle=False,
    num_workers=args.prefetch,
    pin_memory=True,
)
# /////////////// Further Setup ///////////////
def auc(errs):  # area under the distortion-error curve
    """Return the trapezoidal area under ``errs``, averaged per interval.

    Consecutive points are treated as one unit apart; the summed trapezoid
    areas are divided by the number of intervals (``len(errs) - 1``), so a
    constant sequence ``[c, c, ..., c]`` yields ``c``.

    Args:
        errs: sequence of error values, one per distortion level.

    Returns:
        The normalised area. An empty sequence gives 0.0, and a single point
        gives that point's value (there is no interval to integrate over, and
        dividing by zero intervals used to raise ZeroDivisionError).
    """
    n_points = len(errs)
    if n_points == 0:
        return 0.0
    if n_points == 1:
        return errs[0]
    area = 0
    for left, right in zip(errs[:-1], errs[1:]):
        area += (left + right) / 2
    return area / (n_points - 1)
# correct = 0
# for batch_idx, (data, target) in enumerate(clean_loader):
# data = V(data.cuda(), volatile=True)
#
# output = net(data)
#
# pred = output.data.max(1)[1]
# correct += pred.eq(target.cuda()).sum()
#
# clean_error = 1 - correct / len(clean_loader.dataset)
# print('Clean dataset error (%): {:.2f}'.format(100 * clean_error))
def show_performance(distortion_name):
    """Evaluate the global ``net`` on one corruption at severities 1 to 5.

    For every severity an ImageFolder is opened at
    ``<ImageNet-C root>/<distortion_name>/<severity>``, the top-1 accuracy is
    printed, and the matching error rate (``1 - accuracy``) is collected.

    Args:
        distortion_name: name of the corruption sub-directory.

    Returns:
        The mean of the five per-severity error rates, as a fraction.
    """
    # Follow the device the model actually lives on instead of calling .cuda()
    # unconditionally, so that running with --ngpu 0 (CPU) works as well.
    device = next(net.parameters()).device
    # The transform is identical for every severity, so build it once.
    transform = trn.Compose(
        [trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)]
    )

    errs = []
    for severity in range(1, 6):
        distorted_dataset = dset.ImageFolder(
            root="/var/sinz-shared/image_classification/ImageNet-C/"
            + distortion_name
            + "/"
            + str(severity),
            transform=transform,
        )
        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset,
            batch_size=args.test_bs,
            shuffle=False,
            num_workers=args.prefetch,
            pin_memory=True,
        )

        correct = 0
        # torch.no_grad() replaces the removed Variable(..., volatile=True).
        with torch.no_grad():
            for data, target in distorted_dataset_loader:
                output = net(data.to(device))
                pred = output.max(1)[1]  # index of the highest score per sample
                correct += pred.eq(target.to(device)).sum().item()

        accuracy = correct / len(distorted_dataset)
        print(f"ACC: {distortion_name} {severity}", accuracy)
        errs.append(1 - accuracy)

    print("\n=Average", tuple(errs))
    return np.mean(errs)
# /////////////// End Further Setup ///////////////
# /////////////// Display Results ///////////////
import collections

print("\nUsing ImageNet data")

# Corruption sub-directories to evaluate, in reporting order.
distortions = [
    "gaussian_noise",
    "shot_noise",
    "impulse_noise",
    "defocus_blur",
    "glass_blur",
    "motion_blur",
    "zoom_blur",
    "snow",
    "frost",
    "fog",
    "brightness",
    "contrast",
    "elastic_transform",
    "pixelate",
    "jpeg_compression",
    # Further corruptions, currently disabled:
    # "speckle_noise",
    # "gaussian_blur",
    # "spatter",
    # "saturate",
]

# Directory name -> display name used as the key into CE_ALEXNET.
NAME_MAP = {
    "brightness": "Brightness",
    "contrast": "Contrast",
    "defocus_blur": "Defocus Blur",
    "elastic_transform": "Elastic Transform",
    "fog": "Fog",
    "frost": "Frost",
    "gaussian_noise": "Gaussian Noise",
    "glass_blur": "Glass Blur",
    "impulse_noise": "Impulse Noise",
    "jpeg_compression": "JPEG Compression",
    "motion_blur": "Motion Blur",
    "pixelate": "Pixelate",
    "shot_noise": "Shot Noise",
    "snow": "Snow",
    "zoom_blur": "Zoom Blur",
}

# AlexNet reference values (fractions) that each corruption's error rate is
# divided by when computing the normalised corruption error.
CE_ALEXNET = {
    "Gaussian Noise": 0.886428,
    "Shot Noise": 0.894468,
    "Impulse Noise": 0.922640,
    "Defocus Blur": 0.819880,
    "Glass Blur": 0.826268,
    "Motion Blur": 0.785948,
    "Zoom Blur": 0.798360,
    "Snow": 0.866816,
    "Frost": 0.826572,
    "Fog": 0.819324,
    "Brightness": 0.564592,
    "Contrast": 0.853204,
    "Elastic Transform": 0.646056,
    "Pixelate": 0.717840,
    "JPEG Compression": 0.606500,
}
def get_mce_from_accuracy(error, error_alexnet):
    """Return one corruption error (CE) normalised by the AlexNet reference.

    Despite the function name, the first argument is an error rate, not an
    accuracy, and the result is the CE of a single corruption; averaging over
    corruptions (the "mean" in mCE) is left to the caller.

    Args:
        error: the model's error rate for one corruption, in percent (0-100).
        error_alexnet: the AlexNet reference error for the same corruption,
            as a fraction (0-1); it is scaled to percent before dividing.

    Returns:
        ``error / (error_alexnet * 100.0)``, i.e. 1.0 means "as bad as the
        reference".
    """
    ce = error / (error_alexnet * 100.0)
    return ce
# Evaluate every corruption and collect raw and AlexNet-normalised error rates.
error_rates = []
normalized_error_rates = []

for distortion_name in distortions:
    rate = show_performance(distortion_name)

    # Only corruptions with a known AlexNet reference contribute to the
    # normalised score.
    display_name = NAME_MAP.get(distortion_name)
    if display_name is not None:
        reference = CE_ALEXNET[display_name]
        normalized_error_rates.append(get_mce_from_accuracy(rate * 100, reference))

    error_rates.append(rate)
    summary = "Distortion: {:15s} | CE (unnormalized) (%): {:.2f}".format(
        distortion_name, 100 * rate
    )
    print(summary)

# Final scores: plain mean error and AlexNet-normalised mean error, in percent.
mean_error = np.mean(error_rates)
print("mCE (unnormalized by AlexNet errors) (%): {:.2f}".format(100 * mean_error))

mean_normalized_error = np.mean(normalized_error_rates)
print(
    "mCE (normalized by AlexNet errors) (%): {:.2f}".format(
        100 * mean_normalized_error
    )
)
Python Version: 3.10.6 (main, Mar 10 2023, 10:55:28) [GCC 11.3.0]
TORCH VERSION: 2.0.0+cu117
Namespace(model_name='resnet50', ngpu=1)
Model Loaded
Using ImageNet data
ACC: gaussian_noise 1 0.59616
ACC: gaussian_noise 2 0.48556
ACC: gaussian_noise 3 0.31494
ACC: gaussian_noise 4 0.1385
ACC: gaussian_noise 5 0.02998
=Average (0.40384, 0.51444, 0.68506, 0.8614999999999999, 0.97002)
Distortion: gaussian_noise | CE (unnormalized) (%): 68.70
ACC: shot_noise 1 0.5779
ACC: shot_noise 2 0.4456
ACC: shot_noise 3 0.28524
ACC: shot_noise 4 0.09948
ACC: shot_noise 5 0.03702
=Average (0.42210000000000003, 0.5544, 0.7147600000000001, 0.90052, 0.96298)
Distortion: shot_noise | CE (unnormalized) (%): 71.10
ACC: impulse_noise 1 0.48372
ACC: impulse_noise 2 0.38534
ACC: impulse_noise 3 0.28936
ACC: impulse_noise 4 0.1134
ACC: impulse_noise 5 0.0264
=Average (0.5162800000000001, 0.61466, 0.7106399999999999, 0.8866, 0.9736)
Distortion: impulse_noise | CE (unnormalized) (%): 74.04
ACC: defocus_blur 1 0.58998
ACC: defocus_blur 2 0.51766
ACC: defocus_blur 3 0.37874
ACC: defocus_blur 4 0.26482
ACC: defocus_blur 5 0.17904
=Average (0.41002000000000005, 0.48234, 0.6212599999999999, 0.73518, 0.82096)
Distortion: defocus_blur | CE (unnormalized) (%): 61.40
ACC: glass_blur 1 0.53708
ACC: glass_blur 2 0.40148
ACC: glass_blur 3 0.16798
ACC: glass_blur 4 0.12678
ACC: glass_blur 5 0.09736
=Average (0.46292, 0.5985199999999999, 0.83202, 0.87322, 0.90264)
Distortion: glass_blur | CE (unnormalized) (%): 73.39
ACC: motion_blur 1 0.64454
ACC: motion_blur 2 0.54074
ACC: motion_blur 3 0.37664
ACC: motion_blur 4 0.21914
ACC: motion_blur 5 0.1472
=Average (0.35546, 0.45926, 0.62336, 0.78086, 0.8528)
Distortion: motion_blur | CE (unnormalized) (%): 61.43
ACC: zoom_blur 1 0.52188
ACC: zoom_blur 2 0.42396
ACC: zoom_blur 3 0.35032
ACC: zoom_blur 4 0.28254
ACC: zoom_blur 5 0.22456
=Average (0.47812, 0.57604, 0.64968, 0.71746, 0.77544)
Distortion: zoom_blur | CE (unnormalized) (%): 63.93
ACC: snow 1 0.54278
ACC: snow 2 0.31776
ACC: snow 3 0.34816
ACC: snow 4 0.23714
ACC: snow 5 0.1661
=Average (0.45721999999999996, 0.68224, 0.65184, 0.76286, 0.8339)
Distortion: snow | CE (unnormalized) (%): 67.76
ACC: frost 1 0.61088
ACC: frost 2 0.43914
ACC: frost 3 0.31878
ACC: frost 4 0.29648
ACC: frost 5 0.23066
=Average (0.38912, 0.56086, 0.6812199999999999, 0.7035199999999999, 0.76934)
Distortion: frost | CE (unnormalized) (%): 62.08
ACC: fog 1 0.61402
ACC: fog 2 0.55444
ACC: fog 3 0.46244
ACC: fog 4 0.39868
ACC: fog 5 0.24002
=Average (0.38598, 0.44555999999999996, 0.53756, 0.6013200000000001, 0.75998)
Distortion: fog | CE (unnormalized) (%): 54.61
ACC: brightness 1 0.73798
ACC: brightness 2 0.7215
ACC: brightness 3 0.69612
ACC: brightness 4 0.65104
ACC: brightness 5 0.59116
=Average (0.26202000000000003, 0.27849999999999997, 0.30388000000000004, 0.34896000000000005, 0.40884)
Distortion: brightness | CE (unnormalized) (%): 32.04
ACC: contrast 1 0.64406
ACC: contrast 2 0.57972
ACC: contrast 3 0.4563
ACC: contrast 4 0.2036
ACC: contrast 5 0.05378
=Average (0.35594000000000003, 0.42028, 0.5437000000000001, 0.7964, 0.94622)
Distortion: contrast | CE (unnormalized) (%): 61.25
ACC: elastic_transform 1 0.66366
ACC: elastic_transform 2 0.44534
ACC: elastic_transform 3 0.54984
ACC: elastic_transform 4 0.41434
ACC: elastic_transform 5 0.16498
=Average (0.33634, 0.5546599999999999, 0.45016, 0.5856600000000001, 0.83502)
Distortion: elastic_transform | CE (unnormalized) (%): 55.24
ACC: pixelate 1 0.63906
ACC: pixelate 2 0.63736
ACC: pixelate 3 0.4624
ACC: pixelate 4 0.2902
ACC: pixelate 5 0.20878
=Average (0.36094000000000004, 0.36263999999999996, 0.5376000000000001, 0.7098, 0.79122)
Distortion: pixelate | CE (unnormalized) (%): 55.24
ACC: jpeg_compression 1 0.65868
ACC: jpeg_compression 2 0.62392
ACC: jpeg_compression 3 0.59332
ACC: jpeg_compression 4 0.48164
ACC: jpeg_compression 5 0.32634
=Average (0.34131999999999996, 0.37607999999999997, 0.40668000000000004, 0.5183599999999999, 0.6736599999999999)
Distortion: jpeg_compression | CE (unnormalized) (%): 46.32
mCE (unnormalized by AlexNet errors) (%): 60.57
mCE (normalized by AlexNet errors) (%): 76.43
Python Version: 3.8.16 (default, Dec 7 2022, 01:12:06)
[GCC 11.3.0]
TORCH VERSION: 2.0.0+cu117
Namespace(model_name='resnet50', ngpu=1)
Model Loaded
Using ImageNet data
ACC: gaussian_noise 1 0.59442
ACC: gaussian_noise 2 0.462
ACC: gaussian_noise 3 0.27596
ACC: gaussian_noise 4 0.11016
ACC: gaussian_noise 5 0.02208
=Average (0.40558000000000005, 0.538, 0.72404, 0.88984, 0.97792)
Distortion: gaussian_noise | CE (unnormalized) (%): 70.71
ACC: shot_noise 1 0.572
ACC: shot_noise 2 0.42096
ACC: shot_noise 3 0.2505
ACC: shot_noise 4 0.07892
ACC: shot_noise 5 0.02928
=Average (0.42800000000000005, 0.57904, 0.7495, 0.92108, 0.97072)
Distortion: shot_noise | CE (unnormalized) (%): 72.97
ACC: impulse_noise 1 0.4803
ACC: impulse_noise 2 0.35802
ACC: impulse_noise 3 0.25142
ACC: impulse_noise 4 0.08218
ACC: impulse_noise 5 0.01852
=Average (0.5197, 0.64198, 0.74858, 0.91782, 0.98148)
Distortion: impulse_noise | CE (unnormalized) (%): 76.19
ACC: defocus_blur 1 0.5936
ACC: defocus_blur 2 0.5197
ACC: defocus_blur 3 0.3796
ACC: defocus_blur 4 0.26534
ACC: defocus_blur 5 0.17922
=Average (0.4064, 0.48029999999999995, 0.6204000000000001, 0.73466, 0.8207800000000001)
Distortion: defocus_blur | CE (unnormalized) (%): 61.25
ACC: glass_blur 1 0.54042
ACC: glass_blur 2 0.4044
ACC: glass_blur 3 0.1686
ACC: glass_blur 4 0.12768
ACC: glass_blur 5 0.09816
=Average (0.45958, 0.5956, 0.8314, 0.87232, 0.90184)
Distortion: glass_blur | CE (unnormalized) (%): 73.21
ACC: motion_blur 1 0.64666
ACC: motion_blur 2 0.54256
ACC: motion_blur 3 0.37746
ACC: motion_blur 4 0.21926
ACC: motion_blur 5 0.14778
=Average (0.35334, 0.45743999999999996, 0.62254, 0.78074, 0.85222)
Distortion: motion_blur | CE (unnormalized) (%): 61.33
ACC: zoom_blur 1 0.52496
ACC: zoom_blur 2 0.42588
ACC: zoom_blur 3 0.35234
ACC: zoom_blur 4 0.28408
ACC: zoom_blur 5 0.22496
=Average (0.47504, 0.57412, 0.64766, 0.71592, 0.77504)
Distortion: zoom_blur | CE (unnormalized) (%): 63.76
ACC: snow 1 0.54578
ACC: snow 2 0.31946
ACC: snow 3 0.35198
ACC: snow 4 0.24058
ACC: snow 5 0.16894
=Average (0.45421999999999996, 0.6805399999999999, 0.64802, 0.75942, 0.83106)
Distortion: snow | CE (unnormalized) (%): 67.47
ACC: frost 1 0.61276
ACC: frost 2 0.44116
ACC: frost 3 0.32126
ACC: frost 4 0.2989
ACC: frost 5 0.23306
=Average (0.38724000000000003, 0.55884, 0.67874, 0.7011000000000001, 0.76694)
Distortion: frost | CE (unnormalized) (%): 61.86
ACC: fog 1 0.61818
ACC: fog 2 0.55882
ACC: fog 3 0.46642
ACC: fog 4 0.40406
ACC: fog 5 0.24426
=Average (0.38182000000000005, 0.44118, 0.5335799999999999, 0.59594, 0.75574)
Distortion: fog | CE (unnormalized) (%): 54.17
ACC: brightness 1 0.74024
ACC: brightness 2 0.72408
ACC: brightness 3 0.69614
ACC: brightness 4 0.65126
ACC: brightness 5 0.58934
=Average (0.25976, 0.27592000000000005, 0.30386, 0.34874000000000005, 0.41066)
Distortion: brightness | CE (unnormalized) (%): 31.98
ACC: contrast 1 0.64892
ACC: contrast 2 0.58408
ACC: contrast 3 0.46012
ACC: contrast 4 0.20554
ACC: contrast 5 0.05428
=Average (0.35107999999999995, 0.41591999999999996, 0.53988, 0.7944599999999999, 0.94572)
Distortion: contrast | CE (unnormalized) (%): 60.94
ACC: elastic_transform 1 0.6665
ACC: elastic_transform 2 0.44806
ACC: elastic_transform 3 0.55628
ACC: elastic_transform 4 0.42198
ACC: elastic_transform 5 0.16952
=Average (0.3335, 0.55194, 0.44372, 0.57802, 0.83048)
Distortion: elastic_transform | CE (unnormalized) (%): 54.75
ACC: pixelate 1 0.64142
ACC: pixelate 2 0.64064
ACC: pixelate 3 0.46194
ACC: pixelate 4 0.2892
ACC: pixelate 5 0.20608
=Average (0.35858, 0.35936, 0.53806, 0.7108, 0.79392)
Distortion: pixelate | CE (unnormalized) (%): 55.21
ACC: jpeg_compression 1 0.6616
ACC: jpeg_compression 2 0.62466
ACC: jpeg_compression 3 0.59284
ACC: jpeg_compression 4 0.47496
ACC: jpeg_compression 5 0.3165
=Average (0.33840000000000003, 0.37534, 0.40715999999999997, 0.52504, 0.6835)
Distortion: jpeg_compression | CE (unnormalized) (%): 46.59
mCE (unnormalized by AlexNet errors) (%): 60.83
mCE (normalized by AlexNet errors) (%): 76.70
# Singularity image using the distribution's default python3 (the Python 3.10 run).
Bootstrap: docker
From: nvidia/cuda:11.7.0-devel-ubuntu22.04

%labels
    MAINTAINER Arne Nix <arnenix@gmail.com>

%files

%environment
    export TZ=Europe/Berlin
    # `export NAME value` (without '=') does not assign anything.
    export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH}

%post
    # %environment is only sourced at runtime, so TZ has to be set again for
    # the build; otherwise $TZ expands to an empty string below.
    export TZ=Europe/Berlin
    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
    apt-get update &&\
    apt-get install -y software-properties-common
    # -y: do not wait for interactive confirmation during the build.
    add-apt-repository -y ppa:deadsnakes/ppa &&\
    apt-get install -y \
        build-essential \
        curl \
        zlib1g-dev \
        pkg-config \
        libgl-dev \
        libblas-dev \
        liblapack-dev \
        python3-tk \
        python3-wheel \
        graphviz \
        libhdf5-dev \
        python3 \
        python3-dev \
        python3-distutils \
        swig \
        apt-transport-https \
        lsb-release \
        ca-certificates &&\
    apt-get clean &&\
    # This image installs the distribution python3 (3.10 on Ubuntu 22.04), not
    # python3.8, so linking to /usr/bin/python3.8 left dangling symlinks.
    ln -s /usr/bin/python3.10 /usr/local/bin/python &&\
    ln -s /usr/bin/python3.10 /usr/local/bin/python3 &&\
    curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py &&\
    python3 get-pip.py &&\
    rm get-pip.py
    python3 -m pip --no-cache-dir install \
        numpy \
        scipy \
        Pillow==8.0.1
    python3 -m pip --no-cache-dir install \
        torch==2.0.0+cu117 \
        --extra-index-url https://download.pytorch.org/whl/cu117 \
        torchvision \
        torchaudio

%startscript
    exec "$@"

%runscript
    exec "$@"
# Singularity image using python3.8 from the deadsnakes PPA (the Python 3.8 run).
Bootstrap: docker
From: nvidia/cuda:11.7.0-devel-ubuntu22.04
#From: nvidia/cuda:11.1.1-cudnn8-devel-ubuntu18.04

%labels
    MAINTAINER Arne Nix <arnenix@gmail.com>

%files

%environment
    export TZ=Europe/Berlin
    # `export NAME value` (without '=') does not assign anything.
    export LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:${LD_LIBRARY_PATH}

%post
    # %environment is only sourced at runtime, so TZ has to be set again for
    # the build; otherwise $TZ expands to an empty string below.
    export TZ=Europe/Berlin
    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
    apt-get update &&\
    apt-get install -y software-properties-common
    # -y: do not wait for interactive confirmation during the build.
    add-apt-repository -y ppa:deadsnakes/ppa &&\
    apt-get install -y \
        build-essential \
        curl \
        zlib1g-dev \
        pkg-config \
        libgl-dev \
        libblas-dev \
        liblapack-dev \
        python3-tk \
        python3-wheel \
        graphviz \
        libhdf5-dev \
        python3.8 \
        python3.8-dev \
        python3.8-distutils \
        swig \
        apt-transport-https \
        lsb-release \
        ca-certificates &&\
    apt-get clean &&\
    # Make python3.8 the interpreter behind `python` and `python3`.
    ln -s /usr/bin/python3.8 /usr/local/bin/python &&\
    ln -s /usr/bin/python3.8 /usr/local/bin/python3 &&\
    curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py &&\
    python3 get-pip.py &&\
    rm get-pip.py
    python3 -m pip --no-cache-dir install \
        numpy \
        scipy \
        Pillow==8.0.1
    python3 -m pip --no-cache-dir install \
        torch==2.0.0+cu117 \
        --extra-index-url https://download.pytorch.org/whl/cu117 \
        torchvision \
        torchaudio

%startscript
    exec "$@"

%runscript
    exec "$@"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment