@mehdidc
Created May 10, 2023 13:28
# Extract single-figure/caption pairs from raw PubMed Central (PMC) article
# shards and repack them into shuffled webdataset tar shards.
import io
import json
import os
import random
import tarfile
import time
from collections import defaultdict

import webdataset as wds
from lxml import etree
def get_wds(path):
    # Build a WebDataset pipeline over the raw PMC shards. Each raw sample's
    # 'flac' field holds the article's nested .tar.gz package (unpacked by
    # preprocess_pmc); the pipeline yields batches of (key, image, caption).
    bs = 8
    ds = wds.DataPipeline(
        wds.SimpleShardList(path),
        wds.split_by_worker,
        wds.tarfile_to_samples(),
        preprocess_pmc,
        wds.rename(image="jpg;png", txt="txt"),
        wds.to_tuple("key", "image", "txt"),
        wds.batched(bs),
    )
    # Batching is done in the pipeline, so the loader itself uses batch_size=None
    # and an identity collate_fn.
    loader = wds.WebLoader(
        ds, num_workers=bs, batch_size=None,
        collate_fn=lambda x: x, persistent_workers=False,
    )
    return loader
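
# A minimal usage sketch (the shard range below is illustrative): each batch
# from the loader is a tuple (keys, images, txts) of parallel lists of up to
# bs=8 items, where each image is the raw JPEG/PNG bytes.
#
# loader = get_wds("/p/fastdata/datasets/pubmed/raw/{00000..00010}.tar")
# for keys, images, txts in loader:
#     print(keys[0], len(images[0]), txts[0][:80])
#     break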
def preprocess_pmc(src):
    # Pipeline stage: for each PMC article sample, open the nested .tar.gz
    # archive, parse the article's .nxml file, and yield one sample per <fig>
    # that has exactly one caption and exactly one graphic with a matching
    # .jpg/.png member in the archive.
    for sample in src:
        try:
            js = json.loads(sample['json'])
            K = js['url'].replace('/', '_')
            desc = io.BytesIO(sample['flac'])
            tf = tarfile.open(fileobj=desc, mode="r:gz")
            # Index archive members by (basename, extension) and by extension.
            by_name_and_ext = {}
            names = set()
            by_ext = defaultdict(list)
            members = {}
            for member in tf.getmembers():
                f = member.name
                members[f] = member
                name = os.path.basename(os.path.splitext(f)[0])
                ext = os.path.splitext(f)[1]
                by_name_and_ext[(name, ext)] = f
                names.add(name)
                by_ext[ext].append(f)
            if not by_ext['.nxml']:
                continue
            xml_file = tf.extractfile(members[by_ext['.nxml'][0]])
            tree = etree.parse(xml_file)
            fig_tags = tree.xpath('//fig')
            if len(fig_tags) == 0:
                continue
            for tag in fig_tags:
                captions = [get_text(c) for c in tag.findall("caption")]
                imgs = [get_href(i) for i in tag.findall("graphic")]
                # Keep only unambiguous figures: one caption, one graphic.
                if len(imgs) != 1 or len(captions) != 1:
                    continue
                img_name = imgs[0]
                caption = captions[0]
                if img_name not in names:
                    continue
                if (img_name, '.jpg') in by_name_and_ext:
                    filename = by_name_and_ext[(img_name, '.jpg')]
                elif (img_name, '.png') in by_name_and_ext:
                    filename = by_name_and_ext[(img_name, '.png')]
                else:
                    continue
                # Pass raw image bytes through; decoding happens downstream.
                img = tf.extractfile(members[filename]).read()
                ext = os.path.splitext(filename)[1][1:]
                key = K + "-" + img_name
                yield {"key": key, "__key__": key, ext: img, "txt": caption}
        except Exception as ex:
            print(ex)
            continue
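
# For reference, the shape of one yielded sample (values illustrative):
# {
#     "key": "<article url with '/' replaced by '_'>-<graphic name>",
#     "__key__": <same as "key">,
#     "jpg": <raw image bytes>,   # or "png", matching the archive member
#     "txt": "<figure caption text>",
# }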
class ShuffledIter:
    # Infinite iterator that repeatedly reshuffles `data` and yields its items;
    # used below to spread written samples across output shards in random order.
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        while True:
            random.shuffle(self.data)
            yield from self.data
def get_text(node):
    # Concatenate all text in the element, including text inside nested markup.
    return ''.join(node.itertext())


def get_href(i):
    # <graphic> references its image via the XLink href attribute; handle both
    # the prefixed and the namespace-expanded spelling of the attribute name.
    if 'xlink:href' in i.attrib:
        return i.attrib['xlink:href']
    elif '{http://www.w3.org/1999/xlink}href' in i.attrib:
        return i.attrib['{http://www.w3.org/1999/xlink}href']
    else:
        return None
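
# For reference, a minimal sketch of the JATS-style <fig> markup these helpers
# parse from the PMC .nxml files (content illustrative):
#
# <fig id="F1">
#   <caption><p>Figure 1. Example caption.</p></caption>
#   <graphic xmlns:xlink="http://www.w3.org/1999/xlink" xlink:href="fig1"/>
# </fig>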
def main():
    random.seed(0)
    # Open all output shards up front and write samples to them in shuffled
    # order, so that figures from one article are spread across shards. Note:
    # this keeps 2500 tar files open at once, which may require raising the
    # open-file limit.
    nb_shards = 2500
    sinks = [wds.TarWriter(f"/p/fastdata/datasets/pubmed/figure-captions/{i:05d}.tar") for i in range(nb_shards)]
    sink_iter = iter(ShuffledIter(sinks))
    dataset = get_wds("/p/fastdata/datasets/pubmed/raw/{00000..00520}.tar")
    t0 = time.time()
    nb = 0
    i = 0
    for keys, ims, txts in dataset:
        for key, im, txt in zip(keys, ims, txts):
            # Raw bytes go under the 'jpg' key regardless of source extension.
            data = {
                "__key__": key,
                "jpg": im,
                "txt": txt,
            }
            sink = next(sink_iter)
            sink.write(data)
        nb += len(keys)
        if i % 1000 == 0:
            # Periodic throughput report: elapsed seconds, samples per second.
            dt = time.time() - t0
            print(dt, nb / dt)
        i += 1
    for sink in sinks:
        sink.close()


if __name__ == "__main__":
    main()