This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Total number of scalar elements across every trainable tensor
# (requires_grad=True) registered on `model`.
pytorch_total_params = sum(
    tensor.numel()
    for tensor in filter(lambda t: t.requires_grad, model.parameters())
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def freeze_network(model, trainable_key="l0"):
    """Disable gradients for every parameter whose name lacks ``trainable_key``.

    Parameters whose qualified name (as yielded by ``model.named_parameters()``)
    contains ``trainable_key`` are left untouched; every other parameter gets
    ``requires_grad = False`` so it stops receiving gradients.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        trainable_key: Substring marking a parameter as trainable. Defaults to
            ``"l0"`` for backward compatibility. This is a plain substring
            test, so ``"l0"`` also matches names such as ``"model0.weight"``;
            pass a more specific key when that matters.

    Returns:
        The same ``model`` object, modified in place.
    """
    for name, param in model.named_parameters():
        if trainable_key not in name:
            param.requires_grad = False
    return model
# ---
# Freeze, in place, every parameter of `model` whose name does not contain
# "l0"; freeze_network hands back the very same model object.
model = freeze_network(model=model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EarlyStopping: | |
def __init__(self, patience=7, mode="max", delta=0.001): | |
self.patience = patience | |
self.counter = 0 | |
self.mode = mode | |
self.best_score = None | |
self.early_stop = False | |
self.delta = delta | |
if self.mode == "min": | |
self.val_score = np.Inf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)

# `compare_ssim` moved to `skimage.metrics.structural_similarity` in
# scikit-image 0.16 and was removed from `skimage.measure` in 0.18. Prefer the
# new location and fall back for older installs, keeping the same local name.
try:
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:
    from skimage.measure import compare_ssim
import imutils
import cv2
from tqdm import tqdm

# Dataset directories, given as paths relative to the working directory.
TRAIN_VIDEOS = '../input/deepfake-detection-challenge/train_sample_videos'
TEST_VIDEOS = '../input/deepfake-detection-challenge/test_videos'
ALL_VIDEOS = '../input/dfdc_train_all'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### Model Visualizations | |
def normalize_channels(img):
    """Min-max scale each channel of ``img`` independently to ``[0, 1]``.

    Statistics are taken over axes 0 and 1, so for an ``(H, W, C)`` array each
    of the ``C`` channels is scaled by its own min and max (a 2-D array is
    scaled as a whole). Channels whose values are all equal map to 0 instead
    of producing NaN/inf from a division by zero.

    Args:
        img: NumPy array to normalize; reduced over axes 0 and 1.

    Returns:
        Floating-point array with the same shape as ``img``.
    """
    import numpy as np  # local import keeps this helper self-contained

    _min, _max = img.min(axis=(0, 1)), img.max(axis=(0, 1))
    # Divide by the true range (max - min); only that maps values onto [0, 1].
    span = _max - _min
    # Constant channels have zero range: use 1 so they become 0, not NaN.
    span = np.where(span == 0, 1, span)
    return (img - _min) / span
def plot_first_kernels(weight): | |
''' plot first filters of a model ''' | |
with torch.no_grad(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def seed_everything(seed):
    """Seed the random number generators used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), exports ``PYTHONHASHSEED`` and configures cuDNN
    for deterministic algorithm selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the current
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one (no-op when CUDA is unavailable).
    torch.cuda.manual_seed_all(seed)
    # deterministic=True alone is not enough: with benchmark=True cuDNN may
    # auto-tune to a different algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run on the first CUDA GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Memory optimization | |
# Original code from https://www.kaggle.com/gemartin/load-data-reduce-memory-usage by @gemartin | |
# Modified to support timestamp type, categorical type | |
# Modified to add option to use float16 | |
from pandas.api.types import is_datetime64_any_dtype as is_datetime | |
from pandas.api.types import is_categorical_dtype | |
def reduce_mem_usage(df, use_float16=False): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Preview the first 25 training images on a 5x5 grid, each captioned with the
# class name looked up from its label; ticks and grid lines are hidden.
plt.figure(figsize=(10, 10))
for i in range(25):
    ax = plt.subplot(5, 5, i + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(train_images[i], cmap=plt.cm.binary)
    ax.set_xlabel(class_names[train_labels[i]])
plt.show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show the first training image on its own, with a colorbar for its value range.
fig, ax = plt.subplots()
image = ax.imshow(train_images[0])
fig.colorbar(image, ax=ax)
ax.grid(False)
plt.show()
Newer Older