-
-
Save janodev/46438dc0756caf816f426a7a80385609 to your computer and use it in GitHub Desktop.
Prepare Core ML MiniLM_L6_v2 for macOS with FP16 and static sequence length.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Prepare Core ML MiniLM_L6_v2 for macOS with FP16 and static sequence length.

This script converts sentence-transformers/all-MiniLM-L6-v2 to a Core ML
MLProgram with:

- FP16 compute (compute_precision=FLOAT16)
- minimum_deployment_target macOS 14
- Fixed sequence length = 512 (static second dimension)
- Bounded batch = 1..16 (default 16)

The converted model returns the encoder's ``last_hidden_state`` (one vector
per token). No pooling or normalization is baked into the model, so any
sentence-level pooling has to be done by the consumer.

The package is first saved as MiniLM_L6_v2.mlpackage in the current working
directory. It is then copied into Sources/EmbeddingsDB/Embeddings/Models,
resolved relative to the parent of this script's directory (assumed to be the
repository root).
"""

import os
import pathlib
import shutil

import coremltools as ct
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer
# Fetch the tokenizer and pretrained encoder for all-MiniLM-L6-v2.
# Tokenizer parallelism is switched off through its environment variable,
# and the model is put into evaluation mode (eval() returns the module itself).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_id = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).eval()
class Wrapper(torch.nn.Module):
    """Thin module that exposes only the encoder's ``last_hidden_state``.

    Tracing this wrapper gives a forward with exactly two tensor inputs
    (``input_ids``, ``attention_mask``) and a single tensor output, instead of
    the wrapped model's full output object.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return its ``last_hidden_state``."""
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
# Wrap the loaded encoder; eval() returns the same module, already in eval mode.
wrapped_model = Wrapper(model).eval()
# Disable PyTorch's MHA fastpath so coremltools can see the attention structure.
# Some torch versions do not have this toggle; in that case the attribute lookup
# raises AttributeError, and we continue with a notice instead of failing.
# Other errors are no longer silently swallowed.
try:
    torch.backends.mha.set_fastpath_enabled(False)  # type: ignore[attr-defined]
except AttributeError:
    print(
        "torch.backends.mha.set_fastpath_enabled is unavailable in this torch "
        "version; leaving the MHA fastpath setting unchanged"
    )
# Tokenize one example sentence, padded/truncated to exactly MAX_SEQ tokens,
# so the trace is recorded with a (1, MAX_SEQ) input shape.
# NOTE(review): the "search_document:" prefix looks carried over from another
# embedding model; it only affects the example token values used for tracing.
MAX_SEQ = 512
example = tokenizer(
    "search_document: This is an example sentence.",
    max_length=MAX_SEQ,
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

# Record the wrapper's forward pass with TorchScript tracing.
traced_model = torch.jit.trace(
    wrapped_model,
    (example["input_ids"], example["attention_mask"]),
)
# Convert the traced graph to a Core ML ML Program:
#   - batch dimension flexible in [1, 16], default 16
#   - sequence dimension pinned to MAX_SEQ
#   - FP16 compute precision, all compute units allowed
#   - minimum deployment target macOS 14
batch_range = ct.RangeDim(lower_bound=1, upper_bound=16, default=16)
static_shape = ct.Shape(shape=(batch_range, MAX_SEQ))
mlmodel = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[
        ct.TensorType(name=input_name, shape=static_shape, dtype=np.int32)
        for input_name in ("input_ids", "attention_mask")
    ],
    minimum_deployment_target=ct.target.macOS14,
    compute_precision=ct.precision.FLOAT16,
    compute_units=ct.ComputeUnit.ALL,
)

# Written relative to the current working directory.
mlmodel.save("MiniLM_L6_v2.mlpackage")
# Copy the saved package into the Swift package's model directory.
# Assuming the script is at MiniLM/prepare-model.py, the target is
# <repo>/Sources/EmbeddingsDB/Embeddings/Models/MiniLM_L6_v2.mlpackage.
#
# __file__ is resolved first. On Python < 3.9 it can be a bare relative name
# such as "prepare-model.py", whose .parent.parent collapses to "." and would
# place the model under the current directory instead of the repository root.
source_path = pathlib.Path("MiniLM_L6_v2.mlpackage").resolve()
script_dir = pathlib.Path(__file__).resolve().parent
target_dir = script_dir.parent / "Sources" / "EmbeddingsDB" / "Embeddings" / "Models"
target_path = target_dir / "MiniLM_L6_v2.mlpackage"

if not source_path.is_dir():
    # Fail before touching the existing target so a good model is not deleted.
    raise FileNotFoundError(f"Converted model not found at {source_path}")

if source_path == target_path.resolve():
    # The script was run from inside the target directory, so mlmodel.save()
    # already wrote the package in place. Removing the target here would
    # delete the only copy.
    print(f"Model already saved in place at {target_path}")
else:
    # Create target directory if it doesn't exist
    target_dir.mkdir(parents=True, exist_ok=True)
    # Remove any existing model first: copytree refuses an existing destination.
    if target_path.exists():
        print(f"Removing existing model at {target_path}")
        shutil.rmtree(target_path)
    print(f"Copying model from {source_path} to {target_path}")
    shutil.copytree(source_path, target_path)

print(f"Model successfully saved to: {target_path}")
print(f"Absolute path: {target_path.absolute()}")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.