import lance
import pyarrow as pa
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
# We'll be using the GPT-NeoX tokenizer for tokenizing the code files
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Only load the Python code files from the codeparrot dataset
dataset = load_dataset(
    "codeparrot/github-code",
    streaming=True,
    split="train",
    languages=["Python"]
)
dataset = dataset.shuffle(seed=42)
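# Quick sanity check (illustrative, not part of the original pipeline): peek at
# one streamed record to confirm each sample exposes the file contents under
# the 'code' key before committing to the full 5M-sample pass.
first_sample = next(iter(dataset))
print(first_sample["code"][:200])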
def tokenize(sample):
    return tokenizer(sample['code'])['input_ids']
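# Illustrative check of the helper above: the tokenizer returns a dict whose
# 'input_ids' entry is a plain list of token ids, so tokenize() yields a flat
# list of ints for any code string.
example_ids = tokenize({"code": "def add(a, b):\n    return a + b"})
print(example_ids[:10])  # first few token ids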
total_samples = 5_000_000  # 5 million samples
def process_samples():
    current_sample = 0
    for sample in tqdm(dataset, total=total_samples):
        # If we have added all 5M samples, stop
        if current_sample == total_samples:
            break
        # Tokenize the current sample
        tokenized_sample = tokenize(sample)
        # Increment the counter
        current_sample += 1
        # Yield a PyArrow RecordBatch; each token id becomes one row
        yield pa.RecordBatch.from_arrays(
            [pa.array(tokenized_sample, type=pa.int64())],
            names=["value"]
        )
# Define the dataset schema: a single int64 column holding one token per row
schema = pa.schema([
    pa.field("value", pa.int64())
])
# The reader takes in a schema and the generator of record batches
reader = pa.RecordBatchReader.from_batches(schema, process_samples())
# Write the dataset to disk (this starts the actual tokenization pass)
lance.write_dataset(reader, "code_parrot_5M_subset.lance", schema)
# Try reading the dataset back
# First open the dataset and see how many rows it has
dataset = lance.dataset("code_parrot_5M_subset.lance")
print(dataset.count_rows())  # Total token count (one row per token) across the 5M samples
def load_data(dataset, indices):
    # Load the rows at these indices
    data = dataset.take(indices).to_pylist()
    # A little shortcut to flatten the rows into one list of token ids
    data = list(map(lambda x: x['value'], data))
    return data
# Load the first 100 tokens
indices = list(range(100))
tokens = load_data(dataset, indices)
print(tokenizer.decode(tokens))  # Decode the tokens to check that coherence is preserved
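# A minimal sketch (sample_token_window is a hypothetical helper, not from the
# original gist) of how the flat token stream could feed language-model
# training: since every row is one token, a contiguous slice of block_size
# rows is a ready-made training window.
import random

def sample_token_window(dataset, block_size=512):
    # Pick a random start offset and read block_size contiguous token ids
    start = random.randint(0, dataset.count_rows() - block_size - 1)
    return load_data(dataset, list(range(start, start + block_size)))

window = sample_token_window(dataset)
print(tokenizer.decode(window))  # one random contiguous window of code tokens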
The package is lance or lancedb?
The PyPI package is pylance (pip install pylance); it provides the lance module imported above. lancedb is a separate vector-database library built on top of the Lance format.