Last active
May 3, 2025 07:15
-
-
Save isaias-b/5b67ef499e497f21c9a9481b6a266f8c to your computer and use it in GitHub Desktop.
MCP Commons: smooth tool declarations in Python, similar to the TypeScript SDK's variant
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| from collections.abc import Awaitable, Callable, Sequence | |
| from typing import Any, TypeVar | |
| from mcp.server import Server | |
| from mcp.types import EmbeddedResource, ImageContent, TextContent, Tool | |
| from pydantic import ConfigDict | |
logger = logging.getLogger(__name__)

# Result type produced by a tool handler and consumed by its formatter.
T = TypeVar("T")

# Formatter signature: converts a handler result into MCP content items.
ToolFormatter = Callable[[T], Sequence[TextContent | ImageContent | EmbeddedResource]]


class ToolDeclaration(Tool):
    """An MCP ``Tool`` that additionally carries its implementation.

    ``handler`` is awaited with the tool-call arguments and ``formatter`` turns
    the awaited result into the content returned to the client.
    """

    handler: Callable[..., Awaitable[T]]
    formatter: ToolFormatter

    @property
    def input_schema(self) -> dict[str, Any]:
        """Snake-case alias for the ``inputSchema`` field."""
        schema = self.inputSchema
        return schema
def single_text_content(text: str) -> Sequence[TextContent]:
    """Wrap *text* in a one-element list holding a single ``TextContent`` item."""
    item = TextContent(type="text", text=text)
    return [item]
class EnhancedServer(Server):
    """Low-level MCP ``Server`` with built-in, decorator-based tool registration.

    Tools registered with :meth:`declare_tool` are kept in ``self.tools`` (keyed
    by tool name) and served by the ``list_tools`` and ``call_tool`` handlers
    that ``__init__`` installs.
    """

    def __init__(self, name: str, version: str | None = None, instructions: str | None = None):
        super().__init__(name, version, instructions)
        # Registry of declared tools, keyed by the name advertised to clients.
        self.tools: dict[str, ToolDeclaration] = {}

        # Register default handlers.
        @self.list_tools()
        async def list_tools() -> list[Tool]:
            return list(self.tools.values())

        @self.call_tool()
        async def call_tool(
            name: str, arguments: dict[str, Any] | None
        ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
            """Dispatch a tool call to its registered handler and format the result.

            Failures (unknown tool, handler exception, formatter exception) are
            reported back to the client as a single text item instead of raised.
            """
            declaration = self.tools.get(name)
            if declaration is None:
                logger.error("Tool %s not found. Available tools: %s", name, list(self.tools))
                return [TextContent(type="text", text=f"Error: Tool '{name}' not found")]

            # Defensive: treat a missing (None) arguments payload as "no arguments".
            kwargs = arguments or {}
            logger.info("Calling tool %s with arguments %s", name, kwargs)
            try:
                result = await declaration.handler(**kwargs)
                # Formatting stays inside the try so a formatter failure is
                # reported the same way as a handler failure.
                content = declaration.formatter(result)
            except Exception as e:
                logger.exception("Error calling tool %s: %s", name, e)
                # Return an error message as text content instead of raising.
                return [TextContent(type="text", text=f"Error calling tool {name}: {str(e)}")]

            logger.info("Tool %s executed successfully", name)
            return content

    def declare_tool(
        self,
        name: str | None = None,
        description: str | None = None,
        input_schema: dict[str, Any] | None = None,
        model_config: ConfigDict | None = None,
        formatter: Callable[[T], Sequence[TextContent | ImageContent | EmbeddedResource]] = lambda x: single_text_content(
            str(x)
        ),
    ) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
        """Decorator to register an async function as a tool.

        Args:
            name: Tool name; defaults to the decorated function's ``__name__``.
            description: Human-readable description advertised to clients.
            input_schema: JSON Schema for the tool arguments. When omitted, an
                empty object schema (``{"type": "object", "properties": {}}``)
                is used, because MCP requires ``inputSchema`` to describe an object.
            model_config: Forwarded to ``ToolDeclaration`` as-is; defaults to
                ``ConfigDict(extra="allow")``. NOTE(review): pydantic reads
                ``model_config`` from the class, not from constructor kwargs, so
                this probably only ends up stored as an extra field — confirm intent.
            formatter: Converts the handler's result into MCP content; defaults
                to ``str(result)`` wrapped in a single text item.

        Returns:
            A decorator that registers the function in ``self.tools`` and
            returns the function unchanged.
        """

        def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
            tool_name = name or func.__name__
            # Build defaults per registration so no mutable default object is
            # shared between tools.
            schema = input_schema if input_schema is not None else {"type": "object", "properties": {}}
            config = model_config if model_config is not None else ConfigDict(extra="allow")
            self.tools[tool_name] = ToolDeclaration(
                name=tool_name,
                description=description,
                inputSchema=schema,
                model_config=config,
                handler=func,
                formatter=formatter,
            )
            return func

        return decorator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import logging | |
| from pathlib import Path | |
| from typing import Callable, Awaitable, Optional, List, Any | |
| import git | |
| from pydantic import BaseModel, Field | |
| from functools import wraps | |
| from mcp_commons import single_text_content, EnhancedServer | |
| from mcp.server.stdio import stdio_server | |
logger = logging.getLogger(__name__)

# Single server instance; every @server.declare_tool below registers into it.
server = EnhancedServer("mcp-git")


def _open_repo(repo_path: Any) -> git.Repo:
    """Open *repo_path* as a git repository.

    Raises:
        ValueError: if the path does not exist, is not a git repository, or
            cannot be opened for any other reason (the original error is chained).
    """
    try:
        path = Path(repo_path)
        logger.debug("Initializing repository at %s", path)
        if not path.exists():
            logger.error("Repository path does not exist: %s", path)
            raise ValueError(f"Repository path does not exist: {path}")
        repo = git.Repo(path)
    except ValueError:
        # Already descriptive (e.g. the existence check above); don't re-wrap it.
        raise
    except git.InvalidGitRepositoryError as e:
        logger.error("Not a valid git repository: %s", path)
        raise ValueError(f"Not a valid git repository: {path}") from e
    except git.NoSuchPathError as e:
        logger.error("Invalid repository path: %s", path)
        raise ValueError(f"Invalid repository path: {path}") from e
    except Exception as e:
        logger.error("Error accessing repository: %s", e)
        raise ValueError(f"Error accessing repository: {e}") from e
    logger.debug("Successfully initialized repository at %s", path)
    return repo


def with_repo(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """Decorator to handle the ``repo_path`` parameter for git tools.

    The wrapper pops ``repo_path`` from the keyword arguments, opens it as a
    ``git.Repo`` and passes that repo as the first positional argument to
    *func*. Without ``repo_path``, *func* is called with the remaining
    arguments unchanged. Failures while opening the repository surface as
    ``ValueError``; exceptions raised by *func* itself propagate unchanged.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        repo_path = kwargs.pop("repo_path", None)
        if repo_path is None:
            logger.debug("No repository path provided, skipping repo initialization")
            return await func(*args, **kwargs)
        # Only opening the repo is guarded, so errors from the tool itself are
        # not re-labelled as "Error accessing repository".
        repo = _open_repo(repo_path)
        # Repo.__exit__ calls Repo.close(), releasing resources GitPython holds.
        with repo:
            return await func(repo, *args, **kwargs)

    return wrapper
# Shared description for the repo_path argument that every tool accepts.
REPO_PATH_FIELD = Field(description="The path to the git repository")


# Pydantic models for input validation. Each model's JSON schema becomes the
# inputSchema of the matching tool, so these classes deliberately carry no
# docstrings: pydantic would copy a docstring into the schema "description".
class _RepoArgs(BaseModel):
    repo_path: str = REPO_PATH_FIELD


class GitStatus(_RepoArgs):
    pass


class GitDiffUnstaged(_RepoArgs):
    pass


class GitDiffStaged(_RepoArgs):
    pass


class GitDiff(_RepoArgs):
    target: str = Field(description="The target to diff against")


class GitCommit(_RepoArgs):
    message: str = Field(description="The commit message")


class GitAdd(_RepoArgs):
    files: List[str] = Field(description="The files to add")


class GitReset(_RepoArgs):
    pass


class GitLog(_RepoArgs):
    max_count: int = Field(default=10, description="Maximum number of commits to show")


class GitCreateBranch(_RepoArgs):
    branch_name: str = Field(description="The name of the new branch")
    base_branch: Optional[str] = Field(default=None, description="The base branch to create from")


class GitCheckout(_RepoArgs):
    branch_name: str = Field(description="The branch to checkout")


class GitShow(_RepoArgs):
    revision: str = Field(description="The revision to show")


class GitInit(_RepoArgs):
    pass
@server.declare_tool(
    description="Shows the working tree status",
    input_schema=GitStatus.model_json_schema(),
    formatter=lambda status_text: single_text_content(f"Repository status:\n{status_text}"),
)
@with_repo
async def git_status(repo: git.Repo) -> str:
    """Return the output of ``git status`` for *repo*."""
    status_output = repo.git.status()
    return status_output
@server.declare_tool(
    description="Shows changes in the working directory that are not yet staged",
    input_schema=GitDiffUnstaged.model_json_schema(),
    formatter=lambda diff_text: single_text_content(f"Unstaged changes:\n{diff_text}"),
)
@with_repo
async def git_diff_unstaged(repo: git.Repo) -> str:
    """Return ``git diff`` output (working tree versus index) for *repo*."""
    working_tree_diff = repo.git.diff()
    return working_tree_diff
@server.declare_tool(
    description="Shows changes that are staged for commit",
    input_schema=GitDiffStaged.model_json_schema(),
    formatter=lambda diff_text: single_text_content(f"Staged changes:\n{diff_text}"),
)
@with_repo
async def git_diff_staged(repo: git.Repo) -> str:
    """Return ``git diff --cached`` output for *repo*."""
    cached_flag = "--cached"
    return repo.git.diff(cached_flag)
@server.declare_tool(
    description="Shows differences between branches or commits",
    input_schema=GitDiff.model_json_schema(),
    formatter=lambda result: single_text_content(f"Diff with target:\n{result}"),
)
@with_repo
async def git_diff(repo: git.Repo, target: str) -> str:
    """Return ``git diff <target>`` output for *repo*.

    Raises:
        ValueError: if *target* starts with ``-``. ``target`` comes from the
            tool caller, and git would otherwise parse it as a command-line
            option (e.g. ``--output=<file>`` makes git write to a file).
    """
    if target.startswith("-"):
        raise ValueError(f"Invalid diff target (must not start with '-'): {target}")
    return repo.git.diff(target)
@server.declare_tool(
    description="Records changes to the repository",
    input_schema=GitCommit.model_json_schema(),
)
@with_repo
async def git_commit(repo: git.Repo, message: str) -> str:
    """Commit the current index of *repo* with *message* and report the new hash."""
    new_commit = repo.index.commit(message)
    sha = new_commit.hexsha
    return f"Changes committed successfully with hash {sha}"
@server.declare_tool(
    description="Adds file contents to the staging area",
    input_schema=GitAdd.model_json_schema(),
)
@with_repo
async def git_add(repo: git.Repo, files: List[str]) -> str:
    """Stage *files* in the index of *repo*."""
    index = repo.index
    index.add(files)
    return "Files staged successfully"
@server.declare_tool(
    description="Unstages all staged changes",
    input_schema=GitReset.model_json_schema(),
)
@with_repo
async def git_reset(repo: git.Repo) -> str:
    """Reset the index of *repo* via ``IndexFile.reset()`` with default arguments."""
    index = repo.index
    index.reset()
    return "All staged changes reset"
@server.declare_tool(
    description="Shows the commit logs",
    input_schema=GitLog.model_json_schema(),
    formatter=lambda entries: single_text_content("Commit history:\n" + "\n".join(entries)),
)
@with_repo
async def git_log(repo: git.Repo, max_count: int = 10) -> List[str]:
    """Describe up to *max_count* commits yielded by ``repo.iter_commits``.

    Each entry is a multi-line block with hash, author, date and message.
    """
    return [
        f"Commit: {entry.hexsha}\n"
        f"Author: {entry.author}\n"
        f"Date: {entry.authored_datetime}\n"
        f"Message: {entry.message}\n"
        for entry in repo.iter_commits(max_count=max_count)
    ]
@server.declare_tool(
    description="Creates a new branch from an optional base branch",
    input_schema=GitCreateBranch.model_json_schema(),
)
@with_repo
async def git_create_branch(repo: git.Repo, branch_name: str, base_branch: Optional[str] = None) -> str:
    """Create *branch_name* from *base_branch*, or from the active branch if none is given."""
    start_point = repo.refs[base_branch] if base_branch else repo.active_branch
    repo.create_head(branch_name, start_point)
    return f"Created branch '{branch_name}' from '{start_point.name}'"
@server.declare_tool(description="Switches branches", input_schema=GitCheckout.model_json_schema())
@with_repo
async def git_checkout(repo: git.Repo, branch_name: str) -> str:
    """Run ``git checkout <branch_name>`` in *repo*.

    Raises:
        ValueError: if *branch_name* starts with ``-``. The name comes from the
            tool caller, and git would otherwise parse it as an option
            (e.g. ``-f`` throws away local changes).
    """
    if branch_name.startswith("-"):
        raise ValueError(f"Invalid branch name (must not start with '-'): {branch_name}")
    repo.git.checkout(branch_name)
    return f"Switched to branch '{branch_name}'"
@server.declare_tool(description="Shows the contents of a commit", input_schema=GitShow.model_json_schema())
@with_repo
async def git_show(repo: git.Repo, revision: str) -> str:
    """Return metadata for *revision* followed by its patch.

    The patch is taken against the commit's first parent, or against
    ``git.NULL_TREE`` when the commit has no parents.
    """
    commit = repo.commit(revision)
    output = [
        f"Commit: {commit.hexsha}\n"
        f"Author: {commit.author}\n"
        f"Date: {commit.authored_datetime}\n"
        f"Message: {commit.message}\n"
    ]
    if commit.parents:
        parent = commit.parents[0]
        diff = parent.diff(commit, create_patch=True)
    else:
        # NOTE(review): confirm the direction of commit.diff(NULL_TREE); root-commit
        # files may be reported as deleted rather than added.
        diff = commit.diff(git.NULL_TREE, create_patch=True)
    for d in diff:
        output.append(f"\n--- {d.a_path}\n+++ {d.b_path}\n")
        patch = d.diff
        # Patch bodies arrive as bytes. Decode leniently so a file that is not
        # valid UTF-8 does not abort the whole call with UnicodeDecodeError.
        if isinstance(patch, bytes):
            patch = patch.decode("utf-8", errors="replace")
        output.append(patch)
    return "".join(output)
@server.declare_tool(
    description="Initialize a new Git repository",
    input_schema=GitInit.model_json_schema(),
)
async def git_init(repo_path: str) -> str:
    """Create a git repository at *repo_path*, creating the directory if needed.

    Any failure is returned as an error string rather than raised.
    """
    try:
        new_repo = git.Repo.init(path=repo_path, mkdir=True)
    except Exception as e:
        return f"Error initializing repository: {str(e)}"
    else:
        return f"Initialized empty Git repository in {new_repo.git_dir}"
async def serve() -> None:
    """Run the module-level ``server`` over stdio."""
    init_options = server.create_initialization_options()
    async with stdio_server() as streams:
        reader, writer = streams
        await server.run(reader, writer, init_options, raise_exceptions=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment