Created
January 13, 2026 20:30
-
-
Save omkaark/e8df777234dbd5011d9420e906ef04e2 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tinker
import math
from datasets import load_dataset
from tinker import types
import numpy as np

# --- Configuration ---------------------------------------------------------
service_client = tinker.ServiceClient()

BASE_MODEL = "Qwen/Qwen3-8B-Base"
LEARNING_RATE = 5e-5
TARGET_MAX_TOKENS = 10_000_000   # total number of training tokens to consume
BATCH_SIZE_TOKENS = 128_000      # tokens accumulated per optimizer step
MAX_STEPS = math.ceil(TARGET_MAX_TOKENS / BATCH_SIZE_TOKENS)

# Fill-in-the-middle sentinel strings used to assemble the FIM prompt.
FIM_PREFIX = "<|fim_prefix|>"
FIM_SUFFIX = "<|fim_suffix|>"
FIM_MIDDLE = "<|fim_middle|>"

# Fix: the original defined BASE_MODEL, never used it, and passed a duplicate
# lowercase `base_model` variable holding the same string. Use the one constant.
training_client = service_client.create_lora_training_client(
    base_model=BASE_MODEL,
    train_unembed=False
)
tokenizer = training_client.get_tokenizer()

# Streaming dataset with a shuffle buffer; consumed lazily via dataset_iter.
ds = load_dataset("Etherll/code-fim-v2", split="train", streaming=True)
ds = ds.shuffle(seed=42, buffer_size=10000)
dataset_iter = iter(ds)
def process_example(example: dict, tokenizer) -> tuple[types.Datum | None, int]:
    """Convert one FIM dataset row into a training Datum.

    Builds a ``<prefix><suffix><middle>`` FIM prompt, gives the prompt
    tokens loss weight 0 and the completion (middle + EOS) tokens weight 1.
    Returns ``(None, 0)`` when the combined sequence exceeds 20k tokens,
    otherwise ``(datum, total_token_count)``.
    """
    fim_prompt = f"{FIM_PREFIX}{example['prefix']}{FIM_SUFFIX}{example['suffix']}{FIM_MIDDLE}"
    prompt_ids = tokenizer.encode(fim_prompt, add_special_tokens=False)
    completion_ids = tokenizer.encode(f"{example['middle']}{tokenizer.eos_token}", add_special_tokens=False)

    # Reject over-long examples before constructing the Datum.
    if len(prompt_ids) + len(completion_ids) > 20000:
        return None, 0

    all_ids = prompt_ids + completion_ids
    # Loss weights: 0 over the prompt, 1 over the completion.
    weights = [0] * len(prompt_ids) + [1] * len(completion_ids)

    # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:],
    # and the weights are shifted in lockstep with the targets.
    datum = types.Datum(
        model_input=types.ModelInput.from_ints(tokens=all_ids[:-1]),
        loss_fn_inputs=dict(weights=weights[1:], target_tokens=all_ids[1:])
    )
    return datum, len(all_ids)
def get_batch():
    """Accumulate processed examples until BATCH_SIZE_TOKENS is reached.

    Draws from the module-level ``dataset_iter``; may return a short or
    empty list if the iterator is exhausted first. Examples rejected by
    ``process_example`` (returned as None) are skipped.
    """
    batch = []
    total_tokens = 0
    while total_tokens < BATCH_SIZE_TOKENS:
        try:
            raw_example = next(dataset_iter)
            datum, example_tokens = process_example(raw_example, tokenizer)
        except StopIteration:
            break
        if datum is None:
            continue
        batch.append(datum)
        total_tokens += example_tokens
    return batch
| for _ in range(MAX_STEPS): | |
| batch = get_batch() | |
| # handle dataset exhaustion | |
| if len(batch) == 0: | |
| print("Dataset exhausted, restarting...") | |
| dataset_iter = iter(ds.shuffle(seed=42 + step, buffer_size=10000)) | |
| batch = get_batch() | |
| fwdbwd_future = await training_client.forward_backward_async(batch, "cross_entropy") | |
| optim_future = await training_client.optim_step_async(types.AdamParams(learning_rate=LEARNING_RATE)) | |
| fwdbwd_result = await fwdbwd_future | |
| optim_result = await optim_future | |
| logprobs = np.concatenate([output['logprobs'].tolist() for output in fwdbwd_result.loss_fn_outputs]) | |
| weights = np.concatenate([sample.loss_fn_inputs['weights'].tolist() for sample in batch]) | |
| print(f"Loss per token: {-np.dot(logprobs, weights) / weights.sum():.4f}") | |
# Persist the trained LoRA weights so they can be loaded by a sampler.
save_future = training_client.save_weights_for_sampler(name="x10c-tachi")
final_path = save_future.result().path
print(f"Checkpoint saved at: {final_path}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment