@numb3r3
Created March 7, 2022 06:34
prepare-dataset — builds a DocumentArray of (caption, image) pairs from a folder of JSON metadata files.
import json
import re
from pathlib import Path

import click
from docarray import Document, DocumentArray

# Tokens that never carry useful caption information.
BLACK_TOKENS = {'test', 'prefab', 'background'}


def strip_token(token):
    # Replace runs of digits with spaces, then trim surrounding whitespace.
    token = re.sub(r"\d+", " ", token)
    return token.strip()
    # Alternative: keep alphabetic characters only.
    # return ''.join(filter(lambda x: x.isalpha(), token))

def is_valid_token(token):
    """Keep only tokens that look like natural-language words."""
    if not token:
        return False
    if len(token) <= 2:
        return False
    if token.startswith('[') and token.endswith(']'):
        return False
    if token in BLACK_TOKENS:
        return False
    # Require that most characters are lowercase letters.
    count = sum(1 for c in token if c.isalpha() and c.islower())
    return count / len(token) > 0.65

def split_on_uppercase(s, separators=('.', '-', '_', ' '), keep_contiguous: bool = False):
    """Split s at separators and at lower->upper case boundaries.

    With keep_contiguous=True, runs of uppercase letters (e.g. acronyms)
    are kept together instead of being split character by character.
    """
    string_length = len(s)
    is_lower_around = lambda: (s[i - 1].islower()
                               or (string_length > i + 1 and s[i + 1].islower()))
    start = 0
    parts = []
    for i in range(1, string_length):
        if s[i] in separators or (s[i].isupper() and (not keep_contiguous or is_lower_around())):
            parts.append(s[start:i])
            # Skip the separator itself; an uppercase letter starts the next part.
            start = i + 1 if s[i] in separators else i
    parts.append(s[start:])
    return [t for t in parts if t]
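
# For illustration (hypothetical package names, not from the dataset):
#   split_on_uppercase('MyPackage_v2', keep_contiguous=True) -> ['My', 'Package', 'v2']
#   split_on_uppercase('HDRPSky', keep_contiguous=True)      -> ['HDRP', 'Sky']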

def get_tokens(s):
    tokens = split_on_uppercase(s, keep_contiguous=True)
    tokens = [strip_token(t) for t in tokens]
    tokens = [t for t in tokens if is_valid_token(t)]
    # Discard names whose surviving tokens carry too little signal.
    if len(tokens) <= 2:
        max_token_len = max([len(t) for t in tokens] + [0])
        if max_token_len < 3:
            return None
    return tokens

def norm_tokens(tokens):
    # Lowercase and join into a single caption string.
    return ' '.join(t.lower() for t in tokens)
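
# End-to-end example of the caption pipeline (hypothetical filename):
#   get_tokens('SciFi_Spaceship Pack2')        -> ['Sci', 'Spaceship', 'Pack']
#   norm_tokens(['Sci', 'Spaceship', 'Pack'])  -> 'sci spaceship pack'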

@click.command()
@click.option('-i', '--input_path', help='the input JSON-data path')
@click.option('-o', '--output_path', help='the output DocumentArray path')
def main(input_path, output_path):
    da = DocumentArray()
    for fn in Path(input_path).glob('**/*.json'):
        data = json.loads(fn.read_text())
        category = data['category']['slug']
        package_path = data['extendedProperties']['packagePath']
        images = data['images']['default']['featured']
        # Derive a caption from the package filename, without its extension.
        package_file = ' '.join(package_path.split('/')[-1].split('.')[:-1])
        tokens = get_tokens(package_file)
        if tokens:
            caption = norm_tokens(tokens)
            doc = Document(
                uri=package_path,
                tags={
                    'name': data['name'],
                    'category': category,
                    'caption': caption,
                    **data['extendedProperties'],
                },
            )
            # Attach each featured image as a chunk, resized to 256x256.
            for img in images:
                img_doc = Document(uri=img['href'])
                img_doc.load_uri_to_image_tensor().set_image_tensor_shape((256, 256))
                doc.chunks.append(img_doc)
            da.append(doc)
    da.save_binary(output_path)


if __name__ == '__main__':
    main()
numb3r3 commented Mar 7, 2022

Run

$ python3 prepare_data.py -i /path/to/data_json_folder -o dataset.bin
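To sanity-check the result, the saved archive can be loaded back with docarray (a minimal sketch; the dataset.bin path and the inspected fields assume the command above succeeded):

from docarray import DocumentArray

da = DocumentArray.load_binary('dataset.bin')
print(len(da))                  # number of packages kept
print(da[0].tags['caption'])    # caption derived from the package filename
print(len(da[0].chunks))        # one chunk per featured image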
