Skip to content

Instantly share code, notes, and snippets.

@aminnj
Last active January 13, 2024 23:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aminnj/c1d66cc7d5be4f14a9f1e093731d7f75 to your computer and use it in GitHub Desktop.
Save aminnj/c1d66cc7d5be4f14a9f1e093731d7f75 to your computer and use it in GitHub Desktop.
Evaluate layers based on a list of indices with MLX
diff --git a/llms/mistral/mistral.py b/llms/mistral/mistral.py
index 9b9a602..5fd5146 100644
--- a/llms/mistral/mistral.py
+++ b/llms/mistral/mistral.py
@@ -144,6 +144,7 @@ class Mistral(nn.Module):
self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
+ self.ilayers = list(range(len(self.layers)))
def __call__(
self,
@@ -158,9 +159,10 @@ class Mistral(nn.Module):
mask = mask.astype(h.dtype)
if cache is None:
- cache = [None] * len(self.layers)
+ cache = [None] * len(self.ilayers)
- for e, layer in enumerate(self.layers):
+ for e, ilayer in enumerate(self.ilayers):
+ layer = self.layers[ilayer]
h, cache[e] = layer(h, mask, cache[e])
return self.output(self.norm(h)), cache
@@ -267,6 +269,21 @@ if __name__ == "__main__":
print("[INFO] Loading model from disk.")
model, tokenizer = load_model(args.model_path)
+ # default = list(range(model.n_layers))
+ # model.ilayers = default
+
+ overlap_8_by_4 = (
+ []
+ + list(range(0,8))
+ + list(range(4,12))
+ + list(range(8,16))
+ + list(range(12,20))
+ + list(range(16,24))
+ + list(range(20,28))
+ + list(range(24,32))
+ )
+ model.ilayers = overlap_8_by_4
+
print("[INFO] Starting generation...")
tic = time.time()
print(args.prompt, end="", flush=True)
@aminnj
Copy link
Author

aminnj commented Jan 13, 2024

  • Run `pip install mlx` on an Apple Silicon Mac.
  • Clone https://github.com/ml-explore/mlx-examples
  • Follow the instructions in the mlx-examples Mistral example (llms/mistral) to download Mistral-7B.
  • Apply the patch above inside your mlx-examples checkout (e.g. with `git apply`).
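  • Modify `ilayers`. It is a list of 0-based transformer layer indices, evaluated in order. For example, `ilayers = [0,1,2,0,1,2]` runs the first 3 layers twice. Repeating layers via `ilayers` does not duplicate the weights, so weight memory does not grow. However, every entry in `ilayers` gets its own KV-cache slot, so cache memory grows with `len(ilayers)`.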
  • Run inference from the llms/mistral directory, e.g.:
    python mistral.py --model-path ../../../../../../lora/mistral-mlx/ --max-tokens 100 --prompt "A sci-fi story about aliens. Title: Alien virus. Story:"

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment