# custom model garden client
import logging
import os
from typing import Any, List, Optional

from langchain.callbacks.manager import AsyncCallbackManager
from langchain.schema import Generation, LLMResult
from langchain_community.utilities.vertexai import (
    create_retry_decorator,
)
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.outputs import GenerationChunk
from langchain_google_vertexai import VertexAI, VertexAIModelGarden

logger = logging.getLogger(__name__)
stream_manager = AsyncCallbackManager([])

# the changes below are from the PR, except _decorate_response:
# https://github.com/langchain-ai/langchain/pull/14444


def is_codey_model(model_name: str) -> bool:
    """Returns True if the model name is a Codey model.

    Args:
        model_name: The model name to check.

    Returns:
        True if the model name is a Codey model.
    """
    return "code" in model_name


def completion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return llm.client.predict(*args, **kwargs)

    return _completion_with_retry(*args, **kwargs)


def stream_completion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(
        llm, max_retries=llm.max_retries, run_manager=run_manager
    )

    @retry_decorator
    def _completion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return llm.client.predict_streaming(*args, **kwargs)

    return _completion_with_retry(*args, **kwargs)


async def acompletion_with_retry(
    llm: VertexAI,
    *args: Any,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Use tenacity to retry the completion call."""
    retry_decorator = create_retry_decorator(llm, run_manager=run_manager)

    @retry_decorator
    async def _acompletion_with_retry(*args: Any, **kwargs: Any) -> Any:
        return await llm.client.predict_async(*args, **kwargs)

    return await _acompletion_with_retry(*args, **kwargs)
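
# Usage sketch (assumption: these helpers are not exercised elsewhere in this
# gist). They wrap the low-level Vertex AI client calls with tenacity retries,
# for example:
#
#     completion_with_retry(llm, prompt_text, run_manager=run_manager)
#
# retries llm.client.predict(prompt_text) on transient errors.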


class Mistral(VertexAIModelGarden):
    strip_prefix: bool = False
    """Whether to strip the prompt from the generated text."""
    max_tokens: int = 2024
    temperature: float = 0.0

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Keep explicit overrides and whitelist the arguments that
        # _prepare_request forwards to the Model Garden endpoint.
        self.max_tokens = kwargs.get("max_tokens", self.max_tokens)
        self.temperature = kwargs.get("temperature", self.temperature)
        self.allowed_model_args = ["max_tokens", "temperature"]

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        kwargs["max_tokens"] = self.max_tokens
        kwargs["temperature"] = self.temperature
        instances = self._prepare_request(prompts, **kwargs)
        response = self.client.predict(endpoint=self.endpoint_path, instances=instances)
        # google.cloud.aiplatform_v1.types.prediction_service.PredictResponse => LLMResult
        return self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )

    def _parse_response(
        self,
        prompts: List[str],
        predictions: "Prediction",
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> LLMResult:
        generations: List[List[GenerationChunk]] = []
        for prompt, result in zip(prompts, predictions.predictions):
            chunks = [
                GenerationChunk(text=self._parse_prediction(prediction))
                for prediction in result
            ]
            if self.strip_prefix:
                chunks = self._strip_generation_context(prompt, chunks)
            generation = self._aggregate_response(
                chunks,
                run_manager=run_manager,
                verbose=self.verbose,
            )
            generations.append([generation])
        return LLMResult(generations=generations)

    def _aggregate_response(
        self,
        chunks: List[Generation],
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,
    ) -> GenerationChunk:
        final_chunk: Optional[GenerationChunk] = None
        for chunk in chunks:
            if final_chunk is None:
                final_chunk = chunk
            else:
                final_chunk += chunk
            if run_manager:
                run_manager.on_llm_new_token(
                    chunk.text,
                    verbose=verbose,
                )
        if final_chunk is None:
            raise ValueError("Malformed response from VertexAIModelGarden")
        return self._decorate_response(final_chunk)

    def _decorate_response(self, final_chunk: GenerationChunk) -> GenerationChunk:
        # The deployed Mistral endpoint echoes the request as
        # "Prompt:\n...\nOutput:\n..."; keep only the text after the first
        # "Output:" marker, e.g. "Prompt:\nHi\nOutput:\nHello" -> "Hello".
        parts = final_chunk.text.split("Output:", 1)
        if len(parts) > 1:
            return GenerationChunk(text=parts[1].strip())
        return final_chunk

    def _strip_generation_context(
        self,
        prompt: str,
        chunks: List[GenerationChunk],
    ) -> List[GenerationChunk]:
        context = self._format_generation_context(prompt)
        chunk_cursor = 0
        context_cursor = 0
        while chunk_cursor < len(chunks) and context_cursor < len(context):
            chunk = chunks[chunk_cursor]
            for c in chunk.text:
                if c == context[context_cursor]:
                    context_cursor += 1
                else:
                    break
            chunk_cursor += 1
        return chunks[chunk_cursor:] if chunk_cursor == context_cursor else chunks

    def _format_generation_context(self, prompt: str) -> str:
        return "\n".join(["Prompt:", prompt.strip(), "Output:", ""])

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        instances = self._prepare_request(prompts, **kwargs)
        response = await self.async_client.predict(
            endpoint=self.endpoint_path, instances=instances
        )
        return self._parse_response(
            prompts,
            response,
            run_manager=run_manager,
        )


def create_mistral() -> Mistral:
    project_id = os.environ["GOOGLE_CLOUD_PROJECT_ID"]
    endpoint_id = os.environ["LLM_ENDPOINT_ID"]
    llm_gcp_location = os.environ["LLM_GCP_LOCATION"]
    return Mistral(
        project=project_id,
        endpoint_id=endpoint_id,
        location=llm_gcp_location,
        temperature=0,
        verbose=True,
        max_tokens=2028,
    )


if __name__ == "__main__":
    mistral = create_mistral()
    print(mistral.invoke("Summarize Occam's Razor"))
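
# Example setup (assumption: the values below are placeholders; a real run
# needs a Mistral model deployed to a Vertex AI Model Garden endpoint):
#
#     import os
#     os.environ["GOOGLE_CLOUD_PROJECT_ID"] = "my-gcp-project"
#     os.environ["LLM_ENDPOINT_ID"] = "1234567890"
#     os.environ["LLM_GCP_LOCATION"] = "us-central1"
#
#     llm = create_mistral()
#     print(llm.invoke("Explain Occam's Razor in one sentence."))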