Skip to content

Instantly share code, notes, and snippets.

@zhuyifei1999
Last active June 16, 2020 04:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zhuyifei1999/329a2028295b0c28e27b2dd5f79f83c1 to your computer and use it in GitHub Desktop.
Save zhuyifei1999/329a2028295b0c28e27b2dd5f79f83c1 to your computer and use it in GitHub Desktop.
# WARNING:
# This code is thread-unsafe.
# This code may break dicts if the __hash__ of any key changes as a side effect.
# Don't use in production or with threads unless you know what you're doing.
import types
import ctypes
import sys
from queue import Queue
from guppy import hpy
# ctypes binding for CPython's C-API function PyFrame_LocalsToFast
# (signature: void (PyFrameObject *, int)). apply_rel calls it after editing
# a frame's f_locals so the edit is written back into the frame's locals.
PyFrame_LocalsToFast = ctypes.pythonapi.PyFrame_LocalsToFast
PyFrame_LocalsToFast.restype = None
PyFrame_LocalsToFast.argtypes = (ctypes.py_object, ctypes.c_int)

# Shared guppy heap-inspection session used by hidden_scope and do_replace.
hp = hpy()
class NotReadyException(Exception):
    """Signal that an object's replacement has not been built yet.

    Raised while rebuilding immutable objects when one of their references
    points at another object that is still waiting to be rebuilt; the
    caller catches it and retries the rebuild on a later pass.
    """
class DoRollback(Exception):
    """Internal signal that an in-place patch failed and must be undone.

    The exception that caused the failure is attached as ``__cause__``; the
    handler reverts the already-applied patches and re-raises that cause.
    """
def hidden_scope(old, new):
    """Redirect every reference to ``old`` so that it points at ``new``.

    Works in three phases:

    1. Discovery: walk the referrers of ``old`` using guppy path
       information. Each referrer is probed by re-assigning the reference
       to its current target. If that raises (anything but AssertionError),
       the referrer is treated as immutable. It is scheduled for
       reconstruction, and its own referrers are examined in the next round.
    2. Reconstruction: build a fresh copy of every immutable referrer whose
       references point at the replacement objects. Objects whose
       dependencies are not rebuilt yet are retried on a later pass.
    3. Mutation: patch every mutable referrer in place so it points at the
       replacement objects. If any patch fails, the patches already applied
       are undone (newest first) and the failure is re-raised.

    Args:
        old: object whose references should be redirected.
        new: object the references should point at afterwards.

    Raises:
        AssertionError: an internal consistency check failed, or the
            remaining immutable referrers wait on each other (cycle).
        NotImplementedError: a reference kind or base type is unsupported.
        Exception: the error raised by a failing in-place patch in the
            mutation phase (after earlier patches were rolled back). Other
            errors raised while rebuilding objects propagate unchanged.
    """
    # Not read below. Presumably guppy looks for a local named
    # `_hiding_tag_` to hide this frame's objects from its heap/path queries,
    # so it must not be removed. NOTE(review): confirm against guppy docs.
    _hiding_tag_ = hp._hiding_tag_
    Path = hp._parent.Path
    # f_locals dicts written through by apply_rel for R_LOCAL_VAR edits;
    # path_matters ignores their item references.
    localsdict = hp.View.mutnodeset()

    def mkrel(reltype, relarg):
        """Build a guppy relation of kind ``reltype`` with argument ``relarg``."""
        return Path.rel_table[reltype.code](relarg)

    def path_matters(src, rel, dst):
        """Return False for references that must be neither probed nor rewritten."""
        if (
            isinstance(src, type) and
            rel.code == Path.R_ATTRIBUTE.code and
            rel.r == '__mro__'
        ):
            # absolutely don't care about mro tuples,
            # we have __base__ and __bases__
            return False
        if (
            src in localsdict and
            rel.code == Path.R_INDEXVAL.code
        ):
            # don't care about locals dict, R_LOCAL_VAR is sufficient
            return False
        if (
            rel.code == Path.R_ATTRIBUTE.code and
            rel.r == '__objclass__'
        ):
            # NOTE(review): reason not stated; presumably because a
            # descriptor's __objclass__ cannot be reassigned.
            return False
        return True

    # Sentinel for apply_rel's ``orig``: the current target is not checked.
    DO_NOT_CARE = object()

    def apply_rel(src, rel, orig, repl):
        """Make reference ``rel`` of ``src`` point at ``repl``.

        Unless ``orig`` is DO_NOT_CARE, the current target is asserted to be
        ``orig`` first. Any error from the underlying assignment propagates;
        the discovery phase takes that as "src is immutable". Unsupported
        relation kinds raise NotImplementedError.
        """
        care_orig = orig is not DO_NOT_CARE
        if rel.code == Path.R_ATTRIBUTE.code:
            if care_orig:
                assert getattr(src, rel.r) is orig
            setattr(src, rel.r, repl)
            assert getattr(src, rel.r) is repl
        elif rel.code == Path.R_INDEXVAL.code:
            if care_orig:
                assert src[rel.r] is orig
            src[rel.r] = repl
            assert src[rel.r] is repl
        elif rel.code == Path.R_INDEXKEY.code:
            # Rename key ``orig`` to ``repl``, keeping its value. Nothing
            # happens when orig is DO_NOT_CARE or repl is already a key; the
            # latter covers the discovery probe, where orig is repl.
            # NOTE(review): in the mutation phase a pre-existing ``repl`` key
            # makes the rename a silent no-op.
            if care_orig and repl not in src:
                val = src[orig]
                del src[orig]
                try:
                    src[repl] = val
                except Exception:
                    # Restore the original entry before propagating.
                    src[orig] = val
                    raise
                assert src[repl] is val
        elif rel.code == Path.R_HASATTR.code:
            # The referenced object is an attribute-name key of src.__dict__.
            apply_rel(src.__dict__,
                      mkrel(Path.R_INDEXKEY, rel.r),
                      orig, repl)
        elif rel.code == Path.R_LOCAL_VAR.code:
            assert type(src) is types.FrameType
            assert care_orig
            # XXX: Guppy yields cell objects on cell variables,
            # but f_locals yields cell dereferenced
            f_locals = src.f_locals
            assert f_locals[rel.r] is orig
            f_locals[rel.r] = repl
            # Write the edited f_locals back into the frame.
            PyFrame_LocalsToFast(src, 1)
            localsdict.add(f_locals)
            assert src.f_locals[rel.r] is repl
        # R_CELL (closure cell) references are not handled; they fall through
        # to NotImplementedError below.
        elif rel.code == Path.R_INSET.code:
            if care_orig:
                assert orig in src
                src.remove(orig)
            src.add(repl)
            assert repl in src
        else:
            raise NotImplementedError(rel)

    def get_replaced_obj(obj):
        """Return the replacement for ``obj``, or ``obj`` if it is not replaced.

        Raises NotReadyException if ``obj`` is scheduled for reconstruction
        but has not been rebuilt yet.
        """
        if obj not in reconstruct:
            return obj
        if not replaces.domain_covers((obj,)):
            raise NotReadyException
        return replaces[obj]

    def do_reconstruct(obj):
        """Build a copy of immutable ``obj`` whose references use replacements.

        The copy is made via the nearest "builtin" base type of obj's
        (possibly replaced) type. Supported bases are object, type, tuple,
        frozenset, classmethod, staticmethod, property, slice, function,
        code and method. Any outgoing references not consumed while
        constructing the copy are re-attached with apply_rel afterwards.
        May raise NotReadyException (retry later) or NotImplementedError.
        """
        paths = [path.path for path in hp.iso(obj).pathsout]
        paths = [(src.theone, rel, dst.theone) for src, rel, dst in paths]

        # Source: inspect.getfile
        def is_class_builtin(cls):
            """True unless cls's defining module has a ``__file__``."""
            if hasattr(cls, '__module__'):
                module = sys.modules.get(cls.__module__)
                if getattr(module, '__file__', None):
                    return False
            return True

        def get_rel(inrel):
            """Remove the one path with relation ``inrel`` and return its
            (replaced) target; raise KeyError if there is none."""
            res = []
            for i, (src, rel, dst) in reversed(list(enumerate(paths))):
                assert src is obj
                if rel == inrel:
                    res.append(dst)
                    paths.pop(i)
            if not res:
                raise KeyError
            assert len(res) == 1
            return get_replaced_obj(res[0])

        def get_rel_fb(rel, fb):
            """Like get_rel, but return ``fb()`` when no such path exists."""
            try:
                return get_rel(rel)
            except KeyError:
                return fb()

        def getattrrel(attr):
            """Replaced value of attribute ``attr``, else ``getattr(obj, attr)``."""
            return get_rel_fb(mkrel(Path.R_ATTRIBUTE, attr),
                              lambda: getattr(obj, attr))

        try:
            typ = get_rel(mkrel(Path.R_INTERATTR, 'ob_type'))
        except KeyError:
            typ = type(obj)
        assert isinstance(typ, type)
        # Walk up __base__ until reaching a class treated as builtin; that
        # base decides how the copy is constructed.
        basetype = typ
        while not is_class_builtin(basetype):
            basetype = get_replaced_obj(basetype.__base__)
        if basetype is object:
            res = basetype.__new__(typ)
        elif basetype is type:
            res = basetype.__new__(
                typ,
                getattrrel('__name__'),
                getattrrel('__bases__'),
                getattrrel('__dict__'),
            )
            # Consume the __base__ path; it is implied by __bases__.
            getattrrel('__base__')
        elif basetype is tuple:
            res = basetype.__new__(
                typ,
                (get_rel_fb(mkrel(Path.R_INDEXVAL, i), lambda: obj[i])
                 for i in range(len(obj)))
            )
        elif basetype is frozenset:
            res = set()
            for i, (src, rel, dst) in reversed(list(enumerate(paths))):
                if rel.code == Path.R_INSET.code:
                    assert dst in obj
                    assert dst not in res
                    repl = get_replaced_obj(dst)
                    assert repl not in res
                    res.add(repl)
                    paths.pop(i)
            res = basetype.__new__(typ, res)
        elif basetype is classmethod:
            res = basetype.__new__(typ, getattrrel('__func__'))
        elif basetype is staticmethod:
            res = basetype.__new__(typ, getattrrel('__func__'))
        elif basetype is property:
            res = basetype.__new__(
                typ,
                getattrrel('fget'),
                getattrrel('fset'),
                getattrrel('fdel'),
            )
        elif basetype is slice:
            res = basetype.__new__(
                typ,
                getattrrel('start'),
                getattrrel('stop'),
                getattrrel('step'),
            )
        elif basetype is types.FunctionType:
            res = basetype.__new__(
                typ,
                getattrrel('__code__'),
                getattrrel('__globals__'),
                getattrrel('__name__'),
                getattrrel('__defaults__'),
                getattrrel('__closure__'),
            )
        elif basetype is types.CodeType:
            assert typ is basetype
            args = {}
            for i, (src, rel, dst) in reversed(list(enumerate(paths))):
                if rel.code == Path.R_ATTRIBUTE.code and rel.r.startswith('co_'):
                    assert rel.r not in args
                    args[rel.r] = get_replaced_obj(dst)
                    paths.pop(i)
            try:
                res = obj.replace
            except AttributeError:
                # No CodeType.replace: derive the constructor's positional
                # parameter names from the type's docstring signature.
                kwds = typ.__doc__.translate(str.maketrans('', '', ' []\n'))
                kwds = kwds[kwds.index('(')+1:kwds.index(')')].split(',')
                kwds = [{
                    'codestring': 'code',
                    'constants': 'consts',
                }.get(kw, kw) for kw in kwds]
                res = typ(*(
                    args.get('co_' + kw, getattr(obj, 'co_' + kw))
                    for kw in kwds))
            else:
                res = res(**args)
        elif basetype is types.MethodType:
            res = basetype.__new__(
                typ,
                getattrrel('__func__'),
                getattrrel('__self__'),
            )
        else:
            raise NotImplementedError(basetype)
        # NOTE(review): leftover references are attached to the original
        # ``dst``, not get_replaced_obj(dst); confirm this is intended.
        for src, rel, dst in paths:
            if path_matters(src, rel, dst):
                apply_rel(res, rel, DO_NOT_CARE, dst)
        return res

    # Objects that receive a replacement instead of an in-place patch:
    # ``old`` itself (mapped straight to ``new``) plus immutable referrers.
    reconstruct = hp.View.mutnodeset()
    # Objects whose referrers still have to be examined.
    totraverse = hp.View.mutnodeset()
    # Every object seen referring to a traversed object.
    affected = hp.View.mutnodeset()
    # Mapping: object in ``reconstruct`` -> its replacement.
    replaces = hp.View.nodegraph(is_mapping=True)
    reconstruct.add(old)
    totraverse.add(old)
    replaces.add_edge(old, new)

    # Phase 1: find referrers and classify them as mutable or immutable.
    while totraverse:
        paths = hp.idset(totraverse).pathsin
        totraverse = hp.View.mutnodeset()
        for path in paths:
            src, rel, dst = path.path
            src, dst = src.theone, dst.theone
            affected.add(src)
            if src not in reconstruct:
                if not path_matters(src, rel, dst):
                    continue
                try:
                    # Probe: re-point the reference at its current target.
                    apply_rel(src, rel, dst, dst)
                except AssertionError:
                    raise
                except Exception:
                    # src cannot be edited in place: rebuild it instead and
                    # examine its own referrers in the next round.
                    totraverse.add(src)
                    reconstruct.add(src)

    # Phase 2: rebuild immutable referrers until each has a replacement.
    toreconstruct = hp.View.mutnodeset(reconstruct - replaces.get_domain())
    while toreconstruct:
        reconstructed_any = False
        for obj in list(toreconstruct):
            try:
                repl = do_reconstruct(obj)
            except NotReadyException:
                # Depends on an object not rebuilt yet; retry next pass.
                continue
            toreconstruct.remove(obj)
            replaces.add_edge(obj, repl)
            reconstructed_any = True
        # A pass without progress means the remaining objects wait on each
        # other. Raise explicitly instead of using `assert`: under
        # `python -O` the assert is stripped and this loop never ends.
        if not reconstructed_any:
            raise AssertionError('Circular immutable reference')

    # Phase 3: patch mutable referrers in place, recording an undo log.
    undo_log = []
    try:
        mutate = hp.View.mutnodeset(affected - reconstruct)
        for path in list(hp.idset(mutate).pathsout):
            src, rel, dst = path.path
            src, dst = src.theone, dst.theone
            if not path_matters(src, rel, dst):
                continue
            if dst not in reconstruct:
                continue
            repl = get_replaced_obj(dst)
            try:
                apply_rel(src, rel, dst, repl)
            except Exception as e:
                raise DoRollback from e
            else:
                undo_log.append((src, rel, repl, dst))
    except DoRollback as e:
        # Undo newest-first, so patches to the same container are reverted
        # in the reverse of the order they were applied.
        for src, rel, repl, dst in reversed(undo_log):
            apply_rel(src, rel, repl, dst)
        raise e.__cause__
def do_replace(old, new):
    """Redirect all references to ``old`` so they point at ``new``.

    Thin wrapper that runs ``hidden_scope(old, new)`` through guppy's
    ``hp.View.enter``; the result of that call is discarded.
    NOTE(review): presumably ``enter`` sets up the heap-view context that
    hidden_scope's queries rely on — confirm against guppy docs. Not
    thread-safe (see the module warning).
    """
    hp.View.enter(lambda: hidden_scope(old, new))
def do_it():
    """Demonstrate do_replace.

    The demo has two scenarios: (1) replacing class ``A`` with ``B`` while
    ``A`` is a base of a dynamically created frozenset subclass, and
    (2) replacing a string constant returned by an exec'd function.
    Each scenario prints its state before and after the replacement.
    """
    class A:
        # __slots__ = 'foo',
        attr = 1

    class B:
        # __slots__ = 'foo',
        attr = 2

    # Disabled alternative scenarios (instance attribute / closure over A):
    # a = A()
    # a.foo = 'abcdefg'
    # print(a, a.attr, a.foo)
    # do_replace(A, B)
    # print(a, a.attr, a.foo)
    # def res():
    #     return A()
    #
    # print(res())
    # do_replace(A, B)
    # print(res())
    # Instance of a frozenset subclass with A as a base, containing an A().
    a = type('frozenset_built', (frozenset, A), {'attr2': 3})({A()})
    print(a, a.attr, a.attr2, type(a).__mro__)
    # Expected to swap A for B in the class bases and in the set contents.
    do_replace(A, B)
    print(a, a.attr, a.attr2, type(a).__mro__)
    # Disabled alternative scenarios (plain strings / int constants):
    # a = 'abc'
    # b = 'def'
    # print(a)
    # do_replace(a, b)
    # print(b)
    # def b():
    #     return 123456789
    # print(b())
    # do_replace(b(), 987654321)
    # Define a function from source so its return value is a constant stored
    # in an (immutable) code object.
    ns = {}
    exec(__import__('textwrap').dedent('''
        def A():
            return 'foobarbaz'
    '''), ns)
    print(ns['A']())
    # Replace that constant object; the second call is expected to return
    # 'basbarfoo'.
    do_replace(ns['A'](), 'basbarfoo')
    print(ns['A']())
# Run the demo. NOTE(review): this also runs whenever the module is
# imported; consider guarding it with `if __name__ == '__main__':`.
do_it()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment