n8programs N8python
@N8python
N8python / matrix_exp.py
Created January 19, 2025 22:18
MLX matrix_exp.
import mlx.core as mx

@mx.compile
def _compute_T1(A):
    """I + A"""
    return mx.eye(A.shape[-1]) + A

@mx.compile
def _compute_T2(A):
    """I + A + A^2/2"""
    A2 = A @ A
    return mx.eye(A.shape[-1]) + A + A2 / 2
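The preview stops at the second-order term. A minimal sketch of how these truncated Taylor polynomials could be combined with scaling and squaring (the wrapper name matrix_exp_taylor and the squaring step are illustrative assumptions, not the gist's actual implementation):

# Hypothetical continuation (not from the gist): exp(A) = (exp(A / 2^s))^(2^s),
# using a low-order Taylor polynomial on the scaled-down matrix.
def matrix_exp_taylor(A, order=2, scale_steps=4):
    A_scaled = A / (2 ** scale_steps)
    E = _compute_T1(A_scaled) if order == 1 else _compute_T2(A_scaled)
    for _ in range(scale_steps):
        E = E @ E  # repeated squaring undoes the scaling
    return E

# Quick smoke test on a small, well-conditioned random matrix.
A = mx.random.normal((4, 4)) * 0.1
print(matrix_exp_taylor(A))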
@N8python
N8python / orthomatrix.py
Created January 19, 2025 05:32
OPTIMIZE with me
import torch
import torch.nn as nn
import torch.optim as optim

# Define the differentiable orthonormal linear layer
class OrthonormalLayer(nn.Module):
    def __init__(self, n):
        """
        Initializes a learnable layer with an orthonormal weight matrix.
        :param n: Dimension of the square weight matrix.
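The preview cuts off inside __init__. One common way to keep a square weight matrix exactly orthonormal while staying differentiable is to parametrize it as the matrix exponential of a skew-symmetric matrix; a minimal sketch along those lines, reusing the torch imports above (the parametrization choice is an assumption, not necessarily what the gist does):

# Hypothetical sketch (assumed parametrization, not from the gist):
# W = exp(S - S^T) is orthogonal for any square S, and torch.matrix_exp
# is differentiable, so W stays orthonormal throughout training.
class OrthonormalLayerSketch(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.Parameter(torch.randn(n, n) * 0.01)

    def weight(self):
        S = self.param
        return torch.matrix_exp(S - S.T)  # orthogonal by construction

    def forward(self, x):
        return x @ self.weight()

layer = OrthonormalLayerSketch(8)
W = layer.weight()
print(torch.allclose(W @ W.T, torch.eye(8), atol=1e-5))  # expected: True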
@N8python
N8python / pretrain.py
Created January 15, 2025 23:41
Simple character-level pretraining in MLX. Gets roughly a billion tokens/day for an 18M-parameter model on one M3 Max.
import json
import random
import mlx.optimizers as optim
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime
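Only the imports survive in the preview. A minimal sketch of what a character-level pretraining step in MLX typically looks like, reusing the mx/nn/optim imports above (the model, loss, and dummy batch below are illustrative assumptions, not the gist's code):

# Hypothetical training step (illustrative, not from the gist).
class CharLM(nn.Module):
    def __init__(self, vocab_size, dims=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.hidden = nn.Linear(dims, dims)
        self.out = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        h = nn.relu(self.hidden(self.embed(x)))
        return self.out(h)

def loss_fn(model, inputs, targets):
    logits = model(inputs)
    return nn.losses.cross_entropy(logits, targets).mean()

model = CharLM(vocab_size=256)
optimizer = optim.AdamW(learning_rate=3e-4)
step = nn.value_and_grad(model, loss_fn)

# One step on a dummy batch of byte-level "characters".
inputs = mx.random.randint(0, 256, (8, 128))
targets = mx.random.randint(0, 256, (8, 128))
loss, grads = step(model, inputs, targets)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)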
@N8python
N8python / batch.py
Created December 31, 2024 00:39
Prototypes.
def generate_batched(
    model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    batch_size: int,
    *,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    max_tokens: int = 256,
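A hedged usage sketch of how generate_batched might be called with an mlx-lm model (the model path and the assumption that it returns one completion per batch element are illustrative, not confirmed by the gist):

# Hypothetical usage (illustrative assumptions noted above).
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
completions = generate_batched(
    model,
    tokenizer,
    prompt="Write a haiku about matrices.",
    batch_size=4,
    max_tokens=64,
    verbose=True,
)
for i, text in enumerate(completions):
    print(f"--- sample {i} ---\n{text}")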
@N8python
N8python / MORE.py
Last active December 7, 2024 22:01
faster every day
from mlx_lm import load
import mlx.core as mx
from mlx.utils import tree_flatten, tree_map, tree_unflatten
import numpy as np
# Copyright © 2023-2024 Apple Inc.
import contextlib
import copy
import glob
import importlib
def generate_speculative(
    model: nn.Module,
    draft_model: nn.Module,
    tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
    prompt: str,
    max_tokens: int = 100,
    verbose: bool = False,
    formatter: Optional[Callable] = None,
    **kwargs,
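The signature suggests speculative decoding with a smaller draft model. A minimal sketch of the greedy accept/reject core such a function usually builds on, assuming the models map a (1, T) token array to (1, T, vocab) logits (everything below is an assumed illustration, not the gist's implementation):

# Hypothetical greedy speculative-decoding step (illustrative, not from the gist).
def speculative_step(model, draft_model, tokens, n_draft=4):
    """Draft n_draft tokens with the small model, verify them in one pass
    of the large model, and keep the longest matching prefix."""
    draft = tokens
    for _ in range(n_draft):
        logits = draft_model(draft[None])[:, -1, :]
        draft = mx.concatenate([draft, mx.argmax(logits, axis=-1)])
    # One forward pass of the target model over the drafted continuation.
    target_logits = model(draft[None])[0]
    target_preds = mx.argmax(target_logits, axis=-1)
    # Accept draft tokens while the target model would have picked the same one.
    n_accepted = 0
    for i in range(tokens.shape[0], draft.shape[0]):
        if draft[i].item() != target_preds[i - 1].item():
            break
        n_accepted += 1
    # Always gain at least one token: the target model's own next prediction.
    accepted = draft[: tokens.shape[0] + n_accepted]
    bonus = target_preds[accepted.shape[0] - 1][None]
    return mx.concatenate([accepted, bonus])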
@N8python
N8python / cursed-convergence.c
Created November 1, 2024 07:15
It's cursed: https://www.desmos.com/calculator/asjxggbglo. What if you had only one discrete variable that you could plug into a sum sign, and *everything else had to be perfectly continuous* - no floor, no mod, no tricks. Could you even DO a generalized 2D Riemann sum that converges as k -> infinity? Find out now (spoiler: you kinda can) - modi…
#include <stdio.h>
#include <math.h>
#include <time.h>
#define PI 3.14159265358979323846
double s_inv(double x) {
    return asin(2.0 * (x - 0.5)) / PI + 0.5;
}
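A quick numerical check of what s_inv undoes: it is the inverse of the smooth map s(x) = (sin(pi*(x - 1/2)) + 1)/2 on [0, 1], which is the kind of fully continuous building block the description calls for. A small sketch (in Python for brevity; the forward map s is inferred from s_inv, not copied from the gist):

# Hypothetical check (not from the gist): s_inv inverts
# s(x) = (sin(pi*(x - 0.5)) + 1) / 2 on [0, 1].
import math

def s(x):
    return (math.sin(math.pi * (x - 0.5)) + 1) / 2

def s_inv(x):
    return math.asin(2.0 * (x - 0.5)) / math.pi + 0.5

for x in (0.0, 0.25, 0.5, 0.75, 1.0):
    assert abs(s(s_inv(x)) - x) < 1e-9
    print(x, s_inv(x), s(s_inv(x)))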
import os
import time
from pynput import keyboard
from datetime import datetime
import subprocess
import threading
import tkinter as tk
import queue
# ML imports
@N8python
N8python / nano-mnist-replicate.py
Created October 3, 2024 05:11
Code to replicate the score
import torch
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
# Import the generated predict function
from predict_function import predict
# Load MNIST test dataset
transform = transforms.Compose([transforms.ToTensor()])
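The preview ends before the evaluation loop. A minimal sketch of how the imported predict function would typically be scored on the MNIST test set, reusing the imports and transform above (the batching and the exact call signature of predict are assumptions):

# Hypothetical evaluation loop (assumes predict(images) -> class ids).
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
loader = torch.utils.data.DataLoader(test_set, batch_size=256, shuffle=False)

correct, total = 0, 0
for images, labels in tqdm(loader):
    # Flatten to (batch, 784) numpy arrays for the generated predict function.
    x = images.view(images.shape[0], -1).numpy()
    preds = np.asarray(predict(x))
    correct += int((preds == labels.numpy()).sum())
    total += labels.shape[0]

print(f"Test accuracy: {correct / total:.4f}")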