@tanaymeh
Last active May 12, 2024 18:04
import lance
import pyarrow as pa
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
# We'll be using the GPT-NeoX tokenizer for tokenizing the code files
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Only load the Python code files from codeparrot dataset
dataset = load_dataset(
    "codeparrot/github-code",
    streaming=True,
    split="train",
    languages=["Python"]
)
dataset = dataset.shuffle(seed=42)
def tokenize(sample):
    return tokenizer(sample['code'])['input_ids']
total_samples = 5_000_000 # 5 Million samples
def process_samples():
    current_sample = 0
    for sample in tqdm(dataset, total=total_samples):
        # If we have added all 5M samples, stop
        if current_sample == total_samples:
            break
        # Tokenize the current sample
        tokenized_sample = tokenize(sample)
        # Increment the counter
        current_sample += 1
        # Yield a PyArrow RecordBatch (one row per token)
        yield pa.RecordBatch.from_arrays(
            [pa.array(tokenized_sample, type=pa.int64())],
            names=["value"]
        )
# Define the dataset schema
schema = pa.schema([
    pa.field("value", pa.int64())
])
# The reader takes in the schema and the generator of batches
reader = pa.RecordBatchReader.from_batches(schema, process_samples())
# Write the dataset to disk (this will start the actual process)
lance.write_dataset(reader, "code_parrot_5M_subset.lance", schema)
# Try reading the dataset
# First make a dataset descriptor and see how many rows we have
dataset = lance.dataset("code_parrot_5M_subset.lance")
print(dataset.count_rows())  # Total rows = total number of tokens across the 5M samples
def load_data(dataset, indices):
    # Load the data at these indices
    data = dataset.take(indices).to_pylist()
    # A little short-cut to get the tokens in one single list
    data = list(map(lambda x: x['value'], data))
    return data
# Load first 100 tokens
indices = [x for x in range(100)]
tokens = load_data(dataset, indices)
print(tokenizer.decode(tokens)) # Decode the tokens to see that coherence is preserved.
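Because each row in the Lance dataset is a single token, contiguous windows of tokens can be read straight off disk by index, which is what makes this layout convenient for LLM pre-training. Below is a minimal sketch of that sampling pattern; the context length and the sample_window helper are illustrative assumptions and are not part of the original gist:

import random

import lance

# Hypothetical training parameter, chosen only for illustration
CONTEXT_LEN = 2048

dataset = lance.dataset("code_parrot_5M_subset.lance")
total_tokens = dataset.count_rows()

def sample_window(ds, context_len=CONTEXT_LEN):
    # Pick a random starting position and read a contiguous run of rows (tokens)
    start = random.randint(0, total_tokens - context_len - 1)
    indices = list(range(start, start + context_len))
    rows = ds.take(indices).to_pylist()
    # Each row is {"value": token_id}; flatten to a plain list of token ids
    return [row["value"] for row in rows]

tokens = sample_window(dataset)
print(len(tokens))  # CONTEXT_LEN token ids, ready to be fed to a model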
@elcachorrohumano
The package is lance or lancedb?
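For reference: the lance module imported in the snippet above is distributed on PyPI as the pylance package, while lancedb is a separate, higher-level vector-database package built on top of the Lance format. A quick way to check which one is installed (standard-library only; nothing below is from the gist):

from importlib.metadata import version

# `pip install pylance` provides the `lance` module used in the gist;
# `lancedb` is a different package (a vector DB built on the Lance format)
print(version("pylance"))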

@tanaymeh (Author)