Skip to content

Instantly share code, notes, and snippets.

@eavae
Created August 6, 2021 04:05
Show Gist options
  • Save eavae/b08108ef7fbcce6de28f3d4370e1d924 to your computer and use it in GitHub Desktop.
Save eavae/b08108ef7fbcce6de28f3d4370e1d924 to your computer and use it in GitHub Desktop.
tunable parameter
from collections import OrderedDict
import inspect
import typing
class SearchSpacesDict(OrderedDict):
    """An ``OrderedDict`` in which the most recently assigned key comes last.

    A plain ``OrderedDict`` leaves a re-assigned key at its original
    position. This subclass moves it to the end instead, so iteration order
    follows the order in which keys were last written.
    """

    def __setitem__(self, key, value):
        # Store first, then relocate: move_to_end needs the key to be present.
        super().__setitem__(key, value)
        self.move_to_end(key, last=True)
class Tunable:
    """Base class for components whose parameters can be searched over.

    The ``choice`` and ``real`` class decorators record search spaces on a
    class in a ``__search_spaces__`` mapping. ``get_search_spaces`` reads
    them back.
    """

    def __init__(self, *args, **kargs) -> None:
        """Accept and ignore any positional and keyword arguments."""

    @classmethod
    def get_search_spaces(cls):
        """Return the search spaces that apply to ``cls``.

        Each entry of the ``__search_spaces__`` mapping visible from ``cls``
        records, under ``'_cls'``, the class it was registered for. Only
        entries registered for ``cls`` itself or one of its base classes are
        kept. Entries registered for a subclass are skipped even when they
        appear in a mapping that ``cls`` can see.

        Returns:
            A new ``SearchSpacesDict`` mapping parameter name to its space
            description. It is empty when ``cls`` has no
            ``__search_spaces__``.
        """
        selected = SearchSpacesDict()
        try:
            registry = getattr(cls, '__search_spaces__')
        except AttributeError:
            return selected
        for param_name, space in registry.items():
            owner = space['_cls']
            if not issubclass(cls, owner):
                continue
            selected[param_name] = space
        return selected
def get_tunable_class(cls):
    """Return the ``Tunable`` classes in the method resolution order of ``cls``.

    Walks ``inspect.getmro(cls)`` (``cls`` first, ``object`` last) and keeps
    every class that is ``Tunable`` or derives from it.

    Args:
        cls: The class to inspect. Non-class values are accepted and yield an
            empty list.

    Returns:
        A list of classes in MRO order. It is empty when ``cls`` is not a
        class or is unrelated to ``Tunable``.
    """
    if not inspect.isclass(cls):
        return []
    # issubclass (not ``==``) so every Tunable-derived class in the hierarchy
    # is collected, not just ``Tunable`` itself. The comprehension variable
    # no longer shadows the ``cls`` parameter.
    return [klass for klass in inspect.getmro(cls) if issubclass(klass, Tunable)]
def choice(name, options):
    """Class decorator that registers a categorical search space.

    Args:
        name: Name of the tunable parameter.
        options: Candidate values. Each option that is a ``Tunable`` subclass
            also contributes its own search spaces to the decorated class.
            Those spaces are re-registered with ``'_cls'`` set to the
            decorated class.

    Returns:
        A decorator that records the space in the decorated class's
        ``__search_spaces__`` mapping and returns the class itself.

    If the decorated class's ``__init__`` has a parameter called ``name``,
    the entry stores that parameter's default under ``'default'``. That value
    is ``inspect.Parameter.empty`` when the parameter has no default.
    Otherwise the entry stores a ``'name'`` key instead.
    """
    def inner(cls):
        # Use a mapping owned by ``cls`` itself. ``hasattr`` would also find
        # one inherited from a base class, and writing into it would let
        # sibling subclasses overwrite each other's entries (and a subclass
        # overwrite its parent's). The new mapping is seeded with the spaces
        # declared on the bases, so they are still reported for ``cls``.
        cs = cls.__dict__.get('__search_spaces__')
        if cs is None:
            cs = SearchSpacesDict()
            for base in reversed(inspect.getmro(cls)[1:]):
                inherited = base.__dict__.get('__search_spaces__', {})
                for key, space in inherited.items():
                    cs[key] = space
            setattr(cls, '__search_spaces__', cs)

        # Nested Tunable options contribute their spaces, re-owned by cls.
        for option in options:
            if inspect.isclass(option) and issubclass(option, Tunable):
                for key, space in option.get_search_spaces().items():
                    cs[key] = {**space, '_cls': cls}

        param = inspect.signature(cls.__init__).parameters.get(name)
        if param is not None:
            cs[name] = {'bound': options, 'default': param.default, '_cls': cls}
        else:
            cs[name] = {'name': name, 'bound': options, '_cls': cls}
        return cls
    return inner
def real(name, bound):
    """Class decorator that registers a real-valued search space.

    Args:
        name: Name of the tunable parameter.
        bound: Range for the parameter. This is presumably ``[low, high]``,
            since the example in this file passes ``[1, 10]``; TODO confirm.

    Returns:
        A decorator that records the space in the decorated class's
        ``__search_spaces__`` mapping and returns the class itself.

    If the decorated class's ``__init__`` has a parameter called ``name``,
    the entry stores that parameter's default under ``'default'``. That value
    is ``inspect.Parameter.empty`` when the parameter has no default.
    Otherwise the entry stores a ``'name'`` key instead.
    """
    def inner(cls):
        # Use a mapping owned by ``cls`` itself. ``hasattr`` would also find
        # one inherited from a base class, and writing into it would let
        # sibling subclasses overwrite each other's entries (and a subclass
        # overwrite its parent's). The new mapping is seeded with the spaces
        # declared on the bases, so they are still reported for ``cls``.
        cs = cls.__dict__.get('__search_spaces__')
        if cs is None:
            cs = SearchSpacesDict()
            for base in reversed(inspect.getmro(cls)[1:]):
                inherited = base.__dict__.get('__search_spaces__', {})
                for key, space in inherited.items():
                    cs[key] = space
            setattr(cls, '__search_spaces__', cs)

        param = inspect.signature(cls.__init__).parameters.get(name)
        if param is not None:
            cs[name] = {'bound': bound, 'default': param.default, '_cls': cls}
        else:
            cs[name] = {'name': name, 'bound': bound, '_cls': cls}
        return cls
    return inner
# Simple example: collecting several search spaces on one class with stacked
# class decorators. Decorators apply bottom-up, so ``real`` runs first and
# ``choice`` second.
@choice('tunable_choice', [1, 2, 3])
@real('tunable_real', [1, 10])
class Hello(Tunable):
    """Example class carrying one ``choice`` and one ``real`` search space."""

    def __init__(self, list_var, dict_var=1, **kargs) -> None:
        """Accept the example arguments and ignore them."""


# Uncomment to inspect the collected spaces:
# print(Hello.get_search_spaces())
# Deep inheritance: ``Apple`` sees ``Food``'s ``have_seed`` space through
# inheritance and adds its own ``color`` space.
@choice('have_seed', [True, False])
class Food(Tunable):
    """Example base class with a boolean ``have_seed`` choice."""


@choice('color', ['red', 'green'])
class Apple(Food, Tunable):
    """Example ``Food`` subclass adding a ``color`` choice."""


# Uncomment to compare the spaces reported for the parent and the child:
# print(Food.get_search_spaces())
# print(Apple.get_search_spaces())
# Class composition: a ``choice`` whose options are ``Tunable`` classes also
# pulls those classes' own search spaces into the decorated class.
# NOTE(review): "Noddle" is presumably a misspelling of "Noodle"; the name is
# kept as-is because it is a public class name.
@choice('size', ['s', 'm', 'x'])
class Noddle(Food, Tunable):
    """Example ``Food`` subclass adding a ``size`` choice."""


@choice('eat', [Apple, Noddle])
class Lunch(Tunable):
    """Example class whose ``eat`` choice is between two ``Tunable`` classes."""


print(Lunch.get_search_spaces())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment