Peter Melchior (pmelchior)

pmelchior / pytorch-adaprox.py
Last active June 16, 2022 03:23
Proximal Adam for PyTorch
from torch.optim import Optimizer
import math
import torch
from torch import Tensor
from typing import List, Optional, Callable
def adaprox(params: List[Tensor],
            proxes: List[Callable[[Tensor, float], Tensor]],
            grads: List[Tensor],
            exp_avgs: List[Tensor],
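
The preview above cuts off mid-signature; the full gist implements an Adam update followed by a proximal step on each parameter. As an illustration only, and not the gist's code, here is a minimal sketch of that idea on a single tensor, using an L1 proximal operator with the Callable[[Tensor, float], Tensor] signature declared above; Adam's bias-correction terms are omitted.

import torch

def prox_l1(x: torch.Tensor, step: float, lam: float = 0.1) -> torch.Tensor:
    # soft-thresholding: argmin_z 0.5*||z - x||^2 + step*lam*||z||_1
    return torch.sign(x) * torch.clamp(x.abs() - step * lam, min=0.0)

param = torch.randn(5, requires_grad=True)
loss = (param ** 2).sum()
loss.backward()

lr, beta1, beta2, eps = 1e-2, 0.9, 0.999, 1e-8
exp_avg = torch.zeros_like(param)      # first-moment estimate
exp_avg_sq = torch.zeros_like(param)   # second-moment estimate

with torch.no_grad():
    exp_avg.mul_(beta1).add_(param.grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(param.grad, param.grad, value=1 - beta2)
    param.sub_(lr * exp_avg / (exp_avg_sq.sqrt() + eps))  # plain Adam step (no bias correction)
    param.copy_(prox_l1(param, lr))                       # proximal step after the gradient update
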
pmelchior / get_source_code.py
Last active May 20, 2022 03:08
Source code for a Python class instance
import inspect
from IPython.display import Code
def get_source_code(obj, display=False):
    if inspect.isclass(obj):
        this_class = obj
    else:
        # get class from instance
        this_class = type(obj)
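
The preview ends here; given the imports, the remainder presumably passes this_class to inspect.getsource and, when display=True, wraps the result in IPython.display.Code. A hedged usage sketch of that pattern (the Example class and the calls below are illustrative, not part of the gist):

import inspect
from IPython.display import Code

class Example:
    """Toy class whose source we want to see."""
    def method(self):
        return 42

instance = Example()
# inspect.getsource needs a class, function, or module, so map the instance to its class
src = inspect.getsource(type(instance))
print(src)                      # plain-text source
Code(src, language="python")    # syntax-highlighted display in a notebook
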
pmelchior / cprof2png.py
Last active June 18, 2019 20:19
Profile Python code and display the call graph inline
import cProfile, os, tempfile
from IPython.display import Image
def cprof2png(command):
    # create tmp file to store cprof
    temp = tempfile.NamedTemporaryFile()
    # run profiler
    cProfile.run(command, temp.name)
    # parse to dot and make png figure from temp
    os.system(f"gprof2dot -f pstats {temp.name} | dot -Tpng -o {temp.name}")
pmelchior / pytorch_pgm.py
Created December 30, 2018 23:23
Proximal Gradient Method for PyTorch (minimal extension of torch.optim.SGD)
from torch.optim.sgd import SGD
from torch.optim.optimizer import required
class PGM(SGD):
    def __init__(self, params, proxs, lr=required, momentum=0, dampening=0,
                 nesterov=False):
        kwargs = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=0, nesterov=nesterov)
        super().__init__(params, **kwargs)

        if len(proxs) != len(self.param_groups):
            raise ValueError("Invalid length of argument proxs: {} instead of {}".format(len(proxs), len(self.param_groups)))
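
The preview shows only the constructor, which checks that one proximal operator is supplied per parameter group; the step() that applies the operators after the SGD update is cut off. A hedged usage sketch, assuming the usual Optimizer.step() interface and using a nonnegativity projection as an illustrative proximal operator (prox_plus and the toy training loop are not part of the gist):

import torch
from torch import nn

def prox_plus(x, step=None):
    # projection onto the nonnegative orthant
    return x.clamp_(min=0)

model = nn.Linear(3, 1)
opt = PGM(model.parameters(), proxs=[prox_plus], lr=0.1)  # one prox per parameter group

x, y = torch.randn(8, 3), torch.randn(8, 1)
for _ in range(10):
    opt.zero_grad()
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    opt.step()  # gradient step, then (assumed) proximal step on each group
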