Consider this blog post model:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from bs4 import BeautifulSoup | |
from markdown import markdown | |
def markdown_to_text(markdown_string): | |
""" Converts a markdown string to plaintext """ | |
# md -> html -> text since BeautifulSoup can extract text cleanly | |
html = markdown(markdown_string) | |
# remove code snippets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Script for fine-tuning Pegasus | |
Example usage: | |
# use XSum dataset as example, with first 1000 docs as training data | |
from datasets import load_dataset | |
dataset = load_dataset("xsum") | |
train_texts, train_labels = dataset['train']['document'][:1000], dataset['train']['summary'][:1000] | |
# use Pegasus Large model as base for fine-tuning | |
model_name = 'google/pegasus-large' | |
train_dataset, _, _, tokenizer = prepare_data(model_name, train_texts, train_labels) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer, AutoModel | |
def mean_pooling(model_output, attention_mask): | |
""" | |
Mean pooling to get sentence embeddings. See: | |
https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1 | |
""" | |
token_embeddings = model_output[0] | |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) # Sum columns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.data import Dataset, DataLoader | |
import numpy as np | |
class MyDataset(Dataset): | |
def __init__(self): | |
x = np.random.rand(1000, 3) # 1000 3-dim samples | |
self.x = [x[i].tolist() for i in range(1000)] | |
y = np.random.randint(low=0, high=2, size=(1000,)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Hack to add per-session state to Streamlit. | |
Usage | |
----- | |
>>> import SessionState | |
>>> | |
>>> session_state = SessionState.get(user_name='', favorite_color='black') | |
>>> session_state.user_name | |
'' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import faiss | |
import numpy as np | |
class FaissKMeans: | |
def __init__(self, n_clusters=8, n_init=10, max_iter=300): | |
self.n_clusters = n_clusters | |
self.n_init = n_init | |
self.max_iter = max_iter | |
self.kmeans = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
import torch | |
class ConvLSTMCell(nn.Module): | |
def __init__(self, input_dim, hidden_dim, kernel_size, bias): | |
""" | |
Initialize ConvLSTM cell. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from models.ConvLSTMCell import ConvLSTMCell | |
class EncoderDecoderConvLSTM(nn.Module): | |
def __init__(self, nf, in_chan): | |
super(EncoderDecoderConvLSTM, self).__init__() | |
""" ARCHITECTURE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import pytorch_lightning as pl | |
class MyTransformer(pl.LightningModule): | |
def __init__( | |
self, | |
learning_rate=0.001, | |
warmup=4000, | |
): | |
self.learning_rate = learning_rate |
NewerOlder