@Jokeren
Last active August 29, 2023 23:40
record function reproducer
import torch
import sys
device = torch.device('cpu')
left = torch.zeros(100, device=device, requires_grad=True)
right = torch.zeros(100, device=device, requires_grad=True)
grad = torch.zeros(100, device=device)
for _ in range(10):
    output = torch.add(left, right)
    output.backward(grad)
#include <torch/extension.h>
#include <iostream>
int driver_register() {
  at::addGlobalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            std::cout << fn.forwardThreadId() << std::endl;
            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
            return;
          })
          .needsInputs(false)   // TODO(Keren): monitor inputs if needed?
          .needsOutputs(false)  // TODO(Keren): monitor outputs if needed?
          .scopes({}));
  return 0;
}
int _ret = driver_register();
@davidberard98

Just to consolidate the discussion here: @Jokeren said, "I always get fwd_thread_id = 18446744073709551615, which should be 0 or 1."
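
For context on the reported number, here is a small sketch (plain Python, no PyTorch needed) showing that 18446744073709551615 is exactly the largest unsigned 64-bit value, which is also what -1 becomes when stored in a 64-bit unsigned integer. This only illustrates the arithmetic; whether the value here actually comes from an unset or sentinel thread id is an assumption that the gist does not confirm. The variable names are made up for illustration.

```python
import ctypes

# Value reported in the discussion above.
reported = 18446744073709551615

# Largest value representable in an unsigned 64-bit integer.
UINT64_MAX = 2**64 - 1

# -1 stored in an unsigned 64-bit integer (ctypes wraps without overflow checks).
minus_one_as_u64 = ctypes.c_uint64(-1).value

print(reported == UINT64_MAX)
print(reported == minus_one_as_u64)
```

Both comparisons hold, so the printed id is the all-ones 64-bit pattern rather than a small thread index such as 0 or 1.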
