Last active
February 14, 2023 15:12
-
-
Save flolas/c5306392f5368b7601ad7b68f246b365 to your computer and use it in GitHub Desktop.
Working example of Langchain custom LLM for revChatGPT (ChatGPT API)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from langchain.llms.base import BaseLLM | |
from langchain.schema import Generation, LLMResult | |
from langchain.utils import get_from_dict_or_env | |
import asyncio | |
from revChatGPT.V2 import Chatbot | |
import nest_asyncio | |
from retrying_async import retry | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
Generator, | |
List, | |
Mapping, | |
Optional, | |
Set, | |
Tuple, | |
Union, | |
) | |
import logging | |
import sys | |
from pydantic import BaseModel, Extra, Field, root_validator | |
class ChatGPTLLM(BaseLLM, BaseModel):
    """LangChain LLM that answers prompts through a revChatGPT ``Chatbot``.

    Each prompt is sent to ``chatbot_instance.ask``; the streamed chunks are
    joined into one completion. Replies that are empty, or that contain the
    chatbot's generic greeting, are treated as failures and retried.
    """

    # Client every prompt is sent through (via its async ``ask`` generator).
    chatbot_instance: Chatbot
    # Label reported in ``_identifying_params``; it is not sent to the chatbot.
    model_name: str = "chatgpt"
    # Extra parameters merged into ``_default_params``. NOTE(review): they are
    # not currently forwarded to ``Chatbot.ask``.
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    class Config:
        """Configuration for this pydantic object."""

        # Unknown constructor keyword arguments are silently dropped.
        extra = Extra.ignore
        # ``Chatbot`` is a third-party class rather than a pydantic model, so
        # pydantic must accept it as-is (BaseLLM may already set this too).
        arbitrary_types_allowed = True

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "chatgpt"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"model_name": self.model_name, **self._default_params}

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return the chatbot instance merged with any ``model_kwargs``."""
        return {"chatbot_instance": self.chatbot_instance, **self.model_kwargs}

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Synchronously run :meth:`_agenerate` on a private event loop.

        Args:
            prompts: Prompts to send, in order.
            stop: Optional stop sequences applied to every reply.

        Returns:
            The ``LLMResult`` produced by :meth:`_agenerate`.

        Raises:
            Whatever :meth:`_agenerate` raises; it is logged, then re-raised.
        """
        loop = asyncio.new_event_loop()
        # nest_asyncio patches the loop so run_until_complete may be nested,
        # presumably so this works when called from inside an already-running
        # loop (e.g. a notebook).
        nest_asyncio.apply(loop)
        try:
            return loop.run_until_complete(self._agenerate(prompts, stop))
        except Exception:
            logging.exception("ChatGPT generation failed")
            raise
        finally:
            # Close the loop on success AND failure so it is never leaked.
            loop.close()

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Ask the chatbot each prompt in turn and collect one reply per prompt.

        Args:
            prompts: Prompts to send, in order.
            stop: Optional stop sequences; each reply is cut at the earliest
                occurrence of any of them.

        Returns:
            An ``LLMResult`` with one generation list per prompt.
        """
        choices: List[Dict[str, Any]] = []
        for prompt in prompts:
            # Retrying happens per prompt (see ``_ask_one``), so a failure on
            # one prompt does not re-ask prompts that were already answered.
            choice = await self._ask_one(prompt)
            choice["text"] = self._truncate_at_stop(choice["text"], stop)
            choices.append(choice)
        # No token usage is collected here; an empty dict (not ``0``) keeps
        # consumers that look up keys in ``token_usage`` from crashing.
        return self.create_llm_result(choices, prompts, {})

    @retry(attempts=20, delay=5)
    async def _ask_one(self, prompt: str) -> Dict[str, Any]:
        """Send one prompt to the chatbot and return its final streamed choice.

        The returned dict is the last streamed choice with its ``"text"``
        replaced by the concatenation of every streamed choice's text.

        Raises:
            Exception: when nothing was streamed, when the reply contains the
                chatbot's greeting, or when the reply is blank. The ``retry``
                decorator re-invokes this method (up to 20 attempts, 5 s apart).
        """
        chunks: List[Dict[str, Any]] = []
        async for response in self.chatbot_instance.ask(prompt):
            chunks.extend(response["choices"])
        if not chunks:
            logging.info(prompt)
            raise Exception("No results")
        text = "".join(chunk["text"] for chunk in chunks)
        # A greeting instead of an answer presumably means the prompt was not
        # actually answered, so treat it as a failure and let retry kick in.
        if "Hi, How can I help you today?" in text:
            logging.info(prompt)
            logging.info("bad text recieved")
            raise Exception("bad text recieved")
        if text.strip() == "":
            logging.info(prompt)
            logging.info("empty text recieved")
            raise Exception("empty text recieved")
        logging.info("Completition: %s", text)
        result = chunks[-1]
        result["text"] = text
        return result

    @staticmethod
    def _truncate_at_stop(text: str, stop: Optional[List[str]]) -> str:
        """Cut ``text`` at the earliest occurrence of any non-empty stop sequence.

        Returns ``text`` unchanged when ``stop`` is None/empty or no stop
        sequence occurs in it.
        """
        if not stop:
            return text
        cut = len(text)
        for token in stop:
            if not token:
                # An empty token would match at index 0 and erase everything.
                continue
            index = text.find(token)
            if index != -1 and index < cut:
                cut = index
        return text[:cut]

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts.

        Args:
            choices: One choice dict per prompt, in prompt order. Each must
                have a ``"text"`` key; ``"finish_reason"``/``"logprobs"`` are
                optional.
            prompts: The prompts the choices answer.
            token_usage: Stored verbatim under ``llm_output["token_usage"]``.

        Returns:
            An ``LLMResult`` whose i-th generation list holds only the i-th
            choice. Each text is stripped of surrounding whitespace, then of
            ``<|im_end|>`` markers, and prefixed with two newlines.
        """
        generations = []
        for i, _prompt in enumerate(prompts):
            # Pair each prompt with its own choice only (previously every
            # prompt received the full list of choices).
            sub_choices = choices[i : i + 1]
            generations.append(
                [
                    Generation(
                        text="\n\n"
                        + choice["text"].strip().replace("<|im_end|>", ""),
                        generation_info=dict(
                            finish_reason=choice.get("finish_reason"),
                            logprobs=choice.get("logprobs"),
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        return LLMResult(
            generations=generations, llm_output={"token_usage": token_usage}
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment