Colab client for TDM
""" | |
Quick TDM client for demos. | |
Snipped to load in Colab: | |
!pip install nltk gensim pyLDAvis | |
import nltk | |
nltk.download('stopwords') | |
nltk.download('wordnet') | |
def setup_tdm(): | |
import importlib | |
import requests | |
gist = "https://api.github.com/gists/9ccb340f15c1aab6846983738cffb4cc" | |
rsp = requests.get(gist) | |
url = rsp.json()['files']['tdm_client.py']['raw_url'] | |
with open("/content/tdm_client.py", "w") as of: | |
rsp = requests.get(url) | |
of.write(rsp.text) | |
import tdm_client | |
_ = importlib.reload(tdm_client) | |
setup_tdm() | |
""" | |

import concurrent.futures
import copy
import gzip
import json
import logging
import os
import threading
import time
import urllib.request
import zlib
from pathlib import Path

import requests

version = 2.9

DEV_ENV = False

if DEV_ENV is True:
    PROD_SERVICE = 'http://localhost:5000/'
else:
    PROD_SERVICE = 'https://www.jstor.org/api/tdm/v1/'

if DEV_ENV is True:
    DOWNLOAD_WORKERS = 1
else:
    DOWNLOAD_WORKERS = 5

# Try to use a shared cache directory if it exists, otherwise
# create a cache dir in the home directory. On Colab, the cache dir
# will always be /content/tdm_data.
shared_cache_dir = '/var/cache/tdm-data'

if os.environ.get("COLAB_GPU") is not None:
    cache_dir = "/content/tdm_data"
elif os.path.exists(shared_cache_dir):
    cache_dir = shared_cache_dir
else:
    home = str(Path.home())
    cache_dir = os.path.join(home, 'tdm-data')

if not os.path.exists(cache_dir):
    os.mkdir(cache_dir)


def get_token():
    """
    Request a JWT from the JSTOR labs-jwt-service endpoint. The token
    is sent as the Authorization header on TDM API requests.
    """
    qargs = {
        "appName": "tdm",
    }
    headers = {
        'Content-type': 'application/json',
        'Accept': 'application/json',
        'User-Agent': "tdm-colab-client"
    }
    rsp = requests.get(
        "https://www.jstor.org/api/labs-jwt-service/iac-jwt",
        params=qargs,
        headers=headers
    )
    token = rsp.json()["jwt"]
    return token


class FileCache(object):

    def __init__(self, **kwargs):
        self.path = kwargs.get('path', self.default_cache_dir)

    @property
    def default_cache_dir(self):
        """
        Cache files in the module-level cache directory
        (shared dir, home directory, or /content on Colab).
        """
        return cache_dir

    def file_path(self, key):
        # Turn a URL-style key into a safe filename and bucket it into
        # a subdirectory named after the filename prefix.
        fn = key.replace('http://', '')\
                .replace('www', '')\
                .replace('://', '-')\
                .replace('/', '-')\
                .replace('.', '-')\
                .lstrip('-')
        prefix = fn.split('-')[0]
        if not os.path.exists(self.path):
            os.mkdir(self.path)
        sub_dir = os.path.join(self.path, prefix)
        if not os.path.exists(sub_dir):
            os.mkdir(sub_dir)
        return os.path.join(sub_dir, fn)

    def __getitem__(self, key):
        fp = self.file_path(key)
        with open(fp, 'rb') as inf:
            raw = zlib.decompress(inf.read())
            return json.loads(raw)

    def __setitem__(self, key, value):
        fp = self.file_path(key)
        with open(fp, 'wb') as outf:
            raw = json.dumps(value)
            c = zlib.compress(raw.encode('utf-8'))
            outf.write(c)

    def __delitem__(self, key):
        fp = self.file_path(key)
        return os.remove(fp)

    def get(self, key):
        try:
            return self.__getitem__(key)
        except (FileNotFoundError, zlib.error):
            return None

    def set(self, key, value):
        return self.__setitem__(key, value)
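

# --- Usage sketch (not part of the original gist) ---
# A minimal round trip through FileCache, assuming the default cache
# directory is writable. The key below is a hypothetical document URL.
def _filecache_demo():
    fc = FileCache()
    key = "http://example.org/documents/123"  # hypothetical key
    fc[key] = {"unigramCount": {"analysis": 3}}
    assert fc.get(key) == {"unigramCount": {"analysis": 3}}
    del fc[key]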


class Dataset(object):

    def __init__(self, bundle_id, token=None, cache=True):
        self.id = bundle_id
        # Use the supplied token if given, otherwise request a fresh one.
        self.auth_token = token if token is not None else get_token()
        self.request_headers = {
            'Authorization': 'JWT {}'.format(self.auth_token),
            'User-Agent': 'TDM',
            'Accept-Encoding': 'gzip'
        }
        self.cache = cache
        self.spec = self._load()
        self.available = []
        self.items = self.spec['documents']
        self.gzip_path = os.path.join(cache_dir, "{}.jsonl.gz".format(self.id))
        self.queued_items = copy.copy(self.spec['documents'])
        self.num_docs = len(self.items)
        if cache_dir is None:
            self.file_cache = {}
        else:
            fc = FileCache()
            self.file_cache = fc
        _ = self._download_gzip()
        # Background thread that downloads the documents.
        # See: https://gist.github.com/sebdah/832219525541e059aefa
        # thread = threading.Thread(target=self._populate, args=(), name='tdm-download', daemon=True)
        # thread.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def close(self):
        pass

    def __len__(self):
        return self.num_docs

    def get_features2(self):
        # Yield features as a background worker appends URLs to
        # self.available; each document is yielded once.
        fetched = 0
        while fetched < self.num_docs:
            if len(self.available) <= fetched:
                time.sleep(.25)
            else:
                for url in self.available[fetched:]:
                    yield self._fetch(url)
                    fetched += 1

    def get_features(self):
        """
        Yield the unigramCount dictionary for each document in the
        downloaded dataset file.
        """
        with gzip.open(self.gzip_path, "rb") as inf:
            for row in inf:
                doc = json.loads(row.decode("utf-8"))
                uc = doc.get("unigramCount")
                if uc is None:
                    continue
                yield uc

    def get_metadata2(self):
        url = PROD_SERVICE + 'nb/dataset/{}/metadata/'.format(self.id)
        rsp = requests.get(url, headers=self.request_headers)
        if rsp.status_code != 200:
            raise Exception("Can't load metadata for dataset {}.".format(self.id))
        return rsp.json()

    def get_metadata(self):
        """
        Return the metadata (everything except unigramCount) for each
        document in the downloaded dataset file.
        """
        out = []
        with gzip.open(self.gzip_path, "rb") as inf:
            for row in inf:
                doc = json.loads(row.decode("utf-8"))
                try:
                    del doc["unigramCount"]
                except KeyError:
                    pass
                out.append(doc)
        return out

    def get_feature(self, url):
        return self._fetch(url)

    def _load(self):
        url = PROD_SERVICE + "nb/dataset/{}/info/".format(self.id)
        d = requests.get(url, headers=self.request_headers)
        meta = d.json()
        if meta.get('message') is not None:
            raise Exception(json.dumps(meta))
        return meta

    def _download_gzip(self, force=False):
        # Skip the download when a cached copy already exists.
        if force is False and os.path.exists(self.gzip_path):
            return
        _ = urllib.request.urlretrieve(self.spec["download_url"], self.gzip_path)

    def _multiple_populate(self):
        # Multiple background download workers.
        with concurrent.futures.ThreadPoolExecutor(max_workers=DOWNLOAD_WORKERS) as executor:
            future_to_url = {executor.submit(self._fetch, url): url for url in self.items}
            for future in concurrent.futures.as_completed(future_to_url):
                url = future_to_url[future]
                self.available.append(url)

    def _populate(self):
        # Single background thread for downloading features.
        if hasattr(self, 'items'):
            for url in self.items:
                _ = self._fetch(url)
                self.available.append(url)

    def query(self):
        return self.spec['query']

    def query_text(self):
        """
        Plain-language search string.
        """
        return self.spec["search_description"]

    def _fetch(self, url):
        """
        Return a dictionary of extracted-features content,
        using the file cache when enabled.
        """
        logging.info("Fetching {}".format(url))
        if self.cache is True:
            from_cache = self.file_cache.get(url)
        else:
            from_cache = None
        if from_cache is not None:
            logging.info("{} found in cache.".format(url))
            return from_cache
        service_url = PROD_SERVICE + "nb/dataset/features/"
        rsp = requests.post(
            service_url,
            data=json.dumps({"id": url}),
            headers=self.request_headers
        )
        try:
            raw = rsp.json()
            self.file_cache[url] = raw
        except json.JSONDecodeError:
            logging.debug("Unable to load features: {}".format(url))
            raw = {}
        return raw


class LocalDataset(Dataset):

    def get(self, name=None, force=False):
        """
        Download the dataset and unzip it into a ./datasets directory,
        returning the relative path to the .jsonl file.
        """
        path = os.path.join(os.getcwd(), "datasets")
        if not os.path.exists(path):
            os.mkdir(path)
        gz_file = "{}.jsonl.gz".format(self.id)
        gzip_path = os.path.join(path, gz_file)
        if name is None:
            unzip_file = gz_file.replace(".gz", "")
        else:
            unzip_file = "{}.jsonl".format(name)
        unzip_path = os.path.join(path, unzip_file)
        # Download it.
        _ = urllib.request.urlretrieve(self.spec["download_url"], gzip_path)
        # Unzip to a file with a .jsonl extension.
        with open(unzip_path, "w") as f_out:
            with gzip.open(gzip_path, "rb") as f_in:
                for row in f_in:
                    f_out.write(row.decode("utf-8"))
        os.remove(gzip_path)
        return os.path.join("datasets", unzip_file)


def get_dataset(dataset_id, name=None, force=False):
    return LocalDataset(dataset_id).get(name=name, force=force)
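

# --- Usage sketch (not part of the original gist) ---
# How the client is typically called from a notebook. The dataset id
# below is a hypothetical placeholder for an id issued by the TDM
# dataset service; nothing here runs at import time.
def _dataset_demo():
    dataset_id = "00000000-0000-0000-0000-000000000000"  # hypothetical id
    # Stream unigram counts and metadata from the cached .jsonl.gz file.
    dset = Dataset(dataset_id)
    print(len(dset), "documents")
    metadata = dset.get_metadata()
    unigrams = list(dset.get_features())
    # Or download and unzip to ./datasets and work with the .jsonl directly.
    jsonl_path = get_dataset(dataset_id, name="sample")
    return metadata, unigrams, jsonl_path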


## HTRC Stopwords and corrections

CORRECTIONS_URL = 'http://data.analytics.hathitrust.org/data/text/default_corrections.txt'
STOPWORDS_URL = 'http://data.analytics.hathitrust.org/data/text/default_stopwords_en.txt'


def htrc_fetch(url):
    response = urllib.request.urlopen(url)
    data = response.read()
    text = data.decode('utf-8')
    return text


def get_corrections():
    fc = FileCache()
    _key = "htrc-corrections"
    cv = fc.get(_key)
    if cv is not None:
        return cv
    raw = htrc_fetch(CORRECTIONS_URL)
    d = {}
    for n, row in enumerate(raw.splitlines()):
        if n == 1:
            continue
        token, corr = row.split(',')
        d[token] = corr
    fc[_key] = d
    return d


def get_stopwords():
    fc = FileCache()
    _key = "htrc-stopwords"
    cv = fc.get(_key)
    if cv is not None:
        return cv
    # Only fetch from HTRC when the list is not already cached.
    raw = htrc_fetch(STOPWORDS_URL)
    stopwords = []
    for row in raw.splitlines():
        token = row.strip()
        stopwords.append(token)
    fc[_key] = stopwords
    return stopwords


try:
    htrc_corrections = get_corrections()
except Exception:
    htrc_corrections = {}

try:
    htrc_stopwords = get_stopwords()
except Exception:
    htrc_stopwords = []
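

# --- Usage sketch (not part of the original gist) ---
# One way the HTRC lists above might be applied during token cleanup:
# map OCR corrections first, then drop stopwords. This helper is an
# illustration only, not part of the original client.
def _clean_tokens_demo(tokens):
    corrected = [htrc_corrections.get(t, t) for t in tokens]
    return [t for t in corrected if t not in htrc_stopwords]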