Skip to content

Instantly share code, notes, and snippets.

@adalke
Created June 17, 2021 15:31
Show Gist options
  • Save adalke/5563a54ba95edab06cc96cbe56a592ca to your computer and use it in GitHub Desktop.
Save adalke/5563a54ba95edab06cc96cbe56a592ca to your computer and use it in GitHub Desktop.
# Prototype implementation of a "toolbox" API for Open Force Field
# See https://github.com/openforcefield/openff-toolkit/issues/966 .
# Vocabulary:
#
# toolkit_wrapper = the existing OpenEyeToolkitWrapper, RDKitToolkitWrapper, etc.
#
# method name = something like "from_smiles" or "get_tagged_smarts_connectivity"
#
# method handler = a function which implements a named method
# (a given name may have multiple handlers)
#
# NOTE: the toolkit wrappers map a single method name to a single handler.
#
# toolbox = a collection of methods. This may have:
# - a dictionary mapping a method name to a handler function
# - zero or more toolkit_wrappers
#
# context stack = a sequence of toolboxes, traversed from newest to oldest
# - each context element includes an "inherit" flag which, when false,
# stops the traversal at that element
#
# context method = a method handler found by traversing the context
# stack, along with the remaining context stack.
#
########
import dataclasses
import functools
import inspect
########
# Store information about the current toolbox call.
# Plain record type holding the four values that describe one toolbox call.
_ToolboxCall = dataclasses.make_dataclass(
    "ToolboxCall",
    ["name", "args", "kwargs", "context_stack"],
)


class ToolboxCall(_ToolboxCall):
    """Describe a single toolbox call.

    Fields (all inherited from the generated dataclass):
      name          -- the method name that was requested
      args, kwargs  -- the arguments the handler was called with
      context_stack -- the part of the context stack still to be searched
    """

    def get_super(self):
        """Return the next available context method for this call, or None."""
        remaining_stack = self.context_stack
        return _get_context_method(remaining_stack, self.name)
# Provide a way to get access to the current method handler call.
# This lets the method handler forward the call to its super() method.
# Module-wide slot for the call being handled right now (None when idle).
_current_call = None


def get_current_call():
    """Return the call record most recently stored by _set_current_call().

    The value is None until a call record has been stored.
    """
    return _current_call


def _set_current_call(toolbox_call):
    """Replace the module-wide current call record with *toolbox_call*."""
    global _current_call
    _current_call = toolbox_call
########
# Search for the next available method handler in the context stack
# and return a context method.
# The context method wraps the handler function. When the context
# method is called, it updates the current call information, calls the
# handler, restores the old call information, and returns the result.
def _get_context_method(stack, name):
    """Return a context method for the first handler of *name* in *stack*.

    *stack* is a list of ``(toolbox, inherit)`` pairs, searched in order.
    The first toolbox that reports at least one handler wins, and its first
    handler is wrapped. A toolbox with no handlers and a false ``inherit``
    flag ends the search. Returns None when nothing is found.

    The returned wrapper records a ToolboxCall for the duration of the call
    and restores the previous one afterwards, even if the handler raises.
    """
    for i, (toolbox, inherit) in enumerate(stack):
        # These are the actual callables
        method_handlers = toolbox.get_method_handlers(name)

        if not method_handlers:
            if inherit:
                continue
            # This element does not inherit from older ones: stop looking.
            return None

        method_handler = method_handlers[0]

        # Give the wrapper the same argspec and docstring as the handler.
        @functools.wraps(method_handler)
        def new_context_method(*args, **kwargs):
            # Remember the call that was active before this one.
            previous_call = get_current_call()

            # Any further handlers for this name go into a toolkit-like
            # container at the front of the remaining stack, so that
            # get_super() can continue to look for them.
            # TODO: could optimize if no handlers and/or stack remaining.
            leftover = _MethodHandlerToolkit(name, method_handlers[1:])
            remaining_stack = [(leftover, inherit)] + stack[i + 1:]

            # Publish the information about the call that's about to be done.
            _set_current_call(ToolboxCall(name, args, kwargs, remaining_stack))
            try:
                # THOUGHT. The method handler could raise a
                #   "I can't do this but keep on going".
                # This could use the new call's get_super() to
                #   get the next possible handler, and
                #   keep track of all handlers that failed.
                return method_handler(*args, **kwargs)
            finally:
                # Restore the old call
                _set_current_call(previous_call)

        return new_context_method

    return None
# Keep track of the toolbox context stack
class ContextStack:
    """A stack of ``(toolbox, inherit)`` pairs; the newest entry is last."""

    def __init__(self):
        self.stack = []  # The top of the stack is [-1]
        # Not otherwise referenced in the visible code.
        self.current_call = None

    def push(self, toolbox, inherit=True):
        """Put *toolbox* on top of the stack together with its inherit flag."""
        entry = (toolbox, inherit)
        self.stack.append(entry)

    def pop(self):
        """Discard the top entry of the stack (nothing is returned)."""
        self.stack.pop()

    def get_context_method(self, name):
        """Search from newest to oldest for *name*; return a context method or None."""
        newest_first = self.stack[::-1]
        return _get_context_method(newest_first, name)

    def get_method_names(self):
        """Return the set of method names visible from the top of the stack.

        Entries are visited newest first; an entry whose inherit flag is
        false is still included, but nothing older than it is.
        """
        visible = set()
        for toolbox, inherit in self.stack[::-1]:
            visible.update(toolbox.get_method_names())
            if inherit:
                continue
            break
        return visible


# Maybe make this public?
context_stack = ContextStack()
# Don't allow names starting or ending with "_".
# Those aren't part of the public toolbox API.
def _allowed_name(name):
    """Return True unless *name* begins or ends with an underscore."""
    return name[:1] != "_" and name[-1:] != "_"
class CurrentToolbox:
    """Attribute-style access to whatever the context stack currently offers."""

    def __getattr__(self, name):
        """Resolve *name* through the context stack.

        Raises AttributeError for names that are not allowed (leading or
        trailing underscore) and for names no toolbox on the stack handles.
        """
        if _allowed_name(name):
            method = context_stack.get_context_method(name)
            if method is not None:
                return method
            raise AttributeError(f"No toolbox function available named {name!r}")
        raise AttributeError(f"{self.__class__.__name__!r} has no attribute '{name}'")

    def __dir__(self):
        """List, in sorted order, the method names the context stack reports."""
        available = context_stack.get_method_names()
        return sorted(available)


current_toolbox = CurrentToolbox()
# A Toolbox acts like a collection of methods.
# It uses method_lookups to find method handler or method names in the toolbox.
# These can be from a toolkit wrapper, dictionary of functions, etc.
class Toolbox:
    """A collection of named methods, usable as a context manager.

    method_lookups -- sequence of lookup objects; each must provide
                      ``get_method(name)`` (raising ``_NotFound`` when it has
                      no handler) and ``get_method_names()``.
    inherit        -- flag pushed onto the context stack with this toolbox.
    """

    def __init__(self, method_lookups, inherit):
        self._method_lookups = method_lookups
        self._inherit = inherit

    def get_method_handlers(self, name):
        """Find all method handlers for the given name.

        Handlers are returned in lookup order. Raises ValueError when *name*
        starts or ends with an underscore.
        """
        if not _allowed_name(name):
            raise ValueError(f"Invalid tool method name: {name!r}")
        method_handlers = []
        for lookup in self._method_lookups:
            # Keep the try body to the one call that signals "not here".
            try:
                handler = lookup.get_method(name)
            except _NotFound:
                continue
            method_handlers.append(handler)
        return method_handlers

    def get_method_names(self):
        """Find all method names offered by any of the lookups."""
        names = set()
        for lookup in self._method_lookups:
            names.update(lookup.get_method_names())
        return names

    def __dir__(self):
        # support dir(toolbox)
        names = self.get_method_names()
        names.update([
            "get_method_names", "get_method_handlers",
        ])
        return sorted(names)

    # Support 'with toolbox:'
    def __enter__(self):
        """Push this toolbox onto the context stack and return it.

        Returning self makes ``with toolbox as tb:`` bind the toolbox
        instead of None.
        """
        context_stack.push(self, self._inherit)
        return self

    def __exit__(self, type, value, traceback):
        """Pop the context stack; exceptions from the body are not suppressed."""
        context_stack.pop()
### Part of the internal API for finding methods
class _NotFound(Exception):
    """Raised by a method lookup that has no handler for the requested name."""
## Just enough of the toolkit API to allow get_super() to work.
class _MethodHandlerToolkit:
    """Hold the handlers still untried for one method name.

    It offers only ``get_method_handlers``, which is all the context-stack
    search needs from a stack element.
    """

    def __init__(self, name, method_handlers):
        self.name = name
        self.method_handlers = method_handlers

    def get_method_handlers(self, name):
        """Return the stored handlers; *name* must be the name given at construction."""
        # Internal invariant: this container only ever serves its own name.
        assert name == self.name
        return self.method_handlers
# Use attribute lookups; used to get methods from a toolkit wrapper
class _WrapperLookup:
    """Find toolbox methods by attribute lookup on a toolkit wrapper.

    ``get_method`` and ``get_method_names`` apply the same filter, so every
    name that is listed can be fetched and nothing unlisted can.
    """

    # Need some way to indicate which are toolbox functions; until then these
    # callables on a wrapper are treated as bookkeeping, not toolbox methods.
    _EXCLUDED_NAMES = frozenset({"is_available", "is_installed", "requires_toolkit"})

    def __init__(self, toolkit_wrapper):
        self.toolkit_wrapper = toolkit_wrapper

    @classmethod
    def _is_toolbox_method(cls, name, obj):
        """Return True when attribute *name* with value *obj* counts as a toolbox method."""
        return (
            _allowed_name(name)
            and callable(obj)
            and name not in cls._EXCLUDED_NAMES
        )

    def get_method(self, name):
        """Return the wrapper's handler for *name*.

        Raises _NotFound when the attribute is missing, is not callable, or
        is one of the excluded bookkeeping names.
        """
        try:
            candidate = getattr(self.toolkit_wrapper, name)
        except AttributeError:
            # The AttributeError adds nothing for the caller; drop it.
            raise _NotFound(name) from None
        if not self._is_toolbox_method(name, candidate):
            raise _NotFound(name)
        return candidate

    def get_method_names(self):
        """Return the set of names on the wrapper that count as toolbox methods."""
        return {
            name
            for name, obj in inspect.getmembers(self.toolkit_wrapper)
            if self._is_toolbox_method(name, obj)
        }
# Use dictionary lookups; used to get methods from a dictionary
class _DictLookup:
    """Find method handlers in a mapping of name -> handler."""

    def __init__(self, methods):
        self.methods = methods

    def get_method(self, name):
        """Return the handler stored under *name*, or raise _NotFound."""
        try:
            handler = self.methods[name]
        except KeyError:
            raise _NotFound
        else:
            return handler

    def get_method_names(self):
        """Return the mapping's keys as a new set."""
        return {*self.methods}
# Public interface to get methods from toolkit wrapper and/or kwargs.
# The kwargs takes precedence.
def get_toolbox(*toolkit_wrappers, inherit_=True, **methods):
    """Build a Toolbox from toolkit wrappers and/or keyword-named handlers.

    Keyword handlers are looked up first, then the wrappers in the order
    given. ``inherit_`` is passed to the Toolbox as its inherit flag.
    """
    # The dictionary lookup goes first so that the kwargs take precedence.
    method_lookups = [_DictLookup(methods)] if methods else []
    method_lookups.extend(
        _WrapperLookup(toolkit_wrapper) for toolkit_wrapper in toolkit_wrappers
    )
    return Toolbox(method_lookups=method_lookups, inherit=inherit_)
##### Let's see if it works.
# Demo setup: the toolkit wrapper classes come from the openff.toolkit package.
from openff.toolkit import utils
# One instance of each toolkit wrapper used by the demo below.
rdkit_toolkit_wrapper = utils.RDKitToolkitWrapper()
openeye_toolkit_wrapper = utils.OpenEyeToolkitWrapper()
ambertools_toolkit_wrapper = utils.AmberToolsToolkitWrapper()
# Set the default toolkit order
# (pushed without a matching pop, so it stays at the bottom of the stack;
# handlers are looked up in the wrapper order given here)
context_stack.push(
get_toolbox(openeye_toolkit_wrapper, rdkit_toolkit_wrapper, ambertools_toolkit_wrapper)
)
# Override "from_smiles"
def my_from_smiles(smiles):
    """Demo override: announce the substitution, then return an acetate molecule.

    The *smiles* argument is only echoed in the message.
    """
    print(f"Making acetate instead of {smiles!r}")
    # Imported at call time, after the message, exactly as the demo intends.
    from openff.toolkit.tests import create_molecules as molecule_factory
    return molecule_factory.create_acetate()
# Ensure we only use RDKit for this method.
must_use_rdkit = get_toolbox(rdkit_toolkit_wrapper, inherit_=False)


def return_3x_methane():
    """Return the SMILES of methane three times over, joined with "."."""
    toolbox = current_toolbox
    # inherit_=False keeps the lookups below from reaching older toolboxes.
    with must_use_rdkit:
        methane = toolbox.from_smiles("C")
        methane_smiles = toolbox.to_smiles(methane)
    return ".".join((methane_smiles, methane_smiles, methane_smiles))
# Interpose the call, and forward to its super()
def debug_call(*args, **kwargs):
    """Log a toolbox call, forward it to the next handler, and log the outcome.

    Intended to be installed as a method handler: it reads the current call,
    asks it for the next context method via ``get_super()``, and calls that
    with the same arguments.

    Returns whatever the next handler returns. Any exception from the next
    handler is logged and re-raised unchanged. Raises AttributeError when
    there is no further handler for the method name.
    """
    current_call = get_current_call()
    name = current_call.name
    # Resolve the next handler once; use it for both the log line and the call.
    next_call = current_call.get_super()
    if next_call is None:
        raise AttributeError(f"No toolbox function available named {name!r}")
    print(f"DEBUG: about to call {name}() in {next_call.__module__}")
    try:
        result = next_call(*args, **kwargs)
    except BaseException:
        # Format the keyword arguments as text. Forwarding them into print()
        # as keywords would raise TypeError and hide the real exception.
        print(
            "DEBUG: exception",
            name,
            *args,
            *(f"{key}={value!r}" for key, value in kwargs.items()),
        )
        raise
    print("DEBUG: returned", name, result)
    return result
# Set up my toolbox
my_toolbox = get_toolbox(
    from_smiles=my_from_smiles,
    return_3x_methane=return_3x_methane,
    to_smiles=debug_call,
)

# First run: my_toolbox on top of the default toolbox pushed earlier.
with my_toolbox:
    # If functions aren't found in my_toolbox then
    # it will search the context stack
    mol = current_toolbox.from_smiles("C#N")
    print("mol", mol)
    print("to_smiles:", current_toolbox.to_smiles(mol))
    print("3x_methane:", current_toolbox.return_3x_methane())

# Second run: the same calls with an RDKit-only toolbox directly underneath.
with get_toolbox(rdkit_toolkit_wrapper), my_toolbox:
    # If functions aren't found in my_toolbox then
    # it will search the context stack
    mol = current_toolbox.from_smiles("C#N")
    print("mol", mol)
    print("to_smiles:", current_toolbox.to_smiles(mol))
    print("3x_methane:", current_toolbox.return_3x_methane())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment