Skip to content

Instantly share code, notes, and snippets.

@zzstoatzz
Created February 25, 2024 06:13
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zzstoatzz/8b086131d1eb0dc778d28c213fd5eec6 to your computer and use it in GitHub Desktop.
Save zzstoatzz/8b086131d1eb0dc778d28c213fd5eec6 to your computer and use it in GitHub Desktop.
import inspect
from contextlib import contextmanager
import marvin
from devtools import debug # pip install devtools
from marvin.ai.text import ChatRequest, EjectRequest
from marvin.client.openai import AsyncMarvinClient
from marvin.utilities.asyncio import run_sync
from marvin.utilities.context import ctx
from marvin.utilities.strings import count_tokens
from openai.types.chat import ChatCompletion
def process_ejected_request(request: ChatRequest):
    """Print a diagnostic summary of a chat request that was ejected before sending.

    Dumps the full request, the token count of the concatenated message
    contents, the name of the tool forced through ``tool_choice`` and, when
    that tool is present in ``request.tools``, its definition.

    Requests that do not force a specific tool (``tool_choice`` is ``None`` or
    not a mapping) or that carry no tools are reported as ``Called tool: None``
    instead of raising.

    Args:
        request: The request to summarise.
    """
    debug(request)

    message_text = "".join(m.content for m in request.messages)
    print(f"Message tokens: {count_tokens(message_text)}")

    # `tool_choice` is only a mapping of the form
    # {'type': 'function', 'function': {'name': ...}} when a tool is forced;
    # anything else means there is no single tool name to report.
    tool_choice = request.tool_choice
    tool_name = None
    if isinstance(tool_choice, dict):
        tool_name = (tool_choice.get("function") or {}).get("name")
    print(f"Called tool: {tool_name}")
    if tool_name is None:
        return

    # A default of None avoids a bare StopIteration when no tool matches.
    forced_tool = next(
        (tool for tool in request.tools or [] if tool.function.name == tool_name),
        None,
    )
    if forced_tool is not None:
        print("Which looks like:")
        debug(forced_tool)
def process_completion(completion: ChatCompletion):
    """Report the token usage of a finished chat completion.

    Args:
        completion: The completion whose ``usage`` field is dumped.
    """
    # `debug` echoes the argument expression next to its value, so the
    # expression is kept as `completion.usage` on purpose.
    debug(completion.usage)
    # optionally publish this to a queue
class MyClient(AsyncMarvinClient):
    """Marvin client that hands every chat completion to `process_completion`."""

    async def generate_chat(self, **kwargs):
        """Generate a chat completion, run the completion hook, return the completion.

        All keyword arguments are forwarded unchanged to the parent
        implementation. The hook may be a plain function or a coroutine
        function; a returned coroutine is awaited before the completion is
        handed back to the caller.
        """
        completion = await super().generate_chat(**kwargs)

        hook_result = process_completion(completion)
        if inspect.iscoroutine(hook_result):
            # The hook was async: finish it before returning.
            await hook_result

        return completion
@contextmanager
def inspect_mode(_process_fn=None):
    """Run the enclosed block with marvin's ``eject_request`` context flag set.

    An ``EjectRequest`` raised inside the block is caught (and therefore not
    propagated to the caller); its ``request`` attribute is passed to the
    handler. If the handler returns a coroutine, it is driven to completion
    with ``run_sync``.

    Args:
        _process_fn: Callable that receives the ejected request. Defaults to
            ``process_ejected_request`` when ``None``.
    """
    handler = process_ejected_request if _process_fn is None else _process_fn

    with ctx(eject_request=True):
        try:
            yield
        except EjectRequest as ejected:
            outcome = handler(ejected.request)
            if inspect.iscoroutine(outcome):
                # Async handler: run it synchronously from this sync context manager.
                run_sync(outcome)
# Intentionally has no body: `marvin.fn` answers calls to this stub with an LLM.
# The signature and docstring are rendered into the system prompt, so editing
# either one changes the request that is sent. Passing `client=MyClient()`
# routes the call through the subclass that reports completion usage.
@marvin.fn(client=MyClient())
def list_fruit(n: int = 2) -> list[str]:
    """returns a list of `n` fruit"""
if __name__ == "__main__":
    # First call: inside `inspect_mode` the outgoing request is intercepted and
    # summarised, so no value comes back from `list_fruit` here.
    with inspect_mode():
        list_fruit(n=3)

    # Second call: outside the context manager the function runs normally and
    # its result is printed.
    fruit = list_fruit(n=3)
    print(fruit)
@zzstoatzz
Copy link
Author

zzstoatzz commented Feb 25, 2024

» python cookbook/eject_payload.py
cookbook/eject_payload.py:15 process_ejected_request
    request: ChatRequest(
        tools=[
            Tool[~M](
                type='function',
                function=Function[~M](
                    name='FormatResponse',
                    description='Format the response with valid JSON.',
                    parameters={
                        'description': 'Format the response with valid JSON.',
                        'properties': {
                            'value': {
                                'description': 'The formatted response',
                                'items': {
                                    'type': 'string',
                                },
                                'title': 'Value',
                                'type': 'array',
                            },
                        },
                        'required': ['value'],
                        'type': 'object',
                    },
                ),
            ),
        ],
        tool_choice={
            'type': 'function',
            'function': {
                'name': 'FormatResponse',
            },
        },
        logit_bias=None,
        max_tokens=None,
        response_format=None,
        messages=[
            BaseMessage(
                content=(
                    'Your job is to generate likely outputs for a Python function with the\n'
                    'following definition:\n'
                    '\n'
                    'def list_fruit(n: int = 2) -> list[str]:\n'
                    '    """\n'
                    '    returns a list of `n` fruit\n'
                    '    """\n'
                    '\n'
                    'The user will provide function inputs (if any) and you must respond with\n'
                    'the most likely result.'
                ),
                role='system',
            ),
            BaseMessage(
                content=(
                    '## Function inputs\n'
                    '\n'
                    'The function was called with the following inputs:\n'
                    '- n: 3\n'
                    '\n'
                    '\n'
                    "What is the function's output?"
                ),
                role='user',
            ),
            BaseMessage(
                content='The output is',
                role='assistant',
            ),
        ],
        model='gpt-4-1106-preview',
        frequency_penalty=0,
        n=1,
        presence_penalty=0,
        seed=None,
        stop=None,
        stream=False,
        temperature=0.0,
        top_p=1,
        user=None,
    ) (ChatRequest)
Message tokens: 96
Called tool: FormatResponse
Which looks like:
cookbook/eject_payload.py:19 process_ejected_request
    next(iter(tool for tool in request.tools if tool.function.name == t)): Tool[~M](
        type='function',
        function=Function[~M](
            name='FormatResponse',
            description='Format the response with valid JSON.',
            parameters={
                'description': 'Format the response with valid JSON.',
                'properties': {
                    'value': {
                        'description': 'The formatted response',
                        'items': {
                            'type': 'string',
                        },
                        'title': 'Value',
                        'type': 'array',
                    },
                },
                'required': ['value'],
                'type': 'object',
            },
        ),
    ) (Tool[~M])
cookbook/eject_payload.py:23 process_completion
    completion.usage: CompletionUsage(
        completion_tokens=10,
        prompt_tokens=172, # 172 prompt tokens - 96 message tokens = 76 remaining tokens (tool/function schema and chat framing)
        total_tokens=182,
    ) (CompletionUsage)
[02/25/24 00:10:54] DEBUG    marvin.Tools: FormatResponse: called with arguments: {'value': ['apple', 'banana', 'cherry']}                                           logging.py:89
                    DEBUG    marvin.Tools: FormatResponse: returned: ['apple', 'banana', 'cherry']                                                                   logging.py:89
['apple', 'banana', 'cherry']

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment