Created
February 14, 2024 21:15
-
-
Save Cdaprod/f4651e331aa24a7209c8e0f71bf4ebb3 to your computer and use it in GitHub Desktop.
The MinIoAgent handles user queries by parsing each statement and constructing the appropriate sequence of calls to the available tools. Each tool is invoked according to the requested objective: listing, uploading, downloading, or removing objects. Intermediate steps are captured and included in the final response. Users can modify th… (description truncated in the source)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import datetime
import hashlib
import hmac
import io
import mimetypes
import os
import sys
import tarfile
import uuid
import xml.etree.ElementTree as ET
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union, cast
from urllib.parse import quote, urlencode, urlparse

import requests
from pydantic import BaseModel
from rich.console import Console

# Make a locally checked-out LangChain SDK importable before importing it.
sys.path.insert(0, "/path/to/langchain-sdk")
import langchain
from langchain.schema import Tool, AgentAction, AgentFinish
from langchain.tools import StructuredTool
from langchain.agents import AgentExecutor, ZeroShotAgent
from langchain.utilities import SerpAPIWrapper
from langchain.vectorstores import Chroma
from chromadb.config import Settings
from chromadb.errors import CollectionDoesNotExistError
class MinIoConfig(BaseModel):
    """Connection settings for a MinIO (S3-compatible) server."""

    # Bare hostname without a scheme; the tools prepend the scheme and ":port".
    host: str
    port: int
    access_key: str
    secret_key: str
    # Whether the server should be reached over TLS.
    use_ssl: bool
    # Region placed in the AWS SigV4 credential scope. Optional so existing
    # callers keep working; "us-east-1" is MinIO's default region.
    region: str = "us-east-1"
class MinIoTool(StructuredTool):
    """Abstract base class for MinIO tools.

    Holds the ``MinIoConfig`` that every concrete tool uses to reach the server.
    """

    # Declared as a model field: StructuredTool is a pydantic model, and pydantic
    # rejects assigning an undeclared attribute (the original set ``self.config``
    # before ``super().__init__()`` ran, which fails).
    # NOTE(review): depending on the LangChain version, StructuredTool may also
    # require ``args_schema``/``func`` — confirm against the pinned version.
    config: MinIoConfig

    def __init__(self, config: MinIoConfig):
        """Create the tool bound to ``config``."""
        super().__init__(config=config)
class ListObjectsTool(MinIoTool):
    """List objects in a given MinIO bucket."""

    name = "list_objects"
    description = "List all objects in a given MinIO bucket."

    def _run(self, bucket: str) -> Iterator[Dict]:
        """Yield ``{"name": key, "size": bytes}`` for every object in ``bucket``.

        Uses the S3 ListObjectsV2 API and follows continuation tokens, so
        buckets with more than one page of keys are listed completely. On a
        non-200 response the failure is logged and iteration stops.
        """
        console = Console()
        ns = "{http://s3.amazonaws.com/doc/2006-03-01/}"
        query = {"list-type": "2"}
        while True:
            response = self._signed_request("GET", f"/{bucket}", query=query)
            if response.status_code != 200:
                console.log(f"Failed to list objects in bucket '{bucket}', status code: {response.status_code}")
                return
            # NOTE(review): XML comes from the server; consider defusedxml if the
            # endpoint is not trusted.
            root = ET.fromstring(response.content)
            # Key and Size are child elements of <Contents>, not XML attributes.
            for contents in root.iter(f"{ns}Contents"):
                yield {
                    "name": contents.findtext(f"{ns}Key"),
                    "size": int(contents.findtext(f"{ns}Size", default="0")),
                }
            token = root.findtext(f"{ns}NextContinuationToken")
            if root.findtext(f"{ns}IsTruncated") != "true" or not token:
                return
            query = {"list-type": "2", "continuation-token": token}

    def _signed_request(
        self,
        method: str,
        path: str,
        query: Optional[Dict[str, str]] = None,
        payload: bytes = b"",
        extra_headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
    ) -> requests.Response:
        """Send an AWS Signature Version 4 signed request to the MinIO server.

        Args:
            method: HTTP verb, e.g. "GET", "PUT", "DELETE".
            path: Unencoded request path, e.g. "/bucket/key".
            query: Query-string parameters (sent and signed).
            payload: Request body; its SHA-256 is part of the signature.
            extra_headers: Extra headers that are sent but not signed.
            stream: Forwarded to ``requests`` to stream the response body.

        Returns:
            The ``requests.Response`` returned by the server.
        """
        region = getattr(self.config, "region", "us-east-1")
        host = f"{self.config.host}:{self.config.port}"
        # S3 canonical URIs are percent-encoded once, keeping "/" as separator.
        canonical_uri = quote(path, safe="/~")
        canonical_query = urlencode(sorted((query or {}).items()), quote_via=quote)
        now = datetime.datetime.now(datetime.timezone.utc)
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")
        payload_hash = hashlib.sha256(payload).hexdigest()
        signed = {"host": host, "x-amz-content-sha256": payload_hash, "x-amz-date": amz_date}
        signed_headers = ";".join(sorted(signed))
        canonical_headers = "".join(f"{name}:{signed[name]}\n" for name in sorted(signed))
        canonical_request = "\n".join(
            [method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash]
        )
        scope = f"{date_stamp}/{region}/s3/aws4_request"
        string_to_sign = "\n".join(
            ["AWS4-HMAC-SHA256", amz_date, scope, hashlib.sha256(canonical_request.encode()).hexdigest()]
        )
        # Signing key = HMAC chain over date, region, service and terminator.
        key = ("AWS4" + self.config.secret_key).encode()
        for part in (date_stamp, region, "s3", "aws4_request"):
            key = hmac.new(key, part.encode(), hashlib.sha256).digest()
        signature = hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
        headers = dict(signed)
        headers["Authorization"] = (
            f"AWS4-HMAC-SHA256 Credential={self.config.access_key}/{scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )
        headers.update(extra_headers or {})
        scheme = "https" if self.config.use_ssl else "http"
        url = f"{scheme}://{host}{canonical_uri}"
        if canonical_query:
            url = f"{url}?{canonical_query}"
        return requests.request(method, url, headers=headers, data=payload, timeout=10, stream=stream)
class PutObjectTool(MinIoTool):
    """Put an object in a given MinIO bucket."""

    name = "put_object"
    description = "Put an object in a given MinIO bucket."

    def _run(self, bucket: str, object_name: str, file_path: str) -> None:
        """Upload the local file ``file_path`` to ``bucket`` as ``object_name``.

        The outcome (success, HTTP failure, or missing file) is logged to the
        console; nothing is returned.
        """
        console = Console()
        try:
            with open(file_path, "rb") as file:
                file_content = file.read()
        except FileNotFoundError as err:
            console.log(f"Could not locate file '{file_path}'\n{err}")
            return
        # Guess the MIME type from the file extension (stdlib, no libmagic needed).
        mimetype = mimetypes.guess_type(file_path)[0] or "application/octet-stream"
        response = self._signed_request(
            "PUT",
            f"/{bucket}/{object_name}",
            payload=file_content,
            extra_headers={
                "Content-Disposition": f'attachment; filename="{object_name}"',
                "Content-Type": mimetype,
            },
        )
        if response.status_code == 200:
            console.log(f"Successfully uploaded '{object_name}' to '{bucket}'")
        else:
            console.log(f"Failed to upload '{object_name}' to '{bucket}', status code: {response.status_code}")

    def _signed_request(
        self,
        method: str,
        path: str,
        query: Optional[Dict[str, str]] = None,
        payload: bytes = b"",
        extra_headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
    ) -> requests.Response:
        """Send an AWS Signature Version 4 signed request to the MinIO server.

        Args:
            method: HTTP verb, e.g. "GET", "PUT", "DELETE".
            path: Unencoded request path, e.g. "/bucket/key".
            query: Query-string parameters (sent and signed).
            payload: Request body; its SHA-256 is part of the signature.
            extra_headers: Extra headers that are sent but not signed.
            stream: Forwarded to ``requests`` to stream the response body.

        Returns:
            The ``requests.Response`` returned by the server.
        """
        region = getattr(self.config, "region", "us-east-1")
        host = f"{self.config.host}:{self.config.port}"
        # S3 canonical URIs are percent-encoded once, keeping "/" as separator.
        canonical_uri = quote(path, safe="/~")
        canonical_query = urlencode(sorted((query or {}).items()), quote_via=quote)
        now = datetime.datetime.now(datetime.timezone.utc)
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")
        payload_hash = hashlib.sha256(payload).hexdigest()
        signed = {"host": host, "x-amz-content-sha256": payload_hash, "x-amz-date": amz_date}
        signed_headers = ";".join(sorted(signed))
        canonical_headers = "".join(f"{name}:{signed[name]}\n" for name in sorted(signed))
        canonical_request = "\n".join(
            [method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash]
        )
        scope = f"{date_stamp}/{region}/s3/aws4_request"
        string_to_sign = "\n".join(
            ["AWS4-HMAC-SHA256", amz_date, scope, hashlib.sha256(canonical_request.encode()).hexdigest()]
        )
        # Signing key = HMAC chain over date, region, service and terminator.
        key = ("AWS4" + self.config.secret_key).encode()
        for part in (date_stamp, region, "s3", "aws4_request"):
            key = hmac.new(key, part.encode(), hashlib.sha256).digest()
        signature = hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
        headers = dict(signed)
        headers["Authorization"] = (
            f"AWS4-HMAC-SHA256 Credential={self.config.access_key}/{scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )
        headers.update(extra_headers or {})
        scheme = "https" if self.config.use_ssl else "http"
        url = f"{scheme}://{host}{canonical_uri}"
        if canonical_query:
            url = f"{url}?{canonical_query}"
        return requests.request(method, url, headers=headers, data=payload, timeout=10, stream=stream)
class GetObjectTool(MinIoTool):
    """Get an object from a given MinIO bucket."""

    name = "get_object"
    description = "Get an object from a given MinIO bucket."

    def _run(self, bucket: str, object_name: str, dest_folder: str) -> None:
        """Download ``object_name`` from ``bucket`` into ``dest_folder``.

        The object is written to ``os.path.join(dest_folder, object_name)``;
        intermediate directories for keys containing "/" are created. Keys that
        would resolve outside ``dest_folder`` (e.g. "../x" or absolute paths)
        are refused. The outcome is logged; nothing is returned.
        """
        console = Console()
        file_path = os.path.join(dest_folder, object_name)
        dest_root = os.path.abspath(dest_folder)
        try:
            inside = os.path.commonpath([dest_root, os.path.abspath(file_path)]) == dest_root
        except ValueError:  # e.g. paths on different drives on Windows
            inside = False
        if not inside:
            console.log(f"Refusing to write '{object_name}' outside '{dest_folder}'")
            return
        response = self._signed_request("GET", f"/{bucket}/{object_name}", stream=True)
        with response:  # release the streamed connection in every path
            if response.status_code != 200:
                console.log(f"Failed to download '{object_name}' from '{bucket}', status code: {response.status_code}")
                return
            os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
            with open(file_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=64 * 1024):
                    f.write(chunk)
        console.log(f"Downloaded '{object_name}' from '{bucket}' to '{file_path}'")

    def _signed_request(
        self,
        method: str,
        path: str,
        query: Optional[Dict[str, str]] = None,
        payload: bytes = b"",
        extra_headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
    ) -> requests.Response:
        """Send an AWS Signature Version 4 signed request to the MinIO server.

        Args:
            method: HTTP verb, e.g. "GET", "PUT", "DELETE".
            path: Unencoded request path, e.g. "/bucket/key".
            query: Query-string parameters (sent and signed).
            payload: Request body; its SHA-256 is part of the signature.
            extra_headers: Extra headers that are sent but not signed.
            stream: Forwarded to ``requests`` to stream the response body.

        Returns:
            The ``requests.Response`` returned by the server.
        """
        region = getattr(self.config, "region", "us-east-1")
        host = f"{self.config.host}:{self.config.port}"
        # S3 canonical URIs are percent-encoded once, keeping "/" as separator.
        canonical_uri = quote(path, safe="/~")
        canonical_query = urlencode(sorted((query or {}).items()), quote_via=quote)
        now = datetime.datetime.now(datetime.timezone.utc)
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")
        payload_hash = hashlib.sha256(payload).hexdigest()
        signed = {"host": host, "x-amz-content-sha256": payload_hash, "x-amz-date": amz_date}
        signed_headers = ";".join(sorted(signed))
        canonical_headers = "".join(f"{name}:{signed[name]}\n" for name in sorted(signed))
        canonical_request = "\n".join(
            [method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash]
        )
        scope = f"{date_stamp}/{region}/s3/aws4_request"
        string_to_sign = "\n".join(
            ["AWS4-HMAC-SHA256", amz_date, scope, hashlib.sha256(canonical_request.encode()).hexdigest()]
        )
        # Signing key = HMAC chain over date, region, service and terminator.
        key = ("AWS4" + self.config.secret_key).encode()
        for part in (date_stamp, region, "s3", "aws4_request"):
            key = hmac.new(key, part.encode(), hashlib.sha256).digest()
        signature = hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
        headers = dict(signed)
        headers["Authorization"] = (
            f"AWS4-HMAC-SHA256 Credential={self.config.access_key}/{scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )
        headers.update(extra_headers or {})
        scheme = "https" if self.config.use_ssl else "http"
        url = f"{scheme}://{host}{canonical_uri}"
        if canonical_query:
            url = f"{url}?{canonical_query}"
        return requests.request(method, url, headers=headers, data=payload, timeout=10, stream=stream)
class RemoveObjectTool(MinIoTool):
    """Remove an object from a given MinIO bucket."""

    name = "remove_object"
    description = "Remove an object from a given MinIO bucket."

    def _run(self, bucket: str, object_name: str) -> Tuple[bool, str]:
        """Delete ``object_name`` from ``bucket``.

        Returns:
            ``(True, message)`` when the server answers 204 No Content, else
            ``(False, message)`` describing the HTTP or transport error.
        """
        try:
            response = self._signed_request("DELETE", f"/{bucket}/{object_name}")
        except requests.RequestException as e:
            return False, f"Error making request: {e}"
        if response.status_code == 204:
            return True, "Object removed successfully."
        return False, f"Error removing object: {response.reason}. Status Code: {response.status_code}"

    def _signed_request(
        self,
        method: str,
        path: str,
        query: Optional[Dict[str, str]] = None,
        payload: bytes = b"",
        extra_headers: Optional[Dict[str, str]] = None,
        stream: bool = False,
    ) -> requests.Response:
        """Send an AWS Signature Version 4 signed request to the MinIO server.

        Args:
            method: HTTP verb, e.g. "GET", "PUT", "DELETE".
            path: Unencoded request path, e.g. "/bucket/key".
            query: Query-string parameters (sent and signed).
            payload: Request body; its SHA-256 is part of the signature.
            extra_headers: Extra headers that are sent but not signed.
            stream: Forwarded to ``requests`` to stream the response body.

        Returns:
            The ``requests.Response`` returned by the server.
        """
        region = getattr(self.config, "region", "us-east-1")
        host = f"{self.config.host}:{self.config.port}"
        # S3 canonical URIs are percent-encoded once, keeping "/" as separator.
        canonical_uri = quote(path, safe="/~")
        canonical_query = urlencode(sorted((query or {}).items()), quote_via=quote)
        now = datetime.datetime.now(datetime.timezone.utc)
        amz_date = now.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now.strftime("%Y%m%d")
        payload_hash = hashlib.sha256(payload).hexdigest()
        signed = {"host": host, "x-amz-content-sha256": payload_hash, "x-amz-date": amz_date}
        signed_headers = ";".join(sorted(signed))
        canonical_headers = "".join(f"{name}:{signed[name]}\n" for name in sorted(signed))
        canonical_request = "\n".join(
            [method, canonical_uri, canonical_query, canonical_headers, signed_headers, payload_hash]
        )
        scope = f"{date_stamp}/{region}/s3/aws4_request"
        string_to_sign = "\n".join(
            ["AWS4-HMAC-SHA256", amz_date, scope, hashlib.sha256(canonical_request.encode()).hexdigest()]
        )
        # Signing key = HMAC chain over date, region, service and terminator.
        key = ("AWS4" + self.config.secret_key).encode()
        for part in (date_stamp, region, "s3", "aws4_request"):
            key = hmac.new(key, part.encode(), hashlib.sha256).digest()
        signature = hmac.new(key, string_to_sign.encode(), hashlib.sha256).hexdigest()
        headers = dict(signed)
        headers["Authorization"] = (
            f"AWS4-HMAC-SHA256 Credential={self.config.access_key}/{scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )
        headers.update(extra_headers or {})
        scheme = "https" if self.config.use_ssl else "http"
        url = f"{scheme}://{host}{canonical_uri}"
        if canonical_query:
            url = f"{url}?{canonical_query}"
        return requests.request(method, url, headers=headers, data=payload, timeout=10, stream=stream)
class MinIoAgent(ZeroShotAgent):
    """MinIo Agent capable of performing tasks using MinIoTools.

    Maps a parsed user statement onto one MinIoTool invocation and collects
    each tool's output as an intermediate step.

    NOTE(review): ``ToolKit``, ``Input``, ``Output``, ``InputStmt`` and
    ``parse`` are not defined or imported in the visible source, and the
    ``toolkit`` attribute is not something ZeroShotAgent is known to provide;
    confirm where these come from before relying on this class.
    """

    # One shared config for all tools, read from the environment. The original
    # ``MinIoConfig(...)`` passed a positional Ellipsis to a pydantic model,
    # which raises TypeError as soon as the class body executes.
    _minio_config = MinIoConfig(
        host=os.environ.get("MINIO_HOST", "localhost"),
        port=int(os.environ.get("MINIO_PORT", "9000")),
        access_key=os.environ.get("MINIO_ACCESS_KEY", ""),
        secret_key=os.environ.get("MINIO_SECRET_KEY", ""),
        use_ssl=os.environ.get("MINIO_USE_SSL", "false").lower() in ("1", "true", "yes"),
    )
    toolkit = ToolKit(
        tools=[
            ListObjectsTool(_minio_config),
            PutObjectTool(_minio_config),
            GetObjectTool(_minio_config),
            RemoveObjectTool(_minio_config),
        ]
    )

    def format_inputs(self, tools: List[Tool], input_stmt: str) -> Tuple[List[Input], str]:
        """Translate a user statement into the tool call that satisfies it.

        Returns:
            A one-element list of ``Input`` actions and the parsed objective.

        Raises:
            ValueError: if the objective is not list/upload/download/remove.
        """
        parsed_input = parse(input_stmt)
        bucket_name = cast(str, parsed_input.bucket_name)
        objective = parsed_input.objective
        # Tool names are spelled out: pydantic does not expose field defaults
        # such as ``ListObjectsTool.name`` as class attributes.
        if objective == "list":
            tool_name, args = "list_objects", {"bucket": bucket_name}
        elif objective == "upload":
            tool_name, args = "put_object", {
                "bucket": bucket_name,
                "object_name": cast(str, parsed_input.object_name),
                "file_path": cast(str, parsed_input.file_path),
            }
        elif objective == "download":
            tool_name, args = "get_object", {
                "bucket": bucket_name,
                "object_name": cast(str, parsed_input.object_name),
                "dest_folder": cast(str, parsed_input.destination_folder),
            }
        elif objective == "remove":
            tool_name, args = "remove_object", {
                "bucket": bucket_name,
                "object_name": cast(str, parsed_input.object_name),
            }
        else:
            raise ValueError(f"Unsupported objective {objective}")
        return [Input(tool_name=tool_name, args=args)], objective

    async def _arun(self, stmt: str, stop: Optional[Callable[[str], bool]] = None) -> "Output":
        """Run the agent on ``stmt`` and return the collected intermediate steps.

        ``stop``, when given, is called with the string form of each tool
        output; a truthy result ends the run early.
        """
        input_actions, _objective = self.format_inputs(self.toolkit.tools, InputStmt(stmt))
        intermediate_steps = []
        for step_number, action in enumerate(input_actions, start=1):
            response = await self.toolkit.execute(action)
            intermediate_steps.append(
                Output(step_number=step_number, tool_name=action.tool_name, tool_output=response)
            )
            # ``stop`` is typed to take a str, but tools return None, tuples or
            # generators; stringify so str-based predicates don't crash.
            if stop is not None and stop(str(response)):
                break
        return Output(step_number=len(intermediate_steps) + 1, tool_name="MinIoAgent", tool_output=intermediate_steps)
async def main() -> None:
    """Main function demonstrating MinIoAgent usage.

    Sends four example statements (upload, download, list, remove) to the
    agent, each with a ``stop`` predicate applied to the tool output string.
    """
    # NOTE(review): ``...`` is a placeholder; the agent's real constructor
    # arguments are not visible in this file.
    agent = MinIoAgent(...)
    examples = [
        (
            "Upload a file named 'example.txt' located at '/path/to/example.txt' to 'my-bucket'",
            lambda r: r.startswith("Successfully uploaded"),
        ),
        ("Download 'example.txt' from 'my-bucket' to './downloads/'", lambda r: "Downloaded" in r),
        ("List objects in 'my-bucket'", lambda r: len(r) > 0),
        # RemoveObjectTool reports "Object removed successfully." (lower-case),
        # so the former check for "Removed" could never match.
        ("Remove 'example.txt' from 'my-bucket'", lambda r: "removed successfully" in r),
    ]
    for statement, stop in examples:
        # MinIoAgent's async entry point is ``_arun``; it defines no ``run``.
        await agent._arun(statement, stop=stop)
if __name__ == "__main__":
    # asyncio.run creates, runs and closes a fresh event loop; calling
    # get_event_loop() with no running loop is deprecated.
    asyncio.run(main())

# Example: calling a tool directly, without the agent.
# config = MinIoConfig(
#     host="your_minio_domain",
#     port=9000,
#     access_key="your_access_key",
#     secret_key="your_secret_key",
#     use_ssl=False,
# )
#
# remove_object_tool = RemoveObjectTool(config)
# success, message = remove_object_tool._run("your_bucket_name", "your_object_name")
# print(message)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment