Model: groq/deepseek-r1-distill-llama-70b
files
../llm-mistral/llm_mistral.py
---
import click
from httpx_sse import connect_sse, aconnect_sse
import httpx
import json
import llm
from pydantic import Field
from typing import Optional
DEFAULT_ALIASES = {
"mistral/mistral-tiny": "mistral-tiny",
"mistral/open-mistral-nemo": "mistral-nemo",
"mistral/mistral-small": "mistral-small",
"mistral/mistral-medium": "mistral-medium",
"mistral/mistral-large-latest": "mistral-large",
"mistral/codestral-mamba-latest": "codestral-mamba",
"mistral/codestral-latest": "codestral",
"mistral/ministral-3b-latest": "ministral-3b",
"mistral/ministral-8b-latest": "ministral-8b",
"mistral/pixtral-12b-latest": "pixtral-12b",
"mistral/pixtral-large-latest": "pixtral-large",
}
@llm.hookimpl
def register_models(register):
for model in get_model_details():
model_id = model["id"]
vision = model.get("capabilities", {}).get("vision")
our_model_id = "mistral/" + model_id
alias = DEFAULT_ALIASES.get(our_model_id)
aliases = [alias] if alias else []
register(
Mistral(our_model_id, model_id, vision),
AsyncMistral(our_model_id, model_id, vision),
aliases=aliases,
)
@llm.hookimpl
def register_embedding_models(register):
register(MistralEmbed())
def refresh_models():
user_dir = llm.user_dir()
mistral_models = user_dir / "mistral_models.json"
key = llm.get_key("", "mistral", "LLM_MISTRAL_KEY")
if not key:
raise click.ClickException(
"You must set the 'mistral' key or the LLM_MISTRAL_KEY environment variable."
)
response = httpx.get(
"https://api.mistral.ai/v1/models", headers={"Authorization": f"Bearer {key}"}
)
response.raise_for_status()
models = response.json()
mistral_models.write_text(json.dumps(models, indent=2))
return models
def get_model_details():
user_dir = llm.user_dir()
models = {
"data": [
{"id": model_id.replace("mistral/", "")}
for model_id in DEFAULT_ALIASES.keys()
]
}
mistral_models = user_dir / "mistral_models.json"
if mistral_models.exists():
models = json.loads(mistral_models.read_text())
elif llm.get_key("", "mistral", "LLM_MISTRAL_KEY"):
try:
models = refresh_models()
except httpx.HTTPStatusError:
pass
return [model for model in models["data"] if "embed" not in model["id"]]
def get_model_ids():
return [model["id"] for model in get_model_details()]
@llm.hookimpl
def register_commands(cli):
@cli.group()
def mistral():
"Commands relating to the llm-mistral plugin"
@mistral.command()
def refresh():
"Refresh the list of available Mistral models"
before = set(get_model_ids())
refresh_models()
after = set(get_model_ids())
added = after - before
removed = before - after
if added:
click.echo(f"Added models: {', '.join(added)}", err=True)
if removed:
click.echo(f"Removed models: {', '.join(removed)}", err=True)
if added or removed:
click.echo("New list of models:", err=True)
for model_id in get_model_ids():
click.echo(model_id, err=True)
else:
click.echo("No changes", err=True)
class _Shared:
can_stream = True
needs_key = "mistral"
key_env_var = "LLM_MISTRAL_KEY"
class Options(llm.Options):
temperature: Optional[float] = Field(
description=(
"Determines the sampling temperature. Higher values like 0.8 increase randomness, "
"while lower values like 0.2 make the output more focused and deterministic."
),
ge=0,
le=1,
default=0.7,
)
top_p: Optional[float] = Field(
description=(
"Nucleus sampling, where the model considers the tokens with top_p probability mass. "
"For example, 0.1 means considering only the tokens in the top 10% probability mass."
),
ge=0,
le=1,
default=1,
)
max_tokens: Optional[int] = Field(
description="The maximum number of tokens to generate in the completion.",
ge=0,
default=None,
)
safe_mode: Optional[bool] = Field(
description="Whether to inject a safety prompt before all conversations.",
default=False,
)
random_seed: Optional[int] = Field(
description="Sets the seed for random sampling to generate deterministic results.",
default=None,
)
def __init__(self, our_model_id, mistral_model_id, vision):
self.model_id = our_model_id
self.mistral_model_id = mistral_model_id
if vision:
self.attachment_types = {
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
}
def build_messages(self, prompt, conversation):
messages = []
latest_message = None
if prompt.attachments:
latest_message = {
"role": "user",
"content": [{"type": "text", "text": prompt.prompt}]
+ [
{
"type": "image_url",
"image_url": attachment.url
or f"data:{attachment.resolve_type()};base64,{attachment.base64_content()}",
}
for attachment in prompt.attachments
],
}
else:
latest_message = {"role": "user", "content": prompt.prompt}
if not conversation:
if prompt.system:
messages.append({"role": "system", "content": prompt.system})
messages.append(latest_message)
return messages
current_system = None
for prev_response in conversation.responses:
if (
prev_response.prompt.system
and prev_response.prompt.system != current_system
):
messages.append(
{"role": "system", "content": prev_response.prompt.system}
)
current_system = prev_response.prompt.system
if prev_response.attachments:
messages.append(
{
"role": "user",
"content": [
{
"type": "text",
"text": prev_response.prompt.prompt,
}
]
+ [
{
"type": "image_url",
"image_url": attachment.url
or f"data:{attachment.resolve_type()};base64,{attachment.base64_content()}",
}
for attachment in prev_response.attachments
],
}
)
else:
messages.append(
{"role": "user", "content": prev_response.prompt.prompt}
)
messages.append(
{"role": "assistant", "content": prev_response.text_or_raise()}
)
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append(latest_message)
return messages
def build_body(self, prompt, messages):
body = {
"model": self.mistral_model_id,
"messages": messages,
}
if prompt.options.temperature:
body["temperature"] = prompt.options.temperature
if prompt.options.top_p:
body["top_p"] = prompt.options.top_p
if prompt.options.max_tokens:
body["max_tokens"] = prompt.options.max_tokens
if prompt.options.safe_mode:
body["safe_mode"] = prompt.options.safe_mode
if prompt.options.random_seed:
body["random_seed"] = prompt.options.random_seed
return body
def set_usage(self, response, usage):
response.set_usage(
input=usage["prompt_tokens"],
output=usage["completion_tokens"],
)
class Mistral(_Shared, llm.Model):
def execute(self, prompt, stream, response, conversation):
key = self.get_key()
messages = self.build_messages(prompt, conversation)
response._prompt_json = {"messages": messages}
body = self.build_body(prompt, messages)
if stream:
body["stream"] = True
with httpx.Client() as client:
with connect_sse(
client,
"POST",
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
) as event_source:
# In case of unauthorized:
if event_source.response.status_code != 200:
# Try to make this a readable error, it may have a base64 chunk
try:
decoded = json.loads(event_source.response.read())
type = decoded["type"]
words = decoded["message"].split()
except (json.JSONDecodeError, KeyError):
click.echo(
event_source.response.read().decode()[:200], err=True
)
event_source.response.raise_for_status()
# Truncate any words longer than 30 characters
words = [word[:30] for word in words]
message = " ".join(words)
raise click.ClickException(
f"{event_source.response.status_code}: {type} - {message}"
)
usage = None
event_source.response.raise_for_status()
for sse in event_source.iter_sse():
if sse.data != "[DONE]":
try:
event = sse.json()
if "usage" in event:
usage = event["usage"]
yield event["choices"][0]["delta"]["content"]
except KeyError:
pass
if usage:
self.set_usage(response, usage)
else:
with httpx.Client() as client:
api_response = client.post(
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
)
api_response.raise_for_status()
yield api_response.json()["choices"][0]["message"]["content"]
details = api_response.json()
usage = details.pop("usage", None)
response.response_json = details
if usage:
self.set_usage(response, usage)
class AsyncMistral(_Shared, llm.AsyncModel):
async def execute(self, prompt, stream, response, conversation):
key = self.get_key()
messages = self.build_messages(prompt, conversation)
response._prompt_json = {"messages": messages}
body = self.build_body(prompt, messages)
if stream:
body["stream"] = True
async with httpx.AsyncClient() as client:
async with aconnect_sse(
client,
"POST",
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
) as event_source:
# In case of unauthorized:
if event_source.response.status_code != 200:
# Try to make this a readable error, it may have a base64 chunk
try:
decoded = json.loads(event_source.response.read())
type = decoded["type"]
words = decoded["message"].split()
except (json.JSONDecodeError, KeyError):
click.echo(
event_source.response.read().decode()[:200], err=True
)
event_source.response.raise_for_status()
# Truncate any words longer than 30 characters
words = [word[:30] for word in words]
message = " ".join(words)
raise click.ClickException(
f"{event_source.response.status_code}: {type} - {message}"
)
event_source.response.raise_for_status()
usage = None
async for sse in event_source.aiter_sse():
if sse.data != "[DONE]":
try:
event = sse.json()
if "usage" in event:
usage = event["usage"]
yield event["choices"][0]["delta"]["content"]
except KeyError:
pass
if usage:
self.set_usage(response, usage)
else:
async with httpx.AsyncClient() as client:
api_response = await client.post(
"https://api.mistral.ai/v1/chat/completions",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json=body,
timeout=None,
)
api_response.raise_for_status()
yield api_response.json()["choices"][0]["message"]["content"]
details = api_response.json()
usage = details.pop("usage", None)
response.response_json = details
if usage:
self.set_usage(response, usage)
class MistralEmbed(llm.EmbeddingModel):
model_id = "mistral-embed"
batch_size = 10
needs_key = "mistral"
key_env_var = "LLM_MISTRAL_KEY"
def embed_batch(self, texts):
key = self.get_key()
with httpx.Client() as client:
api_response = client.post(
"https://api.mistral.ai/v1/embeddings",
headers={
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {key}",
},
json={
"model": "mistral-embed",
"input": list(texts),
"encoding_format": "float",
},
timeout=None,
)
api_response.raise_for_status()
return [item["embedding"] for item in api_response.json()["data"]]
---
llm_groq2.py
---
import llm
from groq import Groq, AsyncGroq
from pydantic import Field
from typing import Optional, List, Union
model_map: dict = {
"groq-gemma": "gemma-7b-it",
"groq-gemma2": "gemma2-9b-it",
"groq-llama2": "llama2-70b-4096",
"groq-llama3": "llama3-8b-8192",
"groq-llama3-70b": "llama3-70b-8192",
"groq-mixtral": "mixtral-8x7b-32768",
"groq-llama3.1-8b": "llama-3.1-8b-instant",
"groq-llama3.1-70b": "llama-3.1-70b-versatile",
"groq-llama3.1-405b": "llama-3.1-405b-reasoning",
"groq-llama-3.3-70b": "llama-3.3-70b-versatile",
}
@llm.hookimpl
def register_models(register):
for model_id in model_map:
register(LLMGroq(model_id), LLMAsyncGroq(model_id))
class _Options(llm.Options):
temperature: Optional[float] = Field(
description=(
"Controls randomness of responses. A lower temperature leads to"
"more predictable outputs while a higher temperature results in"
"more varies and sometimes more creative outputs."
"As the temperature approaches zero, the model will become deterministic"
"and repetitive."
),
ge=0,
le=1,
default=None,
)
top_p: Optional[float] = Field(
description=(
"Controls randomness of responses. A lower temperature leads to"
"more predictable outputs while a higher temperature results in"
"more varies and sometimes more creative outputs."
"0.5 means half of all likelihood-weighted options are considered."
),
ge=0,
le=1,
default=None,
)
max_tokens: Optional[int] = Field(
description=(
"The maximum number of tokens that the model can process in a"
"single response. This limits ensures computational efficiency"
"and resource management."
"Requests can use up to 2048 tokens shared between prompt and completion."
),
ge=0,
lt=2049,
default=None,
)
stop: Optional[Union[str, List[str]]] = Field(
description=(
"A stop sequence is a predefined or user-specified text string that"
"signals an AI to stop generating content, ensuring its responses"
"remain focused and concise. Examples include punctuation marks and"
'markers like "[end]".'
'For this example, we will use ", 6" so that the llm stops counting at 5.'
"If multiple stop values are needed, an array of string may be passed,"
'stop=[", 6", ", six", ", Six"]'
),
default=None,
)
class _Shared:
def __init__(self, model_id):
self.model_id = model_id
def build_messages(self, prompt, conversation):
messages = []
if not conversation:
if prompt.system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages
current_system = None
for prev_response in conversation.responses:
if (
prev_response.prompt.system
and prev_response.prompt.system != current_system
):
messages.append(
{"role": "system", "content": prev_response.prompt.system}
)
current_system = prev_response.prompt.system
messages.append({"role": "user", "content": prev_response.prompt.prompt})
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and prompt.system != current_system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages
class LLMGroq(llm.Model, _Shared):
can_stream = True
Options = _Options
def execute(self, prompt, stream, response, conversation):
key = llm.get_key("", "groq", "LLM_GROQ_KEY")
messages = self.build_messages(prompt, conversation)
client = Groq(api_key=key)
resp = client.chat.completions.create(
messages=messages,
model=model_map[self.model_id],
stream=stream,
temperature=prompt.options.temperature,
top_p=prompt.options.top_p,
max_tokens=prompt.options.max_tokens,
stop=prompt.options.stop,
)
if stream:
for chunk in resp:
if chunk.choices[0] and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
else:
yield resp.choices[0].message.content
class LLMAsyncGroq(llm.AsyncModel, _Shared):
can_stream = True
Options = _Options
async def execute(self, prompt, stream, response, conversation):
key = llm.get_key("", "groq", "LLM_GROQ_KEY")
messages = self.build_messages(prompt, conversation)
client = AsyncGroq(api_key=key)
resp = await client.chat.completions.create(
messages=messages,
model=model_map[self.model_id],
stream=stream,
temperature=prompt.options.temperature,
top_p=prompt.options.top_p,
max_tokens=prompt.options.max_tokens,
stop=prompt.options.stop,
)
if stream:
async for chunk in resp:
if chunk.choices[0] and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
else:
yield resp.choices[0].message.content
---
Implement the pattern in llm_groq.py where a refresh command and function hit https://api.groq.com/openai/v1/models (needs an Authorization: Bearer $KEY header) and save the results in JSON, then use that for model listing later on. Also implement the "llm groq refresh" command. Output the full updated llm_groq.py file
<think>
Okay, so I need to implement the same pattern as in the Mistral example for the Groq models. Let me see what that involves.
First, looking at the Mistral code, there's a function called refresh_models() that fetches model details from an API and saves them to a JSON file. Then, get_model_details() reads from this JSON to provide the list of models. The register_models hook uses this to register the models.
So, for Groq, I'll need to do something similar. I'll need a function to call the Groq API, fetch the models, and save them. Then, a function to get these models, either from the saved JSON or by fetching them if necessary.
First step: Create a refresh_models() function. It should make a GET request to the Groq API, which is at https://api.groq.com/openai/v1/models. The request needs an Authorization header with the Bearer token, which is stored in the LLM_GROQ_KEY environment variable.
I'll write this function, handling exceptions and saving the models to a user directory, similar to Mistral. The file path will be something like user_dir / "groq_models.json".
Next, the get_model_details() function. It should check if the JSON file exists. If it does, read from it. If not, and if the API key is set, call refresh_models() to fetch and save the models.
Then, in the register_models hook, instead of hardcoding the model_map, I should use the models fetched from get_model_details(). Each model from the API will have an "id" field. I'll map each "id" to the corresponding model_id in the register call.
Wait, but in the existing model_map, the keys are the model IDs we use (like "groq-gemma"), and the values are the actual model names used in the API. So, perhaps I need to invert this mapping. When I get the list of models from the API, each with an "id", I can create model IDs in the format "groq-{model_name}" and use the "id" as the actual model name.
Hmm, maybe I should adjust the model_map to be a dictionary that maps the API model IDs to our local model IDs. Or perhaps it's better to generate the model_map dynamically based on the fetched models.
Wait, in the Mistral code, they have DEFAULT_ALIASES that map their model IDs to the API model IDs. So, for Groq, I might not need to hardcode the model_map. Instead, when models are fetched, each model's "id" can be used to create the model ID as "groq-{id}". Then, the register_models function can register each model.
So, the register_models function will loop through each model in get_model_details() and register the Groq models accordingly.
Next, the command. The Mistral code has a refresh command under the mistral group. Similarly, I'll add a refresh command under a groq group. This command will call the refresh_models() function and display the changes.
I also need to handle exceptions, like when the API key isn't set. So, in refresh_models(), if the key is missing, raise an error.
Putting it all together, I'll modify the llm_groq.py file. I'll add the refresh_models() and get_model_details() functions, update the register_models hook to use the fetched models, and add the refresh command under a new groq group.
I should also ensure that the user directory is handled correctly, using llm.user_dir() to find the path to save the models JSON.
Testing the code, I need to make sure that when "llm groq refresh" is run, it fetches the models and updates the available models. The command should output whether any models were added or removed.
Finally, I'll make sure that all necessary imports are included, like click for the CLI commands, and handle any potential HTTP errors when fetching the models.
Wait, in the existing code, the model_map is hardcoded. So, perhaps the approach is to fetch the models and then generate the model_map dynamically. But the current code in register_models uses model_map. So, maybe I need to change that.
Alternatively, the model_map can be built based on the models fetched. Or, perhaps the Groq API returns all the models, and each can be registered on the fly without a fixed model_map.
In the Mistral example, they have a DEFAULT_ALIASES that maps their model IDs to the API model IDs. So, perhaps for Groq, we can have a similar setup, but since the model IDs are dynamic, we can just register each model as "groq-{model_id}".
Wait, perhaps it's better to let the model_id be whatever the API returns, and then the user can use them directly. Or, maybe the model_ids in the LLM are prefixed with "groq-", and the actual model name is taken from the API.
I think the approach is to fetch the models, then for each model in the API response, create a model ID like "groq-" + model_id, and register the Groq model with that ID, mapping to the actual model_id from the API.
So, in the register_models function, after getting the models, for each model in models, create a model_id like "groq-" + model.id, and then call register with LLMGroq(model_id) and LLMAsyncGroq(model_id), using the model.id as the actual model name in the API.
Wait, but in the existing code, the model_map is a hardcoded dictionary. So, perhaps I should change it to dynamically generate this mapping.
Alternatively, perhaps the model_map is no longer needed. Instead, when the models are fetched, each model's "id" is used to create the model_id in the format "groq-{id}", and the actual model name is the "id" itself. So, in the Groq client, when calling the API, it uses the model's "id" as the model parameter.
So, the build_messages function can remain as is, and the execute function will use the model's "id" from the model_id.
Wait, no, in the existing code, the model is passed as model_map[self.model_id], which is hardcoded. So, perhaps I should change that to use the model's "id" directly.
Alternatively, perhaps the model_map can be dynamic. So, after fetching the models, create a mapping from "groq-{id}" to "id" for each model.
But this might complicate things. Alternatively, perhaps the model can be passed directly without the model_map, as the model_id is the API model name.
Wait, in the Mistral example, the model_id is "mistral/" + model_id, which is then used in the API. So, perhaps for Groq, the model_id is just the model's "id" from the API, without any prefix.
Wait, no, in the Mistral code, the register function is called with Mistral(our_model_id, mistral_model_id, vision), where our_model_id is "mistral/" + model_id, and mistral_model_id is the model_id from the API.
So, for Groq, perhaps the model_id in the register function is "groq-" + model.id, and the actual model name is model.id.
So, in the Groq client, when calling the API, it uses the model.id as the model parameter.
Thus, in the execute function, the client.chat.completions.create uses model=model_map[self.model_id], but if model_map is now dynamic, perhaps it's model=self.mistral_model_id or similar.
Wait, perhaps I'm overcomplicating. Let me look again.
In the Mistral code, the model has a mistral_model_id, which is the model_id from the API. So, the execute function uses that to make the API call.
So, for Groq, I can have a similar setup: each model registered has a model_id (like "groq-gemma") and the actual model name is stored as groq_model_id, which is used in the API call.
So, in the _Shared class, I'll have:
def __init__(self, model_id):
self.model_id = model_id
self.groq_model_id = model_id # or extract the part after 'groq-'
Wait, perhaps self.groq_model_id is the model ID from the API, so when the model is "groq-gemma", the groq_model_id is "gemma".
Wait, but in the model_map, "groq-gemma" maps to "gemma-7b-it". So, perhaps the model_map is needed to map the local model IDs to the actual Groq model names.
But since the model_map is hardcoded, and the API returns available models, perhaps the model_map should be built dynamically from the API response.
Alternatively, perhaps the model_map can be eliminated, and the models are registered based on the fetched models.
Wait, perhaps the model_map is no longer necessary. Instead, when the models are fetched from the API, each model's "id" is the actual model name used in the Groq API. So, when registering the models, the local model ID is "groq-{id}", and the groq_model_id is "id".
So, in the register_models function, for each model in models, we create a model_id like "groq-" + model.id, and then pass that to the Groq model, which stores it as self.model_id, and the actual model name is model.id.
Thus, in the execute function, the model parameter will be model.id, which is passed as model=self.groq_model_id.
Wait, I'm getting a bit confused. Let me think again.
In the Mistral code, the model has an our_model_id (like "mistral/tiny") and a mistral_model_id (like "tiny"). So, in the execute function, it uses mistral_model_id to make the API call.
So, for Groq, perhaps the model has our_model_id (like "groq/gemma") and a groq_model_id (like "gemma-7b-it").
But in the model_map, "groq-gemma" maps to "gemma-7b-it". So, perhaps the model_map is still needed to map the local model IDs to the actual Groq model names.
Therefore, when the models are fetched, each model's "id" is the local model ID (like "groq-gemma"), and the actual model name is obtained from the model_map.
But that would require the model_map to be dynamic, which it's not. So, perhaps the initial approach is not suitable.
Alternatively, perhaps the model_map is no longer necessary, and the models are registered based on the API's response.
Wait, perhaps I'm overcomplicating. Let me proceed step by step.
First, I'll implement the refresh_models() function to fetch the models from the Groq API and save them to a JSON file. Then, the get_model_details() function will read this JSON.
Next, in register_models(), instead of hardcoding model_map, I'll loop through each model from get_model_details() and register the Groq models with their model IDs.
So, the model_map may not be needed anymore. Instead, each model's "id" from the API becomes the model ID in the LLM.
But wait, in the existing code, the model_map is used to map the registered model IDs to the actual Groq model names. So, if the model IDs from the API are different, perhaps the model_map should be built dynamically.
Alternatively, perhaps the model IDs in the LLM are the same as the ones returned by the API, and the model_map is no longer needed.
Wait, perhaps the model_map is not needed. Let me look at the execute function. It uses model_map[self.model_id] to get the actual model name. So, if the model_id is "groq-gemma", then model_map["groq-gemma"] is "gemma-7b-it".
But if I'm fetching the models from the API, the model_id in the LLM should be the same as the "id" from the API, and the model_map would map them to the corresponding Groq model names.
But that would require the model_map to be built dynamically based on the API response.
Alternatively, perhaps the model_map is no longer necessary, and the execute function uses the model_id directly. But that might not be the case because the model names in the API might have a different format.
Alternatively, perhaps the model_map is just a way to alias the model IDs, but it's not necessary. So, perhaps the model_map can be eliminated, and the model_id is passed directly to the Groq client.
Wait, looking at the execute function, it uses model_map[self.model_id] as the model parameter in the Groq client. So, if I can get the list of models from the API, I can build a model_map dynamically where each model's "id" is the key, and the value is the model name from the API.
So, in the get_model_details function, for each model, the "id" is the model ID used in the LLM, and the actual model name is stored elsewhere, perhaps in a dictionary.
But I'm getting stuck here. Maybe I should proceed with the initial plan: implement the refresh function, get_model_details, and then adjust the register_models to loop through the models and register each one.
Let me outline the steps:
- Add the refresh_models() function to fetch models from the Groq API and save to a JSON file.
- Add the get_model_details() function to read from the JSON or fetch if needed.
- Modify the register_models() hook to loop through the models from get_model_details() and register each as a Groq model.
- Add the refresh command under a groq group.
I'll proceed with these steps, adjusting the code accordingly.
</think>
To implement the pattern from `llm_mistral.py` in `llm_groq.py`, we need to add functionality to refresh the models from the Groq API, save them, and use them for model registration. Here's how to update the `llm_groq.py` file:
import click
import httpx
import json
import llm
from groq import Groq, AsyncGroq
from pydantic import Field
from typing import Optional, List, Union
MODEL_ENDPOINT = "https://api.groq.com/openai/v1/models"
@llm.hookimpl
def register_models(register):
models = get_model_details()
for model in models:
model_id = model["id"]
register(LLMGroq(model_id), LLMAsyncGroq(model_id))
def refresh_models():
user_dir = llm.user_dir()
groq_models = user_dir / "groq_models.json"
key = llm.get_key("", "groq", "LLM_GROQ_KEY")
if not key:
raise click.ClickException(
"You must set the 'groq' key or the LLM_GROQ_KEY environment variable."
)
response = httpx.get(
MODEL_ENDPOINT,
headers={"Authorization": f"Bearer {key}"},
)
response.raise_for_status()
    models = response.json()
    groq_models.write_text(json.dumps(models, indent=2))
    return models["data"]
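# For reference, a sketch of the payload refresh_models() expects from the Groq
# endpoint and writes verbatim to groq_models.json, assuming the OpenAI-compatible
# list format (the field values below are illustrative, not real API output):
#
# {
#   "object": "list",
#   "data": [
#     {"id": "llama-3.3-70b-versatile", "object": "model", "owned_by": "Meta"},
#     {"id": "gemma2-9b-it", "object": "model", "owned_by": "Google"}
#   ]
# }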
def get_model_details():
user_dir = llm.user_dir()
default_models = {"data": []}
groq_models = user_dir / "groq_models.json"
if groq_models.exists():
models = json.loads(groq_models.read_text())
return models.get("data", [])
elif llm.get_key("", "groq", "LLM_GROQ_KEY"):
try:
return refresh_models()
except httpx.HTTPStatusError:
return default_models["data"]
else:
return default_models["data"]
@llm.hookimpl
def register_commands(cli):
@cli.group()
def groq():
"Commands relating to the llm-groq plugin"
@groq.command()
def refresh():
"Refresh the list of available Groq models"
        user_dir = llm.user_dir()
        groq_models = user_dir / "groq_models.json"
        # Record the existing IDs before refresh_models() overwrites the JSON file
        previous = set()
        if groq_models.exists():
            previous = {
                model["id"] for model in json.loads(groq_models.read_text())["data"]
            }
        try:
            models = refresh_models()
        except httpx.HTTPStatusError:
            click.echo("Failed to refresh models", err=True)
            return
        current = {model["id"] for model in models}
        added = current - previous
        removed = previous - current
if added:
click.echo(f"Added models: {', '.join(added)}", err=True)
if removed:
click.echo(f"Removed models: {', '.join(removed)}", err=True)
if added or removed:
click.echo("New list of models:", err=True)
for model in models:
click.echo(model["id"], err=True)
else:
click.echo("No changes", err=True)
class _Options(llm.Options):
temperature: Optional[float] = Field(
description=(
"Controls randomness of responses. Lower values result in more"
"predictable outputs, while higher values lead to more varied and"
"creative outputs."
),
ge=0,
le=1,
default=None,
)
top_p: Optional[float] = Field(
description=(
"Controls randomness by considering the top P probability mass."
"A lower value makes the output more focused but less creative."
),
ge=0,
le=1,
default=None,
)
max_tokens: Optional[int] = Field(
description=(
"The maximum number of tokens to generate in the completion. The"
"combined prompt and completion tokens must be under 2048."
),
ge=0,
lt=2049,
default=None,
)
stop: Optional[Union[str, List[str]]] = Field(
description=(
"Sequence that stops the generation. Multiple stop sequences can be"
"specified as a list."
),
default=None,
)
class _Shared:
def __init__(self, model_id):
self.model_id = model_id
def build_messages(self, prompt, conversation):
messages = []
if not conversation:
if prompt.system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages
current_system = None
for prev_response in conversation.responses:
if prev_response.prompt.system and prev_response.prompt.system != current_system:
messages.append({"role": "system", "content": prev_response.prompt.system})
current_system = prev_response.prompt.system
messages.append({"role": "user", "content": prev_response.prompt.prompt})
messages.append({"role": "assistant", "content": prev_response.text()})
if prompt.system and current_system != prompt.system:
messages.append({"role": "system", "content": prompt.system})
messages.append({"role": "user", "content": prompt.prompt})
return messages
def build_body(self, messages, options):
return {
"model": self.model_id,
"messages": messages,
"temperature": options.temperature,
"top_p": options.top_p,
"max_tokens": options.max_tokens,
"stop": options.stop,
}
    def set_usage(self, response, usage):
        # The Groq SDK exposes usage as an object with attributes, not a dict
        response.set_usage(
            input=getattr(usage, "prompt_tokens", None),
            output=getattr(usage, "completion_tokens", None),
        )
class LLMGroq(llm.Model, _Shared):
    can_stream = True
    needs_key = "groq"
    key_env_var = "LLM_GROQ_KEY"
    Options = _Options

    def execute(self, prompt, stream, response, conversation):
        key = self.get_key()
        messages = self.build_messages(prompt, conversation)
        body = self.build_body(messages, prompt.options)
        client = Groq(api_key=key)
        resp = client.chat.completions.create(stream=stream, **body)
        if stream:
            for chunk in resp:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content
            if getattr(resp, "usage", None):
                self.set_usage(response, resp.usage)
class LLMAsyncGroq(llm.AsyncModel, _Shared):
    can_stream = True
    needs_key = "groq"
    key_env_var = "LLM_GROQ_KEY"
    Options = _Options

    async def execute(self, prompt, stream, response, conversation):
        key = self.get_key()
        messages = self.build_messages(prompt, conversation)
        body = self.build_body(messages, prompt.options)
        client = AsyncGroq(api_key=key)
        resp = await client.chat.completions.create(stream=stream, **body)
        if stream:
            async for chunk in resp:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content
            if getattr(resp, "usage", None):
                self.set_usage(response, resp.usage)
- `refresh_models()` function: fetches the available models from the Groq API and saves them to `groq_models.json` in the user directory.
- `get_model_details()` function: reads the saved models from `groq_models.json` if available, and fetches them from the API when the file doesn't exist and the API key is set.
- `register_models` hook: registers each model obtained from `get_model_details()`, allowing dynamic registration based on the API response.
- `refresh` command: added under the `groq` group; calls `refresh_models()` and displays added or removed models.

This implementation mirrors the Mistral plugin's approach, providing a dynamic and refreshable list of models while maintaining compatibility with the existing LLM framework. Users can now update their model list using the `llm groq refresh` command.
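As a quick check once a key is configured, the refreshed models can be exercised through llm's Python API; the model ID below is illustrative and must match one of the IDs returned by the Groq endpoint:

    import llm

    # Requires the key to be available, e.g. LLM_GROQ_KEY or `llm keys set groq`
    model = llm.get_model("llama-3.3-70b-versatile")
    response = model.prompt("Say hello in one short sentence")
    print(response.text())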
Key changes:

- Added `retry_initial_delay`, `retry_multiplier`, `retry_max_delay`, and `max_retries` to the `_Options` class using Pydantic `Field` for configuration and validation. These control the retry behavior.
- `_handle_rate_limit` method: a new method on the `_Shared` class that encapsulates the rate limit handling logic. This promotes code reuse and keeps the `execute` methods cleaner. It takes the API call (`client.chat.completions.create`), its arguments, and keyword arguments.
- Uses a `while` loop for retries (up to `max_retries`) with a `try...except` block to catch `httpx.HTTPStatusError`.
- Checks the `Retry-After` header if present and uses that for the delay, attempting to parse it as both seconds and an HTTP-date. If `Retry-After` is not present or cannot be parsed, it falls back to exponential backoff (delay * multiplier, capped at `max_delay`).
- Uses `asyncio.sleep()` for asynchronous delays and `time.sleep()` for synchronous ones.
- `execute` methods: `LLMGroq.execute` and `LLMAsyncGroq.execute` now call `self._handle_rate_limit`, passing the Groq API call and its parameters. Options are collected in `options` and unpacked as `**options` so that Groq's client takes precedence.
- `_handle_rate_limit` also catches generic exceptions to avoid unhandled ones, and uses `asyncio.sleep()` when called from the asynchronous `execute` method, ensuring non-blocking behavior.

This revised code provides explicit rate limit handling using a retry-with-backoff strategy, making the plugin more robust when interacting with the Groq API. It also allows users to configure the retry parameters. This version is better because it encapsulates retry, handles both sync and async properly, reads `Retry-After`, and uses exponential backoff.
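A minimal sketch of the retry helper described above, shown as a standalone synchronous function: the option names and the assumption that the wrapped call raises `httpx.HTTPStatusError` on a 429 come from the summary rather than a specific SDK (the Groq client's own status errors expose similar `status_code`/`response` attributes), and HTTP-date parsing of `Retry-After` is left out for brevity:

    import time

    import httpx


    def handle_rate_limit(api_call, *args, max_retries=5, initial_delay=1.0,
                          multiplier=2.0, max_delay=60.0, **kwargs):
        # Call api_call, retrying on HTTP 429 using Retry-After or exponential backoff
        delay = initial_delay
        for attempt in range(max_retries):
            try:
                return api_call(*args, **kwargs)
            except httpx.HTTPStatusError as exc:
                if exc.response.status_code != 429 or attempt == max_retries - 1:
                    raise
                retry_after = exc.response.headers.get("retry-after")
                try:
                    wait = float(retry_after) if retry_after else delay
                except ValueError:
                    wait = delay  # e.g. an HTTP-date value this sketch doesn't parse
                time.sleep(min(wait, max_delay))
                delay = min(delay * multiplier, max_delay)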