Skip to content

Instantly share code, notes, and snippets.

@itsdotscience
Last active May 20, 2024 07:17
Show Gist options
  • Save itsdotscience/e888ad00b1747c1ddc7afc36fd705c22 to your computer and use it in GitHub Desktop.
Save itsdotscience/e888ad00b1747c1ddc7afc36fd705c22 to your computer and use it in GitHub Desktop.
Work in progress — this script may corrupt your manifest or blobs. Back up your OLLAMA data before using it.
import os
import logging
import sys
import hashlib
import json
import subprocess
import shutil
import colorama
from colorama import Fore, Style
# Initialize colorama so ANSI colour codes render cross-platform;
# autoreset=True restores the default terminal style after every print().
colorama.init(autoreset=True)
# Configure the root logger: INFO and above are emitted.
logging.basicConfig(level=logging.INFO)
def calculate_digest(data):
    """Return the hex SHA-256 digest of *data* (a str, UTF-8 encoded)."""
    encoded = data.encode('utf-8')
    return hashlib.sha256(encoded).hexdigest()
def get_ollama_home():
    """Return the OLLAMA home directory.

    Uses the $OLLAMA_HOME environment variable when set, otherwise the
    conventional system location /usr/share/ollama.
    """
    home = os.getenv("OLLAMA_HOME")
    return home if home is not None else "/usr/share/ollama"
def get_ollama_models():
    """Return the output of `ollama list` as a list of lines, or [] on failure.

    The first line of the output is the column-header row; callers index
    past it (models[1:]) to reach the actual model entries.
    """
    try:
        # check=True makes a non-zero exit status raise CalledProcessError;
        # without it subprocess.run() never raises for a failed command, so
        # the original except clause was unreachable.
        result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
        return result.stdout.strip().split("\n")
    except FileNotFoundError:
        # The `ollama` binary is not on PATH at all.
        logging.error("`ollama` executable not found; is OLLAMA installed?")
        return []
    except subprocess.CalledProcessError as e:
        logging.error(f"Error occurred while fetching model list from OLLAMA: {e}")
        return []
def update_parameter_layer(model_name, parameters):
    """Write *parameters* as a compact-JSON, content-addressed blob.

    The blob is stored at <OLLAMA_HOME>/.ollama/models/blobs/sha256-<digest>,
    where <digest> is the SHA-256 of the compact JSON encoding itself.
    Returns the path of the blob that was written.
    """
    payload = json.dumps(parameters, separators=(',', ':'))
    digest = calculate_digest(payload)
    blob_dir = os.path.join(get_ollama_home(), ".ollama", "models", "blobs")
    layer_path = os.path.join(blob_dir, f"sha256-{digest}")
    # Writing the pre-serialized payload guarantees the file content matches
    # the digest embedded in its own filename.
    with open(layer_path, 'w') as blob_file:
        blob_file.write(payload)
    logging.info(f"Parameter layer updated for model '{model_name}'.")
    return layer_path
def update_manifest(model_name, version, parameters_digest):
    """Point the model's manifest at the parameter blob `sha256-<digest>`.

    Rewrites the manifest JSON in place: an existing
    application/vnd.ollama.image.params layer is updated, otherwise a new
    one is appended, carrying the blob's digest and on-disk size.
    NOTE(review): os.path.getsize() below requires the blob file to already
    exist -- update_parameter_layer() must have run first; confirm callers
    respect that ordering.
    """
    manifest_path = os.path.join(get_ollama_home(), ".ollama", "models", "manifests", "registry.ollama.ai", "library", model_name, version)
    try:
        # 'r+' so the same handle can read the JSON and then overwrite it.
        with open(manifest_path, 'r+') as manifest_file:
            manifest_data = json.load(manifest_file)
            layers = manifest_data.get('layers', [])
            ollama_parameters = {
                "mediaType": "application/vnd.ollama.image.params",
                "digest": f"sha256:{parameters_digest}",
                "size": os.path.getsize(os.path.join(get_ollama_home(), ".ollama", "models", "blobs", f"sha256-{parameters_digest}"))
            }
            # Replace the existing params layer in place if present,
            # otherwise append a fresh one.
            existing_ollama_parameters = next((layer for layer in layers if layer.get("mediaType") == "application/vnd.ollama.image.params"), None)
            if existing_ollama_parameters:
                existing_ollama_parameters.update(ollama_parameters)
            else:
                layers.append(ollama_parameters)
            manifest_data['layers'] = layers
            # Rewind, overwrite, then truncate in case the new JSON is
            # shorter than the old content.
            manifest_file.seek(0)
            json.dump(manifest_data, manifest_file, separators=(',', ':'))
            manifest_file.truncate()
        logging.info(f"Manifest updated for model '{model_name}' with version '{version}' and parameters digest '{parameters_digest}'.")
    except FileNotFoundError:
        logging.error(f"Manifest file not found at {manifest_path}.")
    except json.JSONDecodeError:
        logging.error(f"Failed to decode JSON from manifest file at {manifest_path}.")
def read_modelfile(modelfile_path):
    """Return the Modelfile's lines (newlines preserved), or [] on any read error."""
    try:
        with open(modelfile_path, 'r') as handle:
            contents = handle.readlines()
    except FileNotFoundError:
        logging.error(f"Modelfile not found at {modelfile_path}.")
        return []
    except IOError as e:
        logging.error(f"Error reading modelfile at {modelfile_path}: {e}")
        return []
    return contents
def write_modelfile(modelfile_path, modelfile_data):
    """Overwrite the file at *modelfile_path* with *modelfile_data* (a list of lines)."""
    try:
        with open(modelfile_path, 'w') as handle:
            handle.writelines(modelfile_data)
    except IOError as e:
        logging.error(f"Error writing modelfile at {modelfile_path}: {e}")
    else:
        logging.info("Modelfile updated successfully.")
def parse_parameters(modelfile_data):
    """Extract PARAMETER directives from a sequence of Modelfile lines.

    Returns a dict mapping parameter name -> value string; the value is
    everything after the name, so multi-word values survive intact.
    Malformed PARAMETER lines (missing a name or value) are skipped with a
    warning instead of raising ValueError, as the unguarded 3-way unpack
    in the original did.
    """
    parameters = {}
    for line in modelfile_data:
        if line.startswith("PARAMETER"):
            parts = line.strip().split(" ", 2)
            if len(parts) == 3:
                _, key, value = parts
                parameters[key] = value
            else:
                logging.warning(f"Skipping malformed PARAMETER line: {line.strip()!r}")
    return parameters
def update_modelfile_parameters(model_name, version, modelfile_path, interactive_mode=False):
    """Apply the PARAMETER directives from *modelfile_path* to an installed model.

    Reads the local Modelfile, diffs its PARAMETERs against the model's
    current ones (via `ollama show --modelfile`), writes the new parameter
    blob, and re-points the manifest at it.  With interactive_mode=True the
    interactive prompt is used instead and *modelfile_path* is ignored.
    """
    if interactive_mode:
        # start_interactive_mode() prompts for the model itself; the old
        # code passed (model_name, version), which raised TypeError.
        start_interactive_mode()
        return
    if not modelfile_path:
        logging.error("Modelfile path is required when not in interactive mode.")
        return
    modelfile_data = read_modelfile(modelfile_path)
    if not modelfile_data:
        return
    new_params = parse_parameters(modelfile_data)
    # Fetch the currently-installed Modelfile so we can report the delta.
    show_command = ['ollama', 'show', '--modelfile', f"{model_name}:{version}"]
    result = subprocess.run(show_command, capture_output=True, text=True)
    current_parameters = parse_parameters(result.stdout.split('\n'))
    diff_params = {key: value for key, value in new_params.items()
                   if current_parameters.get(key) != value}
    logging.info("Changes made:")
    for key, value in diff_params.items():
        logging.info(f"PARAMETER {key} {value}")
    # Write the blob FIRST: update_manifest() stats the blob file to record
    # its size, so it must exist before the manifest is rewritten (the
    # original called update_manifest first and getsize() would fail).
    update_parameter_layer(model_name, new_params)
    update_manifest(model_name, version, calculate_digest(json.dumps(new_params, separators=(',', ':'))))
def start_interactive_mode(model_name=None, version=None):
    """Interactively edit PARAMETERs of installed OLLAMA models.

    Loops: list installed models, let the user pick one, collect
    PARAMETER/DELETE instructions, then rewrite the parameter blob and the
    manifest.  *model_name* and *version* are accepted (and currently
    ignored) so existing call sites that pass them no longer raise
    TypeError -- the function always prompts for the model itself.
    """
    while True:
        models = get_ollama_models()
        if not models:
            logging.error("No models found in OLLAMA.")
            return
        terminal_width = shutil.get_terminal_size().columns
        logging.info("Available Models:")
        # First line of `ollama list` output is the header row; skip it.
        max_model_length = max(len(model.split()[0]) for model in models[1:])
        for i, model in enumerate(models[1:], start=1):
            entry_name, metadata = model.split(None, 1)
            print(f"#{Fore.BLUE}{str(i).rjust(2)}. {Style.RESET_ALL}{entry_name.ljust(max_model_length)} - {metadata.strip()}")
        model_choice = input("Select a model by entering its name or number, or type 'exit' to quit: ")
        if model_choice.lower() == 'exit':
            break
        try:
            model_index = int(model_choice) - 1
            selected_model = models[model_index + 1].split(None, 1)[0]
        except (ValueError, IndexError):
            # Not a valid list number: treat the input as a model name.
            selected_model = model_choice
        logging.info(f"Selected model: {selected_model}")
        # Extract model name and version from the selected model tag.
        model_name, version = selected_model.split(':')
        show_command = ['ollama', 'show', '--modelfile', selected_model]
        result = subprocess.run(show_command, capture_output=True, text=True)
        modelfile_content = result.stdout.split('\n')
        parameters_found = [line.strip() for line in modelfile_content if line.startswith('PARAMETER')]
        if parameters_found:
            logging.info("PARAMETERs found in the Modelfile:")
            for param in parameters_found:
                print(param)
        logging.info("Enter the parameters you want to update/delete (or 'q' to finish):")
        updated_params = []
        while True:
            user_input = input("Enter PARAMETER instruction (or 'q' to finish): ")
            if user_input.lower() == 'q':
                break
            updated_params.append(user_input)
        # Apply the user's edits on top of the current parameter set.
        parameters = parse_parameters(modelfile_content)
        for param in updated_params:
            if param.startswith("PARAMETER"):
                _, key, value = param.split(" ", 2)
                parameters[key] = value
            elif param.startswith("DELETE"):
                _, key = param.split(" ", 1)
                if key in parameters:
                    del parameters[key]
        modelfile_content = [f"PARAMETER {key} {value}\n" for key, value in parameters.items()]
        logging.info("Updated Modelfile content:")
        for item in modelfile_content:
            print(item)
        # Write the blob first so update_manifest() can stat its size.
        parameter_layer_path = update_parameter_layer(selected_model, parameters)
        logging.info("Updating manifest...")
        update_manifest(model_name, version, calculate_digest(json.dumps(parameters, separators=(',', ':'))))
        # Fetch and display the parameters as OLLAMA now reports them.
        logging.info("Fetching updated parameters from the manifest...")
        show_command = ['ollama', 'show', '--modelfile', selected_model]
        result = subprocess.run(show_command, capture_output=True, text=True)
        updated_modelfile_content = result.stdout.split('\n')
        updated_parameters_found = [line.strip() for line in updated_modelfile_content if line.startswith('PARAMETER')]
        if updated_parameters_found:
            logging.info("Updated PARAMETERs in the Modelfile:")
            for param in updated_parameters_found:
                print(param)
        # Sanity check: re-read the blob we just wrote and log its content.
        logging.info("Fetching parameter layer after update:")
        with open(parameter_layer_path, 'r') as parameter_layer_file:
            parameter_layer_content = parameter_layer_file.read()
        logging.info(f"Parameter layer content for {selected_model}: {parameter_layer_content}")
    logging.info("Interactive mode finished.")
def main(args):
    """CLI entry point; *args* is the full sys.argv (argv[0] = script name).

    Dispatch:
      script.py                      -> print usage
      script.py -i                   -> interactive mode
      script.py model:ver path       -> apply Modelfile parameters
      script.py model:ver path -i    -> interactive mode
    """
    if len(args) == 1:
        print_help()
    elif "-i" in args:
        if len(args) == 2:
            start_interactive_mode()
        elif len(args) == 4:
            # Interactive mode prompts for the model itself, so the
            # identifier/path arguments are not forwarded.  (The original
            # called start_interactive_mode(model_name, version), which
            # raised TypeError -- the function takes no arguments.)
            start_interactive_mode()
        else:
            logging.error("Invalid option. Please provide model identifier and Modelfile path with -i flag.")
            print_help()
    elif len(args) == 3:
        model_identifier, modelfile_path = args[1:]
        try:
            model_name, version = model_identifier.split(":")
        except ValueError:
            # Guard the unpack: "name", "name:1:2" etc. would crash here.
            logging.error("Model identifier must be of the form name:version.")
            print_help()
            return
        update_modelfile_parameters(model_name, version, modelfile_path)
    else:
        logging.error("Invalid option. Please provide -i for interactive mode or specify model identifier and Modelfile path.")
def print_help():
    """Print command-line usage information for this script."""
    usage_lines = [
        "",
        "Usage:",
        "python3 script.py [model_name:version] [modelfile_path] [-i]",
        "If only model_name:version is provided:",
        "- Prints current parameters and SHA256 hash of the Modelfile.",
        "If modelfile_path is provided:",
        "- Prompts for parameter updates or deletions if -i is specified and updates the Modelfile accordingly.",
        "",
    ]
    print("\n".join(usage_lines))
# Script entry point: forward the raw argv (argv[0] is the script name,
# which main() accounts for when counting arguments).
if __name__ == "__main__":
    main(sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment