custom model garden client
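A LangChain client for a Mistral model deployed on a Vertex AI Model Garden endpoint: a VertexAIModelGarden subclass with sync/async generation, optional stripping of the echoed prompt, and retry helpers adapted from langchain-ai/langchain#14444.
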
import logging
import os
from typing import Any, List, Optional

from langchain.callbacks.manager import AsyncCallbackManager
from langchain.schema import Generation, LLMResult
from langchain_community.utilities.vertexai import create_retry_decorator
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.outputs import GenerationChunk
from langchain_google_vertexai import VertexAI, VertexAIModelGarden

logger = logging.getLogger(__name__)
stream_manager = AsyncCallbackManager([])

# The changes below are from the PR, except _decorate_response:
# https://github.com/langchain-ai/langchain/pull/14444


def is_codey_model(model_name: str) -> bool:
    """Returns True if the model name is a Codey model.

    Args:
        model_name: The model name to check.

    Returns: True if the model name is a Codey model.
    """
    return "code" in model_name

def completion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return llm.client.predict(*args, **kwargs)

    return _completion_with_retry(*args, **kwargs)

def stream_completion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return llm.client.predict_streaming(*args, **kwargs)

    return _completion_with_retry(*args, **kwargs)

async def acompletion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return await llm.client.predict_async(*args, **kwargs)

    return await _acompletion_with_retry(*args, **kwargs)

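# The three retry wrappers above mirror the PR: create_retry_decorator wraps the raw
# predict / predict_streaming / predict_async calls with tenacity-based retries. The
# Mistral class below currently calls self.client.predict directly, so these helpers
# are available but not wired into _generate/_agenerate.
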
class Mistral(VertexAIModelGarden):
    strip_prefix: bool = False
    "Whether to strip the prompt from the generated text."

    max_tokens: int = 2024
    temperature: float = 0.0

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.max_tokens = kwargs.get("max_tokens", self.max_tokens)
        self.temperature = kwargs.get("temperature", self.temperature)
        self.allowed_model_args = ["max_tokens", "temperature"]
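        # Assumption (based on the VertexAIModelGarden base class): allowed_model_args
        # restricts which extra kwargs _prepare_request forwards to the endpoint, so only
        # max_tokens and temperature reach the deployed model.
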
    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        kwargs["max_tokens"] = self.max_tokens
        kwargs["temperature"] = self.temperature
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        # google.cloud.aiplatform_v1.types.prediction_service.PredictResponse -> LLMResult
        resp = self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )
        return resp
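
    # The synchronous path: _prepare_request (from the base class) turns the prompts and
    # allowed kwargs into Model Garden instances, and self.client.predict hits the
    # endpoint identified by endpoint_path (presumably built by the base class from
    # project, location, and endpoint_id).
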
    def _parse_response(
        self,
        prompts: List[str],
        predictions: "Prediction",
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        generations: List[List[GenerationChunk]] = []
        for prompt, result in zip(prompts, predictions.predictions):
            chunks = [
                GenerationChunk(text=self._parse_prediction(prediction))
                for prediction in result
            ]
            if self.strip_prefix:
                chunks = self._strip_generation_context(prompt, chunks)
            generation = self._aggregate_response(
                chunks,
                run_manager=run_manager,
                verbose=self.verbose,
            )
            generations.append([generation])
        return LLMResult(generations=generations)
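
    # Each prompt yields one list of predictions; those are converted to GenerationChunks,
    # optionally stripped of the echoed prompt, and merged into a single generation per
    # prompt before being wrapped in an LLMResult.
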
    def _aggregate_response(
        self,
        chunks: List[Generation],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,
    ) -> GenerationChunk:
        final_chunk: Optional[GenerationChunk] = None
        for chunk in chunks:
            if final_chunk is None:
                final_chunk = chunk
            else:
                final_chunk += chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    verbose=verbose,
                )
        if final_chunk is None:
            raise ValueError("Malformed response from VertexAIModelGarden")
        return self._decorate_response(final_chunk)

    def _decorate_response(self, final_chunk: GenerationChunk) -> GenerationChunk:
        parts = final_chunk.text.split("Output:")
        if len(parts) > 1:
            final_chunk.text = parts[1].strip()
        return final_chunk
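
    # Assumed raw endpoint format (inferred from _format_generation_context and the
    # "Output:" split above): the deployed model echoes the prompt, e.g.
    #   "Prompt:\n<prompt text>\nOutput:\n<generated text>"
    # _decorate_response keeps only the part after "Output:".
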
    def _strip_generation_context(
        self,
        prompt: str,
        chunks: List[GenerationChunk],
    ) -> List[GenerationChunk]:
        context = self._format_generation_context(prompt)
        chunk_cursor = 0
        context_cursor = 0
        while chunk_cursor < len(chunks) and context_cursor < len(context):
            chunk = chunks[chunk_cursor]
            for c in chunk.text:
                if c == context[context_cursor]:
                    context_cursor += 1
                else:
                    break
            chunk_cursor += 1
        return chunks[chunk_cursor:] if chunk_cursor == context_cursor else chunks

    def _format_generation_context(self, prompt: str) -> str:
        return "\n".join(["Prompt:", prompt.strip(), "Output:", ""])

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        resp = self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )
        return resp
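
    # The async path mirrors _generate but goes through async_client.predict; note that,
    # unlike _generate, it does not inject max_tokens/temperature into kwargs before
    # calling _prepare_request.
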
def create_mistral():
    project_id = os.environ["GOOGLE_CLOUD_PROJECT_ID"]
    endpoint_id = os.environ["LLM_ENDPOINT_ID"]
    llm_gcp_location = os.environ["LLM_GCP_LOCATION"]
    return Mistral(
        project=project_id,
        endpoint_id=endpoint_id,
        location=llm_gcp_location,
        temperature=0,
        verbose=True,
        max_tokens=2028,
    )

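# Required environment variables (as read by create_mistral above):
#   GOOGLE_CLOUD_PROJECT_ID - GCP project that hosts the Model Garden endpoint
#   LLM_ENDPOINT_ID         - Vertex AI endpoint ID of the deployed Mistral model
#   LLM_GCP_LOCATION        - region of the endpoint, e.g. "us-central1"
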
if __name__ == "__main__":
    mistral = create_mistral()
    print(mistral.invoke("Summarize Occam's Razor"))
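
# A minimal async usage sketch (not part of the original gist), assuming the standard
# LangChain LLM interface where ainvoke routes through _agenerate above:
#
#   import asyncio
#   asyncio.run(create_mistral().ainvoke("Summarize Occam's Razor"))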