import torch

def unilm_mask(inputs, s):
    # UniLM seq2seq mask: token i may attend to token j iff cumsum(s)[j] <= cumsum(s)[i]
    idxs = torch.cumsum(s, dim=1)
    mask = idxs[:, None, :] <= idxs[:, :, None]
    return mask.to(dtype=torch.int64)

def create_lm_mask(attention_mask, direction='l2r'):
    seq_len = attention_mask.size(-1)
    if attention_mask.ndim == 2:
        attention_mask = attention_mask.view(-1, 1, seq_len)
    # assumed completion (the gist is truncated here): combine the padding
    # mask with a standard causal triangle, flipped for right-to-left
    idxs = torch.arange(seq_len, device=attention_mask.device)
    triangle = (idxs[None, :] <= idxs[:, None]) if direction == 'l2r' \
        else (idxs[None, :] >= idxs[:, None])
    return attention_mask * triangle.long()
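A quick sanity check of both masks on toy tensors; the segment tensor `s` below (zeros for source tokens, ones for target tokens) is an invented example, not from the original gist:

s = torch.tensor([[0, 0, 0, 1, 1]])    # one sequence: 3 source + 2 target tokens
print(unilm_mask(None, s)[0])          # source attends to all source; target part is causal
pad = torch.tensor([[1, 1, 1, 1, 0]])  # last position is padding
print(create_lm_mask(pad))             # lower-triangular mask with the pad column zeroed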
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
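Beyond the REPL demo, torch.where is the loop-free way to do conditional replacement. A small sketch (all names are illustrative) that masks padded logits before a softmax:

import torch

logits = torch.randn(2, 4)
pad = torch.tensor([[1, 1, 0, 0],
                    [1, 1, 1, 0]], dtype=torch.bool)
# keep real positions, push padding to -inf so it gets zero probability
masked = torch.where(pad, logits, torch.full_like(logits, float('-inf')))
probs = masked.softmax(dim=-1)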
# pass only the parameters that still require gradients to the optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
model = Model()
# The general case: the shared/frozen part usually spans several layers, so loop
for para in model.linear1.parameters():
    para.requires_grad = False
# If it really is a single layer, this also works:
# model.linear1.weight.requires_grad = False
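Putting the freezing snippet and the filtered optimizer together: a minimal end-to-end sketch, where the two-layer `Model` is an assumed stand-in, used only to verify that just `linear2`'s parameters reach the optimizer:

import torch.nn as nn
import torch.optim as optim

class Model(nn.Module):          # hypothetical two-layer model
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.linear2 = nn.Linear(4, 2)
    def forward(self, x):
        return self.linear2(self.linear1(x))

model = Model()
for para in model.linear1.parameters():
    para.requires_grad = False

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
# only linear2's weight and bias were registered with the optimizer
print(sum(len(g['params']) for g in optimizer.param_groups))  # -> 2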
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch
    # implement your own
    out = self(x)
    loss = self.loss(out, y)
    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)  # assumes torchvision is imported
    self.logger.experiment.add_image('example_images', grid, 0)
    # log the loss so it shows up in the test results
    self.log('test_loss', loss)
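A minimal way to exercise this hook, assuming a built LightningModule `model` plus placeholder `train_loader`/`test_loader` DataLoaders (argument names per recent pytorch_lightning releases):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_dataloaders=train_loader)
trainer.test(model, dataloaders=test_loader)  # runs test_step over each batch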
// Despite often being labelled "throttle", this pattern is a debounce:
// each call resets the timer, so fn only fires after calls stop for
// `time` ms. A true throttle would guarantee at most one call per window.
function debounce(fn, time) {
    var t = 0;
    return function () {
        var args = arguments, ctx = this;
        clearTimeout(t);
        t = setTimeout(function () {
            fn.apply(ctx, args);
        }, time);
    };
}
# load the architecture config and pretrained weights from the Hugging Face hub
self.config = AutoConfig.from_pretrained(MODEL_NAME)
# tokenizer = BertTokenizer.from_pretrained(tokenizer_MODEL_NAME)
self.model = AutoModel.from_pretrained(MODEL_NAME)
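The same Auto classes in a standalone, runnable form; `bert-base-uncased` is just a stand-in checkpoint for whatever MODEL_NAME points to:

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_NAME = 'bert-base-uncased'      # stand-in checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

inputs = tokenizer('hello world', return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)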
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        """
        Initialize ConvLSTM cell.
        """
        super(ConvLSTMCell, self).__init__()
        self.hidden_dim = hidden_dim
        # assumed completion of the truncated gist: one convolution emits all
        # four gates at once; "same" padding preserves the spatial size
        self.conv = nn.Conv2d(input_dim + hidden_dim, 4 * hidden_dim,
                              kernel_size=kernel_size,
                              padding=(kernel_size[0] // 2, kernel_size[1] // 2),
                              bias=bias)
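    # Sketch of the matching forward pass, following the common ConvLSTM gate
    # equations (the original gist cuts off before this point, so this body
    # is an assumption, not the author's code).
    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)  # stack along channels
        i, f, o, g = torch.split(self.conv(combined), self.hidden_dim, dim=1)
        c_next = torch.sigmoid(f) * c_cur + torch.sigmoid(i) * torch.tanh(g)
        h_next = torch.sigmoid(o) * torch.tanh(c_next)
        return h_next, c_next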
import torch
import torch.nn as nn
from models.ConvLSTMCell import ConvLSTMCell

class EncoderDecoderConvLSTM(nn.Module):
    def __init__(self, nf, in_chan):
        super(EncoderDecoderConvLSTM, self).__init__()
        """ ARCHITECTURE
        Encoder (ConvLSTM) -> encoder vector (last hidden state)
        -> Decoder (ConvLSTM) -> 3D CNN head that regresses the output frames
        """