Skip to content

Instantly share code, notes, and snippets.

@seporaitis
Last active May 16, 2025 22:39
Show Gist options
  • Save seporaitis/a02d08a9d9ef8d1bf04e6cb1a31ca154 to your computer and use it in GitHub Desktop.
Save seporaitis/a02d08a9d9ef8d1bf04e6cb1a31ca154 to your computer and use it in GitHub Desktop.
#!/usr/bin/env pyright
#
# Either make this file executable and run it (the shebang invokes pyright), or run `pyright pyright_test.py` directly.
#
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Optional, ParamSpec, TypeVar
from ddtrace import tracer
_PFunc = ParamSpec("_PFunc")
_RFunc = TypeVar("_RFunc")
def type_safe_tracer_wrap(
name: Optional[str] = None,
service: Optional[str] = None,
resource: Optional[str] = None,
span_type: Optional[str] = None,
) -> Callable[[Callable[_PFunc, _RFunc]], Callable[_PFunc, _RFunc]]:
def wrapper(func: Callable[_PFunc, _RFunc]) -> Callable[_PFunc, _RFunc]:
@functools.wraps(func)
def wrapped(*args: _PFunc.args, **kwargs: _PFunc.kwargs) -> _RFunc:
return tracer.wrap(name=name, service=service, resource=resource, span_type=span_type)(func)(
*args, **kwargs
)
return wrapped
return wrapper
if not TYPE_CHECKING:
    # Type checkers treat TYPE_CHECKING as True, so pyright considers this branch
    # unreachable and keeps checking callers against the ParamSpec-typed
    # signature defined above. At runtime the name is rebound to ddtrace's own
    # decorator, so decorated functions are wrapped by tracer.wrap directly,
    # with no extra shim layer.
    type_safe_tracer_wrap = tracer.wrap
@tracer.wrap(name="foo")
def foo(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``, traced with plain ``tracer.wrap``.

    This is the counter-example: ``tracer.wrap`` does not carry this signature
    through to pyright, so an under-supplied call such as ``foo(1)`` is not
    flagged.
    """
    total = a + b
    return total
@type_safe_tracer_wrap(name="bar")
def bar(a: int, b: int) -> int:
    """Return the sum of ``a`` and ``b``, traced via ``type_safe_tracer_wrap``.

    Because that decorator is typed with a ``ParamSpec``, pyright still sees
    ``(a: int, b: int) -> int`` here and reports under-supplied calls.
    """
    total = a + b
    return total
if __name__ == "__main__":
    # These calls are deliberately missing `b`: the file is meant to be
    # type-checked (see the pyright shebang), not executed. Running it with
    # python would raise TypeError at foo(1).
    print(foo(1))  # no error(?!) from pyright: tracer.wrap erased foo's signature
    print(bar(1))  # error: Argument missing for parameter "b"
# What's happening here?
#
# tracer.wrap is typed using AnyCallable = TypeVar("AnyCallable", bound=Callable),
# which here effectively becomes Callable[..., Any]; that is, the decorator does
# not preserve the parameter types of the function it wraps.
#
# So when pyright sees:
#
# @tracer.wrap(name="foo")
# def foo(a: int, b: int) -> int:
# return a + b
#
# The final result is a function `foo(*args: Any, **kwargs: Any) -> Any`
#
# On the other hand `type_safe_tracer_wrap` correctly passes through the
# ParamSpec and return type.
#
# @type_safe_tracer_wrap(name="bar")
# def bar(a: int, b: int) -> int:
# return a + b
#
# The final result keeps the full signature `bar(a: int, b: int) -> int`, so pyright reports the missing argument in `bar(1)`.
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment