-
-
Save jansel/fbce058e74ef39a033e4bc14d82b3b68 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py | |
index 084721a5b2f..ed9a8764849 100644 | |
--- a/torch/_dynamo/eval_frame.py | |
+++ b/torch/_dynamo/eval_frame.py | |
@@ -320,7 +320,7 @@ class _TorchDynamoContext: | |
self.export = export | |
self.compiler_config = compiler_config | |
self.cleanup_fns: List[Callable[[], Any]] = [] | |
- self.enter_exit_hooks = [backend_cache_manager(self.callback)] | |
+ self.enter_exit_hooks = [] | |
patch_fn() | |
if dynamic is not None: | |
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py | |
index 2786be65e37..285059a7a71 100644 | |
--- a/torch/_dynamo/guards.py | |
+++ b/torch/_dynamo/guards.py | |
@@ -594,11 +594,6 @@ class GuardBuilder(GuardBuilderBase): | |
def BACKEND_MATCH(self, guard: Guard): | |
"""Guard on backend matching based on id of current_backend""" | |
assert guard.source is GuardSource.GLOBAL | |
- backend_id = ( | |
- f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}" | |
- ) | |
- code = [f"___check_current_backend({backend_id})"] | |
- self._produce_guard_code(guard, code) | |
def SHAPE_ENV(self, guard: Guard): | |
# Let's handle ShapeEnv guards. To do this, we will resolve | |
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c | |
index 5982022ce10..de00f18eae0 100644 | |
--- a/torch/csrc/dynamo/eval_frame.c | |
+++ b/torch/csrc/dynamo/eval_frame.c | |
@@ -651,25 +651,7 @@ static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEn | |
// NB: intentionally not using Py_RETURN_NONE, to return borrowed ref | |
return Py_None; | |
} | |
- PyObject *f_locals = frame->f_locals; | |
- // remember to update the type signature for GuardFn.__call__ in torch/_dynamo/types.py | |
- // if this calling convention changes | |
- PyObject* valid = PyObject_CallOneArg(e->check_fn, f_locals); | |
- if (unlikely(valid == NULL)) { | |
- if (guard_error_hook != NULL) { | |
- PyObject *type = NULL, *value = NULL, *traceback = NULL; | |
- PyErr_Fetch(&type, &value, &traceback); | |
- PyObject* r = call_guard_fail_hook(guard_error_hook, e, index, f_locals); | |
- if (r == NULL) { | |
- return NULL; | |
- } | |
- Py_DECREF(r); | |
- PyErr_Restore(type, value, traceback); | |
- } | |
- return NULL; | |
- } | |
- Py_DECREF(valid); | |
- if (valid == Py_True) { | |
+ if (true) { | |
// Keep the head as the most recently used cache entry. | |
// If the hit cache entry is not the head of the linked list, | |
// move it to the head | |
@@ -945,14 +927,6 @@ static PyObject* _custom_eval_frame( | |
CacheEntry* cache_entry = extract_cache_entry(extra); | |
FrameState* frame_state = extract_frame_state(extra); | |
- // TODO(jansel): investigate directly using the "fast" representation | |
- // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame | |
- // even though we should pass a PyFrameObject. | |
- if (THP_PyFrame_FastToLocalsWithError(frame) < 0) { | |
- DEBUG_TRACE("error %s", get_frame_name(frame)); | |
- return NULL; | |
- } | |
- | |
// A callback of Py_False indicates "run only" mode, the cache is checked, but | |
// we never compile. | |
if (callback == Py_False) { | |
@@ -973,7 +947,6 @@ static PyObject* _custom_eval_frame( | |
DEBUG_TRACE("cache hit %s", get_frame_name(frame)); | |
return eval_custom_code(tstate, frame, cached_code, throw_flag); | |
} | |
- DEBUG_CHECK(PyDict_CheckExact(frame->f_locals)); | |
DEBUG_CHECK(PyDict_CheckExact(frame->f_globals)); | |
DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins)); | |
@@ -996,6 +969,13 @@ static PyObject* _custom_eval_frame( | |
eval_frame_callback_set(callback); | |
return eval_custom_code(tstate, frame, cached_code, throw_flag); | |
} | |
+ | |
+ // only populate frame->f_locals on a cache miss | |
+ if (THP_PyFrame_FastToLocalsWithError(frame) < 0) { | |
+ DEBUG_TRACE("error %s", get_frame_name(frame)); | |
+ return NULL; | |
+ } | |
+ | |
// cache miss | |
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame | |
// that gets re-interpreted as a PyObject (which it is NOT!) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.