Skip to content

Instantly share code, notes, and snippets.

@samuela
Created March 29, 2022 18:49
Show Gist options
  • Save samuela/2cbb09f8f635bde4dab63cfc6ee105ed to your computer and use it in GitHub Desktop.
Save samuela/2cbb09f8f635bde4dab63cfc6ee105ed to your computer and use it in GitHub Desktop.
============================= test session starts ==============================
platform linux -- Python 3.9.10, pytest-6.2.5, py-1.11.0, pluggy-1.0.0
rootdir: /build/source
plugins: hypothesis-6.35.0
collected 140 items
tests/test_optimizer.py .. [ 1%]
tests/test_rng_seq.py .. [ 2%]
tests/test_treex.py ........................................ [ 31%]
tests/losses/cosine_similarity_test.py ... [ 33%]
tests/losses/crossentropy_test.py . [ 34%]
tests/losses/huber_test.py ... [ 36%]
tests/losses/loss_test.py .. [ 37%]
tests/losses/mean_absolute_error_test.py .. [ 39%]
tests/losses/mean_absolute_percentage_error_test.py .. [ 40%]
tests/losses/mean_squared_error_test.py .. [ 42%]
tests/losses/mean_squared_logarithmic_error_test.py ... [ 44%]
tests/metrics/test_accuracy.py FF [ 45%]
tests/metrics/test_loss_and_logs.py FF [ 47%]
tests/metrics/test_losses.py FFFF [ 50%]
tests/metrics/test_metric.py .F.F [ 52%]
tests/metrics/test_metrics.py FFFF [ 55%]
tests/nn/test_conv.py .......... [ 62%]
tests/nn/test_dropout.py ...... [ 67%]
tests/nn/test_embed.py ..... [ 70%]
tests/nn/test_flax_module.py ... [ 72%]
tests/nn/test_linear.py ....... [ 77%]
tests/nn/test_mlp.py .. [ 79%]
tests/nn/test_norm.py .................... [ 93%]
tests/nn/test_recurrent.py ....... [ 98%]
tests/nn/test_sequential.py .. [100%]
=================================== FAILURES ===================================
____________________________ TestAccuracy.test_jit _____________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=0 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_logits_preds>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_logits_preds>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_logits_preds>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, nextitem = <Function test_logits_preds>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, log = True, nextitem = <Function test_logits_preds>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=6>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff867d4b80>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_jit>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_jit>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_jit>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_jit>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_jit>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_jit>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_jit>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_jit>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_accuracy.TestAccuracy object at 0x7fff845bb190>
def test_jit(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metric = Accuracy(num_classes=10)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metric = f(metric, target, preds)
tests/metrics/test_accuracy.py:27:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy....ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy....ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': False, '_field_metadata': {'_initial_sta...ASS: 3>}], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': F...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74762f70>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': F...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74762e50>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ..., [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_accuracy.py:16', traced_for='jit', arg_info=functools.p...], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy....[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
target = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_accuracy.py:20:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
args = ()
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
module = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74762c10>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_accuracy.TestAccuracy object at 0x7fff845bb190>
def test_jit(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metric = Accuracy(num_classes=10)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metric = f(metric, target, preds)
tests/metrics/test_accuracy.py:27:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_accuracy.py:20: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74762c10>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
________________________ TestAccuracy.test_logits_preds ________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=1 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=1 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=1 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=1 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=1 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_logits_preds>, 'nextitem': <Function test_basic>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_logits_preds>, 'nextitem': <Function test_basic>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_logits_preds>, 'nextitem': <Function test_basic>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_logits_preds>, nextitem = <Function test_basic>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_logits_preds>, log = True
nextitem = <Function test_basic>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_logits_preds>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=6>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_logits_preds>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff745365e0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_logits_preds>}, argname = 'item'
firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_logits_preds>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_logits_preds>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_logits_preds>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_logits_preds>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_logits_preds>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_logits_preds>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_logits_preds>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_logits_preds>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_accuracy.TestAccuracy object at 0x7fff5c0e97f0>
def test_logits_preds(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metric = Accuracy()
target = jnp.array([0, 0, 1, 1, 1])
preds = jnp.array(
[
[10.0, 0.0, 0.0],
[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[40.0, 10.0, 0.0],
[0.0, 10.0, 0.0],
]
)
> metric = f(metric, target, preds)
tests/metrics/test_accuracy.py:57:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy...., 0., 0.],
[ 0., 10., 0.],
[40., 10., 0.],
[ 0., 10., 0.]], dtype=float32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy...., 0., 0.],
[ 0., 10., 0.],
[40., 10., 0.],
[ 0., 10., 0.]], dtype=float32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': False, '_field_metadata': {'_initial_sta...ASS: 3>}], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[10., 0., 0.],
[10., 0., 0.],
[ 0., 10., 0.],
[40., 10., 0.],
[ 0., 10., 0.]], dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': F...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff5c646310>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.accuracy.Accuracy'>[{'_mutable': F...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff5c646550>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ..., [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_accuracy.py:38', traced_for='jit', arg_info=functools.p...], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...[{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}]), *, *), {})),)
Core: f
args = (Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy....rray(int32[5])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[5,3])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
target = Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(float32[5,3])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_accuracy.py:42:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
args = ()
kwargs = {'preds': Traced<ShapedArray(float32[5,3])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
kwargs = {'preds': Traced<ShapedArray(float32[5,3])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace(level=0/1)>}
module = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c6468b0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_accuracy.TestAccuracy object at 0x7fff5c0e97f0>
def test_logits_preds(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metric = Accuracy()
target = jnp.array([0, 0, 1, 1, 1])
preds = jnp.array(
[
[10.0, 0.0, 0.0],
[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[40.0, 10.0, 0.0],
[0.0, 10.0, 0.0],
]
)
> metric = f(metric, target, preds)
tests/metrics/test_accuracy.py:57:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_accuracy.py:42: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Accuracy {
average: AverageMethod.MICRO, AverageMethod
dtype: <class 'jax._src.numpy.l...32), MetricState
top_k: None,
tp: jax.DynamicJaxprTracer((), uint32), MetricState
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c6468b0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
__________________________ TestLossAndLogs.test_basic __________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=2 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=2 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=2 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=2 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=2 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_batch_loss>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_batch_loss>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_batch_loss>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, nextitem = <Function test_batch_loss>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, log = True, nextitem = <Function test_batch_loss>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff845eeb80>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_basic>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_basic>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_basic>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_basic>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_basic>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_basic>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_loss_and_logs.TestLossAndLogs object at 0x7fff5c0d92e0>
def test_basic(self):
class MyModule(tx.Module):
aux_loss: jnp.ndarray = tx.LossLog.node()
aux_metric: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux_loss = jnp.array(1.0, jnp.float32)
self.aux_metric = jnp.array(2.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
metrics(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
return metrics
module = MyModule()
metrics = tx.metrics.LossAndLogs(
losses=[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
],
metrics=dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
),
aux_losses=tx.metrics.AuxLosses(module.loss_logs()),
aux_metrics=tx.metrics.AuxMetrics(module.metric_logs()),
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
y_true_metrics = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
y_pred_metrics = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(module, metrics, target, preds, y_true_metrics, y_pred_metrics)
tests/metrics/test_loss_and_logs.py:65:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux_loss: jaxlib.DeviceArray((), float32), LossLog
aux_metric: jaxlib.DeviceArray((), float32...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux_loss: jaxlib.DeviceArray((), float32), LossLog
aux_metric: jaxlib.DeviceArray((), float32...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_basic.<locals>.MyModule'>[{'_mutable': False, '...'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_basic.<l...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742b0670>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_basic.<l...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742b0820>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...alse, False, ...), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...{})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ..._state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_loss_and_logs.py:23', traced_for='jit', arg_info=functo...l_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: jax.DynamicJaxprTracer((),...[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: jax.DynamicJaxprTracer((), ...cLog
name: "my_module", str
some_value: jax.DynamicJaxprTracer((), float32),
}
metrics = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
target = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
y_true_metrics = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
y_pred_metrics = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
> metrics(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
tests/metrics/test_loss_and_logs.py:34:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
args = ()
kwargs = {'aux_losses': MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
n...<DynamicJaxprTrace(level=0/1)>}, 'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, ...}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
metrics_kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
aux_losses = MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
name: "my_module", str
some_value: Nothing,
}
aux_metrics = MyModule {
aux_loss: Nothing,
aux_metric: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}
losses_kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(
self,
metrics_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
aux_losses: tp.Optional[tp.Any] = None,
aux_metrics: tp.Optional[tp.Any] = None,
**losses_kwargs,
) -> tp.Tuple[jnp.ndarray, Logs, Logs]:
> return super().__call__(
metrics_kwargs=metrics_kwargs,
aux_losses=aux_losses,
aux_metrics=aux_metrics,
**losses_kwargs,
)
treex/metrics/loss_and_logs.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
kwargs = {'aux_losses': MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
n...<DynamicJaxprTrace(level=0/1)>}, 'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, ...}
module = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff742b0a60>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_loss_and_logs.TestLossAndLogs object at 0x7fff5c0d92e0>
def test_basic(self):
class MyModule(tx.Module):
aux_loss: jnp.ndarray = tx.LossLog.node()
aux_metric: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux_loss = jnp.array(1.0, jnp.float32)
self.aux_metric = jnp.array(2.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
metrics(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
return metrics
module = MyModule()
metrics = tx.metrics.LossAndLogs(
losses=[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
],
metrics=dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
),
aux_losses=tx.metrics.AuxLosses(module.loss_logs()),
aux_metrics=tx.metrics.AuxMetrics(module.metric_logs()),
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
y_true_metrics = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
y_pred_metrics = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(module, metrics, target, preds, y_true_metrics, y_pred_metrics)
tests/metrics/test_loss_and_logs.py:65:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_loss_and_logs.py:34: in f
metrics(
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/loss_and_logs.py:126: in __call__
return super().__call__(
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff742b0a60>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
_______________________ TestLossAndLogs.test_batch_loss ________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=3 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=3 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=3 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=3 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=3 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_batch_loss>, 'nextitem': <Function test_list>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_batch_loss>, 'nextitem': <Function test_list>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_batch_loss>, 'nextitem': <Function test_list>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_batch_loss>, nextitem = <Function test_list>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_batch_loss>, log = True, nextitem = <Function test_list>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_batch_loss>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=8>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_batch_loss>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff7453f550>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_batch_loss>}, argname = 'item'
firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_batch_loss>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_batch_loss>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_batch_loss>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_batch_loss>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_batch_loss>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_batch_loss>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_batch_loss>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_batch_loss>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_loss_and_logs.TestLossAndLogs object at 0x7fff5c318be0>
def test_batch_loss(self):
class MyModule(tx.Module):
aux_loss: jnp.ndarray = tx.LossLog.node()
aux_metric: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux_loss = jnp.array(1.0, jnp.float32)
self.aux_metric = jnp.array(2.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
loss, losses_logs, metrics_logs = metrics.batch_loss_epoch_logs(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
logs = {**losses_logs, **metrics_logs}
return loss, logs, metrics
module = MyModule()
metrics = tx.metrics.LossAndLogs(
losses=[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
],
metrics=dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
),
aux_losses=tx.metrics.AuxLosses(module.loss_logs()),
aux_metrics=tx.metrics.AuxMetrics(module.metric_logs()),
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
y_true_metrics = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
y_pred_metrics = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> loss, logs, metrics = f(
module, metrics, target, preds, y_true_metrics, y_pred_metrics
)
tests/metrics/test_loss_and_logs.py:158:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux_loss: jaxlib.DeviceArray((), float32), LossLog
aux_metric: jaxlib.DeviceArray((), float32...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux_loss: jaxlib.DeviceArray((), float32), LossLog
aux_metric: jaxlib.DeviceArray((), float32...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_batch_loss.<locals>.MyModule'>[{'_mutable': Fal...'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_batch_lo...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742b08b0>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_loss_and_logs.TestLossAndLogs.test_batch_lo...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742b0940>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(2., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...alse, False, ...), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...{})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ..._state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_loss_and_logs.py:115', traced_for='jit', arg_info=funct...l_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...tate': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}])}]), *, *, *, *), {})),)
Core: f
args = (MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: jax.DynamicJaxprTracer((),...[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: jax.DynamicJaxprTracer((), ...cLog
name: "my_module", str
some_value: jax.DynamicJaxprTracer((), float32),
}
metrics = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
target = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
y_true_metrics = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
y_pred_metrics = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
> loss, losses_logs, metrics_logs = metrics.batch_loss_epoch_logs(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
tests/metrics/test_loss_and_logs.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
metrics_kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
aux_losses = MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
name: "my_module", str
some_value: Nothing,
}
aux_metrics = MyModule {
aux_loss: Nothing,
aux_metric: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}
losses_kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
def batch_loss_epoch_logs(
self,
metrics_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
aux_losses: tp.Optional[tp.Any] = None,
aux_metrics: tp.Optional[tp.Any] = None,
**losses_kwargs,
) -> tp.Tuple[jnp.ndarray, Logs, Logs]:
> batch_loss, *_ = self(
metrics_kwargs=metrics_kwargs,
aux_losses=aux_losses,
aux_metrics=aux_metrics,
**losses_kwargs,
)
treex/metrics/loss_and_logs.py:140:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
args = ()
kwargs = {'aux_losses': MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
n...<DynamicJaxprTrace(level=0/1)>}, 'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, ...}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
metrics_kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
aux_losses = MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
name: "my_module", str
some_value: Nothing,
}
aux_metrics = MyModule {
aux_loss: Nothing,
aux_metric: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}
losses_kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(
self,
metrics_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
aux_losses: tp.Optional[tp.Any] = None,
aux_metrics: tp.Optional[tp.Any] = None,
**losses_kwargs,
) -> tp.Tuple[jnp.ndarray, Logs, Logs]:
> return super().__call__(
metrics_kwargs=metrics_kwargs,
aux_losses=aux_losses,
aux_metrics=aux_metrics,
**losses_kwargs,
)
treex/metrics/loss_and_logs.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
kwargs = {'aux_losses': MyModule {
aux_loss: jax.DynamicJaxprTracer((), float32), LossLog
aux_metric: Nothing,
n...<DynamicJaxprTrace(level=0/1)>}, 'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, ...}
module = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c0d3ca0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_loss_and_logs.TestLossAndLogs object at 0x7fff5c318be0>
def test_batch_loss(self):
class MyModule(tx.Module):
aux_loss: jnp.ndarray = tx.LossLog.node()
aux_metric: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux_loss = jnp.array(1.0, jnp.float32)
self.aux_metric = jnp.array(2.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(
module: MyModule,
metrics: tx.metrics.LossAndLogs,
target,
preds,
y_true_metrics,
y_pred_metrics,
):
nonlocal N
N += 1
loss, losses_logs, metrics_logs = metrics.batch_loss_epoch_logs(
target=target,
preds=preds,
metrics_kwargs=dict(
target=y_true_metrics,
preds=y_pred_metrics,
),
aux_losses=module.loss_logs(),
aux_metrics=module.metric_logs(),
)
logs = {**losses_logs, **metrics_logs}
return loss, logs, metrics
module = MyModule()
metrics = tx.metrics.LossAndLogs(
losses=[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
],
metrics=dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
),
aux_losses=tx.metrics.AuxLosses(module.loss_logs()),
aux_metrics=tx.metrics.AuxMetrics(module.metric_logs()),
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
y_true_metrics = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
y_pred_metrics = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> loss, logs, metrics = f(
module, metrics, target, preds, y_true_metrics, y_pred_metrics
)
tests/metrics/test_loss_and_logs.py:158:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_loss_and_logs.py:126: in f
loss, losses_logs, metrics_logs = metrics.batch_loss_epoch_logs(
treex/metrics/loss_and_logs.py:140: in batch_loss_epoch_logs
batch_loss, *_ = self(
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/loss_and_logs.py:126: in __call__
return super().__call__(
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = LossAndLogs {
aux_losses: AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint...name: "loss_and_logs", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c0d3ca0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
_____________________________ TestLosses.test_list _____________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=4 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=4 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=4 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=4 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=4 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, nextitem = <Function test_dict>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, log = True, nextitem = <Function test_dict>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff74438d30>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_list>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_list>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_list>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_list>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_list>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_list>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_list>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_list>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_losses.TestLosses object at 0x7fff5c7b2ac0>
def test_list(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
losses = tx.metrics.Losses(
[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
]
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
> losses = f(losses, target, preds)
tests/metrics/test_losses.py:30:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Losses {
counts: dict {
mean_squared_error_loss: jaxlib.DeviceArray((), uint32), ...State
},
}, DeviceArray([[[[0., 0., 0., 0.]]]], dtype=float32), DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Losses {
counts: dict {
mean_squared_error_loss: jaxlib.DeviceArray((), uint32), ...State
},
}, DeviceArray([[[[0., 0., 0., 0.]]]], dtype=float32), DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False, '_field_metadata': {'_initial_state':...'mean_squared_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff743d6940>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff5c0ba9d0>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...ared_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_losses.py:14', traced_for='jit', arg_info=functools.par...uared_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ed_error_loss2': *}, 'totals': {'mean_squared_error_loss': *, 'mean_squared_error_loss2': *}}]), *, *), {})),)
Core: f
args = (Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
target = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_losses.py:18:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(self, **kwargs) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
> return super().__call__(**kwargs)
treex/metrics/losses.py:70:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
module = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c0ba5e0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_losses.TestLosses object at 0x7fff5c7b2ac0>
def test_list(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
losses = tx.metrics.Losses(
[
tx.losses.MeanSquaredError(),
tx.losses.MeanSquaredError(),
]
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
> losses = f(losses, target, preds)
tests/metrics/test_losses.py:30:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_losses.py:18: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/losses.py:70: in __call__
return super().__call__(**kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...uared_error_loss2: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff5c0ba5e0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
_____________________________ TestLosses.test_dict _____________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=5 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=5 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=5 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=5 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=5 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, nextitem = <Function test_basic>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, log = True, nextitem = <Function test_basic>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff744389d0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_dict>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_dict>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_dict>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_dict>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_dict>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_dict>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_dict>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_dict>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_losses.TestLosses object at 0x7fff7412dc10>
def test_dict(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
losses = tx.metrics.Losses(
dict(
a=tx.losses.MeanSquaredError(),
b=tx.losses.MeanSquaredError(),
)
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
> losses = f(losses, target, preds)
tests/metrics/test_losses.py:70:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Losses {
counts: dict {
a/mean_squared_error_loss: jaxlib.DeviceArray((), uint32), ...State
},
}, DeviceArray([[[[0., 0., 0., 0.]]]], dtype=float32), DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Losses {
counts: dict {
a/mean_squared_error_loss: jaxlib.DeviceArray((), uint32), ...State
},
}, DeviceArray([[[[0., 0., 0., 0.]]]], dtype=float32), DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False, '_field_metadata': {'_initial_state':...ean_squared_error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[1., 1., 1., 1.]]]], dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74040a60>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.losses.Losses'>[{'_mutable': False..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74040d30>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...ed_error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_losses.py:54', traced_for='jit', arg_info=functools.par...red_error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (..._error_loss': *}, 'totals': {'a/mean_squared_error_loss': *, 'b/mean_squared_error_loss': *}}]), *, *), {})),)
Core: f
args = (Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
target = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_losses.py:58:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(self, **kwargs) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
> return super().__call__(**kwargs)
treex/metrics/losses.py:70:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'preds': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(float32[1,1,1,4])>with<DynamicJaxprTrace(level=0/1)>}
module = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74040f70>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_losses.TestLosses object at 0x7fff7412dc10>
def test_dict(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
losses = tx.metrics.Losses(
dict(
a=tx.losses.MeanSquaredError(),
b=tx.losses.MeanSquaredError(),
)
)
target = jnp.array([0.0, 0.0, 0.0, 0.0])[None, None, None, :]
preds = jnp.array([1.0, 1.0, 1.0, 1.0])[None, None, None, :]
> losses = f(losses, target, preds)
tests/metrics/test_losses.py:70:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_losses.py:58: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/losses.py:70: in __call__
return super().__call__(**kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Losses {
counts: dict {
a/mean_squared_error_loss: jax.DynamicJaxprTracer((), uint32), ...quared_error_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74040f70>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
___________________________ TestAuxLosses.test_basic ___________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=6 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=6 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=6 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=6 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=6 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, nextitem = <Function test_named>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, log = True, nextitem = <Function test_named>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff74438d30>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_basic>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_basic>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_basic>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_basic>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_basic>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_basic>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_losses.TestAuxLosses object at 0x7fff7412d040>
def test_basic(self):
class MyModule(tx.Module):
aux: jnp.ndarray = tx.LossLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = jnp.array(1.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
aux_losses(aux_losses=loss_logs)
return aux_losses
module = MyModule()
loss_logs = module.filter(tx.LossLog)
losses = tx.metrics.AuxLosses(loss_logs)
> losses = f(module, losses)
tests/metrics/test_losses.py:116:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: jaxlib.DeviceArray((), float32), LossLog
name: "my_module", ... str
totals: dict {
aux_loss: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: jaxlib.DeviceArray((), float32), LossLog
name: "my_module", ... str
totals: dict {
aux_loss: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_basic.<locals>.MyModule'>[{'_mutable': False, '_field_me... {'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
donated_invars = (False, False, False, False, False, False)
arg = DeviceArray(0., dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_basic.<locals>.My... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff740790d0>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_basic.<locals>.My... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74079280>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...se, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...])), {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
arg_devices = (None, None, None, None, None, None)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...': {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_losses.py:103', traced_for='jit', arg_info=functools.pa...s': {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (... {'aux_loss': *}, 'totals': {'aux_loss': *}}, 'counts': {'aux_loss': *}, 'totals': {'aux_loss': *}}])), {})),)
Core: f
args = (MyModule {
aux: jax.DynamicJaxprTracer((), float32), LossLog
name: "my_module", ... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
})
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux: jax.DynamicJaxprTracer((), float32), LossLog
name: "my_module", str
some_value: jax.DynamicJaxprTracer((), float32),
}
aux_losses = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
> aux_losses(aux_losses=loss_logs)
tests/metrics/test_losses.py:108:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'aux_losses': MyModule {
aux: jax.DynamicJaxprTracer((), float32), LossLog
name: "my_module", str
some_value: Nothing,
}}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
aux_losses = MyModule {
aux: jax.DynamicJaxprTracer((), float32), LossLog
name: "my_module", str
some_value: Nothing,
}
def __call__(
self, aux_losses: tp.Any
) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
> return super().__call__(aux_losses=aux_losses)
treex/metrics/losses.py:110:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'aux_losses': MyModule {
aux: jax.DynamicJaxprTracer((), float32), LossLog
name: "my_module", str
some_value: Nothing,
}}
module = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74079430>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_losses.TestAuxLosses object at 0x7fff7412d040>
def test_basic(self):
class MyModule(tx.Module):
aux: jnp.ndarray = tx.LossLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = jnp.array(1.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
aux_losses(aux_losses=loss_logs)
return aux_losses
module = MyModule()
loss_logs = module.filter(tx.LossLog)
losses = tx.metrics.AuxLosses(loss_logs)
> losses = f(module, losses)
tests/metrics/test_losses.py:116:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_losses.py:108: in f
aux_losses(aux_losses=loss_logs)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/losses.py:110: in __call__
return super().__call__(aux_losses=aux_losses)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
aux_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
aux_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74079430>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
___________________________ TestAuxLosses.test_named ___________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=7 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=7 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=7 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=7 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=7 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_named>, 'nextitem': <Function test_basic>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_named>, 'nextitem': <Function test_basic>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_named>, 'nextitem': <Function test_basic>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, nextitem = <Function test_basic>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, log = True, nextitem = <Function test_basic>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff7413ba60>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_named>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_named>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_named>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_named>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_named>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_named>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_named>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_named>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_losses.TestAuxLosses object at 0x7fff5c4e8760>
def test_named(self):
class MyModule(tx.Module):
aux: tx.Named[jnp.ndarray] = tx.LossLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = tx.Named("my_loss", jnp.array(1.0, jnp.float32))
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
aux_losses(aux_losses=loss_logs)
return aux_losses
module = MyModule()
loss_logs = module.filter(tx.LossLog)
losses = tx.metrics.AuxLosses(loss_logs)
> losses = f(module, losses)
tests/metrics/test_losses.py:150:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=DeviceArray(1., dtype=float32), kind=<cla... str
totals: dict {
my_loss: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=DeviceArray(1., dtype=float32), kind=<cla... str
totals: dict {
my_loss: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_named.<locals>.MyModule'>[{'_mutable': False, '_field_me...te': {'counts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
donated_invars = (False, False, False, False, False, False)
arg = DeviceArray(0., dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_named.<locals>.My...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74079f70>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_losses.TestAuxLosses.test_named.<locals>.My...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74079700>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...se, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...])), {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
arg_devices = (None, None, None, None, None, None)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...unts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_losses.py:137', traced_for='jit', arg_info=functools.pa...ounts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ts': {'my_loss': *}, 'totals': {'my_loss': *}}, 'counts': {'my_loss': *}, 'totals': {'my_loss': *}}])), {})),)
Core: f
args = (MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dynami... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
})
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dynamic... str
some_value: jax.DynamicJaxprTracer((), float32),
}
aux_losses = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
> aux_losses(aux_losses=loss_logs)
tests/metrics/test_losses.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'aux_losses': MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[... str
some_value: Nothing,
}}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
aux_losses = MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dynamic... str
some_value: Nothing,
}
def __call__(
self, aux_losses: tp.Any
) -> tp.Tuple[jnp.ndarray, tp.Dict[str, jnp.ndarray]]:
> return super().__call__(aux_losses=aux_losses)
treex/metrics/losses.py:110:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'aux_losses': MyModule {
aux: Named(name='my_loss', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[... str
some_value: Nothing,
}}
module = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74079670>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_losses.TestAuxLosses object at 0x7fff5c4e8760>
def test_named(self):
class MyModule(tx.Module):
aux: tx.Named[jnp.ndarray] = tx.LossLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = tx.Named("my_loss", jnp.array(1.0, jnp.float32))
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_losses: tx.metrics.AuxLosses):
nonlocal N
N += 1
loss_logs = module.filter(tx.LossLog)
aux_losses(aux_losses=loss_logs)
return aux_losses
module = MyModule()
loss_logs = module.filter(tx.LossLog)
losses = tx.metrics.AuxLosses(loss_logs)
> losses = f(module, losses)
tests/metrics/test_losses.py:150:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_losses.py:142: in f
aux_losses(aux_losses=loss_logs)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/losses.py:110: in __call__
return super().__call__(aux_losses=aux_losses)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxLosses {
counts: dict {
my_loss: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype:... str
totals: dict {
my_loss: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74079670>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
______________________________ TestMetric.test_on ______________________________
self = <test_metric.TestMetric object at 0x7fff5c2cbbe0>
def test_on(self):
class MyMetric(tx.Metric):
def update(self, target, preds):
self.target = target
self.preds = preds
return self
def compute(self):
return self.target, self.preds
metric = MyMetric(on=("a", 0))
target = {"a": [10]}
preds = {"a": [20]}
> assert metric(target=target, preds=preds) == (10, 20)
tests/metrics/test_metric.py:45:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", ...eds: 20, int
target: 10, int
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff7413baf0>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
_____________________________ TestMetric.test_jit ______________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=9 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=9 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=9 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=9 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=9 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_list>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_list>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_jit>, 'nextitem': <Function test_list>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, nextitem = <Function test_list>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, log = True, nextitem = <Function test_list>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=6>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff742b03a0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_jit>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_jit>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_jit>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_jit>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_jit>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_jit>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_jit>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_jit>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_jit>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_metric.TestMetric object at 0x7fff4c3d9100>
def test_jit(self):
class MyMetric(tx.Metric):
a: int = tx.MetricState.node()
def __init__(self) -> None:
self.a = 0
super().__init__()
def update(self, n):
self.a += n
def compute(self):
return self.a
N = 0
@jax.jit
def f(m):
nonlocal N
N += 1
m(n=2)
return m
metric = MyMetric()
> metric = f(metric)
tests/metrics/test_metric.py:90:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyMetric {
a: 0, MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
},)
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyMetric {
a: 0, MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
},)
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_metric.TestMetric.test_jit.<locals>.MyMetric'>[{'_mutable': False, '_field_metadat...e': 'my_metric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {}))
args_flat = [0, 0], donated_invars = (False, False), arg = 0
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metric.TestMetric.test_jit.<locals>.MyMetri...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74079430>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metric.TestMetric.test_jit.<locals>.MyMetri...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
args = (0, 0)
params = {'backend': None, 'device': None, 'donated_invars': (False, False), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
args = (0, 0)
params = {'backend': None, 'device': None, 'donated_invars': (False, False), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74370550>
tracers = [0, 0]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
tracers = [0, 0]
params = {'backend': None, 'device': None, 'donated_invars': (False, False), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
device = None, backend = None, name = 'f', donated_invars = (False, False)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
args = (None, None, 'f', (False, False), (ShapedArray(int32[], weak_type=True), None), (ShapedArray(int32[], weak_type=True), None))
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...pedArray(int32[], weak_type=True), None), (ShapedArray(int32[], weak_type=True), None)), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
device = None, backend = None, name = 'f', donated_invars = (False, False)
arg_specs = ((ShapedArray(int32[], weak_type=True), None), (ShapedArray(int32[], weak_type=True), None))
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...: *}, 'a': *}]),), {})),)
Core: f
, None, None, 'f', (False, False), (ShapedArray(int32[], weak_type=True), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
device = None, backend = None, name = 'f', donated_invars = (False, False)
arg_specs = ((ShapedArray(int32[], weak_type=True), None), (ShapedArray(int32[], weak_type=True), None))
abstract_args = (ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True))
arg_devices = (None, None)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...etric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
in_avals = (ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True))
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_metric.py:81', traced_for='jit', arg_info=functools.par...metric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True))
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...ric', 'dtype': <class 'jax._src.numpy.lax_numpy.float32'>}], [{'_initial_state': {'a': *}, 'a': *}]),), {})),)
Core: f
args = (MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
},)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
@jax.jit
def f(m):
nonlocal N
N += 1
> m(n=2)
tests/metrics/test_metric.py:85:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
args = (), kwargs = {'n': 2}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
kwargs = {'n': 2}
module = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_metric.TestMetric object at 0x7fff4c3d9100>
def test_jit(self):
class MyMetric(tx.Metric):
a: int = tx.MetricState.node()
def __init__(self) -> None:
self.a = 0
super().__init__()
def update(self, n):
self.a += n
def compute(self):
return self.a
N = 0
@jax.jit
def f(m):
nonlocal N
N += 1
m(n=2)
return m
metric = MyMetric()
> metric = f(metric)
tests/metrics/test_metric.py:90:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_metric.py:85: in f
m(n=2)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = MyMetric {
a: jax.DynamicJaxprTracer((), int32), MetricState
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
name: "my_metric", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
____________________________ TestAccuracy.test_list ____________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=10 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=10 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=10 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=10 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=10 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_list>, 'nextitem': <Function test_dict>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, nextitem = <Function test_dict>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, log = True, nextitem = <Function test_dict>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff740e9a60>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_list>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_list>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_list>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_list>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_list>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_list>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_list>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_list>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_list>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_metrics.TestAccuracy object at 0x7fff4c4eec70>
def test_list(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metrics = tx.metrics.Metrics(
[
tx.metrics.Accuracy(num_classes=10),
tx.metrics.Accuracy(num_classes=10),
]
)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(metrics, target, preds)
tests/metrics/test_metrics.py:30:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': False, '_field_metadata': {'_initial_state... 3>}], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': Fal...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74075310>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': Fal...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff740753a0>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_metrics.py:14', traced_for='jit', arg_info=functools.pa...{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
target = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_metrics.py:18:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
args = ()
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(self, **kwargs) -> tp.Dict[str, jnp.ndarray]:
> return super().__call__(**kwargs)
treex/metrics/metrics.py:63:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
module = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74075670>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_metrics.TestAccuracy object at 0x7fff4c4eec70>
def test_list(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metrics = tx.metrics.Metrics(
[
tx.metrics.Accuracy(num_classes=10),
tx.metrics.Accuracy(num_classes=10),
]
)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(metrics, target, preds)
tests/metrics/test_metrics.py:30:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_metrics.py:18: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metrics.py:63: in __call__
return super().__call__(**kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
ac...2), MetricState
},
},
name: "metrics", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74075670>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
____________________________ TestAccuracy.test_dict ____________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=11 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=11 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=11 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=11 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=11 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_dict>, 'nextitem': <Function test_basic>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, nextitem = <Function test_basic>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, log = True, nextitem = <Function test_basic>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff740758b0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_dict>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_dict>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_dict>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_dict>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_dict>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_dict>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_dict>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_dict>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_dict>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_metrics.TestAccuracy object at 0x7fff4c3a6a00>
def test_dict(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metrics = tx.metrics.Metrics(
dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
)
)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(metrics, target, preds)
tests/metrics/test_metrics.py:58:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...ray([[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]]], dtype=int32), DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32))
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': False, '_field_metadata': {'_initial_state... 3>}], [{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {}))
args_flat = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
donated_invars = (False, False, False, False, False, False, ...)
arg = DeviceArray([[[[0, 1, 2, 3, 0, 5, 6, 7, 0, 9]]]], dtype=int32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': Fal...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff5c45a550>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'treex.metrics.metrics.Metrics'>[{'_mutable': Fal...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...)
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff5c45a5e0>
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
tracers = [DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), DeviceArray(0, dtype=uint32), ...]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False, ...), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ... False, False, ...), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ... {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False, ...), (ShapedArray(uint32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False, ...)
arg_specs = ((ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(uint32[]), None), ...)
abstract_args = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
arg_devices = (None, None, None, None, None, None, ...)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_metrics.py:42', traced_for='jit', arg_info=functools.pa...{'_initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ShapedArray(uint32[]), ...)
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...initial_state': {'fn': *, 'fp': *, 'tn': *, 'tp': *}, 'fn': *, 'fp': *, 'tn': *, 'tp': *}])}}]), *, *), {})),)
Core: f
args = (Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a...[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>)
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
m = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
target = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
preds = Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
> m(target=target, preds=preds)
tests/metrics/test_metrics.py:46:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
args = ()
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
def __call__(self, **kwargs) -> tp.Dict[str, jnp.ndarray]:
> return super().__call__(**kwargs)
treex/metrics/metrics.py:63:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
kwargs = {'preds': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>, 'target': Traced<ShapedArray(int32[1,1,1,10])>with<DynamicJaxprTrace(level=0/1)>}
module = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_metrics.TestAccuracy object at 0x7fff4c3a6a00>
def test_dict(self):
N = 0
@jax.jit
def f(m, target, preds):
nonlocal N
N += 1
m(target=target, preds=preds)
return m
metrics = tx.metrics.Metrics(
dict(
a=tx.metrics.Accuracy(num_classes=10),
b=tx.metrics.Accuracy(num_classes=10),
)
)
target = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])[None, None, None, :]
preds = jnp.array([0, 1, 2, 3, 0, 5, 6, 7, 0, 9])[None, None, None, :]
> metrics = f(metrics, target, preds)
tests/metrics/test_metrics.py:58:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_metrics.py:46: in f
m(target=target, preds=preds)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metrics.py:63: in __call__
return super().__call__(**kwargs)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Metrics {
dtype: <class 'jax._src.numpy.lax_numpy.float32'>, _ScalarMeta
metrics: dict {
a/...2), MetricState
},
},
name: "metrics", str
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
__________________________ TestAuxMetrics.test_basic ___________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=12 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=12 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=12 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=12 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=12 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_basic>, 'nextitem': <Function test_named>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, nextitem = <Function test_named>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, log = True, nextitem = <Function test_named>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff740e9a60>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_basic>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_basic>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_basic>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_basic>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_basic>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_basic>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_basic>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_basic>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_metrics.TestAuxMetrics object at 0x7fff4c3a0c10>
def test_basic(self):
class MyModule(tx.Module):
aux: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = jnp.array(1.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
aux_metrics(aux_metrics=metric_logs)
return aux_metrics
module = MyModule()
metric_logs = module.filter(tx.MetricLog)
metrics = tx.metrics.AuxMetrics(metric_logs)
> metrics = f(module, metrics)
tests/metrics/test_metrics.py:92:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: jaxlib.DeviceArray((), float32), MetricLog
name: "my_module", ... str
totals: dict {
aux: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: jaxlib.DeviceArray((), float32), MetricLog
name: "my_module", ... str
totals: dict {
aux: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_basic.<locals>.MyModule'>[{'_mutable': False, '_field_... [{'_initial_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
donated_invars = (False, False, False, False, False, False)
arg = DeviceArray(0., dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_basic.<locals>....l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742a54c0>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_basic.<locals>....l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff742a50d0>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...se, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...])), {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
arg_devices = (None, None, None, None, None, None)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...ial_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_metrics.py:79', traced_for='jit', arg_info=functools.pa...tial_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...l_state': {'counts': {'aux': *}, 'totals': {'aux': *}}, 'counts': {'aux': *}, 'totals': {'aux': *}}])), {})),)
Core: f
args = (MyModule {
aux: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", ... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
})
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: jax.DynamicJaxprTracer((), float32),
}
aux_metrics = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
> aux_metrics(aux_metrics=metric_logs)
tests/metrics/test_metrics.py:84:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'aux_metrics': MyModule {
aux: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
aux_metrics = MyModule {
aux: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}
def __call__(self, aux_metrics: tp.Any) -> tp.Dict[str, jnp.ndarray]:
> return super().__call__(aux_metrics=aux_metrics)
treex/metrics/metrics.py:100:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'aux_metrics': MyModule {
aux: jax.DynamicJaxprTracer((), float32), MetricLog
name: "my_module", str
some_value: Nothing,
}}
module = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff742a5790>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_metrics.TestAuxMetrics object at 0x7fff4c3a0c10>
def test_basic(self):
class MyModule(tx.Module):
aux: jnp.ndarray = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = jnp.array(1.0, jnp.float32)
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
aux_metrics(aux_metrics=metric_logs)
return aux_metrics
module = MyModule()
metric_logs = module.filter(tx.MetricLog)
metrics = tx.metrics.AuxMetrics(metric_logs)
> metrics = f(module, metrics)
tests/metrics/test_metrics.py:92:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_metrics.py:84: in f
aux_metrics(aux_metrics=metric_logs)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metrics.py:100: in __call__
return super().__call__(aux_metrics=aux_metrics)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
aux: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtype... str
totals: dict {
aux: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff742a5790>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
__________________________ TestAuxMetrics.test_named ___________________________
"""The pytest entry point."""
import pytest
if __name__ == "__main__":
> raise SystemExit(pytest.console_main())
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/pytest/__main__.py:5:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
def console_main() -> int:
"""The CLI entry point of pytest.
This function is not meant for programmable use; use `main()` instead.
"""
# https://docs.python.org/3/library/signal.html#note-on-sigpipe
try:
> code = main()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:185:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = None, plugins = None
def main(
args: Optional[Union[List[str], py.path.local]] = None,
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]] = None,
) -> Union[int, ExitCode]:
"""Perform an in-process test run.
:param args: List of command line arguments.
:param plugins: List of plugin objects to be auto-registered during initialization.
:returns: An exit code.
"""
try:
try:
config = _prepareconfig(args, plugins)
except ConftestImportFailure as e:
exc_info = ExceptionInfo(e.excinfo)
tw = TerminalWriter(sys.stderr)
tw.line(f"ImportError while loading conftest '{e.path}'.", red=True)
exc_info.traceback = exc_info.traceback.filter(
filter_traceback_for_conftest_import_failure
)
exc_repr = (
exc_info.getrepr(style="short", chain=False)
if exc_info.traceback
else exc_info.exconly()
)
formatted_tb = str(exc_repr)
for line in formatted_tb.splitlines():
tw.line(line.rstrip(), red=True)
return ExitCode.USAGE_ERROR
else:
try:
> ret: Union[ExitCode, int] = config.hook.pytest_cmdline_main(
config=config
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/config/__init__.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_cmdline_main'>, args = ()
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
argname = 'config', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_cmdline_main'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_cmdline_main'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/setuponly.py'>>, ...]
caller_kwargs = {'config': <_pytest.config.Config object at 0x7ffff69ff370>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
> return wrap_session(config, _main)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:316:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
doit = <function _main at 0x7ffff6cd4a60>
def wrap_session(
config: Config, doit: Callable[[Config, "Session"], Optional[Union[int, ExitCode]]]
) -> Union[int, ExitCode]:
"""Skeleton command line program."""
session = Session.from_config(config)
session.exitstatus = ExitCode.OK
initstate = 0
try:
try:
config._do_configure()
initstate = 1
config.hook.pytest_sessionstart(session=session)
initstate = 2
> session.exitstatus = doit(config, session) or 0
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
config = <_pytest.config.Config object at 0x7ffff69ff370>
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=13 testscollected=140>
def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
"""Default command line protocol for initialization, session,
running tests and reporting."""
config.hook.pytest_collection(session=session)
> config.hook.pytest_runtestloop(session=session)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:323:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtestloop'>, args = ()
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=13 testscollected=140>}
argname = 'session', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtestloop'
methods = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=13 testscollected=140>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtestloop'
hook_impls = [<HookImpl plugin_name='main', plugin=<module '_pytest.main' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3...t/main.py'>>, <HookImpl plugin_name='logging-plugin', plugin=<_pytest.logging.LoggingPlugin object at 0x7ffff68e37c0>>]
caller_kwargs = {'session': <Session source exitstatus=<ExitCode.OK: 0> testsfailed=13 testscollected=140>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
session = <Session source exitstatus=<ExitCode.OK: 0> testsfailed=13 testscollected=140>
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
% (session.testsfailed, "s" if session.testsfailed != 1 else "")
)
if session.config.option.collectonly:
return True
for i, item in enumerate(session.items):
nextitem = session.items[i + 1] if i + 1 < len(session.items) else None
> item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/main.py:348:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_protocol'>, args = ()
kwargs = {'item': <Function test_named>, 'nextitem': <TestCaseFunction test_call>}
argname = 'nextitem', firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_protocol'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
kwargs = {'item': <Function test_named>, 'nextitem': <TestCaseFunction test_call>}
firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_protocol'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...'/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/warnings.py'>>]
caller_kwargs = {'item': <Function test_named>, 'nextitem': <TestCaseFunction test_call>}
firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, nextitem = <TestCaseFunction test_call>
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
> runtestprotocol(item, nextitem=nextitem)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:109:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, log = True
nextitem = <TestCaseFunction test_call>
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request: # type: ignore[attr-defined]
item._initrequest() # type: ignore[attr-defined]
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
if item.config.getoption("setupshow", False):
show_test_item(item)
if not item.config.getoption("setuponly", False):
> reports.append(call_and_report(item, "call", log))
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, when = 'call', log = True, kwds = {}
call = <CallInfo when='call' excinfo=<ExceptionInfo AttributeError("module 'treeo.api' has no attribute 'apply'") tblen=7>>
hook = <pluggy._hooks._HookRelay object at 0x7ffff7345dc0>
def call_and_report(
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
> call = call_runtest_hook(item, when, **kwds)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:215:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>, when = 'call', kwds = {}
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook: Callable[..., None] = item.ihook.pytest_runtest_setup
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
ihook = item.ihook.pytest_runtest_teardown
else:
assert False, f"Unhandled runtest hook case: {when}"
reraise: Tuple[Type[BaseException], ...] = (Exit,)
if not item.config.getoption("usepdb", False):
reraise += (KeyboardInterrupt,)
> return CallInfo.from_call(
lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:254:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cls = <class '_pytest.runner.CallInfo'>
func = <function call_runtest_hook.<locals>.<lambda> at 0x7fff7413baf0>
when = 'call'
reraise = (<class '_pytest.outcomes.Exit'>, <class 'KeyboardInterrupt'>)
@classmethod
def from_call(
cls,
func: "Callable[[], TResult]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: Optional[
Union[Type[BaseException], Tuple[Type[BaseException], ...]]
] = None,
) -> "CallInfo[TResult]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
> result: Optional[TResult] = func()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:311:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> lambda: ihook(item=item, **kwds), when=when, reraise=reraise
)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:255:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_runtest_call'>, args = ()
kwargs = {'item': <Function test_named>}, argname = 'item', firstresult = False
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_runtest_call'
methods = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
kwargs = {'item': <Function test_named>}, firstresult = False
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_runtest_call'
hook_impls = [<HookImpl plugin_name='runner', plugin=<module '_pytest.runner' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-pyt...0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/unraisableexception.py'>>, ...]
caller_kwargs = {'item': <Function test_named>}, firstresult = False
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
item = <Function test_named>
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
del sys.last_value
del sys.last_traceback
except AttributeError:
pass
try:
> item.runtest()
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/runner.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <Function test_named>
def runtest(self) -> None:
"""Execute the underlying test function."""
> self.ihook.pytest_pyfunc_call(pyfuncitem=self)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:1641:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_HookCaller 'pytest_pyfunc_call'>, args = ()
kwargs = {'pyfuncitem': <Function test_named>}, argname = 'pyfuncitem'
firstresult = True
def __call__(self, *args, **kwargs):
if args:
raise TypeError("hook calling supports only keyword arguments")
assert not self.is_historic()
# This is written to avoid expensive operations when not needed.
if self.spec:
for argname in self.spec.argnames:
if argname not in kwargs:
notincall = tuple(set(self.spec.argnames) - kwargs.keys())
warnings.warn(
"Argument(s) {} which are declared in the hookspec "
"can not be found in this hook call".format(notincall),
stacklevel=2,
)
break
firstresult = self.spec.opts.get("firstresult")
else:
firstresult = False
> return self._hookexec(self.name, self.get_hookimpls(), kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_hooks.py:265:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <_pytest.config.PytestPluginManager object at 0x7ffff74a4070>
hook_name = 'pytest_pyfunc_call'
methods = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
kwargs = {'pyfuncitem': <Function test_named>}, firstresult = True
def _hookexec(self, hook_name, methods, kwargs, firstresult):
# called from all hookcaller instances.
# enable_tracing will set its own wrapping function at self._inner_hookexec
> return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_manager.py:80:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
hook_name = 'pytest_pyfunc_call'
hook_impls = [<HookImpl plugin_name='python', plugin=<module '_pytest.python' from '/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py'>>]
caller_kwargs = {'pyfuncitem': <Function test_named>}, firstresult = True
def _multicall(hook_name, hook_impls, caller_kwargs, firstresult):
"""Execute a call into multiple python functions/methods and return the
result(s).
``caller_kwargs`` comes from _HookCaller.__call__().
"""
__tracebackhide__ = True
results = []
excinfo = None
try: # run impl and wrapper setup functions in a loop
teardowns = []
try:
for hook_impl in reversed(hook_impls):
try:
args = [caller_kwargs[argname] for argname in hook_impl.argnames]
except KeyError:
for argname in hook_impl.argnames:
if argname not in caller_kwargs:
raise HookCallError(
f"hook call must provide argument {argname!r}"
)
if hook_impl.hookwrapper:
try:
gen = hook_impl.function(*args)
next(gen) # first yield
teardowns.append(gen)
except StopIteration:
_raise_wrapfail(gen, "did not yield")
else:
> res = hook_impl.function(*args)
/nix/store/27wmb9723js9y895s18238845xi7rbbj-python3.9-pluggy-1.0.0/lib/python3.9/site-packages/pluggy/_callers.py:39:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
pyfuncitem = <Function test_named>
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
> result = testfunction(**testargs)
/nix/store/xqxk0viw75xf2vrdmf0na80r6cvq5g2j-python3.9-pytest-6.2.5/lib/python3.9/site-packages/_pytest/python.py:183:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_metrics.TestAuxMetrics object at 0x7fff5c045b50>
def test_named(self):
class MyModule(tx.Module):
aux: tx.Named[jnp.ndarray] = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = tx.Named("my_metric", jnp.array(1.0, jnp.float32))
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
aux_metrics(aux_metrics=metric_logs)
return aux_metrics
module = MyModule()
metric_logs = module.filter(tx.MetricLog)
metrics = tx.metrics.AuxMetrics(metric_logs)
> metrics = f(module, metrics)
tests/metrics/test_metrics.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=DeviceArray(1., dtype=float32), kind=<c... str
totals: dict {
my_metric: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, __tracebackhide__ = True
msg = "AttributeError: module 'treeo.api' has no attribute 'apply'\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------"
@util.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True
try:
> return fun(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/traceback_util.py:162:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=DeviceArray(1., dtype=float32), kind=<c... str
totals: dict {
my_metric: jaxlib.DeviceArray((), float32), MetricState
},
})
kwargs = {}, closed_fun = Wrapped function:
Core: f
in_tree = PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_named.<locals>.MyModule'>[{'_mutable': False, '_field_...ounts': {'my_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {}))
args_flat = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
donated_invars = (False, False, False, False, False, False)
arg = DeviceArray(0., dtype=float32)
flat_fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_named.<locals>....y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
out_tree = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74198af0>
@api_boundary
def cache_miss(*args, **kwargs):
### This first part is basically the same code as in _python_jit.
# An alternative would be for cache_miss to accept from C++ the arguments
# (dyn_args, donated_invars, args_flat, in_tree), since otherwise we have
# work/code that is redundant between C++ and Python. We can try that later.
closed_fun, in_tree, args_flat, donated_invars = _prepare_jit(
fun, static_argnums, static_argnames, donate_argnums, args, kwargs)
for arg in args_flat:
_check_arg(arg)
flat_fun, out_tree = flatten_fun(closed_fun, in_tree)
> out_flat = xla.xla_call(
flat_fun, *args_flat,
device=device, backend=backend, name=flat_fun.__name__,
donated_invars=donated_invars, inline=inline)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/api.py:432:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = xla_call
fun = Wrapped function:
0 : flatten_fun (PyTreeDef(((CustomNode(<class 'test_metrics.TestAuxMetrics.test_named.<locals>....y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def bind(self, fun, *args, **params):
> return call_bind(self, fun, *args, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1709:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
primitive = xla_call
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
args = (DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32))
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
top_trace = EvalTrace(level=0/0)
env_trace_todo = <function transformation_with_aux.<locals>.<lambda> at 0x7fff74198b80>
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
def call_bind(primitive: CallPrimitive, fun, *args, **params):
top_trace = find_top_trace(args)
fun, env_trace_todo = process_env_traces_call(
fun, primitive, top_trace and top_trace.level, tuple(params.items()))
tracers = map(top_trace.full_raise, args)
> outs = top_trace.process_call(primitive, fun, tracers, params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:1721:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = EvalTrace(level=0/0), primitive = xla_call
f = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
tracers = [DeviceArray(1., dtype=float32), DeviceArray(10., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32), DeviceArray(0, dtype=uint32), DeviceArray(0., dtype=float32)]
params = {'backend': None, 'device': None, 'donated_invars': (False, False, False, False, False, False), 'inline': False, ...}
def process_call(self, primitive, f, tracers, params):
> return primitive.impl(f, *tracers, **params)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/core.py:614:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline):
del inline # Only used at tracing time
> compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
*unsafe_map(arg_spec, args))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:142:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
args = (None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...)
cache = {}
key = (((<function process_env_traces_call at 0x7fffa54cc160>, (xla_call, 0, (('device', None), ('backend', None), ('name', ...se, False, False), (ShapedArray(float32[]), None), (ShapedArray(float32[]), None), ...), False, (False, 'allow', None))
result = None
def memoized_fun(fun: WrappedFun, *args):
cache = fun_caches.setdefault(fun.f, {})
if config.jax_check_tracer_leaks:
key = (_copy_main_traces(fun.transforms), fun.params, args,
config.x64_enabled, config._trace_context())
else:
key = (fun.transforms, fun.params, args, config.x64_enabled,
config._trace_context())
result = cache.get(key, None)
if result is not None:
ans, stores = result
fun.populate_stores(stores)
else:
> ans = call(fun, *args)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:272:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
> return lower_xla_callable(fun, device, backend, name, donated_invars,
*arg_specs).compile().unsafe_call
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:169:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...])), {})),)
Core: f
, None, None, 'f', (False, False, False, False, False, False), (ShapedArray(float32[]), None), ...)
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
device = None, backend = None, name = 'f'
donated_invars = (False, False, False, False, False, False)
arg_specs = ((ShapedArray(float32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None), (ShapedArray(uint32[]), None), (ShapedArray(float32[]), None))
abstract_args = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
arg_devices = (None, None, None, None, None, None)
@profiler.annotate_function
def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
donated_invars, *arg_specs):
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = util.unzip2(arg_specs)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
> jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
fun, abstract_args, pe.debug_info_final(fun, "jit"))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/dispatch.py:197:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
args = (Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), ...'my_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})), True)))
kwargs = {}
@wraps(func)
def wrapper(*args, **kwargs):
with TraceAnnotation(name, **decorator_kwargs):
> return func(*args, **kwargs)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/_src/profiler.py:206:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
debug_info = DebugInfo(func_src_info='f at /build/source/tests/metrics/test_metrics.py:113', traced_for='jit', arg_info=functools.p...{'my_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})), True))
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.debug_info = debug_info # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
> jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1798:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
fun = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
main = MainTrace(0,DynamicJaxprTrace)
in_avals = (ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]), ShapedArray(uint32[]), ShapedArray(float32[]))
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[List[bool]] = None):
# In general, the Tracers passed to ther Python callable underlying `fun` may
# correspond to a subset of `in_avals` (i.e. a subset of the input binders in
# the jaxpr). For example:
#
# n = core.DShapedArray((), jnp.dtype('int32'), weak_type=False)
# a = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
# b = core.DShapedArray((n,), jnp.dtype('float32'), weak_type=False)
#
# @lu.wrap_init
# def f(x, y):
# return x, y
#
# jaxpr, _, _ = pe.trace_to_jaxpr_dynamic(f, [n, a, b],
# keep_inputs=[False, True, True])
# print(jaxpr)
# # { lambda ; a:i32[] b:f32[a] c:f32[a]. let in (b, c) }
#
# The abstract values passed to trace_to_jaxpr_dynamic are in direct
# correspondence to the input binders (invars) of the jaxpr it returns. But in
# general the Tracers passed to the function f correspond only to a subset of
# those abstract values. That's because axis size variables may not be
# explicit arguments to f, while we make everything explicit in the jaxpr.
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
with extend_jaxpr_stack(main, frame), source_info_util.reset_name_stack():
trace = DynamicJaxprTrace(main, core.cur_sublevel())
in_tracers = _avals_to_tracers(trace, in_avals)
in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
> ans = fun.call_wrapped(*in_tracers_)
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/interpreters/partial_eval.py:1775:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Wrapped function:
0 : process_env_traces_call (xla_call, 0, (('device', None), ('backend', None), ('name', 'f'), (...y_metric': *}, 'totals': {'my_metric': *}}, 'counts': {'my_metric': *}, 'totals': {'my_metric': *}}])), {})),)
Core: f
args = (MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dyna... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
})
kwargs = {}, stack = [], gen = None, gen_static_args = None, out_store = None
def call_wrapped(self, *args, **kwargs):
"""Calls the underlying function, applying the transforms.
The positional `args` and keyword `kwargs` are passed to the first
transformation generator.
"""
stack = []
for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
gen = gen(*(gen_static_args + tuple(args)), **kwargs)
args, kwargs = next(gen)
stack.append((gen, out_store))
gen = gen_static_args = out_store = None
try:
> ans = self.f(*args, **dict(self.params, **kwargs))
/nix/store/l0llmvbc0n1j04gga0hz6jdfrl78zv9p-python3.9-jax-0.3.4/lib/python3.9/site-packages/jax/linear_util.py:166:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
module = MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dynam... str
some_value: jax.DynamicJaxprTracer((), float32),
}
aux_metrics = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
> aux_metrics(aux_metrics=metric_logs)
tests/metrics/test_metrics.py:118:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
args = ()
kwargs = {'aux_metrics': MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=Traced<ShapedArray(float... str
some_value: Nothing,
}}
@functools.wraps(cls.update)
def new_call(self: M, *args, **kwargs) -> M:
if len(args) > 0:
raise TypeError(
f"All arguments to {cls.__name__}.__call__ should be passed as keyword arguments."
)
> return old_call(self, *args, **kwargs)
treex/metrics/metric.py:111:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
aux_metrics = MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=Traced<ShapedArray(float32[])>with<Dynam... str
some_value: Nothing,
}
def __call__(self, aux_metrics: tp.Any) -> tp.Dict[str, jnp.ndarray]:
> return super().__call__(aux_metrics=aux_metrics)
treex/metrics/metrics.py:100:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
kwargs = {'aux_metrics': MyModule {
aux: Named(name='my_metric', value=FieldInfo(name='aux', value=Traced<ShapedArray(float... str
some_value: Nothing,
}}
module = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def __call__(self, **kwargs) -> tp.Any:
if self._labels_filter is not None:
if "target" in kwargs and kwargs["target"] is not None:
for index in self._labels_filter:
kwargs["target"] = kwargs["target"][index]
if "preds" in kwargs and kwargs["preds"] is not None:
for index in self._labels_filter:
kwargs["preds"] = kwargs["preds"][index]
# update cumulative state
self.update(**kwargs)
# compute batch metrics
module = to.copy(self)
> module.reset()
treex/metrics/metric.py:79:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
def reset(self):
def do_reset(metric):
if isinstance(metric, Metric):
metric.__dict__.update(to.copy(metric._initial_state))
> self.apply(do_reset, inplace=True)
treex/metrics/metric.py:88:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E jax._src.traceback_util.UnfilteredStackTrace: AttributeError: module 'treeo.api' has no attribute 'apply'
E
E The stack trace below excludes JAX-internal frames.
E The preceding is the original exception that occurred, unmodified.
E
E --------------------
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: UnfilteredStackTrace
The above exception was the direct cause of the following exception:
self = <test_metrics.TestAuxMetrics object at 0x7fff5c045b50>
def test_named(self):
class MyModule(tx.Module):
aux: tx.Named[jnp.ndarray] = tx.MetricLog.node()
some_value: jnp.ndarray = tx.node()
def __init__(self) -> None:
self.aux = tx.Named("my_metric", jnp.array(1.0, jnp.float32))
self.some_value = jnp.array(10.0, jnp.float32)
N = 0
@jax.jit
def f(module: MyModule, aux_metrics: tx.metrics.AuxMetrics):
nonlocal N
N += 1
metric_logs = module.filter(tx.MetricLog)
aux_metrics(aux_metrics=metric_logs)
return aux_metrics
module = MyModule()
metric_logs = module.filter(tx.MetricLog)
metrics = tx.metrics.AuxMetrics(metric_logs)
> metrics = f(module, metrics)
tests/metrics/test_metrics.py:126:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/metrics/test_metrics.py:118: in f
aux_metrics(aux_metrics=metric_logs)
treex/metrics/metric.py:111: in new_call
return old_call(self, *args, **kwargs)
treex/metrics/metrics.py:100: in __call__
return super().__call__(aux_metrics=aux_metrics)
treex/metrics/metric.py:79: in __call__
module.reset()
treex/metrics/metric.py:88: in reset
self.apply(do_reset, inplace=True)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = AuxMetrics {
counts: dict {
my_metric: jax.DynamicJaxprTracer((), uint32), MetricState
},
dtyp... str
totals: dict {
my_metric: jax.DynamicJaxprTracer((), float32), MetricState
},
}
f = <function Metric.reset.<locals>.do_reset at 0x7fff74370e50>, inplace = True
rest = ()
def apply(self: A, f: tp.Callable[..., None], *rest: A, inplace: bool = False) -> A:
"""
`apply` is a wrapper over `treeo.apply` that passes `self` as the second argument.
Arguments:
f: The function to apply.
*rest: additional pytrees.
inplace: If `True`, the input `obj` is mutated.
Returns:
A new pytree with the updated Trees or the same input `obj` if `inplace` is `True`.
"""
> return api.apply(f, self, *rest, inplace=inplace)
E AttributeError: module 'treeo.api' has no attribute 'apply'
/nix/store/1crrfhbaq60ssjak09baavfxy49masnx-python3.9-treeo-0.0.11/lib/python3.9/site-packages/treeo/mixins.py:222: AttributeError
=============================== warnings summary ===============================
../../nix/store/gqaiyzrxgfc7iwajz4hi9k91g2v382nf-python3.9-flatbuffers-2.0.0/lib/python3.9/site-packages/flatbuffers/compat.py:19
/nix/store/gqaiyzrxgfc7iwajz4hi9k91g2v382nf-python3.9-flatbuffers-2.0.0/lib/python3.9/site-packages/flatbuffers/compat.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
import imp
treex/metrics/mean_absolute_error.py:31
/build/source/treex/metrics/mean_absolute_error.py:31: DeprecationWarning: invalid escape sequence \s
"""
treex/metrics/mean_square_error.py:31
/build/source/treex/metrics/mean_square_error.py:31: DeprecationWarning: invalid escape sequence \s
"""
-- Docs: https://docs.pytest.org/en/stable/warnings.html
=========================== short test summary info ============================
FAILED tests/metrics/test_accuracy.py::TestAccuracy::test_jit - AttributeErro...
FAILED tests/metrics/test_accuracy.py::TestAccuracy::test_logits_preds - Attr...
FAILED tests/metrics/test_loss_and_logs.py::TestLossAndLogs::test_basic - Att...
FAILED tests/metrics/test_loss_and_logs.py::TestLossAndLogs::test_batch_loss
FAILED tests/metrics/test_losses.py::TestLosses::test_list - AttributeError: ...
FAILED tests/metrics/test_losses.py::TestLosses::test_dict - AttributeError: ...
FAILED tests/metrics/test_losses.py::TestAuxLosses::test_basic - AttributeErr...
FAILED tests/metrics/test_losses.py::TestAuxLosses::test_named - AttributeErr...
FAILED tests/metrics/test_metric.py::TestMetric::test_on - AttributeError: mo...
FAILED tests/metrics/test_metric.py::TestMetric::test_jit - AttributeError: m...
FAILED tests/metrics/test_metrics.py::TestAccuracy::test_list - AttributeErro...
FAILED tests/metrics/test_metrics.py::TestAccuracy::test_dict - AttributeErro...
FAILED tests/metrics/test_metrics.py::TestAuxMetrics::test_basic - AttributeE...
FAILED tests/metrics/test_metrics.py::TestAuxMetrics::test_named - AttributeE...
============ 14 failed, 126 passed, 3 warnings in 93.45s (0:01:33) =============
error: builder for '/nix/store/03x06vaz54cyq2xnh77g4sm7vsp5ia7s-python3.9-treex-0.6.10.drv' failed with exit code 1;
last 10 log lines:
> FAILED tests/metrics/test_losses.py::TestLosses::test_dict - AttributeError: ...
> FAILED tests/metrics/test_losses.py::TestAuxLosses::test_basic - AttributeErr...
> FAILED tests/metrics/test_losses.py::TestAuxLosses::test_named - AttributeErr...
> FAILED tests/metrics/test_metric.py::TestMetric::test_on - AttributeError: mo...
> FAILED tests/metrics/test_metric.py::TestMetric::test_jit - AttributeError: m...
> FAILED tests/metrics/test_metrics.py::TestAccuracy::test_list - AttributeErro...
> FAILED tests/metrics/test_metrics.py::TestAccuracy::test_dict - AttributeErro...
> FAILED tests/metrics/test_metrics.py::TestAuxMetrics::test_basic - AttributeE...
> FAILED tests/metrics/test_metrics.py::TestAuxMetrics::test_named - AttributeE...
> ============ 14 failed, 126 passed, 3 warnings in 93.45s (0:01:33) =============
For full logs, run 'nix log /nix/store/03x06vaz54cyq2xnh77g4sm7vsp5ia7s-python3.9-treex-0.6.10.drv'.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment