import functools
import logging
import os
from typing import Any, Awaitable, Callable, TypeVar, cast

import httpx
from dotenv import load_dotenv

logger = logging.getLogger("lakera_guarded")

load_dotenv()
LAKERA_API_KEY = os.environ.get("LAKERA_API_KEY")

AsyncFunc = TypeVar("AsyncFunc", bound=Callable[..., Awaitable[Any]])


def guard_content(input_param: str = "", output_screen: bool = True) -> Callable[[AsyncFunc], AsyncFunc]:
    """Decorator that screens the named input parameter and, optionally, the output."""

    def decorator(func: AsyncFunc) -> AsyncFunc:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Screen the input parameter if one was specified
            logger.debug(f"{input_param=}/{output_screen=}/{args=}/{kwargs=}")
            if input_param and input_param in kwargs:
                content = kwargs[input_param]
                logger.debug("screening input")
                logger.debug(f"{input_param}={content}")
                check = await screen_content(content)
                if not check["is_safe"]:
                    logger.warning(f"input flagged {content=}")
                    # Raising an error here lets the LLM know that the input
                    # is not safe and should not be used.
                    raise ValueError(f"Input rejected: {check['summary']}")

            # Call the original function
            logger.debug("calling original function...")
            result = await func(*args, **kwargs)

            # Screen the output if enabled
            if output_screen and isinstance(result, str):
                logger.debug(f"screening output {result=}")
                check = await screen_content(result)
                if not check["is_safe"]:
                    logger.warning(f"output flagged {result=}")
                    # Raising an error here lets the LLM know that the output
                    # is not safe and should not be used.
                    raise ValueError(f"Output rejected: {check['summary']}")

            return result

        return cast(AsyncFunc, wrapper)

    return decorator


async def screen_content(text: str) -> dict[str, Any]:
    """Screen content using the Lakera Guard API."""
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "https://api.lakera.ai/v2/guard",
            json={"messages": [{"role": "user", "content": text}], "breakdown": True},
            headers={"Authorization": f"Bearer {LAKERA_API_KEY}"},
        )
        if resp.status_code != 200:
            logger.error(f"error: {resp.status_code} - {resp.text}")
            return {"is_safe": False, "summary": "Error screening content"}
        result = resp.json()
        logger.info(f">>{text=}")
        logger.info(f">>{result=}")
        return {
            "is_safe": not result.get("flagged", False),
            "summary": "Lakera Guard screened the message and the content is safe to use."
            if not result.get("flagged", False)
            else "Content has been flagged by Lakera Guard as potentially harmful.",
        }
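

# Example usage: an illustrative sketch only, assuming a hypothetical tool
# function `answer_question` that is not part of the module above. It shows
# how the decorator screens the `question` keyword argument before the call
# and the returned string after it, raising ValueError if Lakera Guard flags
# either side.
@guard_content(input_param="question", output_screen=True)
async def answer_question(question: str) -> str:
    # Stand-in for a real LLM or tool call.
    return f"You asked: {question}"


if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        try:
            # The decorator only screens keyword arguments, so pass
            # `question` by name.
            print(await answer_question(question="What is the capital of France?"))
        except ValueError as exc:
            print(f"Blocked: {exc}")

    asyncio.run(_demo())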