Last active
September 4, 2024 15:46
-
-
Save christian-taillon/c620e81206025e7101c25272fbb27f96 to your computer and use it in GitHub Desktop.
Improved Version to Handle Streaming Bug
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
title: Anthropic Manifold Pipe v2
description: An enhanced implementation of the Anthropic API integration for Open WebUI,
featuring improved streaming performance, robust error handling, and optimized
response processing for both streaming and non-streaming requests.
version: 0.1.6
license: MIT
fork of: https://openwebui.com/f/justinrahb/anthropic/
orig_author: justinh-rahb
orig_author_url: https://github.com/justinh-rahb
current_maintainer: christian-taillon
current_maintainer_url: https://gist.github.com/christian-taillon/c620e81206025e7101c25272fbb27f96
funding_url: https://github.com/open-webui
usage:
This pipe can be used to interact with Anthropic's AI models, supporting both
streaming and non-streaming responses. It handles text and image inputs, with
built-in safeguards for API limitations such as maximum image count and size.
""" | |
import os | |
import requests | |
import json | |
import time | |
from typing import List, Union, Generator, Iterator | |
from pydantic import BaseModel, Field | |
from utils.misc import pop_system_message | |
# Define Pipe class | |
class Pipe:
    """Manifold pipe that forwards chat requests to Anthropic's Messages API.

    The pipe converts an incoming request body (messages with text and
    ``image_url`` parts) into the payload shape posted to
    ``https://api.anthropic.com/v1/messages`` and returns either the full
    reply text or a generator of text chunks when streaming is requested.
    """

    class Valves(BaseModel):
        """Configuration values for the pipe."""

        # API key sent in the ``x-api-key`` header of every request.
        ANTHROPIC_API_KEY: str = Field(default="")

    def __init__(self):
        """Set the pipe identity and load the API key from the environment."""
        self.type = "manifold"
        self.id = "anthropic"
        self.name = "anthropic/"
        self.valves = self.Valves(
            **{"ANTHROPIC_API_KEY": os.getenv("ANTHROPIC_API_KEY", "")}
        )

    def get_anthropic_models(self):
        """Return the hard-coded list of models as ``{"id", "name"}`` dicts."""
        return [
            {"id": "claude-3-haiku-20240307", "name": "claude-3-haiku"},
            {"id": "claude-3-opus-20240229", "name": "claude-3-opus"},
            {"id": "claude-3-sonnet-20240229", "name": "claude-3-sonnet"},
            {"id": "claude-3-5-sonnet-20240620", "name": "claude-3.5-sonnet"},
        ]

    def pipes(self) -> List[dict]:
        """Return the models exposed by this manifold."""
        return self.get_anthropic_models()

    def process_image(self, image_data):
        """Convert an ``image_url`` content part into an ``image`` block.

        Args:
            image_data: Dict with the image location under
                ``image_data["image_url"]["url"]``; either a
                ``data:image/...;base64,...`` URL or any other URL string.

        Returns:
            A dict with ``"type": "image"`` and a ``"source"`` that is
            ``base64`` (with ``media_type`` and ``data``) for data URLs, or
            ``url`` otherwise.
        """
        url = image_data["image_url"]["url"]
        if url.startswith("data:image"):
            # "data:image/png;base64,<payload>" -> header and payload.
            mime_type, base64_data = url.split(",", 1)
            media_type = mime_type.split(":")[1].split(";")[0]
            return {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type,
                    "data": base64_data,
                },
            }
        return {
            "type": "image",
            "source": {"type": "url", "url": url},
        }

    def pipe(self, body: dict) -> Union[str, Generator[str, None, None], Iterator[str]]:
        """Build the request payload from ``body`` and dispatch it.

        Args:
            body: Request dict. Reads ``messages`` and ``model`` (required)
                and the optional ``max_tokens``, ``temperature``, ``top_k``,
                ``top_p``, ``stop`` and ``stream`` keys.

        Returns:
            A generator of text chunks when ``body["stream"]`` is truthy,
            otherwise the reply text. Request failures are reported as an
            ``"Error: ..."`` string rather than raised.

        Raises:
            ValueError: If the messages hold more than 5 images, or the
                decoded size of the base64 images exceeds 100 MB.
        """
        system_message, messages = pop_system_message(body["messages"])

        processed_messages = []
        image_count = 0
        total_image_size = 0
        for message in messages:
            processed_content = []
            if isinstance(message.get("content"), list):
                for item in message["content"]:
                    if item["type"] == "text":
                        processed_content.append({"type": "text", "text": item["text"]})
                    elif item["type"] == "image_url":
                        if image_count >= 5:
                            raise ValueError(
                                "Maximum of 5 images per API call exceeded"
                            )
                        processed_image = self.process_image(item)
                        processed_content.append(processed_image)
                        if processed_image["source"]["type"] == "base64":
                            # Base64 encodes 3 bytes as 4 characters, so this
                            # estimates the decoded size in bytes.
                            image_size = len(processed_image["source"]["data"]) * 3 / 4
                        else:
                            # Remote URLs are not downloaded, so not counted.
                            image_size = 0
                        total_image_size += image_size
                        if total_image_size > 100 * 1024 * 1024:
                            raise ValueError(
                                "Total size of images exceeds 100 MB limit"
                            )
                        image_count += 1
            else:
                processed_content = [
                    {"type": "text", "text": message.get("content", "")}
                ]
            processed_messages.append(
                {"role": message["role"], "content": processed_content}
            )

        # ``stop`` may arrive as a single string or as None; the payload
        # field is sent as a list of strings in every case.
        stop_sequences = body.get("stop") or []
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]

        # Ensure the system_message is coerced to a string
        payload = {
            # Drop everything up to the first "." (presumably the manifold
            # prefix added by the host -- confirm against the caller); an
            # id without a "." is used unchanged.
            "model": body["model"].split(".", 1)[-1],
            "messages": processed_messages,
            "max_tokens": body.get("max_tokens", 4096),
            "temperature": body.get("temperature", 0.8),
            "top_k": body.get("top_k", 40),
            "top_p": body.get("top_p", 0.9),
            "stop_sequences": stop_sequences,
            **({"system": str(system_message)} if system_message else {}),
            "stream": body.get("stream", False),
        }
        headers = {
            "x-api-key": self.valves.ANTHROPIC_API_KEY,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }
        url = "https://api.anthropic.com/v1/messages"

        try:
            if body.get("stream", False):
                # stream_response is a generator function: nothing runs until
                # the caller iterates, so its failures are reported by its own
                # handlers rather than the ones below.
                return self.stream_response(url, headers, payload)
            else:
                return self.non_stream_response(url, headers, payload)
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
            return f"Error: Request failed: {e}"
        except Exception as e:
            print(f"Error in pipe method: {e}")
            return f"Error: {e}"

    def stream_response(self, url, headers, payload) -> Generator[str, None, None]:
        """POST ``payload`` and yield reply text as server-sent events arrive.

        Only lines starting with ``data: `` are parsed. Text is taken from
        ``content_block_start``, ``content_block_delta`` and ``message``
        events; an ``error`` event is surfaced as an ``"Error: ..."`` chunk.
        Lines that fail to parse are logged and skipped. Request failures
        are yielded as an ``"Error: ..."`` chunk instead of being raised.

        Args:
            url: Endpoint to post to.
            headers: HTTP headers for the request.
            payload: JSON-serialisable request body.

        Yields:
            Text chunks (possibly empty strings) in arrival order.
        """
        try:
            with requests.post(
                url, headers=headers, json=payload, stream=True, timeout=(3.05, 60)
            ) as response:
                response.raise_for_status()
                for line in response.iter_lines(delimiter=b"\n"):
                    if not line:
                        continue
                    decoded_line = line.decode("utf-8").strip()
                    if not decoded_line.startswith("data: "):
                        continue
                    try:
                        data = json.loads(decoded_line[6:])
                        event_type = data.get("type")
                        # Process the different types of content to extract text
                        if event_type == "content_block_start":
                            yield data["content_block"].get("text", "")
                        elif event_type == "content_block_delta":
                            yield data["delta"].get("text", "")
                        elif event_type == "message":
                            # Handle entire message at the end
                            for content in data.get("content", []):
                                if content["type"] == "text":
                                    yield content.get("text", "")
                        elif event_type == "error":
                            # Errors delivered mid-stream would otherwise be
                            # dropped, leaving a silently truncated reply.
                            error = data.get("error") or {}
                            message = error.get("message", "Unknown error")
                            print(f"Stream error event: {data}")
                            yield f"Error: {message}"
                        # Delay to avoid overwhelming the client.
                        time.sleep(0.01)
                    except json.JSONDecodeError:
                        print(f"Failed to parse JSON: {decoded_line}")
                    except KeyError as e:
                        print(f"Unexpected data structure: {e}")
                        print(f"Full data: {data}")
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
            yield f"Error: Request failed: {e}"
        except Exception as e:
            print(f"General error in stream_response method: {e}")
            yield f"Error: {e}"

    def non_stream_response(self, url, headers, payload) -> str:
        """POST ``payload`` and return the reply text in one piece.

        The ``text`` of every entry in the response's ``content`` list is
        concatenated, so an empty list yields ``""`` and replies split over
        several blocks are returned whole.

        Args:
            url: Endpoint to post to.
            headers: HTTP headers for the request.
            payload: JSON-serialisable request body.

        Returns:
            The reply text, or an ``"Error: ..."`` string when the request
            fails.
        """
        try:
            response = requests.post(
                url, headers=headers, json=payload, timeout=(3.05, 60)
            )
            response.raise_for_status()
            res = response.json()
            # Blocks without a "text" key contribute nothing.
            return "".join(
                block.get("text", "") for block in (res.get("content") or [])
            )
        except requests.exceptions.RequestException as e:
            print(f"Failed non-stream request: {e}")
            return f"Error: {e}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment