# Gist adalke/5563a54ba95edab06cc96cbe56a592ca, created June 17, 2021 15:31.
# (GitHub page text: "Save adalke/5563a54ba95edab06cc96cbe56a592ca to your
# computer and use it in GitHub Desktop.")
# GitHub warning: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# (GitHub link: "Learn more about bidirectional Unicode characters".)
# Prototype implementation of a "toolbox" API for Open Force Field | |
# See https://github.com/openforcefield/openff-toolkit/issues/966 . | |
# Vocabulary: | |
# | |
# toolkit_wrapper = the existing OpenEyeToolkitWrapper, RDKitToolkitWrapper, etc. | |
# | |
# method name = something like "from_smiles" or "get_tagged_smarts_connectivity" | |
# | |
# method handler = a function that implements a named method
#                  (a given name may have multiple handlers)
# | |
# NOTE: the toolkit wrappers map a single method name to a single handler. | |
# | |
# toolbox = a collection of methods. This may have: | |
# - a dictionary mapping a method name to a handler function | |
# - zero or more toolkit_wrappers | |
# | |
# context stack = a sequence of toolboxes, traversed from newest to oldest
#   - each context element includes an "inherit" flag; when the flag is
#     false, the traversal stops after that element
# | |
# context method = a method handler found by traversing the context | |
# stack, along with the remaining context stack. | |
# | |
######## | |
import dataclasses
import functools
import inspect
######## | |
# A ToolboxCall records one in-progress toolbox call. It holds:
#   - the method name;
#   - the positional and keyword arguments it received;
#   - the part of the context stack not yet searched, which get_super() walks.
_ToolboxCall = dataclasses.make_dataclass(
    cls_name="ToolboxCall",
    fields=["name", "args", "kwargs", "context_stack"],
)


class ToolboxCall(_ToolboxCall):
    def get_super(self):
        """Return the next context method for this call's name, or None.

        Only ``self.context_stack`` is searched. It holds the handlers and
        toolboxes that come after the handler currently executing.
        """
        remaining, method_name = self.context_stack, self.name
        return _get_context_method(remaining, method_name)
# Provide a way to get access to the current method handler call. | |
# This lets the method handler forward the call to its super() method. | |
_current_call = None | |
def get_current_call(): | |
return _current_call | |
def _set_current_call(toolbox_call): | |
global _current_call | |
_current_call = toolbox_call | |
######## | |
# Search the context stack for the first toolbox that has a handler for
# `name`, and return a "context method" that wraps that handler.
#
# Calling the context method does three things:
#   1. records a new ToolboxCall as the current call, so the handler can
#      reach its super() handler through get_current_call();
#   2. runs the handler;
#   3. restores the previously current call, whether the handler returns
#      or raises.
def _get_context_method(stack, name):
    """Return a context method for ``name`` found in ``stack``, or None.

    ``stack`` is a list of ``(toolbox, inherit)`` pairs, searched from index
    0 onwards. An entry whose toolbox has no handler is skipped. If that
    entry's ``inherit`` flag is false, the search instead ends with None.
    """
    for position, (toolbox, inherit) in enumerate(stack):
        handlers = toolbox.get_method_handlers(name)
        if not handlers:
            if not inherit:
                # A non-inheriting toolbox hides everything after it.
                return None
            continue

        chosen = handlers[0]

        def context_method(*args, **kwargs):
            saved_call = get_current_call()
            # get_super() searches the handlers after `chosen` in this same
            # toolbox first. They keep this toolbox's inherit flag; the rest
            # of the stack follows them.
            # TODO: could optimize if no handlers and/or stack remaining.
            leftover = _MethodHandlerToolkit(name, handlers[1:])
            rest = [(leftover, inherit)] + stack[position + 1:]
            _set_current_call(ToolboxCall(name, args, kwargs, rest))
            try:
                # THOUGHT: a handler could signal "I can't do this, keep
                # going". The call's get_super() could then try the next
                # handler and keep track of the handlers that failed.
                return chosen(*args, **kwargs)
            finally:
                _set_current_call(saved_call)

        # Copy the handler's name, docstring, etc. onto the wrapper.
        return functools.update_wrapper(context_method, chosen)
    return None
# Keep track of the toolbox context stack | |
class ContextStack: | |
def __init__(self): | |
self.stack = [] # The top of the stack is [-1] | |
self.current_call = None | |
def push(self, toolbox, inherit=True): | |
self.stack.append( (toolbox, inherit) ) | |
def pop(self): | |
self.stack.pop() | |
def get_context_method(self, name): | |
return _get_context_method(self.stack[::-1], name) | |
def get_method_names(self): | |
names = set() | |
for (toolbox, inherit) in reversed(self.stack): | |
names.update(toolbox.get_method_names()) | |
if not inherit: | |
break | |
return names | |
# Maybe make this public? | |
context_stack = ContextStack() | |
# Don't allow names starting or ending with "_". | |
# Those aren't part of the public toolbox API. | |
def _allowed_name(name): | |
return not (name[:1] == "_" or name[-1:] == "_") | |
class CurrentToolbox: | |
def __getattr__(self, name): | |
if not _allowed_name(name): | |
raise AttributeError(f"{self.__class__.__name__!r} has no attribute '{name}'") | |
method = context_stack.get_context_method(name) | |
if method is None: | |
raise AttributeError(f"No toolbox function available named {name!r}") | |
return method | |
def __dir__(self): | |
return sorted(context_stack.get_method_names()) | |
current_toolbox = CurrentToolbox() | |
# A Toolbox is a collection of method handlers. It does not store handlers
# itself. Instead it asks each of its method lookups (dictionary-backed,
# toolkit-wrapper-backed, ...) in order, so earlier lookups take precedence.
class Toolbox:
    """Collection of method handlers gathered from one or more lookups.

    Parameters
    ----------
    method_lookups : sequence
        Objects providing ``get_method(name)`` and ``get_method_names()``.
        ``get_method`` raises ``_NotFound`` when the name is absent. The
        lookups are searched in order.
    inherit : bool
        Flag pushed with this toolbox when it is used as a context manager.
        If false, toolboxes lower in the context stack are not searched.
    """

    def __init__(self, method_lookups, inherit):
        self._method_lookups = method_lookups
        self._inherit = inherit

    def get_method_handlers(self, name):
        """Return every handler for ``name``, in lookup order (may be empty).

        Raises
        ------
        ValueError
            If ``name`` starts or ends with an underscore.
        """
        if not _allowed_name(name):
            raise ValueError(f"Invalid tool method name: {name!r}")
        method_handlers = []
        for lookup in self._method_lookups:
            try:
                method_handlers.append(lookup.get_method(name))
            except _NotFound:
                # This lookup simply does not provide `name`; try the next.
                pass
        return method_handlers

    def get_method_names(self):
        """Return the union of method names advertised by all lookups."""
        names = set()
        for lookup in self._method_lookups:
            names.update(lookup.get_method_names())
        return names

    def __dir__(self):
        # Support dir(toolbox): the method names plus the public query methods.
        names = self.get_method_names()
        names.update([
            "get_method_names", "get_method_handlers",
        ])
        return sorted(names)

    # Support 'with toolbox:' and 'with toolbox as tb:'.
    def __enter__(self):
        context_stack.push(self, self._inherit)
        # Return the toolbox so `with toolbox as tb:` binds it. Previously
        # this returned None, which made the `as` target useless.
        return self

    def __exit__(self, type, value, traceback):
        context_stack.pop()
        # Returning None (falsy) lets exceptions from the body propagate.
### Part of the internal API for finding methods | |
class _NotFound(Exception): | |
pass | |
## Just enough of the toolkit API to allow get_super() to work. | |
class _MethodHandlerToolkit: | |
def __init__(self, name, method_handlers): | |
self.name = name | |
self.method_handlers = method_handlers | |
def get_method_handlers(self, name): | |
assert name == self.name | |
return self.method_handlers | |
# Method lookup backed by attribute access on a toolkit wrapper object
# (e.g. an RDKitToolkitWrapper instance).
class _WrapperLookup:
    """Find method handlers as attributes of a toolkit wrapper.

    An attribute counts as a toolbox method only if all of these hold:
      - its name is allowed (no leading or trailing underscore);
      - it is callable;
      - it is not one of the wrapper utilities in ``_EXCLUDED_NAMES``.

    ``get_method`` and ``get_method_names`` apply the same rule. Every
    advertised name can therefore be looked up, and nothing that is not
    advertised (such as a non-callable attribute) is returned as a handler.
    """

    # Callable wrapper attributes that are not toolbox functions.
    # NOTE: there should be a better way to mark which wrapper methods are
    # toolbox functions.
    _EXCLUDED_NAMES = frozenset({"is_available", "is_installed", "requires_toolkit"})

    def __init__(self, toolkit_wrapper):
        self.toolkit_wrapper = toolkit_wrapper

    def _is_toolbox_method(self, name, obj):
        """Return True if attribute ``name`` with value ``obj`` is a toolbox method."""
        return (
            _allowed_name(name)
            and callable(obj)
            and name not in self._EXCLUDED_NAMES
        )

    def get_method(self, name):
        """Return the wrapper's handler for ``name``.

        Raises
        ------
        _NotFound
            If the wrapper has no such attribute, or the attribute is not a
            toolbox method (see the class docstring).
        """
        try:
            obj = getattr(self.toolkit_wrapper, name)
        except AttributeError:
            # "Not found" is the whole message; drop the AttributeError context.
            raise _NotFound from None
        if not self._is_toolbox_method(name, obj):
            raise _NotFound
        return obj

    def get_method_names(self):
        """Return the names of all toolbox methods the wrapper provides."""
        return {
            name
            for name, obj in inspect.getmembers(self.toolkit_wrapper)
            if self._is_toolbox_method(name, obj)
        }
# Use dictionary lookups;l used to get methods from a dictionary | |
class _DictLookup: | |
def __init__(self, methods): | |
self.methods = methods | |
def get_method(self, name): | |
try: | |
return self.methods[name] | |
except KeyError: | |
raise _NotFound | |
def get_method_names(self): | |
return set(self.methods) | |
# Public interface to get methods from toolkit wrapper and/or kwargs. | |
# The kwargs takes precedence. | |
def get_toolbox(*toolkit_wrappers, inherit_=True, **methods): | |
method_lookups = [] | |
if methods: | |
method_lookups.append(_DictLookup(methods)) | |
for toolkit_wrapper in toolkit_wrappers: | |
method_lookups.append(_WrapperLookup(toolkit_wrapper)) | |
return Toolbox(method_lookups = method_lookups, inherit=inherit_) | |
##### Let's see if it works. | |
from openff.toolkit import utils | |
rdkit_toolkit_wrapper = utils.RDKitToolkitWrapper() | |
openeye_toolkit_wrapper = utils.OpenEyeToolkitWrapper() | |
ambertools_toolkit_wrapper = utils.AmberToolsToolkitWrapper() | |
# Set the default toolkit order | |
context_stack.push( | |
get_toolbox(openeye_toolkit_wrapper, rdkit_toolkit_wrapper, ambertools_toolkit_wrapper) | |
) | |
# Override "from_smiles" | |
def my_from_smiles(smiles): | |
print(f"Making acetate instead of {smiles!r}") | |
from openff.toolkit.tests import create_molecules | |
return create_molecules.create_acetate() | |
# Ensure we only use RDKit for this method. | |
must_use_rdkit = get_toolbox(rdkit_toolkit_wrapper, inherit_=False) | |
def return_3x_methane(): | |
tb = current_toolbox | |
with must_use_rdkit: | |
mol = tb.from_smiles("C") | |
smi = tb.to_smiles(mol) | |
return ".".join([smi]*3) | |
# Debugging handler: interposes on a toolbox call, logs it, and forwards it
# to the next ("super") handler further down the context stack.
def debug_call(*args, **kwargs):
    """Forward the current toolbox call to its super() handler, with logging.

    Prints the handler's module before the call, then either the result or
    the arguments of a call that raised (the exception is re-raised).

    Raises
    ------
    RuntimeError
        If called outside a toolbox call (there is no current call).
    AttributeError
        If no super() handler is available. This matches CurrentToolbox,
        which raises AttributeError for an unavailable toolbox function.
    """
    current_call = get_current_call()
    if current_call is None:
        raise RuntimeError("debug_call() must be invoked as a toolbox method")
    name = current_call.name
    # Resolve the super() handler once and reuse it for the call.
    next_call = current_call.get_super()
    if next_call is None:
        raise AttributeError(f"No super() toolbox function available for {name!r}")
    print(f"DEBUG: about to call {name}() in {next_call.__module__}")
    try:
        result = next_call(*args, **kwargs)
    except BaseException:
        # Print kwargs as a value. Passing **kwargs into print() would raise
        # TypeError, or be taken as sep=/end=/file=, for keyword calls.
        extra = [kwargs] if kwargs else []
        print("DEBUG: exception", name, *args, *extra)
        raise
    print("DEBUG: returned", name, result)
    return result
# Set up my toolbox | |
my_toolbox = get_toolbox( | |
from_smiles = my_from_smiles, | |
return_3x_methane = return_3x_methane, | |
to_smiles = debug_call, | |
) | |
with my_toolbox: | |
# If functions aren't found in my_toolbox then | |
# it will search the context stack | |
mol = current_toolbox.from_smiles("C#N") | |
print("mol", mol) | |
print("to_smiles:", current_toolbox.to_smiles(mol)) | |
print("3x_methane:", current_toolbox.return_3x_methane()) | |
with get_toolbox(rdkit_toolkit_wrapper): | |
with my_toolbox: | |
# If functions aren't found in my_toolbox then | |
# it will search the context stack | |
mol = current_toolbox.from_smiles("C#N") | |
print("mol", mol) | |
print("to_smiles:", current_toolbox.to_smiles(mol)) | |
print("3x_methane:", current_toolbox.return_3x_methane()) |
# (GitHub page footer: "Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.")