Skip to content

Instantly share code, notes, and snippets.

View justusschock's full-sized avatar

Justus Schock justusschock

View GitHub Profile
@justusschock
justusschock / sam.py
Created February 28, 2023 13:16
PL usage with SAM optimizer
# copied from https://github.com/davda54/sam
import torch
class SAM(torch.optim.Optimizer):
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
super(SAM, self).__init__(params, defaults)
@justusschock
justusschock / lazy_cached_eval.py
Created May 22, 2022 15:08
LazyEvaluateCached
import inspect
from typing import Any
class LazyEvaluateCached:
"""Boolean-like class to check arbitrary commands lazily.
>>> LazyEvaluateCached("[4, 5]")
LazyEvaluateCached([4, 5])
>> LazyEvaluateCached("[4, 5]")()
[4, 5]
@justusschock
justusschock / get_invalid_ips_docker_htcondor.py
Last active September 10, 2021 10:38
Filters invalid IPs in an HTCondor cluster based on log files in which the user has no permission to connect to Docker (in my case due to a group ID mismatch).
import os
from tqdm import tqdm
def get_invalid_ips_docker_htcondor(files, only_ips: bool = True):
invalid_ips = {}
for f in tqdm(files):
with open(f) as efs:
if 'docker: Got permission denied while trying to connect to the Docker daemon socket' in efs.read():
with open(f.replace('_err.log', '_log.log')) as lfs:
for l in lfs.readlines():
@justusschock
justusschock / metrics_usage_without_lightning.py
Created June 21, 2020 11:31
Example showing metric usage outside PyTorch Lightning!
import torch
from pytorch_lightning.metrics.functional import accuracy
# Mocked predictions and targets - for real usage drop in your fancy method here!
# Both are drawn from the same label space [0, NUM_CLASSES) so that every class
# a prediction can take is also a class the target can take; otherwise some
# predictions could never be counted as correct.
# NOTE(review): `accuracy` is imported from `pytorch_lightning.metrics.functional`,
# which only exists in older PyTorch Lightning releases (metrics later moved to
# the separate `torchmetrics` package) - confirm against your installed version.
NUM_CLASSES = 10
pred = torch.randint(NUM_CLASSES, (200,))
target = torch.randint(NUM_CLASSES, (200,))
print(accuracy(pred, target))
@justusschock
justusschock / rmse_lightning_modular.py
Last active June 21, 2020 12:17
Modular interface for metrics in PyTorch Lightning
import torch
from pytorch_lightning.metrics import TensorMetric
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the root-mean-square error between two tensors.

    Args:
        pred: predicted values.
        target: reference values; combined with ``pred`` by elementwise
            subtraction, so the usual torch broadcasting rules apply.

    Returns:
        A 0-dim tensor holding ``sqrt(mean((pred - target) ** 2))`` taken
        over every element of the difference.
    """
    residual = pred - target
    mean_squared_error = residual.pow(2.0).mean()
    return mean_squared_error.sqrt()
class RMSE(TensorMetric):
def forward(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@justusschock
justusschock / rmse_lightning.py
Last active June 21, 2020 12:17
More advanced PyTorch Lightning implementation of RMSE Metric.
import torch
from pytorch_lightning.metrics import tensor_metric
@tensor_metric()
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Root-mean-square error, wrapped by Lightning's ``tensor_metric`` decorator.

    Squares the elementwise difference ``pred - target``, averages it over
    all elements and returns the square root of that mean. Any input/output
    conversion the decorator adds on top is defined in
    ``pytorch_lightning.metrics``, not in this function.
    """
    squared_error = torch.pow(pred - target, 2.0)
    return torch.mean(squared_error).sqrt()
@justusschock
justusschock / rmse_plain_functional.py
Last active June 21, 2020 12:18
Plain functional example of a metric (RMSE)
import torch
def rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(mean((pred - target) ** 2))`` as a 0-dim tensor.

    Plain-PyTorch RMSE with no framework dependency; the mean runs over every
    element of the (possibly broadcast) difference.
    """
    return torch.sqrt(((pred - target) ** 2.0).mean())
@justusschock
justusschock / rising_loading.py
Created May 24, 2020 10:04
Data loading with rising
from rising.loading import DataLoader
from .dataset import DummyDataset
# Guard the demo so it only runs when this file is executed directly.
# num_workers=4 starts worker processes. Under the "spawn" start method
# (the default on Windows and macOS) those workers re-import the main
# module, and an unguarded top-level loop would fail while they bootstrap.
# NOTE(review): the relative `from .dataset import DummyDataset` at the top
# only resolves when this file runs as part of a package (e.g. `python -m`).
if __name__ == "__main__":
    # length=500 presumably sets the number of samples - confirm in DummyDataset.
    dset = DummyDataset(length=500, transforms=None)
    # rising's DataLoader, called with the same arguments as torch's DataLoader.
    loader = DataLoader(dset, num_workers=4, shuffle=True, batch_size=10)
    # Each batch is indexed by the 'data' and 'label' keys.
    for batch in loader:
        print(batch['data'].shape)
        print(batch['label'])
@justusschock
justusschock / plain_pytorch_loading.py
Created May 24, 2020 10:01
Data Loading with Plain PyTorch
from torch.utils.data import DataLoader
from .dataset import DummyDataset
# Guard the demo so it only runs when this file is executed directly.
# num_workers=4 starts worker processes. Under the "spawn" start method
# (the default on Windows and macOS) those workers re-import the main
# module, and an unguarded top-level loop would fail while they bootstrap.
# NOTE(review): the relative `from .dataset import DummyDataset` at the top
# only resolves when this file runs as part of a package (e.g. `python -m`).
if __name__ == "__main__":
    # length=500 presumably sets the number of samples - confirm in DummyDataset.
    dset = DummyDataset(length=500, transforms=None)
    loader = DataLoader(dset, num_workers=4, shuffle=True, batch_size=10)
    # Each batch is indexed by the 'data' and 'label' keys.
    for batch in loader:
        print(batch['data'].shape)
        print(batch['label'])
@justusschock
justusschock / dataset.py
Last active May 24, 2020 10:03
PyTorch Dummy Dataset
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self, length: int, transforms=None):
self.length = length
self.transforms = transforms
def __getitem__(self, idx: int):
# random image shape with 1 channel and 3 spatial dimensions
img = torch.rand(1, 224, 224, 224)