Skip to content

Instantly share code, notes, and snippets.

View fujiyuu75's full-sized avatar

fujiyuu75 fujiyuu75

View GitHub Profile
# Total number of trainable scalars: add up the element count of every
# parameter that still has requires_grad set.
pytorch_total_params = sum(
    param.numel()
    for param in filter(lambda t: t.requires_grad, model.parameters())
)
def freeze_network(model):
    """Freeze every parameter whose name does not contain ``"l0"``.

    Iterates ``model.named_parameters()`` and sets ``requires_grad = False``
    on each parameter whose name lacks the substring ``"l0"``. Parameters
    whose name contains ``"l0"`` are left untouched (they are not forced to
    ``True``).

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.

    Returns:
        The same ``model`` object, modified in place.
    """
    for param_name, param in model.named_parameters():
        if "l0" in param_name:
            # Substring match: any parameter with "l0" in its name stays as is.
            continue
        param.requires_grad = False
    return model
---
# Freeze the model's parameters; freeze_network modifies the model in place
# and returns the same object, so rebinding `model` keeps the reference.
model = freeze_network(model)
class EarlyStopping:
    """Configuration and tracking state for early stopping.

    NOTE(review): only ``__init__`` is present in this excerpt, and it
    initialises ``val_score`` only when ``mode == "min"``. How other modes
    are handled, and the stopping logic itself, are not visible here.
    """

    def __init__(self, patience=7, mode="max", delta=0.001):
        """Store the settings and reset the tracking state.

        Args:
            patience: stored as ``self.patience``; presumably the number of
                non-improving checks tolerated (TODO confirm against the
                code that consumes it).
            mode: ``"min"`` or ``"max"``; stored as ``self.mode``.
            delta: stored as ``self.delta``; presumably the minimum change
                counted as an improvement (TODO confirm).
        """
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.early_stop = False
        self.delta = delta
        if self.mode == "min":
            # Use np.inf: the capitalised alias np.Inf was removed in
            # NumPy 2.0 and raises AttributeError there.
            self.val_score = np.inf
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from skimage.measure import compare_ssim
import imutils
import cv2
from tqdm import tqdm
# Relative paths to the video directories used by this script.
# NOTE(review): the '../input/...' layout looks like a hosted-notebook data
# mount -- confirm before running elsewhere.
TRAIN_VIDEOS = '../input/deepfake-detection-challenge/train_sample_videos'
TEST_VIDEOS = '../input/deepfake-detection-challenge/test_videos'
ALL_VIDEOS = '../input/dfdc_train_all'
### Model Visualizations
def normalize_channels(img):
    """Min-max scale an image to the [0, 1] range, reducing over axes 0 and 1.

    The minimum and maximum are taken over the first two axes, so for an
    array with a trailing channel axis each channel is scaled independently.

    Args:
        img: NumPy array with at least two dimensions.

    Returns:
        A new floating-point array of the same shape where, per reduced
        slice, the smallest value maps to 0 and the largest to 1. A slice
        that is constant (max == min) maps to all zeros.
    """
    _min = img.min(axis=(0, 1))
    # The upper bound must be the maximum; the previous code used the
    # standard deviation here, which does not yield a [0, 1] range.
    _max = img.max(axis=(0, 1))
    span = _max - _min
    # Avoid 0/0 on constant slices: dividing by 1 leaves them at zero.
    span = np.where(span == 0, 1, span)
    return (img - _min) / span
# NOTE(review): this definition is truncated in this excerpt -- the
# `with torch.no_grad():` block below has no body, and the indentation has
# been lost. What is done with `weight` is not visible here.
def plot_first_kernels(weight):
''' plot first filters of a model '''
with torch.no_grad():
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` to the environment and switches cuDNN
    to deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects hash randomisation in interpreters started after this
    # point (e.g. worker subprocesses), not the running process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device; this call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
## Memory optimization
# Original code from https://www.kaggle.com/gemartin/load-data-reduce-memory-usage by @gemartin
# Modified to support timestamp type, categorical type
# Modified to add option to use float16
from pandas.api.types import is_datetime64_any_dtype as is_datetime
from pandas.api.types import is_categorical_dtype
# NOTE(review): only the signature of reduce_mem_usage is present in this
# excerpt; the function body is missing, and the statements that follow
# belong to an unrelated plotting snippet.
def reduce_mem_usage(df, use_float16=False):
# Show the first 25 training images in a 5x5 grid, each labelled with its
# class name looked up through its integer label.
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i + 1)  # subplot indices are 1-based
    # Hide ticks and grid lines so only the image and its label are drawn.
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i]])
plt.show()

# Show the first training image on its own, with a colour bar for its values.
plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.grid(False)
plt.show()