Created
August 6, 2021 04:05
-
-
Save eavae/b08108ef7fbcce6de28f3d4370e1d924 to your computer and use it in GitHub Desktop.
tunable parameter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict | |
import inspect | |
import typing | |
class SearchSpacesDict(OrderedDict):
    """OrderedDict in which every assignment makes the key the newest entry.

    A plain ``OrderedDict`` leaves a re-assigned key in its original slot;
    here each write moves the key to the end, so iteration order reflects the
    order in which keys were last set.
    """

    def __setitem__(self, key, value):
        super(SearchSpacesDict, self).__setitem__(key, value)
        # Re-position even keys that already existed.
        self.move_to_end(key)
class Tunable:
    """Base class for objects whose hyper-parameters can be searched.

    The ``choice`` and ``real`` decorators in this module store search-space
    entries in a ``__search_spaces__`` class attribute. Each entry carries an
    ``'_cls'`` key naming the class that owns it.
    """

    def __init__(self, *args, **kargs) -> None:
        # Accepts and discards any constructor arguments.
        return None

    @classmethod
    def get_search_spaces(cls):
        """Collect the search spaces visible to ``cls``.

        Returns a new SearchSpacesDict holding, in registry order, every entry
        of ``cls.__search_spaces__`` whose ``'_cls'`` is ``cls`` or one of its
        bases. If ``cls`` has no registry, the result is empty.
        """
        if not hasattr(cls, '__search_spaces__'):
            return SearchSpacesDict()
        registry = getattr(cls, '__search_spaces__')
        return SearchSpacesDict(
            (key, entry)
            for key, entry in registry.items()
            if issubclass(cls, entry['_cls'])
        )
def get_tunable_class(cls):
    """Return the classes in ``cls``'s MRO that derive from ``Tunable``.

    Args:
        cls: Any object. Non-class values yield an empty list.

    Returns:
        A list of classes in method-resolution order (most-derived first).
        It includes ``Tunable`` itself when that class is in the MRO.
    """
    if not inspect.isclass(cls):
        return []
    # issubclass (not ==) collects every Tunable-derived base, not only
    # Tunable itself. A distinct loop name avoids shadowing the parameter.
    return [klass for klass in inspect.getmro(cls) if issubclass(klass, Tunable)]
def choice(name, options):
    """Class decorator declaring a categorical search space ``name``.

    Args:
        name: Parameter name the search space is registered under.
        options: Candidate values. Any option that is a ``Tunable`` subclass
            also contributes its own search spaces, re-owned (``'_cls'``) by
            the decorated class.

    Returns:
        A decorator that records the space in the decorated class's
        ``__search_spaces__`` registry and returns the class unchanged.
    """
    def inner(cls):
        # Give each decorated class its own registry, seeded from the one it
        # inherits. If it wrote straight into an inherited registry (found by
        # hasattr/getattr), a subclass could overwrite entries that its bases
        # or sibling classes registered under the same name.
        if '__search_spaces__' not in vars(cls):
            inherited = getattr(cls, '__search_spaces__', None) or {}
            setattr(cls, '__search_spaces__', SearchSpacesDict(inherited))
        cs = getattr(cls, '__search_spaces__')
        for option in options:
            if inspect.isclass(option) and issubclass(option, Tunable):
                # Composition: pull in the option's spaces, owned by cls.
                for key, space in option.get_search_spaces().items():
                    cs[key] = {**space, '_cls': cls}
        param = inspect.signature(cls.__init__).parameters.get(name)
        if param is not None:
            # ``param.default`` is ``inspect.Parameter.empty`` when the
            # parameter has no default.
            cs[name] = {'bound': options, 'default': param.default, '_cls': cls}
        else:
            cs[name] = {'name': name, 'bound': options, '_cls': cls}
        return cls
    return inner
def real(name, bound):
    """Class decorator declaring a continuous search space ``name``.

    Args:
        name: Parameter name the search space is registered under.
        bound: Range of the space, stored as given. The demo passes
            ``[1, 10]``; this is presumably ``[low, high]``, but it is not
            validated here.

    Returns:
        A decorator that records the space in the decorated class's
        ``__search_spaces__`` registry and returns the class unchanged.
    """
    def inner(cls):
        # Give each decorated class its own registry, seeded from the one it
        # inherits. If it wrote straight into an inherited registry (found by
        # hasattr/getattr), a subclass could overwrite entries that its bases
        # or sibling classes registered under the same name.
        if '__search_spaces__' not in vars(cls):
            inherited = getattr(cls, '__search_spaces__', None) or {}
            setattr(cls, '__search_spaces__', SearchSpacesDict(inherited))
        cs = getattr(cls, '__search_spaces__')
        param = inspect.signature(cls.__init__).parameters.get(name)
        if param is not None:
            # ``param.default`` is ``inspect.Parameter.empty`` when the
            # parameter has no default.
            cs[name] = {'bound': bound, 'default': param.default, '_cls': cls}
        else:
            cs[name] = {'name': name, 'bound': bound, '_cls': cls}
        return cls
    return inner
""" | |
Simple example: collecting multiple search spaces on one class with stacked class decorators
""" | |
# Decorators apply bottom-up, so 'tunable_real' is registered before
# 'tunable_choice'. Neither name is an __init__ parameter of Hello, so both
# entries record the space's name instead of a default value.
@choice('tunable_choice', [1, 2, 3])
@real('tunable_real', [1, 10])
class Hello(Tunable):
    def __init__(self, list_var, dict_var=1, **kargs) -> None:
        return None


# Uncomment to inspect the collected spaces:
# print(Hello.get_search_spaces())
""" | |
Deep inheritance | |
""" | |
# Food declares 'have_seed'; Apple adds 'color'. Apple.get_search_spaces()
# reports both, because it keeps entries owned by Apple or any of its bases.
# Food.get_search_spaces() does not include Apple's 'color'.
@choice('have_seed', [True, False])
class Food(Tunable):
    ...


@choice('color', ['red', 'green'])
class Apple(Food, Tunable):
    ...


# print(Food.get_search_spaces())
# print(Apple.get_search_spaces())
""" | |
Class composition: a Tunable class used as a choice option contributes its own search spaces
""" | |
@choice('size', ['s', 'm', 'x'])
class Noddle(Food, Tunable):
    ...


# Apple and Noddle are Tunable options, so their search spaces are copied
# into Lunch (re-owned by Lunch) alongside 'eat'. Entries are keyed by name,
# so the 'have_seed' space that both options inherit from Food appears once.
@choice('eat', [Apple, Noddle])
class Lunch(Tunable):
    ...


print(Lunch.get_search_spaces())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment