Skip to content

Instantly share code, notes, and snippets.

@vpipkt
Last active November 17, 2020 20:58
Show Gist options
  • Save vpipkt/32ad88a71ae266f121976ce8bbd2c75b to your computer and use it in GitHub Desktop.
Save vpipkt/32ad88a71ae266f121976ce8bbd2c75b to your computer and use it in GitHub Desktop.
Get a bunch of segmentation masks from Labelbox
import labelbox
import numpy as np
import PIL
import requests
from tqdm import tqdm
from retry import retry
from typing import *
import io
def session(lb_api_key):
    """
    Build a `requests.Session` that sends the given API key as a bearer token.

    :param lb_api_key: key placed in the `Authorization` header of the session
    :returns: the configured `requests.Session`
    """
    auth_header = {'Authorization': f'Bearer {lb_api_key}'}
    http = requests.Session()
    http.headers.update(auth_header)
    return http
def get_labels(project_id: str, lb_api_key: str, timeout_seconds: int = 60):
    """
    Export the labels of a Labelbox project and download the export as parsed JSON.

    :param project_id: id of the project to export
    :param lb_api_key: API key handed to the Labelbox client; when empty or None no
        explicit key is passed and the SDK is left to find credentials itself
        (the original code relied on an environment variable for auth)
    :param timeout_seconds: how long to wait for the export URL; when positive it is
        also used as the HTTP timeout of the download
    :returns: the decoded JSON payload of the export (a list of label JSON objects,
        per the original author's note)
    :raises TimeoutError: if no export URL came back within `timeout_seconds`
    :raises requests.HTTPError: if the download answers with an error status
    """
    # Previously the key was accepted but ignored; use it so that the export is
    # authenticated the same way as the mask downloads.
    client = labelbox.Client(api_key=lb_api_key or None)
    project = client.get_project(project_id)

    # Pass the timeout by keyword so it cannot land on a different parameter.
    # NOTE(review): some labelbox SDK releases reportedly put another parameter
    # ahead of `timeout_seconds` -- confirm against the installed version.
    # is this blocking?
    labels_url = project.export_labels(timeout_seconds=timeout_seconds)
    if not labels_url:
        raise TimeoutError(f"Labelbox back end didnt generate a response in {timeout_seconds} seconds.")

    # at this point no streaming: the whole export is read into memory.
    # A timeout keeps a stalled connection from hanging the script forever;
    # non-positive values are not valid for requests, so fall back to no timeout.
    download_timeout = timeout_seconds if timeout_seconds and timeout_seconds > 0 else None
    r = requests.get(labels_url, timeout=download_timeout)
    r.raise_for_status()
    # list of label JSON objects
    return r.json()
@retry(requests.HTTPError, tries=5, delay=1, backoff=2)
def get_label_mask(obj: Union[str, Dict], lb_api_key: str) -> Optional[np.ndarray]:
    """
    Download one segmentation mask and return it as a boolean array.

    The call is retried on `requests.HTTPError` (up to 5 tries, starting with a
    1 second delay that doubles each time) by the `retry` decorator.

    :param obj: either a dict or str; if dict it should have a key `instanceURI`; see Label.objects.
        Anything else is converted with `str()` and used as the URI.
    :param lb_api_key: API key sent as a bearer token with the request
    :returns: a 2d boolean numpy ndarray that is True where the first image channel
        is non-zero, or None when the server answers 410 (Gone, deleted)
    :raises requests.HTTPError: for any other error status once the retries are used up
    """
    # `import PIL` alone does not load the Image submodule, so import it explicitly.
    from PIL import Image

    if isinstance(obj, dict) and 'instanceURI' in obj:
        uri = obj['instanceURI']
    else:
        uri = str(obj)

    # Close the session as well as the response so connections are not leaked
    # when this function is called once per mask.
    with session(lb_api_key) as http, http.get(uri) as r:
        if r.status_code == 410:
            # Gone, deleted
            return None
        r.raise_for_status()
        payload = r.content

    with Image.open(io.BytesIO(payload)) as img:
        # an assumption about LB labels. always B/W content encoded as 3 or 4 channel image
        # A single channel converts to a (height, width) array, so no reshape is needed.
        return np.array(img.getchannel(0)) > 0
if __name__ == "__main__":
    import sys

    # two arguments to CLI
    # 1. api key
    # 2. project id
    if len(sys.argv) < 3:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <labelbox api key> <project id>")
    lb_api_key = sys.argv[1]
    PROJECT_ID = sys.argv[2]

    labels_download = get_labels(PROJECT_ID, lb_api_key)

    seg_objects_count = 0
    for label_row in tqdm(labels_download):
        label = label_row.get('Label')
        # Rows whose 'Label' is not a dict or has no 'objects' entry are treated as
        # having no masks instead of aborting the run.
        # NOTE(review): presumably skipped or classification-only labels -- confirm
        # against the export format.
        seg_objects = label.get('objects', []) if isinstance(label, dict) else []
        for o in tqdm(seg_objects, desc='seg objects'):
            mask = get_label_mask(o, lb_api_key)
            # None means the mask is gone on the server (410); do not count it.
            if mask is not None:
                seg_objects_count += 1

    print("got a bunch of masks:", seg_objects_count)
# use this __main__ instead for parallel processing
# this will hit the rate limit in a minute or two
def get_all_masks(label, lb_api_key):
    """
    Download every segmentation mask of one exported label row.

    Defined at module level rather than under the `__main__` guard: worker
    processes started with the "spawn" method re-import this module under a
    different `__name__`, so a function defined inside the guard would not
    exist there and could not be unpickled.

    :param label: one exported label row; its `Label` entry is expected to hold an
        `objects` list. Rows without one contribute zero masks.
    :param lb_api_key: API key forwarded to `get_label_mask`
    :returns: number of masks actually retrieved (results of None are not counted)
    """
    entry = label.get('Label')
    seg_objects = entry.get('objects', []) if isinstance(entry, dict) else []
    return sum(get_label_mask(o, lb_api_key) is not None for o in seg_objects)


if __name__ == "__main__":
    import sys
    from functools import partial
    from tqdm.contrib.concurrent import process_map

    # two arguments to CLI
    # 1. api key
    # 2. project id
    if len(sys.argv) < 3:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"usage: {sys.argv[0]} <labelbox api key> <project id>")
    lb_api_key = sys.argv[1]
    PROJECT_ID = sys.argv[2]

    labels_download = get_labels(PROJECT_ID, lb_api_key)

    # Bind the key so that process_map only has to pass each label row.
    f = partial(get_all_masks, lb_api_key=lb_api_key)
    # One count per label row, computed in worker processes.
    seg_objects_count = process_map(f, labels_download, chunksize=5)
    print("got a bunch of masks:", sum(seg_objects_count))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment