# Gist by @zphang (created October 16, 2019 18:46)
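"""BERT/RoBERTa adapter modules in PyTorch.

Swaps the BertOutput / BertSelfOutput sub-modules of a pretrained `transformers`
model for versions that insert a small bottleneck adapter (down-project ->
nonlinearity -> up-project, with a residual connection) before the residual
LayerNorm, in the style of Houlsby et al. (2019).
"""
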
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass

import transformers.modeling_bert as modeling_bert

import nlpr.shared.torch_utils as torch_utils
import nlpr.shared.model_resolution as model_resolution

import pyutils.io as io
import pyutils.datastructures as datastructures


DEFAULT_ADAPTER_SIZE = 64
DEFAULT_ADAPTER_INITIALIZER_RANGE = 0.0002


@dataclass
class AdapterConfig:
    hidden_act: str = "gelu"
    adapter_size: int = DEFAULT_ADAPTER_SIZE
    adapter_initializer_range: float = DEFAULT_ADAPTER_INITIALIZER_RANGE


class Adapter(nn.Module):
    """Bottleneck adapter: down-project, nonlinearity, up-project, residual connection."""

    def __init__(self, hidden_size: int, adapter_config: AdapterConfig):
        super(Adapter, self).__init__()
        self.hidden_size = hidden_size
        self.adapter_config = adapter_config
        self.down_project = nn.Linear(
            self.hidden_size,
            self.adapter_config.adapter_size,
        )
        self.activation = modeling_bert.ACT2FN[self.adapter_config.hidden_act] \
            if isinstance(self.adapter_config.hidden_act, str) \
            else self.adapter_config.hidden_act
        self.up_project = nn.Linear(self.adapter_config.adapter_size, self.hidden_size)
        self.init_weights()

    def forward(self, hidden_states):
        down_projected = self.down_project(hidden_states)
        activated = self.activation(down_projected)
        up_projected = self.up_project(activated)
        # Residual connection around the bottleneck
        return hidden_states + up_projected

    def init_weights(self):
        # Slightly different from the TF version, which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        self.down_project.weight.data.normal_(mean=0.0, std=self.adapter_config.adapter_initializer_range)
        self.down_project.bias.data.zero_()
        self.up_project.weight.data.normal_(mean=0.0, std=self.adapter_config.adapter_initializer_range)
        self.up_project.bias.data.zero_()


class BertOutputWithAdapters(nn.Module):
    """Drop-in replacement for BertOutput that applies an adapter after dense + dropout."""

    def __init__(self, dense, adapter, layer_norm, dropout):
        super(BertOutputWithAdapters, self).__init__()
        self.dense = dense
        self.adapter = adapter
        self.LayerNorm = layer_norm
        self.dropout = dropout

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.adapter(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

    @classmethod
    def from_original(cls, old_module, adapter_config: AdapterConfig):
        assert isinstance(old_module, modeling_bert.BertOutput)
        return cls(
            dense=old_module.dense,
            adapter=Adapter(
                hidden_size=old_module.dense.out_features,
                adapter_config=adapter_config,
            ),
            layer_norm=old_module.LayerNorm,
            dropout=old_module.dropout,
        )


class BertSelfOutputWithAdapters(nn.Module):
    """Drop-in replacement for BertSelfOutput that applies an adapter after dense + dropout."""

    def __init__(self, dense, adapter, layer_norm, dropout):
        super(BertSelfOutputWithAdapters, self).__init__()
        self.dense = dense
        self.adapter = adapter
        self.LayerNorm = layer_norm
        self.dropout = dropout

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.adapter(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

    @classmethod
    def from_original(cls, old_module, adapter_config: AdapterConfig):
        assert isinstance(old_module, modeling_bert.BertSelfOutput)
        return cls(
            dense=old_module.dense,
            adapter=Adapter(
                hidden_size=old_module.dense.out_features,
                adapter_config=adapter_config,
            ),
            layer_norm=old_module.LayerNorm,
            dropout=old_module.dropout,
        )


def add_adapters(model, adapter_config):
    """Swap BertOutput/BertSelfOutput sub-modules of `model` for adapter-augmented versions, in place."""
    model_architecture = model_resolution.ModelArchitectures.from_ptt_model(model)
    if model_architecture not in [model_resolution.ModelArchitectures.BERT,
                                  model_resolution.ModelArchitectures.ROBERTA]:
        raise KeyError(model_architecture)
    modified = {}
    for p_name, p_module, c_name, c_module in torch_utils.get_parent_child_module_list(model):
        if isinstance(c_module, modeling_bert.BertOutput):
            new_module = BertOutputWithAdapters.from_original(
                old_module=c_module,
                adapter_config=adapter_config,
            )
        elif isinstance(c_module, modeling_bert.BertSelfOutput):
            new_module = BertSelfOutputWithAdapters.from_original(
                old_module=c_module,
                adapter_config=adapter_config,
            )
        else:
            continue
        setattr(p_module, c_name, new_module)
        modified[f"{p_name}.{c_name}"] = new_module
    return modified
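

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a BERT-style `transformers` model that the
# `nlpr` architecture resolution above recognizes. Following the usual
# adapter-tuning recipe, only adapter and LayerNorm parameters are typically
# left trainable; adjust the freezing rule to taste.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import transformers

    model = transformers.BertModel.from_pretrained("bert-base-uncased")
    modified = add_adapters(model, adapter_config=AdapterConfig())
    print(f"Replaced {len(modified)} sub-modules with adapter-augmented versions")

    # Freeze everything except the adapters and LayerNorms (one common choice).
    for name, param in model.named_parameters():
        param.requires_grad = ("adapter" in name) or ("LayerNorm" in name)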