Skip to content

Instantly share code, notes, and snippets.

@isaias-b
Last active May 3, 2025 07:15
Show Gist options
  • Select an option

  • Save isaias-b/5b67ef499e497f21c9a9481b6a266f8c to your computer and use it in GitHub Desktop.

Select an option

Save isaias-b/5b67ef499e497f21c9a9481b6a266f8c to your computer and use it in GitHub Desktop.
MCP Commons: concise tool declarations in Python, similar to the TypeScript SDK's variant
import logging
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, TypeVar
from mcp.server import Server
from mcp.types import EmbeddedResource, ImageContent, TextContent, Tool
from pydantic import ConfigDict
# Module-level logger for the tool-registration helpers below.
logger = logging.getLogger(__name__)
# Type variable tying a tool handler's awaited result type to its formatter's input type.
T = TypeVar("T")
class ToolDeclaration(Tool):
    """Extended Tool with additional fields needed for implementation.

    On top of the MCP ``Tool`` metadata (``name``, ``description``,
    ``inputSchema``) this model carries the coroutine function implementing
    the tool and the formatter that turns its result into MCP content items.
    """

    # Coroutine function that implements the tool; awaited to produce the raw result.
    handler: Callable[..., Awaitable[T]]
    # Converts the handler's awaited result into the content items sent back to the client.
    formatter: Callable[[T], Sequence[TextContent | ImageContent | EmbeddedResource]]
    # NOTE(review): T is a module-level TypeVar but this model is not Generic, so it
    # does not actually bind handler and formatter together for a type checker — confirm intent.

    @property
    def input_schema(self) -> dict[str, Any]:
        """Snake_case alias for the ``inputSchema`` field."""
        return self.inputSchema
def single_text_content(text: str) -> Sequence[TextContent]:
    """Wrap *text* in a one-element list holding a single MCP text content item."""
    item = TextContent(type="text", text=text)
    return [item]
class EnhancedServer(Server):
    """Server with built-in tool registration capabilities.

    Tools are registered with the :meth:`declare_tool` decorator and kept in
    ``self.tools`` keyed by tool name. The constructor installs ``list_tools``
    and ``call_tool`` handlers that are served straight from that registry.
    """

    def __init__(self, name: str, version: str | None = None, instructions: str | None = None):
        super().__init__(name, version, instructions)
        # Registry of declared tools, keyed by tool name.
        self.tools: dict[str, ToolDeclaration] = {}

        # Register default handlers backed by ``self.tools``.
        @self.list_tools()
        async def list_tools() -> list[Tool]:
            return list(self.tools.values())

        @self.call_tool()
        async def call_tool(
            name: str, arguments: dict[str, Any] | None
        ) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
            # Dispatch a call to the registered handler, then pass the result
            # through that tool's formatter. Failures are deliberately returned
            # as text content instead of being raised.
            # Treat missing arguments as "no arguments" so ``**`` unpacking cannot fail.
            arguments = arguments or {}
            logger.info("Calling tool %s with arguments %s", name, arguments)

            declaration = self.tools.get(name)
            if declaration is None:
                logger.error("Tool %s not found. Available tools: %s", name, list(self.tools))
                return [TextContent(type="text", text=f"Error: Tool '{name}' not found")]

            try:
                result = await declaration.handler(**arguments)
                content = declaration.formatter(result)
            except Exception as e:
                # Boundary handler: log with traceback, then report the error as text.
                logger.exception("Error calling tool %s: %s", name, e)
                return [TextContent(type="text", text=f"Error calling tool {name}: {e}")]

            logger.info("Tool %s executed successfully", name)
            return content

    def declare_tool(
        self,
        name: str | None = None,
        description: str | None = None,
        input_schema: dict[str, Any] | None = None,
        model_config: ConfigDict | None = None,
        formatter: Callable[[T], Sequence[TextContent | ImageContent | EmbeddedResource]] = lambda x: single_text_content(str(x)),
    ) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
        """Decorator to register a coroutine function as a tool.

        Args:
            name: Tool name; defaults to the decorated function's ``__name__``.
            description: Human-readable description of the tool.
            input_schema: JSON Schema for the tool's arguments. When omitted, an
                empty object schema (``{"type": "object"}``) is used, because MCP
                tool input schemas are object schemas.
            model_config: Forwarded to ``ToolDeclaration`` unchanged; when omitted,
                a fresh ``ConfigDict(extra="allow")`` is built per registration.
            formatter: Converts the handler's result into MCP content; defaults to
                ``str(result)`` wrapped in a single text item.

        Returns:
            A decorator that records the function in ``self.tools`` under the
            tool name and returns the function itself unchanged.
        """

        def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
            tool_name = name or func.__name__
            # Build defaults per registration. Dict defaults in the signature would
            # be a single object shared by every call of declare_tool.
            schema = input_schema if input_schema is not None else {"type": "object"}
            config = model_config if model_config is not None else ConfigDict(extra="allow")
            # NOTE(review): pydantic reads ``model_config`` from the class, not from
            # constructor data, so passing it here probably does not configure the
            # model. Kept for backward compatibility — confirm intent.
            self.tools[tool_name] = ToolDeclaration(
                name=tool_name,
                description=description,
                inputSchema=schema,
                model_config=config,
                handler=func,
                formatter=formatter,
            )
            return func

        return decorator
import logging
from pathlib import Path
from typing import Callable, Awaitable, Optional, List, Any
import git
from pydantic import BaseModel, Field
from functools import wraps
from mcp_commons import single_text_content, EnhancedServer
from mcp.server.stdio import stdio_server
# Module-level logger for the git tool server.
logger = logging.getLogger(__name__)
# The MCP server instance; each tool below registers itself on it via @server.declare_tool.
server = EnhancedServer("mcp-git")
def with_repo(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """Decorator to handle the ``repo_path`` parameter for git tools.

    The wrapped coroutine function must take a ``git.Repo`` as its first
    positional parameter. When the wrapper is called with ``repo_path=...``,
    that keyword is removed, the repository at that path is opened, and
    ``func(repo, ...)`` is awaited. The Repo is closed afterwards. When no
    ``repo_path`` is given, the call is forwarded to ``func`` unchanged, so a
    caller may pass a ``git.Repo`` positionally itself.

    Raises:
        ValueError: if the path does not exist, is not a git repository, or
            the repository cannot be opened. Exceptions raised by ``func``
            itself propagate with their original type and message.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        repo_path = kwargs.pop('repo_path', None)
        if repo_path is None:
            logger.debug("No repository path provided, skipping repo initialization")
            return await func(*args, **kwargs)

        repo_path = Path(repo_path)
        logger.debug("Initializing repository at %s", repo_path)
        if not repo_path.exists():
            logger.error("Repository path does not exist: %s", repo_path)
            raise ValueError(f"Repository path does not exist: {repo_path}")

        # Guard only the opening of the repository. Errors raised by ``func`` must
        # keep their own type and message rather than being relabelled as
        # "Error accessing repository".
        try:
            repo = git.Repo(repo_path)
        except git.InvalidGitRepositoryError as e:
            logger.error("Not a valid git repository: %s", repo_path)
            raise ValueError(f"Not a valid git repository: {repo_path}") from e
        except git.NoSuchPathError as e:
            logger.error("Invalid repository path: %s", repo_path)
            raise ValueError(f"Invalid repository path: {repo_path}") from e
        except Exception as e:
            logger.error("Error accessing repository: %s", e)
            raise ValueError(f"Error accessing repository: {e}") from e
        logger.debug("Successfully initialized repository at %s", repo_path)

        try:
            return await func(repo, *args, **kwargs)
        finally:
            # Release whatever resources the Repo holds (e.g. persistent git helper
            # processes). Otherwise every tool call in a long-running server leaks them.
            repo.close()

    return wrapper
# Shared field definition for the ``repo_path`` argument that every git tool takes.
REPO_PATH_FIELD = Field(description="The path to the git repository")
# Pydantic models for input validation
# NOTE(review): in the visible code these models are only used to produce each
# tool's ``input_schema`` via ``model_json_schema()``. Incoming arguments are not
# validated against them before the handler runs — confirm whether that is intended.

# Arguments for git_status.
class GitStatus(BaseModel):
    repo_path: str = REPO_PATH_FIELD

# Arguments for git_diff_unstaged.
class GitDiffUnstaged(BaseModel):
    repo_path: str = REPO_PATH_FIELD

# Arguments for git_diff_staged.
class GitDiffStaged(BaseModel):
    repo_path: str = REPO_PATH_FIELD

# Arguments for git_diff: the repository plus the revision to diff against.
class GitDiff(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    target: str = Field(description="The target to diff against")

# Arguments for git_commit.
class GitCommit(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    message: str = Field(description="The commit message")

# Arguments for git_add: the list of file paths to stage.
class GitAdd(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    files: List[str] = Field(description="The files to add")

# Arguments for git_reset.
class GitReset(BaseModel):
    repo_path: str = REPO_PATH_FIELD

# Arguments for git_log; max_count defaults to 10.
class GitLog(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    max_count: int = Field(default=10, description="Maximum number of commits to show")

# Arguments for git_create_branch; base_branch is optional (None by default).
class GitCreateBranch(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    branch_name: str = Field(description="The name of the new branch")
    base_branch: Optional[str] = Field(default=None, description="The base branch to create from")

# Arguments for git_checkout.
class GitCheckout(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    branch_name: str = Field(description="The branch to checkout")

# Arguments for git_show.
class GitShow(BaseModel):
    repo_path: str = REPO_PATH_FIELD
    revision: str = Field(description="The revision to show")

# Arguments for git_init: the directory in which to create the repository.
class GitInit(BaseModel):
    repo_path: str = REPO_PATH_FIELD
@server.declare_tool(
    description="Shows the working tree status",
    input_schema=GitStatus.model_json_schema(),
    formatter=lambda result: single_text_content(f"Repository status:\n{result}"),
)
@with_repo
async def git_status(repo: git.Repo) -> str:
    """Return the raw output of ``git status`` for *repo*."""
    status_text = repo.git.status()
    return status_text
@server.declare_tool(
    description="Shows changes in the working directory that are not yet staged",
    input_schema=GitDiffUnstaged.model_json_schema(),
    formatter=lambda result: single_text_content(f"Unstaged changes:\n{result}"),
)
@with_repo
async def git_diff_unstaged(repo: git.Repo) -> str:
    """Return ``git diff`` output (working tree vs. index) for *repo*."""
    unstaged = repo.git.diff()
    return unstaged
@server.declare_tool(
    description="Shows changes that are staged for commit",
    input_schema=GitDiffStaged.model_json_schema(),
    formatter=lambda result: single_text_content(f"Staged changes:\n{result}"),
)
@with_repo
async def git_diff_staged(repo: git.Repo) -> str:
    """Return ``git diff --cached`` output for *repo*."""
    staged = repo.git.diff("--cached")
    return staged
@server.declare_tool(
    description="Shows differences between branches or commits",
    input_schema=GitDiff.model_json_schema(),
    formatter=lambda result: single_text_content(f"Diff with target:\n{result}")
)
@with_repo
async def git_diff(repo: git.Repo, target: str) -> str:
    """Return ``git diff <target>`` output for *repo*.

    Raises:
        ValueError: if *target* starts with ``-`` and would therefore be
            parsed by git as a command-line option.
    """
    # ``target`` comes from the tool caller and is handed to the git CLI as a raw
    # argument. A value such as "--output=/some/file" would be parsed as an option
    # (``git diff --output`` writes to a file), so option-like values are rejected.
    if target.startswith("-"):
        raise ValueError(f"Invalid diff target {target!r}: must not start with '-'")
    return repo.git.diff(target)
@server.declare_tool(
    description="Records changes to the repository",
    input_schema=GitCommit.model_json_schema(),
)
@with_repo
async def git_commit(repo: git.Repo, message: str) -> str:
    """Commit the current index with *message* and report the new commit's SHA."""
    new_commit = repo.index.commit(message)
    sha = new_commit.hexsha
    return f"Changes committed successfully with hash {sha}"
@server.declare_tool(
    description="Adds file contents to the staging area",
    input_schema=GitAdd.model_json_schema(),
)
@with_repo
async def git_add(repo: git.Repo, files: List[str]) -> str:
    """Stage each path in *files* via the repository index."""
    index = repo.index
    index.add(files)
    return "Files staged successfully"
@server.declare_tool(
    description="Unstages all staged changes",
    input_schema=GitReset.model_json_schema(),
)
@with_repo
async def git_reset(repo: git.Repo) -> str:
    """Unstage everything by calling ``reset()`` on the repository index."""
    index = repo.index
    index.reset()
    return "All staged changes reset"
@server.declare_tool(
    description="Shows the commit logs",
    input_schema=GitLog.model_json_schema(),
    formatter=lambda result: single_text_content("Commit history:\n" + "\n".join(result)),
)
@with_repo
async def git_log(repo: git.Repo, max_count: int = 10) -> List[str]:
    """Describe up to *max_count* commits from ``repo.iter_commits``.

    Each entry is a multi-line string with the commit hash, author, author
    date and message.
    """
    history = list(repo.iter_commits(max_count=max_count))
    return [
        (
            f"Commit: {entry.hexsha}\n"
            f"Author: {entry.author}\n"
            f"Date: {entry.authored_datetime}\n"
            f"Message: {entry.message}\n"
        )
        for entry in history
    ]
@server.declare_tool(
    description="Creates a new branch from an optional base branch",
    input_schema=GitCreateBranch.model_json_schema(),
)
@with_repo
async def git_create_branch(repo: git.Repo, branch_name: str, base_branch: Optional[str] = None) -> str:
    """Create *branch_name* at *base_branch*, or at the active branch if none is given.

    *base_branch* is looked up by name in ``repo.refs``.
    """
    start_point = repo.refs[base_branch] if base_branch else repo.active_branch
    repo.create_head(branch_name, start_point)
    return f"Created branch '{branch_name}' from '{start_point.name}'"
@server.declare_tool(description="Switches branches", input_schema=GitCheckout.model_json_schema())
@with_repo
async def git_checkout(repo: git.Repo, branch_name: str) -> str:
    """Switch *repo* to *branch_name* using ``git checkout``.

    Raises:
        ValueError: if *branch_name* starts with ``-`` and would therefore be
            parsed by git as a command-line option.
    """
    # ``branch_name`` is caller-supplied and passed to the git CLI. Reject
    # option-like values so they cannot inject flags.
    if branch_name.startswith("-"):
        raise ValueError(f"Invalid branch name {branch_name!r}: must not start with '-'")
    # The trailing "--" forces git to read ``branch_name`` as a branch/revision.
    # Without it, a name that matches a file path makes ``git checkout`` restore
    # that file and discard its working-tree changes.
    repo.git.checkout(branch_name, "--")
    return f"Switched to branch '{branch_name}'"
@server.declare_tool(description="Shows the contents of a commit", input_schema=GitShow.model_json_schema())
@with_repo
async def git_show(repo: git.Repo, revision: str) -> str:
    """Show a commit's metadata followed by its patch.

    The patch is computed against the commit's first parent. For a commit
    without parents, it is computed via ``commit.diff(git.NULL_TREE)``.
    """
    commit = repo.commit(revision)
    output = [
        f"Commit: {commit.hexsha}\n"
        f"Author: {commit.author}\n"
        f"Date: {commit.authored_datetime}\n"
        f"Message: {commit.message}\n"
    ]
    if commit.parents:
        diff = commit.parents[0].diff(commit, create_patch=True)
    else:
        diff = commit.diff(git.NULL_TREE, create_patch=True)
    for d in diff:
        output.append(f"\n--- {d.a_path}\n+++ {d.b_path}\n")
        # Patches of binary or non-UTF-8 files are not valid UTF-8. Decode
        # leniently so one such file does not make the whole tool call fail.
        output.append(d.diff.decode('utf-8', errors='replace'))
    return "".join(output)
@server.declare_tool(
    description="Initialize a new Git repository",
    input_schema=GitInit.model_json_schema(),
)
async def git_init(repo_path: str) -> str:
    """Initialise a git repository at *repo_path*, creating the directory if needed.

    Failures are reported in the returned string rather than raised.
    """
    try:
        new_repo = git.Repo.init(path=repo_path, mkdir=True)
    except Exception as e:
        return f"Error initializing repository: {str(e)}"
    return f"Initialized empty Git repository in {new_repo.git_dir}"
async def serve() -> None:
    """Start the git server on stdio.

    Handler exceptions are re-raised (``raise_exceptions=True``).
    """
    init_options = server.create_initialization_options()
    async with stdio_server() as streams:
        reader, writer = streams
        await server.run(reader, writer, init_options, raise_exceptions=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment