Skip to content

Instantly share code, notes, and snippets.

View nbroad1881's full-sized avatar

Nicholas Broad nbroad1881

  • Hugging Face
  • San Francisco, California
  • 16:14 (UTC -07:00)
View GitHub Profile
import torch
from torch import nn
class MultiSampleDropout(nn.Module):
    """Apply several dropout masks to the same input (multi-sample dropout).

    Holds one ``nn.Dropout`` per probability in ``dropout_probs``.  The
    dropouts are kept in an ``nn.ModuleList`` (not a plain Python list) so
    that they are registered as submodules and correctly follow
    ``.train()`` / ``.eval()`` / ``.to(device)`` calls on the parent module.

    Args:
        dropout_probs: Iterable of dropout probabilities, one per sample.
        problem_type: Task identifier (e.g. a HF ``problem_type`` string);
            stored for use by downstream logic.
        num_labels: Number of output labels; stored on the instance.
            (The original snippet accepted this argument but discarded it.)

    NOTE(review): the rest of this class (e.g. ``forward``) is truncated in
    this view — confirm against the full gist.
    """

    def __init__(self, dropout_probs, problem_type, num_labels) -> None:
        super().__init__()
        # nn.ModuleList registers each Dropout so train/eval mode propagates;
        # a plain list would leave them permanently in training mode under eval().
        self.dropouts = nn.ModuleList(nn.Dropout(p=p) for p in dropout_probs)
        self.problem_type = problem_type
        self.num_labels = num_labels
@nbroad1881
nbroad1881 / summarize.js
Last active June 30, 2022 21:12 — forked from feconroses/gist:302474ddd3f3c466dc069ecf16bb09d7
Add summarization to Google Sheets using HuggingFace's API
// Google Sheets custom function: summarize `input` text via the Hugging Face
// Inference API using the model `repo_id` (default: google/pegasus-xsum),
// optionally requesting GPU inference with `use_gpu`.
// NOTE(review): the function body is truncated in this view — only the
// signature and the model-suggestion comments are visible; verify the full
// implementation in the original gist before relying on this documentation.
function SUMMARIZE(input, repo_id="google/pegasus-xsum", use_gpu=false) {
// other models to consider
// short sequences
// sshleifer/distilbart-cnn-12-6
// knkarthick/MEETING_SUMMARY
// long sequences
// google/bigbird-pegasus-large-bigpatent
@nbroad1881
nbroad1881 / deberta_mlm.py
Last active August 17, 2022 22:47
Implementation that makes use of the pretrained weights for Deberta for Masked Language Modeling.
from typing import Any, Optional, Union, Tuple
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.models.deberta.modeling_deberta import (
DebertaPreTrainedModel,
DebertaModel,
)
from transformers.models.deberta_v2.modeling_deberta_v2 import (
# Generic LM
roberta-base
roberta-large
microsoft/deberta-v3-base
microsoft/deberta-v3-large
microsoft/deberta-v3-xsmall
# Long LM
allenai/longformer-base-4096
google/bigbird-roberta-base
@nbroad1881
nbroad1881 / mlflow_tracker.py
Last active September 19, 2022 17:00
MLflow tracker for accelerate, with a check for Azure ML. Azure ML cannot log more than 100 params
import os
import json
from typing import Optional, Any, Union, Dict
import mlflow
from transformers import TrainingArguments
from accelerate.tracking import GeneralTracker
from accelerate.logging import get_logger
@nbroad1881
nbroad1881 / test_mlm.py
Last active September 22, 2022 23:35
Quickly test how a Masked LM will do on texts.
import argparse
from itertools import chain
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
if __name__ == "__main__":