Skip to content

Instantly share code, notes, and snippets.

@marcossilva
Last active December 9, 2022 20:43
Show Gist options
  • Save marcossilva/4025519bda136d0cd202bef936b8d560 to your computer and use it in GitHub Desktop.
Download the pre-trained GECToR model weights
from pathlib import Path
import requests
from tqdm import tqdm
# Name of the directory that holds test fixture files.
TEST_FIXTURES_DIR_PATH = "test_fixtures"

# Pre-trained GECToR checkpoint (RoBERTa variant, file roberta_1_gectorv2.th)
# hosted in Grammarly's public S3 bucket.
MODEL_URL = (
    "https://grammarly-nlp-data-public.s3.amazonaws.com"
    "/gector/roberta_1_gectorv2.th"
)
def download_weights(model_path=None, model_url=None, chunk_size=1024 * 1024, timeout=60):
    """Download the GECToR checkpoint unless it is already present on disk.

    Args:
        model_path: Destination file (str or Path). Defaults to
            ``<TEST_FIXTURES_DIR_PATH>/roberta_model/weights.th``.
        model_url: URL to fetch. Defaults to ``MODEL_URL``.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: Seconds passed to ``requests`` as the connect/read timeout
            (per socket operation, not for the whole transfer).

    Returns:
        ``pathlib.Path`` of the weights file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
        OSError: if the destination cannot be created or written.
    """
    if model_path is None:
        model_path = Path(TEST_FIXTURES_DIR_PATH) / "roberta_model" / "weights.th"
    model_path = Path(model_path)
    if model_path.exists():
        return model_path
    if model_url is None:
        model_url = MODEL_URL

    # open("wb") does not create missing directories.
    model_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling temp file and rename only after the transfer completes,
    # so an interrupted download never leaves a truncated file that the
    # exists() check above would accept on the next run.
    partial_path = model_path.with_name(model_path.name + ".part")
    try:
        # stream=True avoids buffering the whole response body in memory.
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error body would be saved as the weights file.
            response.raise_for_status()
            total = int(response.headers.get("Content-Length", 0)) or None
            with partial_path.open("wb") as out_fp, tqdm(
                total=total, unit="B", unit_scale=True, unit_divisor=1024
            ) as progress:
                # An explicit chunk_size matters: iter_content() defaults to
                # 1-byte chunks, which makes the loop run once per byte.
                for chunk in response.iter_content(chunk_size=chunk_size):
                    out_fp.write(chunk)
                    progress.update(len(chunk))
        partial_path.replace(model_path)
    except BaseException:
        if partial_path.exists():
            partial_path.unlink()
        raise
    return model_path
# Executed at module level: importing or running this file checks for the
# weights file and fetches MODEL_URL if it is missing.
# NOTE(review): consider an `if __name__ == "__main__":` guard so that a plain
# import has no network side effect — confirm no caller relies on import-time download.
download_weights()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment