Skip to content

Instantly share code, notes, and snippets.

@lampts
Forked from ramiil/prepare_thr.py
Created April 16, 2023 02:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lampts/1c6d5c3220eaed7e11980383724fae66 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import os
import tiktoken
import time
import multiprocessing
# Directory containing this script; the input folder is resolved relative to it.
working_dir = os.path.dirname(os.path.realpath(__file__))
# Name of the input sub-directory (under working_dir) holding the .txt files;
# also used as the prefix of the per-process output files.
dataset = 'data'
# Flush threshold: once at least this many characters of decoded text
# (len() of the str, not bytes or tokens) have been buffered, the buffer is
# BPE-encoded and written out. 512*1024*1024 is about 512 Mi characters.
ws = 512*1024*1024
def chunks(arr, size):
    """Lazily yield consecutive slices of *arr*, each at most *size* long.

    The final slice holds whatever is left over and may be shorter. Nothing
    is evaluated until the first item is requested. A *size* of 0 raises
    ValueError at that point, and a negative *size* yields nothing.
    """
    starts = range(0, len(arr), size)
    yield from (arr[s:s + size] for s in starts)
def tofile(lst, name):
    """Append token ids to *name* as little-endian unsigned 16-bit integers.

    Args:
        lst: iterable of ints, each in the range 0..65535.
        name: output file path. It is opened in append-binary mode, so
            repeated calls (and repeated runs of the script) accumulate data.

    Raises:
        OverflowError: if any id is negative or does not fit in 16 bits.
            In that case nothing is written.
    """
    from array import array
    import sys

    # Convert all ids in one C-level pass instead of calling int.to_bytes
    # once per token; this also validates the range before the file is touched.
    ids = array('H', lst)
    with open(name, 'ab') as fh:
        if ids.itemsize == 2:
            # The on-disk format is little-endian regardless of the host.
            if sys.byteorder == 'big':
                ids.byteswap()
            ids.tofile(fh)
        else:
            # Exotic platform where unsigned short is wider than 16 bits.
            fh.write(b''.join(i.to_bytes(2, 'little') for i in ids))
def _encode_chunk(enc, pid, text):
    """BPE-encode *text* with *enc* and append the ids to this worker's output file."""
    print(' [{0}] Encoding {1} mb of data'.format(pid, len(text)//(1024*1024)))
    ids = enc.encode_ordinary(text)
    # Relative path: the file lands in the current working directory, not
    # working_dir. NOTE(review): confirm that is intended.
    tofile(ids, dataset+'_'+str(pid)+'.bin')


def process_files(pid, lst):
    """Tokenize the .txt files named in *lst* and append the ids to a per-worker file.

    Files are read (UTF-8) from ``<working_dir>/<dataset>`` and concatenated
    with no separator between them. Whenever at least ``ws`` characters are
    buffered, the buffer is encoded with tiktoken's "gpt2" encoding and
    appended to ``<dataset>_<pid>.bin``. Any text still buffered after the
    last file is flushed as well, so no input is dropped.

    Args:
        pid: worker index, used in log lines and in the output file name.
        lst: file names (not paths) inside the dataset directory; names not
            ending in ``.txt`` are skipped.
    """
    enc = tiktoken.get_encoding("gpt2")
    # os.path.join instead of a hard-coded '\\' so this also works on POSIX.
    source_dir = os.path.join(working_dir, dataset)
    parts = []
    buffered = 0  # number of characters currently held in `parts`
    for name in lst:
        if not name.endswith('.txt'):
            continue
        print(' [{0}] {1}'.format(pid, name))
        with open(os.path.join(source_dir, name), 'r', encoding="utf8") as f:
            text = f.read()
        parts.append(text)
        buffered += len(text)
        if buffered >= ws:
            _encode_chunk(enc, pid, ''.join(parts))
            parts = []
            buffered = 0
    # Flush the tail that never reached the threshold; the original code
    # silently discarded it.
    if buffered:
        _encode_chunk(enc, pid, ''.join(parts))
# Upper bound on the number of worker processes started.
MAX_THREADS = multiprocessing.cpu_count()
# Every multiprocessing.Process started by the __main__ section below.
threads = []
if __name__ == "__main__":
    nowtime = time.time()
    files = os.listdir(os.path.join(working_dir, dataset))
    # Ceiling division yields at most MAX_THREADS groups. max(1, ...) keeps
    # the step positive when there are fewer files than CPUs (or none);
    # a step of 0 would make chunks() raise ValueError.
    per_process = max(1, -(-len(files) // MAX_THREADS))
    groups = list(chunks(files, per_process))
    for pid, ch in enumerate(groups):
        print('Running process {0} of {1}'.format(pid, len(groups)))
        proc = multiprocessing.Process(target=process_files, args=(pid, ch))
        threads.append(proc)
        proc.start()
    # Join every started process (not just the first MAX_THREADS) so the
    # elapsed time below covers all of the work.
    for proc in threads:
        proc.join()
    print(time.time() - nowtime)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment