This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlx.core as mx | |
@mx.compile
def _compute_T1(A):
    """Return the first-order truncation ``I + A``.

    The identity matrix is sized from ``A.shape[-1]`` only, so ``A`` is
    presumably square in its last two axes -- TODO confirm against callers.
    """
    # NOTE(review): mx.eye is called without an explicit dtype, so the
    # identity uses the library default; confirm this is the intended
    # result dtype when A is not float32.
    identity = mx.eye(A.shape[-1])
    return identity + A
@mx.compile
def _compute_T2(A):
    """Return the second-order truncation ``I + A + A^2/2``.

    The identity matrix is sized from ``A.shape[-1]`` only, so ``A`` is
    presumably square in its last two axes -- TODO confirm against callers.
    """
    # NOTE(review): mx.eye is called without an explicit dtype, so the
    # identity uses the library default; confirm this is the intended
    # result dtype when A is not float32.
    identity = mx.eye(A.shape[-1])
    A2 = A @ A  # matrix product, not an element-wise square
    return identity + A + A2 / 2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
# Define the differentiable orthonormal linear layer | |
class OrthonormalLayer(nn.Module): | |
def __init__(self, n): | |
""" | |
Initializes a learnable layer with an orthonormal weight matrix. | |
:param n: Dimension of the square weight matrix. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import random | |
import mlx.optimizers as optim | |
import mlx.core as mx | |
import mlx.nn as nn | |
import numpy as np | |
from tqdm import tqdm | |
import time | |
from datetime import datetime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generate_batched( | |
model: nn.Module, | |
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], | |
prompt: str, | |
batch_size: int, | |
*, | |
verbose: bool = False, | |
formatter: Optional[Callable] = None, | |
max_tokens: int = 256, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mlx_lm import load | |
import mlx.core as mx | |
from mlx.utils import tree_flatten, tree_map, tree_unflatten | |
import numpy as np | |
# Copyright © 2023-2024 Apple Inc. | |
import contextlib | |
import copy | |
import glob | |
import importlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mlx_lm import load | |
import mlx.core as mx | |
from mlx.utils import tree_flatten, tree_map, tree_unflatten | |
import numpy as np | |
# Copyright © 2023-2024 Apple Inc. | |
import contextlib | |
import copy | |
import glob | |
import importlib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generate_speculative( | |
model: nn.Module, | |
draft_model: nn.Module, | |
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], | |
prompt: str, | |
max_tokens: int = 100, | |
verbose: bool = False, | |
formatter: Optional[Callable] = None, | |
**kwargs, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <math.h> | |
#include <time.h> | |
#define PI 3.14159265358979323846 | |
double s_inv(double x) { | |
return asin(2.0 * (x - 0.5)) / PI + 0.5; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import time | |
from pynput import keyboard | |
from datetime import datetime | |
import subprocess | |
import threading | |
import tkinter as tk | |
import queue | |
# ML imports |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torchvision import datasets, transforms | |
from tqdm import tqdm | |
import numpy as np | |
# Import the generated predict function | |
from predict_function import predict | |
# Load MNIST test dataset | |
transform = transforms.Compose([transforms.ToTensor()]) |
Newer Older