### CONFIGURATION
TAKE_ACTIONS = False  # If False, the agent gets no terminal or code-execution access
USE_MEMORY = True
USE_PYTHON = True
USE_CALCULATOR = True
USE_SHELL = True
USE_WIKI_SEARCH = True
USE_SEARCH_ONLINE = True
ASK_HUMAN_BEFORE_EXE_CMDS = True
# Use either a local fine-tuned transformers model or a Hugging Face hub model
USE_FINE_TUNED_MODEL = True
FINE_TUNED_MODEL_PATH = "../finetuning/mixtral/qlora-mixtral-out"
# DEFAULT_MODEL = "cognitivecomputations/dolphin-2.5-mixtral-8x7b"
DEFAULT_MODEL = FINE_TUNED_MODEL_PATH
#############################
#####                   #####
###   DO NOT EDIT BELOW   ###
#####                   #####
#############################
import platform
import warnings
import torch
import numexpr as ne
from typing import Optional, List, Mapping, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.tools import BaseTool
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.agents import Tool
from langchain_community.tools import WikipediaQueryRun, ShellTool
from langchain.agents import create_json_chat_agent, AgentExecutor
from langchain.memory import ConversationBufferMemory
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
### CODE
USE_PYTHON = USE_PYTHON and TAKE_ACTIONS
USE_SHELL = USE_SHELL and TAKE_ACTIONS
MODEL_NAME = FINE_TUNED_MODEL_PATH if USE_FINE_TUNED_MODEL else DEFAULT_MODEL
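# Loading in 4-bit roughly quarters the weight memory of the Mixtral 8x7B MoE;
# device_map="auto" lets Accelerate place layers across the available devices.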
QUANTIZATION_CONFIG = BitsAndBytesConfig(load_in_4bit=True)
MODEL = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    quantization_config=QUANTIZATION_CONFIG,
    device_map="auto",
)
TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
# print(type(MODEL))
warnings.filterwarnings("ignore")
# Custom LLM wrapper and tool classes
class CustomLLMMistral(LLM):
    model: MixtralForCausalLM  # or MistralForCausalLM for the 7B model
    tokenizer: LlamaTokenizerFast

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)
        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=1024,  # cap on the length of the generated output
            do_sample=True,
            pad_token_id=TOKENIZER.eos_token_id,
            top_k=50,  # wider top_k for more diverse sampling
            temperature=0.5,  # lower temperature for more conservative sampling
            num_return_sequences=1,  # generate a single sequence
        )
        decoded = self.tokenizer.batch_decode(generated_ids)
        try:
            # The chat template wraps the prompt in [INST]...[/INST]; keep only
            # the completion that follows it.
            output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()
        except Exception as e:
            print(e)
            exit()
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()
        # Mixtral sometimes fails to properly close its Markdown snippets. If
        # they are left open, LangChain will struggle to parse the output, so
        # pad backticks until the snippet is closed.
        while not output.endswith("```"):
            output += "`"
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}
class Calculator(BaseTool):
    name = "calculator"
    description = (
        "Use this tool for math operations. It requires numexpr syntax. "
        "Always use it when you need to solve a math operation, and pass it "
        "the whole expression at once. Make sure the syntax is correct."
    )

    def _run(self, expression: str):
        try:
            return ne.evaluate(expression).item()
        except Exception:
            return "This is not valid numexpr syntax. Try a different expression."

    def _arun(self, expression: str):
        raise NotImplementedError("This tool does not support async")
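# A quick usage sketch of the tool's numexpr backend (illustrative only):
#   Calculator()._run("2**10 + 3*4")  # -> 1036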
def custom_py_run(code: str):
    """Execute Python code with exec() and capture anything written to stdout."""
    from io import StringIO
    import contextlib

    with contextlib.redirect_stdout(StringIO()) as output:
        try:
            exec(code)
        except Exception as e:
            return e
    return output.getvalue().strip()
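# Example (hypothetical input): custom_py_run("print(sum(range(10)))") -> '45'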
def _get_platform() -> str:
    system = platform.system()
    if system == "Darwin":
        return "MacOS"
    return system
llm = CustomLLMMistral(model=MODEL, tokenizer=TOKENIZER)
search = DuckDuckGoSearchRun()
wikipedia = WikipediaQueryRun(
    api_wrapper=WikipediaAPIWrapper(top_k_results=2, doc_content_chars_max=2500)
)
shell_tool = ShellTool(
    description=f"""
    Run shell commands on this {_get_platform()} machine.
    The tool may return 'None' when the user declines to run a command; in that
    case, conclude with 'User has not allowed to run the command'.
    If a required package is missing, suggest installing the packages needed to
    run the program.
    Some commands produce no output; you may use other tools, or append
    'echo $?' to the command, to check whether it ran successfully, then take
    whatever actions are needed.
    """,
    ask_human_input=ASK_HUMAN_BEFORE_EXE_CMDS,
)
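# Note: ShellTool runs commands via subprocess on the host machine; with
# ask_human_input=True it asks for confirmation before executing each command.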
wikipedia_tool = Tool(
    name="wikipedia",
    description=(
        "Never search for more than one concept in a single step. If you need "
        "to compare two concepts, search for each one individually and use the "
        "meta data syntax when prompting for the next query. "
        "Syntax: a string with a single, simple concept"
    ),
    func=wikipedia.run,
)
search_tool = Tool(
    name="Search online",
    description="This tool searches the internet and returns the results.",
    func=search.run,
)
calculator_tool = Calculator()
py_tool = Tool(
    name="Python",
    description=f"""This is a Python execution tool; use it only when code
    execution is required. It runs Python code on this {_get_platform()} machine,
    using exec() to evaluate the code, and returns the collected stdout.
    Provide valid Python code to be run once.
    """,
    func=custom_py_run,
)
tools = []
if USE_SEARCH_ONLINE:
    tools.append(search_tool)
if USE_SHELL:
    tools.append(shell_tool)
if USE_PYTHON:
    tools.append(py_tool)
if USE_CALCULATOR:
    tools.append(calculator_tool)
if USE_WIKI_SEARCH:
    tools.append(wikipedia_tool)
system = """
As a Cyber Security Researcher Assistant, you are equipped to solve tasks. Try to minimize the number of steps needed to solve the task, but make sure the result is correct. Each task requires a series of steps, represented by JSON blobs with specific keys called Meta Data:
thought -> your thoughts
action -> tool name or "Final Answer" to give a final answer
action_input -> tool parameters or the final solution
Available tools: {tool_names}
Tool descriptions:
{tools}
If you have enough information, use "Final Answer" with the solution.
If information is insufficient or incorrect, try once more. If the issue persists, use "Final Answer".
"""
human = """
Add "STOP" after each snippet. Follow this JSON schema:
```json
{{"thought": "<your thoughts>",
"action": "<tool name or "Final Answer" to give a final answer or "None" as a to conclude>",
"action_input": "<tool parameters or final output>"}}
```\n
STOP
Query: "{input}". Provide the next necessary step only.
Base your answer on previous steps, even if you believe you know the solution.
Add "STOP" after each snippet.
Previous steps and gathered information:
"""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        MessagesPlaceholder("chat_history", optional=True),
        ("human", human),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)
agent = create_json_chat_agent(
    tools=tools,
    llm=llm,
    prompt=prompt,
    stop_sequence=["STOP"],
    template_tool_response="{observation}",
)
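# The "STOP" stop sequence mirrors the human prompt above: generation halts
# after each JSON blob so the executor can parse one action at a time.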
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
    memory=memory if USE_MEMORY else None,
)
# Display currently enabled tools
print("Tools: [", ", ".join(i.name for i in tools), "]")
# Prompt input
x = input("input: ")
run = agent_executor.invoke(
{"input": f"{x}"}, {"recursion_limit": 4, "max_concurrency": 1}
)
print(run)
out = run["output"]
print(f"\n### FINAL OUTPUT ###\n{out}\n")
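# A possible extension (untested sketch): loop for a multi-turn session and let
# ConversationBufferMemory carry context between turns when USE_MEMORY is set.
#   while True:
#       q = input("input: ")
#       if q in {"exit", "quit"}:
#           break
#       print(agent_executor.invoke({"input": q})["output"])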