lawlesst/colab_setup.py
Last active Jun 25, 2020
Colab client for TDM
!pip install nltk gensim pyLDAvis

import nltk
nltk.download('stopwords')
nltk.download('wordnet')

def setup_tdm():
    import importlib
    import requests
    gist = "https://api.github.com/gists/9ccb340f15c1aab6846983738cffb4cc"
    rsp = requests.get(gist)
    url = rsp.json()['files']['tdm_client.py']['raw_url']
    with open("/content/tdm_client.py", "w") as of:
        rsp = requests.get(url)
        of.write(rsp.text)
    import tdm_client
    _ = importlib.reload(tdm_client)

setup_tdm()
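# Illustrative next step (not part of the setup snippet above): once
# setup_tdm() has run, the downloaded tdm_client module can be used directly
# in the notebook. The dataset id below is a placeholder, not a real bundle id.
#
#   import tdm_client
#   dataset = tdm_client.Dataset("<your-dataset-id>")
#   print(len(dataset))                      # number of documents in the bundle
#   for unigrams in dataset.get_features():  # per-document token -> count dicts
#       pass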
"""
Quick TDM client for demos.
Snipped to load in Colab:
!pip install nltk gensim pyLDAvis
import nltk
nltk.download('stopwords')
nltk.download('wordnet')
def setup_tdm():
import importlib
import requests
gist = "https://api.github.com/gists/9ccb340f15c1aab6846983738cffb4cc"
rsp = requests.get(gist)
url = rsp.json()['files']['tdm_client.py']['raw_url']
with open("/content/tdm_client.py", "w") as of:
rsp = requests.get(url)
of.write(rsp.text)
import tdm_client
_ = importlib.reload(tdm_client)
setup_tdm()
"""
import concurrent.futures
import copy
import gzip
import json
import logging
from pathlib import Path
import os
import threading
import time
import urllib.request
import urllib
import zlib
import requests
version = 2.9
DEV_ENV = False
if DEV_ENV is True:
    PROD_SERVICE = 'http://localhost:5000/'
else:
    PROD_SERVICE = 'https://www.jstor.org/api/tdm/v1/'

if DEV_ENV is True:
    DOWNLOAD_WORKERS = 1
else:
    DOWNLOAD_WORKERS = 5
# Try to use a shared cache directory if it exists, otherwise
# create a cache dir in the home directory. On Colab, the cache dir
# will always be /content/tdm_data.
shared_cache_dir = '/var/cache/tdm-data'
if os.environ.get("COLAB_GPU") is not None:
    cache_dir = "/content/tdm_data"
elif os.path.exists(shared_cache_dir) is True:
    cache_dir = shared_cache_dir
else:
    home = str(Path.home())
    cache_dir = os.path.join(home, 'tdm-data')

if os.path.exists(cache_dir) is not True:
    os.mkdir(cache_dir)
def get_token():
    qargs = {
        "appName": "tdm",
    }
    headers = {
        'Content-type': 'application/json',
        'Accept': 'application/json',
        'User-Agent': "tdm-colab-client"
    }
    rsp = requests.get("https://www.jstor.org/api/labs-jwt-service/iac-jwt", params=qargs, headers=headers)
    token = rsp.json()["jwt"]
    return token
class FileCache(object):
    """
    Simple on-disk cache that stores zlib-compressed JSON blobs,
    keyed by URL.
    """

    def __init__(self, **kwargs):
        self.path = kwargs.get('path', self.default_cache_dir)

    @property
    def default_cache_dir(self):
        """
        Default cache location, resolved at import time: the shared
        /var/cache/tdm-data dir, /content/tdm_data on Colab, or
        tdm-data in the user's home directory.
        """
        return cache_dir

    def file_path(self, key):
        # Convert the key (usually a URL) into a safe file name and
        # bucket it into a subdirectory named after its first segment.
        fn = key.replace('http://', '')\
            .replace('www', '')\
            .replace('://', '-')\
            .replace('/', '-')\
            .replace('.', '-')\
            .lstrip('-')
        prefix = fn.split('-')[0]
        if os.path.exists(self.path) is not True:
            os.mkdir(self.path)
        sub_dir = os.path.join(self.path, prefix)
        if os.path.exists(sub_dir) is not True:
            os.mkdir(sub_dir)
        return os.path.join(sub_dir, fn)

    def __getitem__(self, key):
        fp = self.file_path(key)
        with open(fp, 'rb') as inf:
            raw = zlib.decompress(inf.read())
        return json.loads(raw)

    def __setitem__(self, key, value):
        fp = self.file_path(key)
        with open(fp, 'wb') as outf:
            raw = json.dumps(value)
            c = zlib.compress(raw.encode('utf-8'))
            outf.write(c)

    def __delitem__(self, key):
        fp = self.file_path(key)
        return os.remove(fp)

    def get(self, key):
        try:
            return self.__getitem__(key)
        except (FileNotFoundError, zlib.error):
            return None

    def set(self, key, value):
        # Store ``value`` under ``key`` (dict-style assignment also works).
        return self.__setitem__(key, value)
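# Usage sketch (illustrative, not part of the module's API surface): the cache
# behaves roughly like a dict keyed by URL, persisting zlib-compressed JSON on
# disk. The URL below is a hypothetical key.
#
#   fc = FileCache()
#   fc["https://example.org/doc/1"] = {"title": "A Title"}
#   fc.get("https://example.org/doc/1")    # -> {"title": "A Title"}
#   fc.get("https://example.org/missing")  # -> None (no exception raised)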
class Dataset(object):

    def __init__(self, bundle_id, token=None, cache=True):
        self.id = bundle_id
        self.auth_token = get_token()
        self.request_headers = {
            'Authorization': 'JWT {}'.format(self.auth_token),
            'User-Agent': 'TDM',
            'Accept-Encoding': 'gzip'
        }
        self.cache = cache
        self.spec = self._load()
        self.available = []
        self.items = self.spec['documents']
        self.gzip_path = os.path.join(cache_dir, "{}.jsonl.gz".format(self.id))
        self.queued_items = copy.copy(self.spec['documents'])
        self.num_docs = len(self.items)
        if cache_dir is None:
            self.file_cache = {}
        else:
            fc = FileCache()
            self.file_cache = fc
        _ = self._download_gzip()
        # Background thread that downloads the documents.
        # See: https://gist.github.com/sebdah/832219525541e059aefa
        # thread = threading.Thread(target=self._populate, args=(), name='tdm-download', daemon=True)
        # thread.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def close(self):
        pass

    def __len__(self):
        return self.num_docs

    def get_features2(self):
        fetched = 0
        while True:
            if len(self.available) == 0:
                time.sleep(.25)
            else:
                for url in self.available:
                    yield self._fetch(url)
                    fetched += 1
            if fetched >= self.num_docs:
                break

    def get_features(self):
        with gzip.open(self.gzip_path, "rb") as inf:
            for row in inf:
                doc = json.loads(row.decode("utf-8"))
                uc = doc.get("unigramCount")
                if uc is None:
                    continue
                else:
                    yield uc

    def get_metadata2(self):
        url = PROD_SERVICE + 'nb/dataset/{}/metadata/'.format(self.id)
        rsp = requests.get(url, headers=self.request_headers)
        if rsp.status_code != 200:
            raise Exception("Can't load metadata for dataset {}.".format(self.id))
        else:
            return rsp.json()

    def get_metadata(self):
        out = []
        with gzip.open(self.gzip_path, "rb") as inf:
            for row in inf:
                doc = json.loads(row.decode("utf-8"))
                try:
                    del doc["unigramCount"]
                except KeyError:
                    pass
                out.append(doc)
        return out

    def get_feature(self, url):
        return self._fetch(url)

    def _load(self):
        url = PROD_SERVICE + "nb/dataset/{}/info/".format(self.id)
        d = requests.get(url, headers=self.request_headers)
        meta = d.json()
        if meta.get('message') is not None:
            raise Exception(json.dumps(meta))
        return meta

    def _download_gzip(self, force=False):
        if (force is False) and (os.path.exists(self.gzip_path)):
            pass
        else:
            _ = urllib.request.urlretrieve(self.spec["download_url"], self.gzip_path)

    def _multiple_populate(self):
        # multiple background download workers
        with concurrent.futures.ThreadPoolExecutor(max_workers=DOWNLOAD_WORKERS) as executor:
            future_to_url = {executor.submit(self._fetch, url): url for url in self.items}
            for future in concurrent.futures.as_completed(future_to_url):
                url = future_to_url[future]
                self.available.append(url)

    def _populate(self):
        # single background thread for downloading features
        if hasattr(self, 'items'):
            for url in self.items:
                _ = self._fetch(url)
                self.available.append(url)

    def query(self):
        return self.spec['query']

    def query_text(self):
        """
        Plain language search string.
        """
        return self.spec["search_description"]

    def _fetch(self, url):
        """
        Returns a dictionary of extracted features content.
        """
        logging.info("Fetching {}".format(url))
        if self.cache is True:
            from_cache = self.file_cache.get(url)
        else:
            from_cache = None
        if from_cache is not None:
            logging.info("{} found in cache.".format(url))
            return from_cache
        else:
            service_url = PROD_SERVICE + "nb/dataset/features/"
            rsp = requests.post(
                service_url,
                data=json.dumps({"id": url}),
                headers=self.request_headers
            )
            try:
                raw = rsp.json()
                self.file_cache[url] = raw
            except json.JSONDecodeError:
                logging.debug("Unable to load features: {}".format(url))
                raw = {}
            return raw
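# Usage sketch (illustrative; the dataset id is a placeholder): instantiating
# Dataset fetches a JWT, loads the bundle spec from the TDM service, and
# downloads the bundle's .jsonl.gz dump into the cache directory so that
# features and metadata can then be read locally.
#
#   ds = Dataset("<your-dataset-id>")
#   print(len(ds), ds.query_text())
#   metadata = ds.get_metadata()        # list of per-document metadata dicts
#   for unigrams in ds.get_features():  # per-document token -> count dicts
#       pass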
class LocalDataset(Dataset):

    def get(self, name=None, force=False):
        path = os.path.join(os.getcwd(), "datasets")
        if not os.path.exists(path):
            os.mkdir(path)
        gz_file = "{}.jsonl.gz".format(self.id)
        gzip_path = os.path.join(path, gz_file)
        if name is None:
            unzip_file = gz_file.replace(".gz", "")
        else:
            unzip_file = "{}.jsonl".format(name)
        unzip_path = os.path.join(path, unzip_file)
        # download it
        _ = urllib.request.urlretrieve(self.spec["download_url"], gzip_path)
        # unzip to a file with a .jsonl extension
        with open(unzip_path, "w") as f_out:
            with gzip.open(gzip_path, "rb") as f_in:
                for row in f_in:
                    f_out.write(row.decode("utf-8"))
        os.remove(gzip_path)
        return os.path.join("datasets", unzip_file)

def get_dataset(dataset_id, name=None, force=False):
    return LocalDataset(dataset_id).get(name=name, force=force)
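# Usage sketch (illustrative; the dataset id is a placeholder): get_dataset()
# downloads a bundle and unpacks it to ./datasets/<id-or-name>.jsonl, returning
# the relative path to the unpacked file.
#
#   path = get_dataset("<your-dataset-id>", name="my-corpus")
#   # path == "datasets/my-corpus.jsonl"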
## HTRC Stopwords and corrections

CORRECTIONS_URL = 'http://data.analytics.hathitrust.org/data/text/default_corrections.txt'
STOPWORDS_URL = 'http://data.analytics.hathitrust.org/data/text/default_stopwords_en.txt'

def htrc_fetch(url):
    response = urllib.request.urlopen(url)
    data = response.read()
    text = data.decode('utf-8')
    return text
def get_corrections():
    fc = FileCache()
    _key = "htrc-corrections"
    cv = fc.get(_key)
    if cv is not None:
        return cv
    else:
        raw = htrc_fetch(CORRECTIONS_URL)
        d = {}
        for n, row in enumerate(raw.splitlines()):
            if n == 1:
                continue
            token, corr = row.split(',')
            d[token] = corr
        fc[_key] = d
        return d
def get_stopwords():
    fc = FileCache()
    _key = "htrc-stopwords"
    cv = fc.get(_key)
    if cv is not None:
        return cv
    else:
        # Only hit the network on a cache miss.
        raw = htrc_fetch(STOPWORDS_URL)
        stopwords = []
        for n, row in enumerate(raw.splitlines()):
            token = row.strip()
            stopwords.append(token)
        fc[_key] = stopwords
        return stopwords
try:
    htrc_corrections = get_corrections()
except Exception:
    htrc_corrections = {}

try:
    htrc_stopwords = get_stopwords()
except Exception:
    htrc_stopwords = []
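# Usage sketch (illustrative): the HTRC lists above can be combined with the
# NLTK stopwords downloaded by the Colab snippet when cleaning unigram counts.
# The filtering below is an example only, not something tdm_client itself does;
# `unigrams` is assumed to come from Dataset.get_features().
#
#   from nltk.corpus import stopwords
#   stop = set(stopwords.words('english')) | set(htrc_stopwords)
#   cleaned = {
#       htrc_corrections.get(token, token): count
#       for token, count in unigrams.items()
#       if token.lower() not in stop
#   }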