
Shreyas Shandilya shandilya1998

" Don't try to be vi compatible
set nocompatible
" Helps force plugins to load correctly when it is turned back on below
filetype off
" TODO: Load plugins here (pathogen or vundle)
" Turn on syntax highlighting
syntax on
shandilya1998 / gradient_accumulation.py
Created May 3, 2021 20:52 — forked from thomwolf/gradient_accumulation.py
PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradient tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradient tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
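The loop divides each loss by `accumulation_steps` so that the gradients summed over several backward passes match those of one larger batch. A minimal pure-Python sketch of that equivalence, assuming a hypothetical 1-D linear model `y ≈ w * x`, a mean-squared-error loss and equal-sized micro-batches; `mse_grad`, `accumulated_grad` and `train` are illustrative names, not PyTorch APIs:

```python
# Pure-Python illustration of gradient accumulation (no PyTorch).
# Hypothetical model: y ≈ w * x, loss = mean squared error.

def mse_grad(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_grad(w, batch, accumulation_steps):
    """Split `batch` into equal micro-batches and sum their normalized gradients."""
    size = len(batch) // accumulation_steps
    assert size * accumulation_steps == len(batch), "micro-batches must be equal-sized"
    grad = 0.0  # stands in for the .grad buffer that loss.backward() adds into
    for k in range(accumulation_steps):
        micro = batch[k * size:(k + 1) * size]
        grad += mse_grad(w, micro) / accumulation_steps  # loss / accumulation_steps
    return grad

def train(data, lr=0.01, micro_batch_size=2, accumulation_steps=4, epochs=50):
    """Same control flow as the loop above, with plain floats instead of tensors."""
    w = 0.0
    grad = 0.0  # model.zero_grad()
    micro_batches = [data[i:i + micro_batch_size]
                     for i in range(0, len(data), micro_batch_size)]
    for _ in range(epochs):
        for i, micro in enumerate(micro_batches):
            grad += mse_grad(w, micro) / accumulation_steps  # backward on normalized loss
            if (i + 1) % accumulation_steps == 0:
                w -= lr * grad  # optimizer.step()
                grad = 0.0      # model.zero_grad()
    return w

data = [(float(x), 3.0 * x) for x in range(1, 9)]  # hypothetical data, y = 3x
print(round(train(data), 4))  # 3.0
```

Because `w` only changes after every `accumulation_steps` micro-batches, each summed gradient equals the full-batch gradient of the mean loss. If the number of micro-batches is not a multiple of `accumulation_steps`, the leftover gradient in this sketch carries over into the next epoch.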
shandilya1998 / imagenet1000_clsidx_to_labels.txt
Created December 20, 2020 13:03 — forked from yrevar/imagenet1000_clsidx_to_labels.txt
text: ImageNet 1000 class index to human-readable labels (Fox, E., & Guestrin, C. (n.d.). Coursera Machine Learning Specialization.)
{0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
3: 'tiger shark, Galeocerdo cuvieri',
4: 'hammerhead, hammerhead shark',
5: 'electric ray, crampfish, numbfish, torpedo',
6: 'stingray',
7: 'cock',
8: 'hen',
9: 'ostrich, Struthio camelus',
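The file is a Python dict literal that maps each class index to a comma-separated list of synonyms, so it can be parsed safely without `eval`. A minimal sketch, assuming the file's text has already been read into a string; the helper names `load_labels` and `short_name` and the local filename are hypothetical:

```python
import ast

def load_labels(text):
    """Parse the {index: 'name, synonym, ...'} literal into a dict."""
    return ast.literal_eval(text)  # evaluates Python literals only, never arbitrary code

def short_name(labels, idx):
    """Return the first synonym of a class, e.g. 'tench' for index 0."""
    return labels[idx].split(",")[0].strip()

# Hypothetical usage with the full file saved locally:
# with open("imagenet1000_clsidx_to_labels.txt") as f:
#     labels = load_labels(f.read())

# Excerpt of entries from the file, closed with "}" so it parses on its own.
excerpt = "{0: 'tench, Tinca tinca',\n 1: 'goldfish, Carassius auratus',\n 6: 'stingray'}"
labels = load_labels(excerpt)
print(short_name(labels, 0))  # tench
```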
shandilya1998 / patches.py
Created September 13, 2019 06:25 — forked from dwf/patches.py
Some patch extraction code I'm using to process images.
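For comparison with the gist's code, here is a minimal sliding-window patch extractor in NumPy. It is a sketch, not the gist's `extract_patches`: the name `sliding_patches` and its `stride`, `skip_value` and `max_frac` parameters are assumptions. The optional filter reuses the idea behind the gist's `frac_eq_to` (the fraction of entries equal to a given value) without claiming that this is how `cropvalue` or `overlap_allowed` behave in the original.

```python
import numpy as np

def sliding_patches(image, patch_shape, stride, skip_value=None, max_frac=0.5):
    """Hypothetical helper: collect every patch_shape window taken every `stride`
    pixels. If skip_value is given, drop windows in which more than max_frac of
    the entries equal skip_value (e.g. constant padding)."""
    ph, pw = patch_shape
    h, w = image.shape[:2]
    patches = []
    for r in range(0, h - ph + 1, stride):
        for c in range(0, w - pw + 1, stride):
            patch = image[r:r + ph, c:c + pw]
            if skip_value is not None:
                frac = (patch == skip_value).mean()  # fraction of entries equal to skip_value
                if frac > max_frac:
                    continue
            patches.append(patch)
    return np.array(patches)

# A stride smaller than the patch size makes neighbouring patches overlap.
patches = sliding_patches(np.arange(16).reshape(4, 4), (2, 2), stride=1)
print(patches.shape)  # (9, 2, 2)
```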
import os
import numpy as np
import scipy.ndimage as ndimage
import matplotlib
import matplotlib.pyplot as plt
def frac_eq_to(image, value=0):
    # Fraction of entries in `image` equal to `value`
    return (image == value).sum() / float(np.prod(image.shape))
def extract_patches(image, patchshape, overlap_allowed=0.5, cropvalue=None,