Skip to content

Instantly share code, notes, and snippets.

@omkaark
Created January 13, 2026 20:30
Show Gist options
  • Select an option

  • Save omkaark/e8df777234dbd5011d9420e906ef04e2 to your computer and use it in GitHub Desktop.

Select an option

Save omkaark/e8df777234dbd5011d9420e906ef04e2 to your computer and use it in GitHub Desktop.
import asyncio
import math

import numpy as np
from datasets import load_dataset

import tinker
from tinker import types
# --- Configuration and service setup -------------------------------------
service_client = tinker.ServiceClient()

BASE_MODEL = "Qwen/Qwen3-8B-Base"
LEARNING_RATE = 5e-5
TARGET_MAX_TOKENS = 10_000_000   # total training-token budget for the run
BATCH_SIZE_TOKENS = 128_000      # approximate tokens consumed per optimizer step
MAX_STEPS = math.ceil(TARGET_MAX_TOKENS / BATCH_SIZE_TOKENS)

# Fill-in-the-middle sentinel strings used to frame each training example.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"

# Kept for backward compatibility with code that reads `base_model`;
# reuse the constant instead of repeating the model string literal.
base_model = BASE_MODEL

training_client = service_client.create_lora_training_client(
    base_model=base_model,
    train_unembed=False,
)
tokenizer = training_client.get_tokenizer()

# Stream the FIM dataset so it is never fully materialized in memory;
# shuffle through a bounded reservoir buffer.
ds = load_dataset("Etherll/code-fim-v2", split="train", streaming=True)
ds = ds.shuffle(seed=42, buffer_size=10000)
dataset_iter = iter(ds)
def process_example(
    example: dict, tokenizer, max_tokens: int = 20_000
) -> tuple[types.Datum | None, int]:
    """Convert one FIM example into a training Datum with a loss mask.

    The prompt is ``<prefix><suffix>`` wrapped in FIM sentinel strings; the
    model is trained to emit ``<middle><eos>``.  Loss weights are 0 over the
    prompt and 1 over the completion, so only completion tokens contribute.

    Args:
        example: dict with ``prefix``, ``suffix`` and ``middle`` string fields.
        tokenizer: tokenizer exposing ``encode`` and ``eos_token``.
        max_tokens: skip examples whose total token count exceeds this cap
            (default 20_000, matching the original hard-coded limit).

    Returns:
        ``(datum, token_count)`` on success, or ``(None, 0)`` when the
        example is over the cap and is skipped.
    """
    prompt = f"{FIM_PREFIX}{example['prefix']}{FIM_SUFFIX}{example['suffix']}{FIM_MIDDLE}"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
    prompt_mask = [0] * len(prompt_tokens)

    completion_tokens = tokenizer.encode(
        f"{example['middle']}{tokenizer.eos_token}", add_special_tokens=False
    )
    completion_mask = [1] * len(completion_tokens)

    # Skip over-long examples outright rather than truncating — truncation
    # would break the prefix/suffix/middle structure the model learns from.
    if len(prompt_tokens) + len(completion_tokens) > max_tokens:
        return None, 0

    tokens = prompt_tokens + completion_tokens
    masks = prompt_mask + completion_mask

    # Standard next-token shift: predict tokens[1:] from tokens[:-1];
    # drop the first mask entry so weights align with the targets.
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    masks = masks[1:]

    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs=dict(weights=masks, target_tokens=target_tokens),
    ), len(tokens)
def get_batch():
    """Accumulate processed examples until the batch token budget is reached.

    Pulls examples from the module-level ``dataset_iter`` and converts them
    with ``process_example``.  Stops early — possibly returning a short or
    empty list — when the stream is exhausted.

    Returns:
        list of ``types.Datum`` whose combined token count is at least
        ``BATCH_SIZE_TOKENS`` (unless the dataset ran out first).
    """
    processed_examples = []
    token_count = 0
    while token_count < BATCH_SIZE_TOKENS:
        # Keep the try body minimal: only next() is expected to raise
        # StopIteration; processing errors should propagate, not end the batch.
        try:
            example = next(dataset_iter)
        except StopIteration:
            break
        processed, sample_token_count = process_example(example, tokenizer)
        if processed is not None:
            processed_examples.append(processed)
            token_count += sample_token_count
    return processed_examples
async def _train() -> None:
    """Run the FIM fine-tuning loop for up to MAX_STEPS optimizer steps.

    Fixes two defects in the original top-level loop: the loop variable was
    ``_`` but the exhaustion branch referenced an undefined ``step`` (NameError),
    and ``await`` was used at module top level, which is a SyntaxError in a
    plain script.
    """
    # get_batch() reads the module-level iterator, so the exhaustion restart
    # must rebind the global, not a function-local shadow.
    global dataset_iter

    for step in range(MAX_STEPS):
        batch = get_batch()
        # Handle dataset exhaustion: reshuffle with a fresh, step-dependent
        # seed and restart the stream, then refill the batch.
        if len(batch) == 0:
            print("Dataset exhausted, restarting...")
            dataset_iter = iter(ds.shuffle(seed=42 + step, buffer_size=10000))
            batch = get_batch()

        # NOTE(review): per the original code, the *_async submit calls are
        # awaited once to enqueue and the returned futures awaited again for
        # results — confirm against the tinker client API.
        fwdbwd_future = await training_client.forward_backward_async(batch, "cross_entropy")
        optim_future = await training_client.optim_step_async(
            types.AdamParams(learning_rate=LEARNING_RATE)
        )
        fwdbwd_result = await fwdbwd_future
        optim_result = await optim_future  # awaited to surface errors; value unused

        # Per-token loss: weighted mean of negative log-probs — prompt tokens
        # carry weight 0, so only completion tokens contribute.
        logprobs = np.concatenate(
            [output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]
        )
        weights = np.concatenate(
            [sample.loss_fn_inputs['weights'].tolist() for sample in batch]
        )
        print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}")


asyncio.run(_train())
# Persist the trained LoRA weights under a named checkpoint so they can be
# loaded later by a sampler; save_weights_for_sampler returns a future whose
# result carries the storage path.
final_path_future = training_client.save_weights_for_sampler(name="x10c-tachi")
final_path = final_path_future.result().path
print(f"Checkpoint saved at: {final_path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment