simonw/groq.md Secret

Created January 27, 2025 05:36

2025-01-27T05:34:16 conversation: 01jjk431n5xsny4bzdcskn779k

Model: o1

Prompt:

files-to-prompt ../llm-mistral/llm_mistral.py llm_groq.py

../llm-mistral/llm_mistral.py
---
import click
from httpx_sse import connect_sse, aconnect_sse
import httpx
import json
import llm
from pydantic import Field
from typing import Optional


DEFAULT_ALIASES = {
    "mistral/mistral-tiny": "mistral-tiny",
    "mistral/open-mistral-nemo": "mistral-nemo",
    "mistral/mistral-small": "mistral-small",
    "mistral/mistral-medium": "mistral-medium",
    "mistral/mistral-large-latest": "mistral-large",
    "mistral/codestral-mamba-latest": "codestral-mamba",
    "mistral/codestral-latest": "codestral",
    "mistral/ministral-3b-latest": "ministral-3b",
    "mistral/ministral-8b-latest": "ministral-8b",
    "mistral/pixtral-12b-latest": "pixtral-12b",
    "mistral/pixtral-large-latest": "pixtral-large",
}


@llm.hookimpl
def register_models(register):
    for model in get_model_details():
        model_id = model["id"]
        vision = model.get("capabilities", {}).get("vision")
        our_model_id = "mistral/" + model_id
        alias = DEFAULT_ALIASES.get(our_model_id)
        aliases = [alias] if alias else []
        register(
            Mistral(our_model_id, model_id, vision),
            AsyncMistral(our_model_id, model_id, vision),
            aliases=aliases,
        )


@llm.hookimpl
def register_embedding_models(register):
    register(MistralEmbed())


def refresh_models():
    user_dir = llm.user_dir()
    mistral_models = user_dir / "mistral_models.json"
    key = llm.get_key("", "mistral", "LLM_MISTRAL_KEY")
    if not key:
        raise click.ClickException(
            "You must set the 'mistral' key or the LLM_MISTRAL_KEY environment variable."
        )
    response = httpx.get(
        "https://api.mistral.ai/v1/models", headers={"Authorization": f"Bearer {key}"}
    )
    response.raise_for_status()
    models = response.json()
    mistral_models.write_text(json.dumps(models, indent=2))
    return models


def get_model_details():
    user_dir = llm.user_dir()
    models = {
        "data": [
            {"id": model_id.replace("mistral/", "")}
            for model_id in DEFAULT_ALIASES.keys()
        ]
    }
    mistral_models = user_dir / "mistral_models.json"
    if mistral_models.exists():
        models = json.loads(mistral_models.read_text())
    elif llm.get_key("", "mistral", "LLM_MISTRAL_KEY"):
        try:
            models = refresh_models()
        except httpx.HTTPStatusError:
            pass
    return [model for model in models["data"] if "embed" not in model["id"]]


def get_model_ids():
    return [model["id"] for model in get_model_details()]


@llm.hookimpl
def register_commands(cli):
    @cli.group()
    def mistral():
        "Commands relating to the llm-mistral plugin"

    @mistral.command()
    def refresh():
        "Refresh the list of available Mistral models"
        before = set(get_model_ids())
        refresh_models()
        after = set(get_model_ids())
        added = after - before
        removed = before - after
        if added:
            click.echo(f"Added models: {', '.join(added)}", err=True)
        if removed:
            click.echo(f"Removed models: {', '.join(removed)}", err=True)
        if added or removed:
            click.echo("New list of models:", err=True)
            for model_id in get_model_ids():
                click.echo(model_id, err=True)
        else:
            click.echo("No changes", err=True)


class _Shared:
    can_stream = True
    needs_key = "mistral"
    key_env_var = "LLM_MISTRAL_KEY"

    class Options(llm.Options):
        temperature: Optional[float] = Field(
            description=(
                "Determines the sampling temperature. Higher values like 0.8 increase randomness, "
                "while lower values like 0.2 make the output more focused and deterministic."
            ),
            ge=0,
            le=1,
            default=0.7,
        )
        top_p: Optional[float] = Field(
            description=(
                "Nucleus sampling, where the model considers the tokens with top_p probability mass. "
                "For example, 0.1 means considering only the tokens in the top 10% probability mass."
            ),
            ge=0,
            le=1,
            default=1,
        )
        max_tokens: Optional[int] = Field(
            description="The maximum number of tokens to generate in the completion.",
            ge=0,
            default=None,
        )
        safe_mode: Optional[bool] = Field(
            description="Whether to inject a safety prompt before all conversations.",
            default=False,
        )
        random_seed: Optional[int] = Field(
            description="Sets the seed for random sampling to generate deterministic results.",
            default=None,
        )

    def __init__(self, our_model_id, mistral_model_id, vision):
        self.model_id = our_model_id
        self.mistral_model_id = mistral_model_id
        if vision:
            self.attachment_types = {
                "image/jpeg",
                "image/png",
                "image/gif",
                "image/webp",
            }

    def build_messages(self, prompt, conversation):
        messages = []
        latest_message = None
        if prompt.attachments:
            latest_message = {
                "role": "user",
                "content": [{"type": "text", "text": prompt.prompt}]
                + [
                    {
                        "type": "image_url",
                        "image_url": attachment.url
                        or f"data:{attachment.resolve_type()};base64,{attachment.base64_content()}",
                    }
                    for attachment in prompt.attachments
                ],
            }
        else:
            latest_message = {"role": "user", "content": prompt.prompt}
        if not conversation:
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            messages.append(latest_message)
            return messages

        current_system = None
        for prev_response in conversation.responses:
            if (
                prev_response.prompt.system
                and prev_response.prompt.system != current_system
            ):
                messages.append(
                    {"role": "system", "content": prev_response.prompt.system}
                )
                current_system = prev_response.prompt.system
            if prev_response.attachments:
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prev_response.prompt.prompt,
                            }
                        ]
                        + [
                            {
                                "type": "image_url",
                                "image_url": attachment.url
                                or f"data:{attachment.resolve_type()};base64,{attachment.base64_content()}",
                            }
                            for attachment in prev_response.attachments
                        ],
                    }
                )
            else:
                messages.append(
                    {"role": "user", "content": prev_response.prompt.prompt}
                )
            messages.append(
                {"role": "assistant", "content": prev_response.text_or_raise()}
            )
        if prompt.system and prompt.system != current_system:
            messages.append({"role": "system", "content": prompt.system})

        messages.append(latest_message)
        return messages

    def build_body(self, prompt, messages):
        body = {
            "model": self.mistral_model_id,
            "messages": messages,
        }
        if prompt.options.temperature:
            body["temperature"] = prompt.options.temperature
        if prompt.options.top_p:
            body["top_p"] = prompt.options.top_p
        if prompt.options.max_tokens:
            body["max_tokens"] = prompt.options.max_tokens
        if prompt.options.safe_mode:
            body["safe_mode"] = prompt.options.safe_mode
        if prompt.options.random_seed:
            body["random_seed"] = prompt.options.random_seed
        return body

    def set_usage(self, response, usage):
        response.set_usage(
            input=usage["prompt_tokens"],
            output=usage["completion_tokens"],
        )


class Mistral(_Shared, llm.Model):
    def execute(self, prompt, stream, response, conversation):
        key = self.get_key()
        messages = self.build_messages(prompt, conversation)
        response._prompt_json = {"messages": messages}
        body = self.build_body(prompt, messages)
        if stream:
            body["stream"] = True
            with httpx.Client() as client:
                with connect_sse(
                    client,
                    "POST",
                    "https://api.mistral.ai/v1/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                        "Authorization": f"Bearer {key}",
                    },
                    json=body,
                    timeout=None,
                ) as event_source:
                    # In case of unauthorized:
                    if event_source.response.status_code != 200:
                        # Try to make this a readable error, it may have a base64 chunk
                        try:
                            decoded = json.loads(event_source.response.read())
                            type = decoded["type"]
                            words = decoded["message"].split()
                        except (json.JSONDecodeError, KeyError):
                            click.echo(
                                event_source.response.read().decode()[:200], err=True
                            )
                            event_source.response.raise_for_status()
                        # Truncate any words longer than 30 characters
                        words = [word[:30] for word in words]
                        message = " ".join(words)
                        raise click.ClickException(
                            f"{event_source.response.status_code}: {type} - {message}"
                        )
                    usage = None
                    event_source.response.raise_for_status()
                    for sse in event_source.iter_sse():
                        if sse.data != "[DONE]":
                            try:
                                event = sse.json()
                                if "usage" in event:
                                    usage = event["usage"]
                                yield event["choices"][0]["delta"]["content"]
                            except KeyError:
                                pass
                    if usage:
                        self.set_usage(response, usage)
        else:
            with httpx.Client() as client:
                api_response = client.post(
                    "https://api.mistral.ai/v1/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                        "Authorization": f"Bearer {key}",
                    },
                    json=body,
                    timeout=None,
                )
                api_response.raise_for_status()
                yield api_response.json()["choices"][0]["message"]["content"]
                details = api_response.json()
                usage = details.pop("usage", None)
                response.response_json = details
                if usage:
                    self.set_usage(response, usage)


class AsyncMistral(_Shared, llm.AsyncModel):
    async def execute(self, prompt, stream, response, conversation):
        key = self.get_key()
        messages = self.build_messages(prompt, conversation)
        response._prompt_json = {"messages": messages}
        body = self.build_body(prompt, messages)
        if stream:
            body["stream"] = True
            async with httpx.AsyncClient() as client:
                async with aconnect_sse(
                    client,
                    "POST",
                    "https://api.mistral.ai/v1/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                        "Authorization": f"Bearer {key}",
                    },
                    json=body,
                    timeout=None,
                ) as event_source:
                    # In case of unauthorized:
                    if event_source.response.status_code != 200:
                        # Try to make this a readable error, it may have a base64 chunk
                        try:
                            decoded = json.loads(event_source.response.read())
                            type = decoded["type"]
                            words = decoded["message"].split()
                        except (json.JSONDecodeError, KeyError):
                            click.echo(
                                event_source.response.read().decode()[:200], err=True
                            )
                            event_source.response.raise_for_status()
                        # Truncate any words longer than 30 characters
                        words = [word[:30] for word in words]
                        message = " ".join(words)
                        raise click.ClickException(
                            f"{event_source.response.status_code}: {type} - {message}"
                        )
                    event_source.response.raise_for_status()
                    usage = None
                    async for sse in event_source.aiter_sse():
                        if sse.data != "[DONE]":
                            try:
                                event = sse.json()
                                if "usage" in event:
                                    usage = event["usage"]
                                yield event["choices"][0]["delta"]["content"]
                            except KeyError:
                                pass
                    if usage:
                        self.set_usage(response, usage)
        else:
            async with httpx.AsyncClient() as client:
                api_response = await client.post(
                    "https://api.mistral.ai/v1/chat/completions",
                    headers={
                        "Content-Type": "application/json",
                        "Accept": "application/json",
                        "Authorization": f"Bearer {key}",
                    },
                    json=body,
                    timeout=None,
                )
                api_response.raise_for_status()
                yield api_response.json()["choices"][0]["message"]["content"]
                details = api_response.json()
                usage = details.pop("usage", None)
                response.response_json = details
                if usage:
                    self.set_usage(response, usage)


class MistralEmbed(llm.EmbeddingModel):
    model_id = "mistral-embed"
    batch_size = 10
    needs_key = "mistral"
    key_env_var = "LLM_MISTRAL_KEY"

    def embed_batch(self, texts):
        key = self.get_key()
        with httpx.Client() as client:
            api_response = client.post(
                "https://api.mistral.ai/v1/embeddings",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                    "Authorization": f"Bearer {key}",
                },
                json={
                    "model": "mistral-embed",
                    "input": list(texts),
                    "encoding_format": "float",
                },
                timeout=None,
            )
            api_response.raise_for_status()
            return [item["embedding"] for item in api_response.json()["data"]]


---
llm_groq.py
---
import llm
from groq import Groq, AsyncGroq
from pydantic import Field
from typing import Optional, List, Union

model_map: dict = {
    "groq-gemma": "gemma-7b-it",
    "groq-gemma2": "gemma2-9b-it",
    "groq-llama2": "llama2-70b-4096",
    "groq-llama3": "llama3-8b-8192",
    "groq-llama3-70b": "llama3-70b-8192",
    "groq-mixtral": "mixtral-8x7b-32768",
    "groq-llama3.1-8b": "llama-3.1-8b-instant",
    "groq-llama3.1-70b": "llama-3.1-70b-versatile",
    "groq-llama3.1-405b": "llama-3.1-405b-reasoning",
    "groq-llama-3.3-70b": "llama-3.3-70b-versatile",
}


@llm.hookimpl
def register_models(register):
    for model_id in model_map:
        register(LLMGroq(model_id), LLMAsyncGroq(model_id))


class _Options(llm.Options):
    temperature: Optional[float] = Field(
        description=(
            "Controls randomness of responses. A lower temperature leads to"
            "more predictable outputs while a higher temperature results in"
            "more varies and sometimes more creative outputs."
            "As the temperature approaches zero, the model will become deterministic"
            "and repetitive."
        ),
        ge=0,
        le=1,
        default=None,
    )
    top_p: Optional[float] = Field(
        description=(
            "Controls randomness of responses. A lower temperature leads to"
            "more predictable outputs while a higher temperature results in"
            "more varies and sometimes more creative outputs."
            "0.5 means half of all likelihood-weighted options are considered."
        ),
        ge=0,
        le=1,
        default=None,
    )
    max_tokens: Optional[int] = Field(
        description=(
            "The maximum number of tokens that the model can process in a"
            "single response. This limits ensures computational efficiency"
            "and resource management."
            "Requests can use up to 2048 tokens shared between prompt and completion."
        ),
        ge=0,
        lt=2049,
        default=None,
    )
    stop: Optional[Union[str, List[str]]] = Field(
        description=(
            "A stop sequence is a predefined or user-specified text string that"
            "signals an AI to stop generating content, ensuring its responses"
            "remain focused and concise. Examples include punctuation marks and"
            'markers like "[end]".'
            'For this example, we will use ", 6" so that the llm stops counting at 5.'
            "If multiple stop values are needed, an array of string may be passed,"
            'stop=[", 6", ", six", ", Six"]'
        ),
        default=None,
    )


class _Shared:
    def __init__(self, model_id):
        self.model_id = model_id

    def build_messages(self, prompt, conversation):
        messages = []
        if not conversation:
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            messages.append({"role": "user", "content": prompt.prompt})
            return messages
        current_system = None
        for prev_response in conversation.responses:
            if (
                prev_response.prompt.system
                and prev_response.prompt.system != current_system
            ):
                messages.append(
                    {"role": "system", "content": prev_response.prompt.system}
                )
                current_system = prev_response.prompt.system
            messages.append({"role": "user", "content": prev_response.prompt.prompt})
            messages.append({"role": "assistant", "content": prev_response.text()})
        if prompt.system and prompt.system != current_system:
            messages.append({"role": "system", "content": prompt.system})
        messages.append({"role": "user", "content": prompt.prompt})
        return messages


class LLMGroq(llm.Model, _Shared):
    can_stream = True
    Options = _Options

    def execute(self, prompt, stream, response, conversation):
        key = llm.get_key("", "groq", "LLM_GROQ_KEY")
        messages = self.build_messages(prompt, conversation)
        client = Groq(api_key=key)
        resp = client.chat.completions.create(
            messages=messages,
            model=model_map[self.model_id],
            stream=stream,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            max_tokens=prompt.options.max_tokens,
            stop=prompt.options.stop,
        )
        if stream:
            for chunk in resp:
                if chunk.choices[0] and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content


class LLMAsyncGroq(llm.AsyncModel, _Shared):
    can_stream = True
    Options = _Options

    async def execute(self, prompt, stream, response, conversation):
        key = llm.get_key("", "groq", "LLM_GROQ_KEY")
        messages = self.build_messages(prompt, conversation)
        client = AsyncGroq(api_key=key)
        resp = await client.chat.completions.create(
            messages=messages,
            model=model_map[self.model_id],
            stream=stream,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            max_tokens=prompt.options.max_tokens,
            stop=prompt.options.stop,
        )
        if stream:
            async for chunk in resp:
                if chunk.choices[0] and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content

Implement the pattern in llm_groq.py where a refresh command and function hits https://api.groq.com/openai/v1/models (needs an Authorization: Bearer $KEY header) and saves the results in JSON, then uses that for model listing later on. Also implement the "llm groq refresh" command. Output the full updated llm_groq.py file

Response:

-------------------------------------------------------------------------------
Below is the complete updated llm_groq.py file with a "groq refresh" command and
the associated refresh_models logic. This approach mirrors the Mistral plugin's
refresh pattern, hitting https://api.groq.com/openai/v1/models and persisting
the results to a local "groq_models.json" file in your llm user directory.
-------------------------------------------------------------------------------

import click
import httpx
import json
import llm
from groq import Groq, AsyncGroq
from pydantic import Field
from typing import Optional, List, Union

# An internal map of recognized Groq model IDs to their "internal" or "actual"
# model names on the server. By default, if the server returns an "id" that is
# in this dict, we will pass the mapped model name to the .create() calls.
model_map: dict = {
    "groq-gemma": "gemma-7b-it",
    "groq-gemma2": "gemma2-9b-it",
    "groq-llama2": "llama2-70b-4096",
    "groq-llama3": "llama3-8b-8192",
    "groq-llama3-70b": "llama3-70b-8192",
    "groq-mixtral": "mixtral-8x7b-32768",
    "groq-llama3.1-8b": "llama-3.1-8b-instant",
    "groq-llama3.1-70b": "llama-3.1-70b-versatile",
    "groq-llama3.1-405b": "llama-3.1-405b-reasoning",
    "groq-llama-3.3-70b": "llama-3.3-70b-versatile",
}


@llm.hookimpl
def register_models(register):
    """
    Register models based on the local 'groq_models.json' file if it exists,
    otherwise on first run it uses model_map as defaults. The user can run
    'llm groq refresh' to fetch the list dynamically from the Groq API.
    """
    models = get_model_details()
    # Only register models we know how to map. The refreshed server list uses
    # Groq's own model names (model_map values), while the local fallback uses
    # our "groq-..." aliases (model_map keys), so accept either and register
    # under our alias:
    reverse_map = {value: key for key, value in model_map.items()}
    for model in models:
        model_id = model["id"]
        our_model_id = model_id if model_id in model_map else reverse_map.get(model_id)
        if our_model_id:
            register(LLMGroq(our_model_id), LLMAsyncGroq(our_model_id))


def refresh_models():
    """
    Fetch the list of available models from Groq's API and store them locally.
    """
    user_dir = llm.user_dir()
    groq_models_path = user_dir / "groq_models.json"
    key = llm.get_key("", "groq", "LLM_GROQ_KEY")
    if not key:
        raise click.ClickException(
            "You must set the 'groq' key or the LLM_GROQ_KEY environment variable."
        )
    response = httpx.get(
        "https://api.groq.com/openai/v1/models",
        headers={"Authorization": f"Bearer {key}"},
    )
    response.raise_for_status()
    data = response.json()
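    # The endpoint is OpenAI-compatible, so the payload is shaped roughly like
    # {"object": "list", "data": [{"id": "llama3-8b-8192", "object": "model", ...}, ...]};
    # only the "data" list (and each entry's "id") is relied on elsewhere in this plugin.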
    groq_models_path.write_text(json.dumps(data, indent=2))
    return data


def get_model_details():
    """
    Return details of available Groq models, preferring whatever we have in
    groq_models.json. If that file doesn't exist (and the user has a key), try
    calling refresh_models() once. If we fail, fall back to local model_map keys.
    """
    user_dir = llm.user_dir()
    groq_models_path = user_dir / "groq_models.json"
    # Default fallback if we can't fetch from the server:
    models = {"data": [{"id": m} for m in model_map.keys()]}
    if groq_models_path.exists():
        models = json.loads(groq_models_path.read_text())
    else:
        # Attempt to refresh if a key is present
        if llm.get_key("", "groq", "LLM_GROQ_KEY"):
            try:
                models = refresh_models()
            except httpx.HTTPStatusError:
                pass
    return models["data"]


def get_model_ids():
    """Return just the list of Groq model IDs from get_model_details()."""
    return [model["id"] for model in get_model_details()]


@llm.hookimpl
def register_commands(cli):
    @cli.group()
    def groq():
        "Commands relating to the llm-groq plugin"

    @groq.command()
    def refresh():
        """
        Refresh the list of available Groq models from https://api.groq.com/
        and update the local groq_models.json file in your llm user directory.
        """
        before = set(get_model_ids())
        refresh_models()
        after = set(get_model_ids())
        added = after - before
        removed = before - after
        if added:
            click.echo(f"Added models: {', '.join(added)}", err=True)
        if removed:
            click.echo(f"Removed models: {', '.join(removed)}", err=True)
        if added or removed:
            click.echo("New list of models:", err=True)
            for model_id in get_model_ids():
                click.echo(model_id, err=True)
        else:
            click.echo("No changes", err=True)


class _Options(llm.Options):
    temperature: Optional[float] = Field(
        description=(
            "Controls randomness of responses. A lower temperature leads to "
            "more predictable outputs, while a higher temperature results in "
            "more varied and sometimes more creative outputs. As the temperature "
            "approaches zero, the model will become deterministic and repetitive."
        ),
        ge=0,
        le=1,
        default=None,
    )
    top_p: Optional[float] = Field(
        description=(
            "Controls randomness of responses via nucleus sampling. 0.5 means "
            "half of all likelihood-weighted options are considered."
        ),
        ge=0,
        le=1,
        default=None,
    )
    max_tokens: Optional[int] = Field(
        description=(
            "The maximum number of tokens that the model can process in a "
            "single response. This limits ensures computational efficiency "
            "and resource management. Requests can use up to 2048 tokens "
            "shared between prompt and completion."
        ),
        ge=0,
        lt=2049,
        default=None,
    )
    stop: Optional[Union[str, List[str]]] = Field(
        description=(
            "A stop sequence is a predefined text string that signals the AI "
            "to stop generating content, ensuring responses remain focused and "
            "concise. If multiple values are needed, pass an array of strings."
        ),
        default=None,
    )


class _Shared:
    def __init__(self, model_id):
        self.model_id = model_id

    def build_messages(self, prompt, conversation):
        messages = []
        if not conversation:
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            messages.append({"role": "user", "content": prompt.prompt})
            return messages
        current_system = None
        for prev_response in conversation.responses:
            if (
                prev_response.prompt.system
                and prev_response.prompt.system != current_system
            ):
                messages.append(
                    {"role": "system", "content": prev_response.prompt.system}
                )
                current_system = prev_response.prompt.system
            messages.append({"role": "user", "content": prev_response.prompt.prompt})
            messages.append({"role": "assistant", "content": prev_response.text()})
        if prompt.system and prompt.system != current_system:
            messages.append({"role": "system", "content": prompt.system})
        messages.append({"role": "user", "content": prompt.prompt})
        return messages


class LLMGroq(llm.Model, _Shared):
    can_stream = True
    Options = _Options

    def execute(self, prompt, stream, response, conversation):
        key = llm.get_key("", "groq", "LLM_GROQ_KEY")
        messages = self.build_messages(prompt, conversation)

        # Map the user-specified model_id to the actual Groq model name:
        actual_model_name = model_map[self.model_id]
        client = Groq(api_key=key)
        resp = client.chat.completions.create(
            messages=messages,
            model=actual_model_name,
            stream=stream,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            max_tokens=prompt.options.max_tokens,
            stop=prompt.options.stop,
        )
        if stream:
            for chunk in resp:
                if chunk.choices[0] and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content


class LLMAsyncGroq(llm.AsyncModel, _Shared):
    can_stream = True
    Options = _Options

    async def execute(self, prompt, stream, response, conversation):
        key = llm.get_key("", "groq", "LLM_GROQ_KEY")
        messages = self.build_messages(prompt, conversation)

        # Map the user-specified model_id to the actual Groq model name:
        actual_model_name = model_map[self.model_id]
        client = AsyncGroq(api_key=key)
        resp = await client.chat.completions.create(
            messages=messages,
            model=actual_model_name,
            stream=stream,
            temperature=prompt.options.temperature,
            top_p=prompt.options.top_p,
            max_tokens=prompt.options.max_tokens,
            stop=prompt.options.stop,
        )
        if stream:
            async for chunk in resp:
                if chunk.choices[0] and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
        else:
            yield resp.choices[0].message.content
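
A quick way to exercise the new refresh flow once this file is installed as the llm-groq plugin (a sketch, assuming a Groq API key is available; "llm keys set" and "llm models list" are standard llm CLI commands):

llm keys set groq              # or: export LLM_GROQ_KEY=...
llm groq refresh               # fetches https://api.groq.com/openai/v1/models and
                               # writes groq_models.json to the llm user directory
llm models list | grep groq    # the registered models now reflect the saved JSON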

Token usage:

4,535 input, 4,762 output, {"completion_tokens_details": {"reasoning_tokens": 2624}}
