Last active
May 16, 2025 22:39
-
-
Save seporaitis/a02d08a9d9ef8d1bf04e6cb1a31ca154 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env pyright
#
# Either make this file executable and run it (the shebang hands it to pyright),
# or run `pyright pyright_test.py` on it directly.
#
import functools | |
from collections.abc import Callable | |
from typing import TYPE_CHECKING, Optional, ParamSpec, TypeVar | |
from ddtrace import tracer | |
_PFunc = ParamSpec("_PFunc") | |
_RFunc = TypeVar("_RFunc") | |
def type_safe_tracer_wrap( | |
name: Optional[str] = None, | |
service: Optional[str] = None, | |
resource: Optional[str] = None, | |
span_type: Optional[str] = None, | |
) -> Callable[[Callable[_PFunc, _RFunc]], Callable[_PFunc, _RFunc]]: | |
def wrapper(func: Callable[_PFunc, _RFunc]) -> Callable[_PFunc, _RFunc]: | |
@functools.wraps(func) | |
def wrapped(*args: _PFunc.args, **kwargs: _PFunc.kwargs) -> _RFunc: | |
return tracer.wrap(name=name, service=service, resource=resource, span_type=span_type)(func)( | |
*args, **kwargs | |
) | |
return wrapped | |
return wrapper | |
if TYPE_CHECKING:
    # Type checkers keep the ParamSpec-preserving definition above.
    pass
else:
    # At runtime, hand out ddtrace's own decorator instead of the typed shim,
    # so decorated functions are not wrapped twice.
    type_safe_tracer_wrap = tracer.wrap
@tracer.wrap(name="foo")
def foo(a: int, b: int) -> int:
    """Return ``a + b`` inside a "foo" span, decorated with plain ``tracer.wrap``."""
    total = a + b
    return total
@type_safe_tracer_wrap(name="bar")
def bar(a: int, b: int) -> int:
    """Return ``a + b`` inside a "bar" span, decorated with the typed shim."""
    total = a + b
    return total
if __name__ == "__main__":
    # Both calls deliberately omit `b`: they exist so pyright can show the
    # difference between the two decorators. Executing this block with Python
    # would presumably fail with a TypeError at the first call.
    print(foo(1))  # no error(?!) -- plain tracer.wrap hid foo's signature
    print(bar(1))  # error: Argument missing for parameter "b"
# What's happening here?
#
# tracer.wrap is typed using AnyCallable = TypeVar("AnyCallable", bound=Callable),
# which in practice behaves like Callable[..., Any]. In other words, the
# decorator does not preserve the decorated function's parameter types.
#
# So when pyright sees:
#
#     @tracer.wrap(name="foo")
#     def foo(a: int, b: int) -> int:
#         return a + b
#
# the final result is a function `foo(*args: Any, **kwargs: Any) -> Any`.
#
# On the other hand, `type_safe_tracer_wrap` correctly passes through the
# ParamSpec and return type. So for:
#
#     @type_safe_tracer_wrap(name="bar")
#     def bar(a: int, b: int) -> int:
#         return a + b
#
# the final result is a function that keeps its parameter type information.
#
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment