|
""" |
|
|
|
""" |
|
from __future__ import annotations |
|
__all__ = [] |
|
|
|
|
|
#--- Wrap: Create Pre- and Postfix Functions |
|
|
|
import collections
import dataclasses
import random
import typing
|
|
|
@dataclasses.dataclass(frozen=True)
class Wrap:
    """Partial application of *func* that also supports ``@`` and ``|`` syntax.

    ``Wrap(f, args, kwargs)(x)`` calls ``f(*args, x, **kwargs)``.  The
    ``@`` and ``|`` operators (both directions) apply the wrap to a single
    operand, so ``wrap(f) @ x == x @ wrap(f) == f(x)``.
    """

    func: typing.Callable
    # Pre-bound positional arguments, placed BEFORE call-time positionals.
    args: list = dataclasses.field(default_factory=list)
    # Pre-bound keyword arguments, merged with call-time keywords.
    kwargs: dict = dataclasses.field(default_factory=dict)

    def __call__(self, *args, **kwargs) -> typing.Any:
        return self.func(*self.args, *args, **self.kwargs, **kwargs)

    def __matmul__(self, other: typing.Any) -> typing.Any:
        return self(other)

    def __rmatmul__(self, other: typing.Any) -> typing.Any:
        return self(other)

    def __or__(self, other: typing.Any) -> typing.Any:
        return self(other)

    def __ror__(self, other: typing.Any) -> typing.Any:
        return self(other)
|
|
|
def wrap(func: typing.Callable, /, *args, **kwargs) -> Wrap:
    """Bind *args and **kwargs to *func*, returning a Wrap (partial application)."""
    # list(args): Wrap.args is declared as a list; don't smuggle a tuple in.
    return Wrap(func, list(args), kwargs)
|
|
|
# Import-time smoke tests for the Wrap combinators.
assert (wrap(lambda x: x + 1) @ 1) == 2
assert (1 @ wrap(lambda x: x + 1)) == 2

# Rebind wrap as a Wrap of itself, so wrap itself gains @ / | syntax.
wrap = wrap(wrap) # :)
|
|
|
|
|
|
|
#--- Automatic Importer (and Package Installer) |
|
|
|
class AutoImportError(ImportError):
    """Raised when a module can neither be imported nor installed."""
|
|
|
class AutoImport(object):
    """Importer that resolves attribute access to modules, installing on demand.

    ``auto = AutoImport(); auto.np`` imports (and, if the import fails,
    pip-installs) numpy.  ``register()`` maps a short attribute name to the
    pip package name(s) and importable module name(s); unregistered names
    are installed/imported verbatim.
    """

    # Expose the module-level wrap helper as auto.wrap.
    wrap: ClassVar = wrap

    # attribute name -> pip package names / importable module names.
    __install_names: ClassVar[dict[str, list[str]]] = {}
    __import_names: ClassVar[dict[str, list[str]]] = {}

    # XXX(th): My Jupyter notebooks keep crashing when Google Colab tries to access
    # these properties, and AutoImport tries to install it. Hopefully this fixes that
    # problem.
    __custom_documentations__ = {}
    __wrapped__ = None

    @classmethod
    def __register(
        cls,
        *,
        name: str,
        install_names: list[str],
        import_names: list[str],
    ):
        cls.__install_names[name] = install_names
        cls.__import_names[name] = import_names

    @classmethod
    def register(
        cls,
        name: str,
        *extra_names: str,
        install_names: list[str]=None,
        import_names: list[str]=None,
    ):
        """Register *name*.

        Dotted *extra_names* are treated as import names, undotted ones as
        pip install names.  Empty lists default to [name].
        """
        # Copy so the caller's lists are never mutated by the appends below.
        install_names = list(install_names) if install_names is not None else []
        import_names = list(import_names) if import_names is not None else []

        for extra_name in extra_names:
            if '.' in extra_name:
                import_names.append(extra_name)
            else:
                install_names.append(extra_name)

        if not install_names:
            install_names.append(name)

        if not import_names:
            import_names.append(name)

        cls.__register(
            name=name,
            install_names=install_names,
            import_names=import_names,
        )

    def __import(self, name: str):
        """Import every module registered for *name*; return the first one."""
        import importlib

        import_names = self.__import_names.get(name, [name])
        modules = [
            importlib.import_module(import_name)
            for import_name in import_names
        ]
        return modules[0]

    def __install(self, name: str):
        """pip-install the package(s) registered for *name*."""
        import pip, warnings

        install_names = self.__install_names.get(name, [name])
        with warnings.catch_warnings():
            # pip.main() is deprecated; silence the warning instead of failing.
            # (Use the locally imported warnings module, not auto.warnings:
            # `auto` does not exist yet while this module is still loading.)
            warnings.simplefilter(action='ignore', category=DeprecationWarning)
            pip.main([
                'install',
                '--quiet',
                *install_names,
            ])

    def __getitem__(self, name: str):
        return getattr(self, name)

    def __getattr__(self, name: str):
        """Import *name*, installing it first if the initial import fails."""
        try:
            module = self.__import(name)
        except ImportError:
            self.__install(name)
            try:
                module = self.__import(name)
            except ImportError as e:
                # Still unimportable after installing: surface a dedicated
                # error (it subclasses ImportError, so existing handlers work).
                raise AutoImportError(name) from e

        # Cache on the instance so __getattr__ only runs once per name.
        setattr(self, name, module)
        return module
|
|
|
# auto.langchain imports the top-level package plus every first-level
# submodule so attribute access works without further imports.
AutoImport.register('langchain', import_names=[
    'langchain',

    'langchain.adapters', 'langchain.agents', 'langchain._api',
    'langchain.callbacks', 'langchain.chains', 'langchain.chat_loaders',
    'langchain.chat_models', 'langchain.docstore',
    'langchain.document_loaders', 'langchain.document_transformers',
    'langchain.embeddings', 'langchain.evaluation', 'langchain.globals',
    'langchain.graphs', 'langchain.indexes', 'langchain.llms',
    'langchain.load', 'langchain.memory', 'langchain.output_parsers',
    'langchain.prompts', 'langchain.pydantic_v1', 'langchain.retrievers',
    'langchain.runnables', 'langchain.schema', 'langchain.smith',
    'langchain.storage', 'langchain.tools', 'langchain.utilities',
    'langchain.utils', 'langchain.vectorstores',

    # NOTE(review): the second-level subpackages (langchain.agents.*,
    # langchain.chains.*, langchain.tools.*, ...) were listed here but
    # commented out — importing them all is slow and they are reachable via
    # their parents anyway.  See version control history for the full list.
])
|
|
|
# auto.tf -> tensorflow.
AutoImport.register('tf', import_names=[
    'tensorflow',
])

# Colab helpers; importing the submodules makes google.colab.* usable directly.
AutoImport.register('google', import_names=[
    'google', 'google.colab', 'google.colab.syntax', 'google.colab.userdata',
])

# tkinter under several aliases; submodules are imported eagerly so that
# e.g. auto.tk.ttk works without a separate import.
AutoImport.register('tk', import_names=[
    'tkinter', 'tkinter.ttk', 'tkinter.scrolledtext', 'tkinter.dnd',
    'tkinter.font', 'tkinter.tix', 'tkinter.colorchooser',
    'tkinter.messagebox',
])

AutoImport.register('tkinter', import_names=[
    'tkinter', 'tkinter.ttk', 'tkinter.scrolledtext', 'tkinter.dnd',
    'tkinter.font', 'tkinter.tix', 'tkinter.colorchooser',
    'tkinter.messagebox',
])

# auto.ttk resolves to tkinter.ttk (first import name is the one returned).
AutoImport.register('ttk', import_names=[
    'tkinter.ttk', 'tkinter.scrolledtext', 'tkinter.dnd',
    'tkinter.font', 'tkinter.tix', 'tkinter.colorchooser',
    'tkinter.messagebox',
])

AutoImport.register('np', install_names=['numpy'], import_names=['numpy'])
AutoImport.register('pd', install_names=['pandas'], import_names=['pandas'])

AutoImport.register('tqdm', import_names=['tqdm', 'tqdm.auto', 'tqdm.notebook'])

# Both auto.pyplot and auto.plt resolve to matplotlib.pyplot.
for pyplot_name in ['pyplot', 'plt']:
    AutoImport.register(pyplot_name, install_names=['matplotlib'], import_names=['matplotlib.pyplot'])
AutoImport.register('matplotlib', import_names=[
    'matplotlib',
    'matplotlib.pyplot',
])
|
# Thanks https://docs.scipy.org/doc/scipy/reference/index.html |
|
# (() => { |
|
# const texts = []; |
|
# for (const $el of $$('ul.nav > li.toctree-l1 > a.reference.internal')) { |
|
# texts.push(` "${$el.textContent.trim()}",`); |
|
# } |
|
# console.log(texts.join("\n")); |
|
# })(); |
|
# Every public scipy subpackage (scraped from the docs index; see the
# JavaScript snippet above), so auto.scipy.<sub> works immediately.
AutoImport.register('scipy', import_names=[
    "scipy",
    "scipy.cluster",
    "scipy.constants",
    "scipy.datasets",
    "scipy.fft",
    "scipy.fftpack",
    "scipy.integrate",
    "scipy.interpolate",
    "scipy.io",
    "scipy.linalg",
    "scipy.misc",
    "scipy.ndimage",
    "scipy.odr",
    "scipy.optimize",
    "scipy.signal",
    "scipy.sparse",
    "scipy.spatial",
    "scipy.special",
    "scipy.stats",
])
|
|
|
# Thanks https://scikit-learn.org/stable/modules/classes.html |
|
# (() => { |
|
# const texts = []; |
|
# for (const $el of $$('code.xref.py.py-mod.docutils.literal.notranslate > span.pre')) { |
|
# texts.push(` "${$el.textContent}",`); |
|
# } |
|
# console.log(texts.join("\n")); |
|
# })(); |
|
# Every public sklearn submodule (scraped from the docs; see the JavaScript
# snippet above).  NOTE(review): the scraped list contained each module twice
# (once from the overview, once per-section); the duplicates are removed
# here.  "sklearn" stays first — AutoImport returns the first import name.
AutoImport.register('sklearn', import_names=[
    "sklearn",
    "sklearn.base",
    "sklearn.calibration",
    "sklearn.cluster",
    "sklearn.compose",
    "sklearn.covariance",
    "sklearn.cross_decomposition",
    "sklearn.datasets",
    "sklearn.decomposition",
    "sklearn.discriminant_analysis",
    "sklearn.dummy",
    "sklearn.ensemble",
    "sklearn.exceptions",
    "sklearn.experimental",
    "sklearn.feature_extraction",
    "sklearn.feature_extraction.image",
    "sklearn.feature_extraction.text",
    "sklearn.feature_selection",
    "sklearn.gaussian_process",
    "sklearn.impute",
    "sklearn.inspection",
    "sklearn.isotonic",
    "sklearn.kernel_approximation",
    "sklearn.kernel_ridge",
    "sklearn.linear_model",
    "sklearn.manifold",
    "sklearn.metrics",
    "sklearn.metrics.cluster",
    "sklearn.mixture",
    "sklearn.model_selection",
    "sklearn.multiclass",
    "sklearn.multioutput",
    "sklearn.naive_bayes",
    "sklearn.neighbors",
    "sklearn.neural_network",
    "sklearn.pipeline",
    "sklearn.preprocessing",
    "sklearn.random_projection",
    "sklearn.semi_supervised",
    "sklearn.svm",
    "sklearn.tree",
    "sklearn.utils",
])
|
|
|
# Pillow: pip name differs from the import name; common plugins and helper
# modules are imported eagerly so auto.PIL.Image etc. work immediately.
AutoImport.register('PIL', install_names=['pillow'], import_names=[
    'PIL',
    'PIL.BmpImagePlugin',
    'PIL.ExifTags',
    'PIL.GifImagePlugin',
    'PIL.GimpGradientFile',
    'PIL.GimpPaletteFile',
    'PIL.Image',
    'PIL.ImageChops',
    'PIL.ImageColor',
    'PIL.ImageFile',
    'PIL.ImageMode',
    'PIL.ImageOps',
    'PIL.ImagePalette',
    'PIL.ImageSequence',
    'PIL.JpegImagePlugin',
    'PIL.JpegPresets',
    'PIL.PaletteFile',
    'PIL.PngImagePlugin',
    'PIL.PpmImagePlugin',
    'PIL.TiffImagePlugin',
    'PIL.TiffTags',
])
|
|
|
# Hugging Face + PyTorch stack: requesting any one of these installs the
# whole set (they are commonly needed together).
machine_learning_packages = [
    'transformers',
    'accelerate',
    'datasets',
    'tokenizers',
    'evaluate',
    'huggingface_hub',
    'torch',
]
for machine_learning_package in machine_learning_packages:
    AutoImport.register(machine_learning_package, install_names=machine_learning_packages)

# The module-wide importer instance: auto.<name> imports on demand.
auto = AutoImport()
__all__ += [
    'auto',
]
|
|
|
|
|
#--- Auto Import Submodules |
|
|
|
def __getattr__(name: str):
    """Module-level PEP 562 hook: expose already-loaded submodules as attributes.

    Only returns a module if ``mediocreatbest.<name>`` is already present in
    sys.modules (e.g. created by the %%module magic); it never imports.
    """
    # Reject private/dunder probes outright.
    if name.startswith('_') or name.endswith('_'):
        raise AttributeError(name)

    submodule = auto.sys.modules.get(f'mediocreatbest.{name}')
    if submodule is not None:
        return submodule

    raise AttributeError(name)
|
|
|
|
|
#--- Auto Display in Jupyter |
|
|
|
class AutoDisplay:
    """Proxy for IPython.display: calling it displays; attribute access
    returns a builder that constructs the display object AND displays it
    (e.g. auto.display.HTML('<b>hi</b>')).
    """

    def __getattr__(self, name: str, /) -> auto.typing.Callable:
        # Look up the factory (HTML, Markdown, Image, ...) on IPython.display.
        func = getattr(auto.IPython.display, name)

        @auto.functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Build the display object, then immediately display it.
            return self(func(*args, **kwargs))

        return wrapper

    @auto.functools.wraps(auto.IPython.display.display)
    def __call__(self, *args, **kwargs):
        return auto.IPython.display.display(*args, **kwargs)

auto.display = AutoDisplay()
|
|
|
|
|
#--- Run doctests on a single function |
|
|
|
def doctest(func=None, /, verbose=False, sterile=False):
    """Run *func*'s doctests immediately; return *func* unchanged.

    Usable bare (``@doctest``) or parameterized (``@doctest(verbose=True)``).

    verbose: forwarded to the doctest finder/runner.
    sterile: run the examples against an empty globals dict instead of a
        copy of this module's globals (the function itself is always added).

    Raises AssertionError if any example fails.
    """
    def wrapper(func):
        # Thanks https://stackoverflow.com/a/49659927
        import doctest, copy

        # The stdlib run_docstring_examples() does not signal failure; this
        # variant asserts that every example passed.  (Fixed to actually use
        # its `f` parameter instead of silently closing over `func`.)
        def run_docstring_examples(f, globs, verbose=False, name="NoName", compileflags=None, optionflags=0):
            finder = doctest.DocTestFinder(verbose=verbose, recurse=False)
            runner = doctest.DocTestRunner(verbose=verbose, optionflags=optionflags)
            for test in finder.find(f, name, globs=globs):
                runner.run(test, compileflags=compileflags)
            assert runner.failures == 0

        name = func.__name__

        if sterile:
            globs = {}
        else:
            globs = copy.copy(globals())
        # Make the function visible to its own examples.
        globs[name] = func
        run_docstring_examples(func, globs, verbose=verbose, name=name)
        return func

    # Support both @doctest and @doctest(...) usage.
    if func is not None:
        return wrapper(func)
    else:
        return wrapper
|
|
|
#--- |
|
|
|
# Persistent notebook state for run(): g maps names to produced values,
# f maps names to the producing functions.  Re-executing this module keeps
# any existing contents (useful when the file is %run repeatedly).
try:
    g
except NameError:
    g = {}

try:
    f
except NameError:
    f = {}
|
|
|
def run(func=None, /, name=None, cond=True, splat=False, after=None, scope=None, once=False):
    """Decorator: call *func* with arguments pulled from ``g`` and store the
    result(s) back into ``g``.

    Each POSITIONAL_ONLY parameter of *func* is looked up in ``g`` by name
    (preferring the scoped key ``'{scope}__{key}'`` when *scope* is given).

    name: key to store the result under (default: func.__name__).
    cond: bool or zero-arg callable gate; if falsy the call is skipped.
    splat: treat the return value as a mapping of many (name, value) pairs.
    after: optional callback invoked with the raw return value.
    scope: prefix for all keys, as '{scope}__{key}'.
    once: skip the call when the result key already exists in g.
    """
    def wrapper(func, /, *, name=name, cond=cond):
        import inspect

        if callable(cond):
            cond = cond()

        if not cond:
            return None

        if name is None:
            name = func.__name__

        f[name] = func

        # Resolve each positional-only parameter from g, trying the scoped
        # key first; the for/else raises when no candidate key exists.
        args = []
        for key, parameter in inspect.signature(func).parameters.items():
            if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
                keys = [key]
                if scope is not None:
                    keys.insert(0, f'{scope}__{key}')

                for key in keys:
                    try:
                        value = g[key]
                    except KeyError:
                        continue
                    else:
                        args.append(value)
                        break
                else:
                    raise KeyError(f'None of {keys=!r} found in g')

        # NOTE(review): the `once` check happens AFTER argument resolution,
        # so a missing input still raises even when the call would be skipped.
        if once:
            if scope is not None:
                if f'{scope}__{name}' in g:
                    return
            else:
                if name in g:
                    return

        ret = func(*args)

        if callable(after):
            after(ret)

        if splat:
            it = ret.items()
        else:
            it = [(name, ret)]

        for name, ret in it:
            if scope is not None:
                name = f'{scope}__{name}'
            g[name] = ret

        return None

    # Support both @run and @run(...) usage.
    if func is not None:
        return wrapper(func)
    else:
        return wrapper
run.f = f
run.g = g
|
|
|
|
|
#--- Jupyter Magic: Create Module |
|
|
|
try:
    get_ipython
except NameError:
    pass # Not in Jupyter
else:
    # %%module NAME [--reuse]: execute the cell in a fresh module namespace
    # and publish it as sys.modules[NAME]; --reuse executes into the
    # existing module instead (preserving its other attributes).
    @auto.IPython.core.magic.register_cell_magic('mediocreatbest.module')
    @auto.IPython.core.magic.register_cell_magic('module')
    def _(line: str, cell: str):
        __tracebackhide__ = True

        parser = auto.argparse.ArgumentParser()
        parser.add_argument('name', type=str)
        parser.add_argument('--reuse', action='store_true')
        args = parser.parse_args(auto.shlex.split(line))

        # The notebook cell's pseudo-filename sits 4 frames up the stack.
        # NOTE(review): this frame offset is IPython-version dependent.
        traceback = auto.traceback.extract_stack()
        filename = traceback[-4][0]

        code = get_ipython().transform_cell(cell)
        code = compile(code, filename, 'exec')

        if args.reuse and args.name in auto.sys.modules:
            module = auto.sys.modules[args.name]
        else:
            module = auto.types.ModuleType(args.name)

        try:
            exec(code, module.__dict__)
        except Exception as e:
            # Re-print the traceback with '<module>' rewritten to the cell
            # line so the error is attributable to the notebook cell.
            traceback = auto.traceback.extract_tb(auto.sys.exc_info()[2])
            frame = traceback[1]
            lineno = frame.lineno

            traceback = auto.traceback.format_exc()
            traceback = traceback.replace('<module>', f'<cell line: {lineno}>()')
            print(traceback)

            raise e from None

        auto.sys.modules[args.name] = module
        # Drop any cached AutoImport attribute so auto.<name> resolves to
        # the freshly created module next time.
        if hasattr(auto, args.name):
            delattr(auto, args.name)

        return module
|
|
|
|
|
#--- Jupyter Magic: Embed source code here |
|
|
|
try:
    get_ipython
except NameError:
    pass # Not in Jupyter
else:
    # %embed [--replace] EXPR: paste the source code of the object EXPR
    # evaluates to into the next cell (or over this cell with --replace).
    @auto.IPython.core.magic.register_cell_magic('mediocreatbest.embed')
    @auto.IPython.core.magic.register_cell_magic('embed')
    @auto.IPython.core.magic.register_line_magic('mediocreatbest.embed')
    @auto.IPython.core.magic.register_line_magic('embed')
    def __embed(line: str, cell: str=None):
        import inspect, textwrap

        def embed(arg: str, replace: bool):
            arg = get_ipython().ev(arg)
            arg = inspect.getsource(arg)
            arg = textwrap.dedent(arg)
            if replace:
                # Keep the magic invocation at the top so it can be re-run.
                arg = f'# %mediocreatbest.embed {line}\n{arg}'
                get_ipython().set_next_input(arg, replace=replace)
            else:
                # Comment out this cell's invocation, then append the source
                # as a new cell below it.
                get_ipython().set_next_input(f'# %mediocreatbest.embed {line}\n{cell if cell is not None else ""}', replace=True)
                get_ipython().set_next_input(arg, replace=False)

        opt_replace = False

        arg = line.strip()
        if arg.startswith('--replace'):
            opt_replace = True
            arg = arg.removeprefix('--replace')
            arg = arg.strip()

        embed(arg, replace=opt_replace)
|
|
|
|
|
|
|
#--- Jupyter Magic: Source a Bash script and turn its environment variables into a cell. |
|
|
|
try:
    get_ipython
except NameError:
    pass
else:
    # %source SCRIPT: source a Bash script, diff the resulting environment
    # against the current one, and rewrite the cell as an explicit
    # `os.environ |= {...}` assignment (which is then executed).
    @auto.IPython.core.magic.register_line_magic('mediocreatbest.source')
    @auto.IPython.core.magic.register_line_magic('source')
    @auto.IPython.core.magic.register_cell_magic('mediocreatbest.source')
    @auto.IPython.core.magic.register_cell_magic('source')
    def source(magic_line, magic_cell=None):
        import os, subprocess, shlex

        if magic_cell is None or magic_cell == '':
            before = os.environ.copy()

            # `export` prints `declare -x NAME="VALUE"` lines after sourcing.
            process = subprocess.run([
                'bash', '-c', f'source {magic_line}; export',
            ], capture_output=True, text=True)

            after = {}
            for line in process.stdout.split('\n'):
                if line == '': continue
                parts = shlex.split(line)
                assert parts[0] == 'declare', f'{line=!r}'
                assert parts[1] == '-x', f'{line=!r}'
                if '=' not in parts[2]: continue
                name, value = parts[2].split('=', 1)

                # Keep only variables that the script changed or added.
                if before.get(name, None) == value: continue
                after[name] = value

            # Generate the replacement cell; PATH-like values are split on
            # ':' for readability.
            magic_cell = f'%%source {magic_line}\n'
            magic_cell += f'os.environ |= {{\n'
            for name, value in after.items():
                magic_cell += f'    {name!r}: '
                if ':' in value:
                    magic_cell += f'":".join([\n'
                    for value in value.split(':'):
                        magic_cell += f'        {value!r},\n'
                    magic_cell += f'    ]),\n'
                else:
                    magic_cell += f'    {value!r},\n'
            magic_cell += f'}}\n'

            get_ipython().set_next_input(magic_cell, replace=True)

        get_ipython().run_cell(magic_cell)
|
|
|
#--- Jupyter Magic: Assign verbatim contents of a cell as a local variable |
|
|
|
try:
    get_ipython
except NameError:
    pass
else:
    # %%verbatim [-v NAME]: store the raw cell text (untransformed) in the
    # user namespace under NAME.
    @auto.IPython.core.magic.register_cell_magic('mediocreatbest.verbatim')
    @auto.IPython.core.magic.register_cell_magic('verbatim')
    def verbatim(line, cell=None):
        def verbatim(*, variable: str):
            get_ipython().push({
                variable: cell,
            })

        parser = auto.argparse.ArgumentParser()
        # Fixed: the default variable name was misspelled 'verbvatim'.
        parser.add_argument('-v', dest='variable', default='verbatim')
        args = vars(parser.parse_args(auto.shlex.split(line)))

        return verbatim(**args)
|
|
|
|
|
#--- Better singledispatch implementation |
|
# |
|
# functools.singledispatch *only* supports positional arguments. If you want the |
|
# base function to be called with keyword arguments, but still allow the |
|
# dispatching on any positional argument, then you need to work around that |
|
# problem. |
|
|
|
def dispatch(func, /):
    """functools.singledispatch that tolerates keyword-only calls.

    singledispatch dispatches on the first POSITIONAL argument and raises
    when there is none; this wrapper falls back to the base *func* for
    keyword-only calls.  register/dispatch/registry are re-exposed.
    """
    # Local import: keep this helper independent of the `auto` magic.
    import functools

    dispatch = functools.singledispatch(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # No positional argument to dispatch on: use the base implementation.
        if not args:
            return func(**kwargs)

        return dispatch(*args, **kwargs)

    wrapper.register = dispatch.register
    wrapper.dispatch = dispatch.dispatch
    wrapper.registry = dispatch.registry
    wrapper._clear_cache = dispatch._clear_cache
    return wrapper
|
|
|
|
|
#--- Combination Context Manager + Generator/Coroutine |
|
# |
|
# Example: A coroutine/generator for reading from a file. |
|
# |
|
# @contextgenerator |
|
# def foo(): |
|
# with open(__file__, 'rt') as f: |
|
# text = None |
|
# while True: |
|
# size_to_read = yield text |
|
# text = f.read(size_to_read) |
|
# |
|
# with foo() as foo: |
|
# foo.send(64) |
|
# foo.send(64) |
|
|
|
def contextgenerator(func=None, /, *, call=lambda generator: generator):
    """Combine a context manager with a generator/coroutine.

    The decorated generator function becomes a context manager: entering it
    creates the generator, primes it with next(), and yields
    ``call(generator)``; exiting always closes the generator.

    call: optional transform applied to the primed generator before it is
        handed to the `with` body (e.g. ``call=lambda g: g.send``).
    """
    # Local imports: keep this helper independent of the `auto` magic.
    import contextlib, functools

    # Support parameterized usage: @contextgenerator(call=...).
    if func is None:
        return functools.partial(contextgenerator, call=call)

    @contextlib.contextmanager
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        generator = None
        try:
            generator = func(*args, **kwargs)

            # Advance to the first yield so .send() works immediately.
            next(generator)
            yield call(generator)

        finally:
            if generator is not None:
                generator.close()

    return wrapper
|
|
|
|
|
#--- Simple decorator to immediately call a function |
|
|
|
def immediate(func, /):
    """Decorator that invokes *func* at once and binds its result to the name."""
    return func()
|
|
|
|
|
#--- Simple decorator for coroutines |
|
|
|
def coroutine(func, /):
    """Decorator: create the generator and prime it with next().

    The returned generator is already advanced to its first yield, so the
    caller can .send() into it immediately.
    """
    # Local import: keep this helper independent of the `auto` magic.
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        coroutine = func(*args, **kwargs)
        next(coroutine)
        return coroutine

    return wrapper
|
|
|
|
|
#--- Automatic namedtuple generation |
|
# It can be helpful to quickly have a class/namedtuple when declaring data |
|
# structures, but it is annoying to have to switch back and forth to write code, |
|
# write a class, then back to writing code that uses the class. |
|
# |
|
# Instead, this helper class lets you write code like the following: |
|
# |
|
# >>> mediocreatbest.namedtuple.name.position('hello', mediocreatbest.namedtuple.x.y.z(0.0, 1.0, 2.0)) |
|
# np(name='hello', position=xyz(x=0., y=1., z=2.)) |
|
|
|
|
|
class NamedTuple:
    """Build ad-hoc namedtuple factories by attribute chaining.

    >>> namedtuple.x.y(1, 2)
    xy(x=1, y=2)

    Each attribute access appends a field name; calling the chain constructs
    (and caches) a collections.namedtuple whose type name is the initial
    letter of each field.
    """

    def __init__(self, names=None):
        if names is None:
            names = []

        self.__names = names
        self.__class = None

    def __getattr__(self, name):
        # Refuse private/dunder probes (copy, pickle, IPython display hooks)
        # so they don't pollute the chain with bogus field names.
        if name.startswith('_'):
            raise AttributeError(name)

        attr = NamedTuple(self.__names + [name])
        # Cache so repeated access reuses the same child node (and its
        # memoized namedtuple class).
        setattr(self, name, attr)
        return attr

    def __call__(self, *args, **kwargs):
        import collections

        if self.__class is None:
            # Type name is the initial of each field: ('x', 'y') -> 'xy'.
            typename = ''.join(name[0] for name in self.__names)
            fields = self.__names
            self.__class = collections.namedtuple(typename, fields)

        return self.__class(*args, **kwargs)


namedtuple = NamedTuple()
|
|
|
|
|
#--- Jupyter Magic: Lexical Scope Cell |
|
|
|
try:
    get_ipython
except NameError:
    pass
else:
    # %%scope [-n NAME] [-i VAR[=EXPR]]... [-o VAR[=EXPR]]... [--skip]:
    # wrap the cell body in a function (giving it a private lexical scope),
    # then — unless --skip — immediately call it with the -i inputs and
    # assign its -o outputs back into the interactive namespace.
    @auto.IPython.core.magic.register_cell_magic('mediocreatbest.scope')
    @auto.IPython.core.magic.register_cell_magic('scope')
    def scope(line: str, cell: str):
        def scope(
            filename: str,
            name: str,
            inpvars: list[tuple[str, str]],
            outvars: list[tuple[str, str]],
            skip: bool,
            *,
            line=line,
            cell=cell,
        ):
            cell = get_ipython().transform_cell(
                cell,
            )

            # Parse the (IPython-transformed) cell into an AST module.
            module = auto.ast.parse(
                cell,
                filename,
                'exec',
            )

            # auto.astpretty.pprint(module)

            # Each -i input becomes a keyword-only parameter (no default).
            kwonlyargs = []
            kw_defaults = []
            for inpvar, inpval in inpvars:
                kwonlyargs.append(auto.ast.arg(
                    arg=(
                        inpvar
                    ),
                    annotation=None,
                    type_comment=None,
                ))
                kw_defaults.append(None)

            # Build the function's return value: None, a single name, or a
            # tuple of the -o output expressions.
            value = None
            if len(outvars) == 0:
                value = auto.ast.Constant(
                    value=None,
                )

            elif len(outvars) == 1:
                value = auto.ast.Name(
                    id=(
                        outvars[0][0]
                    ),
                    ctx=auto.ast.Load(),
                )

            else:
                elts = []
                for outvar, outval in outvars:
                    elts.append(auto.ast.parse(
                        outval,
                        mode='eval',
                    ).body)

                tuple_ = auto.ast.Tuple(
                    elts=(
                        elts
                    ),
                    ctx=auto.ast.Load(),
                )

                value = tuple_

            return_ = auto.ast.Return(
                value=(
                    value
                ),
            )

            # Function body = the cell's statements + the synthesized return.
            body = module.body[:]
            body.append(return_)

            function = auto.ast.FunctionDef(
                name=(
                    name
                ),
                args=auto.ast.arguments(
                    posonlyargs=[],
                    args=[],
                    vararg=None,
                    kwonlyargs=(
                        kwonlyargs
                    ),
                    kw_defaults=(
                        kw_defaults
                    ),
                    kwarg=None,
                    defaults=[],
                ),
                body=(
                    body
                ),
                decorator_list=[],
                returns=None,
                type_comment=None,
            )

            body = []
            body.append(function)

            if not skip:
                # Build `TARGET(S) = name(inp=..., ...)` to run the scope now.
                target = None

                if len(outvars) == 0:
                    target = auto.ast.Name(
                        id=(
                            '_'
                        ),
                        ctx=auto.ast.Store(),
                    )

                elif len(outvars) == 1:
                    target = auto.ast.Name(
                        id=(
                            outvars[0][0]
                        ),
                        ctx=auto.ast.Store(),
                    )

                else:
                    elts = []
                    for outvar, outval in outvars:
                        elts.append(auto.ast.Name(
                            id=(
                                outvar
                            ),
                            ctx=auto.ast.Store(),
                        ))

                    tuple_ = auto.ast.Tuple(
                        elts=(
                            elts
                        ),
                        ctx=auto.ast.Store(),
                    )

                    target = tuple_

                targets = []
                targets.append(target)

                # Each -i input is passed as a keyword whose value is the
                # parsed EXPR (or the variable itself when no '=' was given).
                keywords = []
                for inpvar, inpval in inpvars:
                    keywords.append(auto.ast.keyword(
                        arg=(
                            inpvar
                        ),
                        value=auto.ast.parse(
                            inpval,
                            mode='eval',
                        ).body,
                    ))

                call = auto.ast.Call(
                    func=auto.ast.Name(
                        id=(
                            name
                        ),
                        ctx=auto.ast.Load(),
                    ),
                    args=[],
                    keywords=(
                        keywords
                    ),
                    starargs=None,
                    kwargs=None,
                )

                assign = auto.ast.Assign(
                    targets=(
                        targets
                    ),
                    value=(
                        call
                    ),
                    type_comment=None,
                )

                body.append(assign)
            #/if not skip

            module = auto.ast.Module(
                body=(
                    body
                ),
                type_ignores=[],
            )

            module = auto.ast.fix_missing_locations(module)

            # pretty print ast tree
            # auto.astpretty.pprint(module)

            # Execute the synthesized module in the interactive namespace.
            code = compile(module, filename, 'exec')
            get_ipython().ex(code)

            return module

        # The notebook cell's pseudo-filename sits 4 frames up the stack.
        # NOTE(review): this frame offset is IPython-version dependent.
        stack = auto.traceback.extract_stack()
        # print(f'{stack[-1].filename = !r}')
        # print(f'{stack[-2].filename = !r}')
        # print(f'{stack[-3].filename = !r}')
        # print(f'{stack[-4].filename = !r}')
        # print(f'{stack[-5].filename = !r}')
        # print(f'{stack[-6].filename = !r}')
        filename = stack[-4].filename

        def csv(s: str) -> list[str]:
            return s.split(',')

        # 'VAR=EXPR' -> (VAR, EXPR); bare 'VAR' -> (VAR, VAR).
        def kv(s: str, /) -> tuple[str, str]:
            kv = s.split('=', 1)
            if len(kv) == 1:
                return kv[0], kv[0]
            else:
                k, v = kv
                return k, v

        parser = auto.argparse.ArgumentParser()
        parser.add_argument('--name', '-n', default='scope')
        parser.add_argument('--inpvars', '-i', type=kv, default=[], action='append')
        parser.add_argument('--outvars', '-o', type=kv, default=[], action='append')
        parser.add_argument('--skip', action='store_true')

        args = auto.shlex.split(line)
        args = vars(parser.parse_args(args))

        scope(filename=filename, **args)
|
|
|
|
|
#--- Tkinter Grid |
|
|
|
def Grid(parent: auto.tk.Widget, widgets: list[list[auto.tk.Widget]], /) -> None:
    """Lay out a rectangular 2-D list of widgets with tkinter's grid manager.

    The same widget object occupying adjacent cells is merged into a single
    rowspan/columnspan region.  Every row and column gets weight=1, and every
    widget is gridded with sticky='nsew'.
    """
    # Local import: keep this helper independent of the `auto` magic.
    import itertools

    nrows = len(widgets)
    assert nrows >= 1
    ncols = len(widgets[0])
    assert ncols >= 1
    for row in widgets:
        assert len(row) == ncols

    grid = {
        (ri, ci): widgets[ri][ci]
        for ri in range(nrows)
        for ci in range(ncols)
    }

    for ri in range(nrows):
        parent.grid_rowconfigure(ri, weight=1)

    for ci in range(ncols):
        parent.grid_columnconfigure(ci, weight=1)

    seen = set()
    for ri, ci in itertools.product(range(nrows), range(ncols)):
        # Only grid each distinct widget once, from its top-left cell.
        if id(grid[ri, ci]) in seen:
            continue
        seen.add(id(grid[ri, ci]))

        ri0, ci0 = ri, ci

        # Walk down: the loop leaves `ri` one past the last matching row.
        for ri in range(ri0, nrows+1):
            if grid.get((ri, ci0), None) is not grid[ri0, ci0]:
                break

        # Walk right: the loop leaves `ci` one past the last matching column.
        for ci in range(ci0, ncols+1):
            if grid.get((ri0, ci), None) is not grid[ri0, ci0]:
                break

        nr = ri - ri0
        nc = ci - ci0

        grid[ri0, ci0].grid(row=ri0, column=ci0, rowspan=nr, columnspan=nc, sticky='nsew')
|
|
|
|
|
#--- DotDict: Access Dictionary Members with getattr |
|
|
|
class DotDict(dict):
    """dict whose items are also readable, writable, and deletable as attributes.

    Missing attributes raise KeyError (delegated straight to dict), matching
    item access.
    """

    def __getattr__(self, key):
        return dict.__getitem__(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)
|
|
|
|
|
#--- Convert HTML into a static https://itty.bitty.site link |
|
|
|
def IttyBittySite(html: str, /) -> str:
    """Encode *html* as a static, self-contained https://itty.bitty.site link.

    The page content is UTF-8 encoded, LZMA-compressed (legacy .lzma "alone"
    container, preset 9 — the format itty.bitty.site expects), base64
    encoded, and appended to the site's fragment URL.
    """
    # Local imports: keep this helper independent of the `auto` magic.
    import base64, lzma

    ret = html
    ret = bytes(ret, encoding="utf-8")
    ret = lzma.compress(ret, format=lzma.FORMAT_ALONE, preset=9)
    ret = base64.b64encode(ret)
    ret = ret.decode("utf-8")
    ret = 'https://itty.bitty.site/#/'+ret
    return ret
|
|
|
|
|
#--- Format Chat |
|
|
|
def FormatChat(*, prompt: dict={}, output: None | dict=None, **kwargs) -> auto.IPython.display.HTML:
    """Render an OpenAI-style chat exchange as a styled HTML transcript.

    ``prompt`` holds the request; its ``messages`` list is displayed, and
    extra keyword arguments are merged into it.  When ``output`` (the API
    response) is given, its first choice's message is appended.  The HTML
    is prefixed with a link to a static itty.bitty.site copy.
    NOTE(review): the mutable default ``prompt={}`` is safe here because
    ``prompt | kwargs`` builds a new dict rather than mutating it.
    """
    prompt = prompt | kwargs

    # Tailwind-styled bubbles: system centered, user right, assistant left.
    # Newlines are made visible with a pilcrow-style arrow before escaping.
    ret = r"""

        <script src="https://cdn.tailwindcss.com"></script>

        <div class="

            flex

            flex-col

            w-[640px]

            border

        ">

            {% for message in messages %}

            <div

                class="

                    w-[80%]

                    p-4

                    m-2

                    rounded-lg

                    whitespace-pre-wrap

                    text-left

                    {% if message.role == 'system' %}

                        self-center

                        bg-gray-100

                    {% elif message.role == 'user' %}

                        self-end

                        bg-blue-100

                    {% elif message.role == 'assistant' %}

                        self-start

                        bg-gray-100

                    {% endif %}

                "

            >{{ message.content | replace("\n", "⏎\n") | escape }}</div>

            {% endfor %}

        </div>

    """

    ret = auto.textwrap.dedent(ret)
    ret = auto.jinja2.Template(ret)
    ret = ret.render(
        messages=(
            [*prompt['messages'], output['choices'][0]['message']]
            if output is not None else
            prompt['messages']
        ),
    )
    # Link to a self-contained static copy of the transcript.
    url = IttyBittySite(ret)
    ret = (
        f'''<a href="{url}" class="underline text-blue-500">[itty bitty site]</a>'''
    ) + ret
    ret = auto.IPython.display.HTML(ret)
    return ret
|
|
|
|
|
#--- Hide the contents of the string when printed in repr format |
|
|
|
class HiddenStr(auto.collections.UserString):
    """A string whose repr hides its contents (e.g. secrets in tracebacks)."""

    def __repr__(self):
        name = type(self).__name__
        return '<' + name + '>'
|
|
|
|
|
#--- Random |
|
|
|
# @title RANDOM { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown RANDOM = ( |
|
#@markdown *, |
|
#@markdown seed: int, |
|
#@markdown ) -> Random |
|
#@markdown |
|
#@markdown Random = ( |
|
#@markdown n: int, |
|
#@markdown ) -> random.Random |
|
#@markdown |
|
#@markdown int(Random(n: int)) = int |
|
#@markdown ``` |
|
|
|
class __Random(auto.random.Random):
    """random.Random that converts to int by drawing its next 4 bytes."""

    def __int__(self):
        # Little-endian so the int mirrors the raw byte-stream order.
        raw = self.randbytes(4)
        return int.from_bytes(raw, 'little')
|
|
|
|
|
def RANDOM(*, seed: int):
    """Return a factory of reproducible, independently seeded generators.

    ``RANDOM(seed=...)(n)`` advances a seed-derived stream by ``n`` bytes
    and reseeds from the next four, so each ``n`` yields a distinct but
    deterministic generator.  The result carries a matching numpy
    generator at ``.np`` and supports ``int()`` via __Random.
    """
    def Random(n: int, /):
        source = auto.random.Random(seed)
        # XXX(th): one call to source.randbytes(n) yields a different
        # stream than n single-byte draws, so advance one byte at a time
        # to keep historical outputs stable.
        for _ in range(n):
            source.randbytes(1)
        material = source.randbytes(4)
        derived = auto.random.Random(material)
        derived.np = auto.np.random.default_rng(int.from_bytes(material, 'little'))
        derived.__class__ = __Random
        return derived
    return Random
|
|
|
def scope():
    """Demo: RANDOM(seed)(n) is reproducible and varies only with n."""
    for n in (0, 1, 2, 2):
        display(RANDOM(seed=1337)(n).sample(list(range(100)), 10))

    for n in (0, 1, 2, 2):
        display(RANDOM(seed=1337)(n).np.integers(0, 100, 10))

    # Two fresh factories print identical sequences.
    for _ in range(2):
        Random = RANDOM(seed=1337)
        for n in (0, 1, 2, 2):
            print(f'{int(Random(n)):08x}')

# /scope
|
|
|
|
|
#--- Textarea |
|
|
|
#@title Textarea { display-mode: "form" } |
|
class Textarea:
    """Capture text (or stdout) and display it as an HTML <textarea>.

    Use as a context manager to capture everything printed inside the
    block; on exit the captured text is stored in ``self.value`` and the
    widget displays itself.  May also be constructed directly from a
    string and displayed.
    """

    def __init__(self, value: str | None = None, /):
        self.io = auto.io.StringIO()
        if value is not None:
            self.io.write(value)
        # Bug fix: .value used to be assigned only in __exit__, so
        # displaying a Textarea before (or without) using it as a context
        # manager raised AttributeError in _repr_html_.
        self.value = self.io.getvalue()

    def _repr_html_(self):
        return f'<textarea rows=12 style="width: 90%; margin-left: 5%">{auto.html.escape(self.value)}</textarea>'

    def __enter__(self):
        # Redirect stdout into our buffer until __exit__.
        self.stack = auto.contextlib.ExitStack()
        self.stack.enter_context( auto.contextlib.redirect_stdout(self.io) )
        return self

    def __exit__(self, *args):
        self.stack.close()
        self.value = self.io.getvalue()
        display(self)
|
|
|
def scope():
    """Demo: capture a print() into a Textarea widget."""
    with Textarea() as _:
        print('hello')

# /scope
|
|
|
|
|
#--- Remember |
|
|
|
#@title Remember { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown Remember = ( |
|
#@markdown func: auto.typing.Callable[[], T] | None = None, |
|
#@markdown /, |
|
#@markdown **kwargs: dict, |
|
#@markdown ) -> T |
|
#@markdown ``` |
|
|
|
def _Remember(
    func: auto.typing.Callable | None = None,
    /,
    **kwargs,
):
    """Memoize ``func()`` under a key derived from the keyword arguments.

    The key is the SHA-256 of the canonical JSON of ``kwargs``; results
    live in ``Remember.cache`` so they survive notebook cell re-runs.
    ``func`` is only invoked on a cache miss.
    """
    digest = auto.json.dumps(dict(kwargs), sort_keys=True)
    digest = auto.hashlib.sha256(digest.encode('utf-8')).hexdigest()

    try:
        return Remember.cache[digest]
    except KeyError:
        pass

    value = func()
    Remember.cache[digest] = value
    return value
|
|
|
class __Remember:
    """Callable/subscriptable front-end over _Remember.

    ``Remember[func](**kwargs)`` and ``Remember(func, **kwargs)`` are two
    spellings of the same memoized call.
    """

    def __getitem__(self, func):
        return auto.functools.partial(_Remember, func)

    def __call__(self, func, **kwargs):
        return _Remember(func, **kwargs)
|
|
|
# Singleton front-end: use Remember(...) or Remember[...] everywhere.
Remember = __Remember()

# Preserve the cache across notebook cell re-execution: only create a
# fresh dict the first time this cell runs in the interpreter.
try:
    __Remember_cache
except NameError:
    __Remember_cache = {}

Remember.cache = __Remember_cache
# /Remember.cache.clear
|
|
|
|
|
#--- TEMPLATE |
|
|
|
#@title TEMPLATE { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown TEMPLATE = ( |
|
#@markdown s: str, |
|
#@markdown /, |
|
#@markdown ) -> Template |
|
#@markdown ``` |
|
#@markdown ```python |
|
#@markdown TEMPLATE = ( |
|
#@markdown s: str, |
|
#@markdown /, |
|
#@markdown **context: dict, |
|
#@markdown ) -> str |
|
#@markdown ``` |
|
#@markdown ```python |
|
#@markdown Template = ( |
|
#@markdown **context: dict, |
|
#@markdown ) -> str |
|
#@markdown ``` |
|
|
|
def TEMPLATE(s: str, /, **context):
    """Compile *s* as a Jinja2 template with `auto` and `config` in scope.

    Called with no keyword arguments, returns a render function; called
    with keyword arguments, renders immediately and returns the string.
    """
    environment = auto.jinja2.Environment(
    )
    environment.globals.update({
        'auto': auto,
        'config': config,
    })

    compiled = environment.from_string(s)

    def Template(**context):
        return compiled.render(**context)

    return Template(**context) if context else Template
|
|
|
|
|
#--- Export |
|
|
|
#@title Export { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown with Export( |
|
#@markdown name: str, |
|
#@markdown /, |
|
#@markdown mode: auto.typing.Literal['w'] = 'wb', |
|
#@markdown ) -> typing.BinaryIO: |
|
#@markdown ... |
|
#@markdown ``` |
|
#@markdown ```python |
|
#@markdown with Export( |
|
#@markdown name: str, |
|
#@markdown /, |
|
#@markdown mode: auto.typing.Literal['w'] = 'wb', |
|
#@markdown ) -> typing.TextIO: |
|
#@markdown ... |
|
#@markdown ``` |
|
#@markdown ```python |
|
#@markdown Export.clear = ( |
|
#@markdown ) -> None |
|
#@markdown ``` |
|
|
|
@auto.contextlib.contextmanager
def Export(
    name: str,
    /,
    mode: auto.typing.Literal['w', 'wb'] = 'wb',
):
    """Open entry *name* inside the shared ``Export.path`` zip for writing.

    Yields a binary file object ('wb', the default) or a text wrapper
    ('w').  An existing entry with the same name is deleted first.  On
    exit, prints how much the archive grew.
    """
    assert mode in ['w', 'wb']

    path = Export.path
    if path.exists():
        old_size = path.stat().st_size
    else:
        old_size = None

    # ZipFile can only append; deleting an existing entry is delegated to
    # the external `zip -d` command (assumed present on PATH).
    if path.exists():
        to_delete = []
        with auto.zipfile.ZipFile(path, 'r') as arc:
            names = arc.namelist()
            if name in names:
                to_delete.append(name)

        if to_delete:
            auto.subprocess.run([
                'zip',
                '-d',
                path,
                *to_delete,
            ])

    with auto.contextlib.ExitStack() as stack:
        arc = stack.enter_context(auto.zipfile.ZipFile(path, 'a'))

        f = stack.enter_context(arc.open(name, 'w'))
        if mode == 'w':
            # Wrap the binary zip stream for text-mode callers.
            f = stack.enter_context(auto.io.TextIOWrapper(f))
        elif mode == 'wb':
            pass
        else:
            raise ValueError(f'{mode=}')

        # Control returns to the caller's `with Export(...)` block here.
        yield f

    # Report growth after the archive has been flushed and closed.
    new_size = path.stat().st_size
    if old_size is not None:
        print(f'Added {new_size-old_size:,d} bytes to {path}')
        print(f'  Total: {new_size:,d} bytes')
    else:
        print(f'Wrote {new_size:,d} bytes to {path}')
|
|
|
def __Export_clear():
    """Remove the export archive entirely; Export() will recreate it."""
    if not Export.path.exists():
        return
    Export.path.unlink()

# Shared archive location and the public clear() hook.
Export.path = auto.pathlib.Path('export.zip')
Export.clear = __Export_clear
# /Export.clear
|
|
|
|
|
#--- Grow |
|
|
|
#@title Grow { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown # +/- d |
|
#@markdown Grow = ( |
|
#@markdown lo, |
|
#@markdown hi, |
|
#@markdown /, |
|
#@markdown *, |
|
#@markdown d: float, |
|
#@markdown ) -> tuple[float, float] |
|
#@markdown ``` |
|
#@markdown ```python |
|
#@markdown # +/- p * (hi - lo) |
|
#@markdown Grow = ( |
|
#@markdown lo, |
|
#@markdown hi, |
|
#@markdown /, |
|
#@markdown *, |
|
#@markdown p: float, |
|
#@markdown ) -> tuple[float, float] |
|
#@markdown ``` |
|
|
|
def Grow(lo, hi, /, *, d=None, p=None):
    """Widen the interval [lo, hi] symmetrically.

    Exactly one of ``d`` (absolute margin added on each side) or ``p``
    (fractional growth of the half-width) must be given.  Returns the
    widened ``(lo, hi)`` and sanity-checks that the interval did not
    shrink beyond a small epsilon.
    """
    assert (d is not None) != (p is not None)

    if d is not None:
        newlo = lo - d
        newhi = hi + d
    else:
        center = (lo + hi) / 2
        half = (hi - lo) / 2 * (1.0 + p)
        newlo = center - half
        newhi = center + half

    eps = 1e-3
    assert newlo <= lo + eps, \
        f'{newlo=!r} !<= {lo=!r}'
    assert newhi >= hi - eps, \
        f'{newhi=!r} !>= {hi=!r}'

    # print(f'Grow [{lo}, {hi}] to [{newlo}, {newhi}] ({d=}, {p=})')
    return newlo, newhi
|
|
|
|
|
#--- Complete |
|
|
|
# @title Complete { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown def Complete( |
|
#@markdown *, |
|
#@markdown config = Complete.config, |
|
#@markdown **prompt, |
|
#@markdown ) -> dict: |
|
#@markdown ... |
|
#@markdown ``` |
|
|
|
def Complete(
    *,
    config=None,
    block = True,
    **prompt,
):
    """Call an OpenAI/llama.cpp-style completion endpoint, with caching.

    The prompt (plus default model from *config*) is hashed into a cache
    key.  On a miss with ``block=False`` the request is appended to the
    ``Complete.todo`` NDJSON queue and None is returned; otherwise the
    request is made synchronously.  Raw 'prompt' requests go to
    /completion, 'messages' to /v1/chat/completions.
    """
    if config is None:
        config = Complete.config

    prompt.setdefault('model', config.model)

    key = auto.json.dumps(prompt, sort_keys=True)
    key = auto.hashlib.sha256(key.encode()).hexdigest()
    if key not in Complete.cache:
        if not block:
            with Complete.lock:
                with open(Complete.todo, 'a') as f:
                    auto.json.dump({ key: prompt }, f)
                    # Bug fix: the todo file is newline-delimited JSON, but
                    # json.dump writes no newline, so records used to run
                    # together on a single (unparseable) line.
                    f.write('\n')

            return None

        # Route by payload shape: raw completion vs. chat completion.
        url = config.base_url
        if 'prompt' in prompt:
            url = auto.urllib.parse.urljoin(url,
                'completion',
            )

        elif 'messages' in prompt:
            url = auto.urllib.parse.urljoin(url,
                'v1/chat/completions',
            )

        else:
            assert False

        with auto.requests.request(
            'POST',
            url,
            headers={
                'Accept': 'application/json',
                'Authorization': f'Bearer {config.api_key}',
                'Content-Type': 'application/json',
            },
            json=prompt,
        ) as r:
            r.raise_for_status()
            output = r.json()

        Complete.was_cached = False
        Complete.cache[key] = output

    else:
        Complete.was_cached = True
        output = Complete.cache[key]

    return output
|
|
|
# Preserve the completion cache across notebook cell re-execution.
try:
    __Complete_cache
except NameError:
    __Complete_cache = {}
# __Complete_cache.clear()

Complete.cache = __Complete_cache
# Complete.config = config.complete.default
# Serializes appends to the todo queue from concurrent callers.
Complete.lock = auto.threading.Lock()
Complete.todo = auto.pathlib.Path('complete.todo.ndjson')
|
|
|
def scope():
    """Demo: queue a chat completion without blocking."""
    output = Complete(
        messages=[
            { 'role': 'user', 'content': 'What is the capital of France?' },
        ],
        max_tokens=100,
        # config=config.complete.tinyllama,
        block=False,
    )
    auto.pprint.pp(output)

# /scope
|
|
|
|
|
#--- PROMPT |
|
|
|
#@title PROMPT { display-mode: "form" } |
|
|
|
#@markdown ```python |
|
#@markdown def PROMPT( |
|
#@markdown s: str, |
|
#@markdown /, |
|
#@markdown ) -> Prompt: |
|
#@markdown ... |
|
#@markdown |
|
#@markdown def Prompt( |
|
#@markdown **query, |
|
#@markdown ) -> dict: |
|
#@markdown ... |
|
#@markdown ``` |
|
|
|
@auto.functools.cache
def PROMPT(s: str, /):
    """Compile template source *s* into a reusable ``Prompt(**query)`` function.

    *s* is a Jinja2 template that *describes* an LLM request by invoking
    injected helpers inside ``{% call %}`` blocks:
      user()/assistant()/system() -- append a chat message,
      prompt()                    -- set a raw completion prompt,
      grammar()                   -- attach a GBNF grammar,
      parser()                    -- attach a parser string (stored as an
                                     attribute, not in the payload).
    Rendering is performed only for these side effects; the rendered text
    itself is discarded.  Exactly one of messages/prompt must be set.
    Named sub-templates come from ``PROMPT.templates``.
    """
    def Prompt(**query):
        environment = auto.jinja2.Environment(
            loader=auto.jinja2.DictLoader(PROMPT.templates),
            undefined=auto.jinja2.StrictUndefined,
        )
        environment.globals.update({
            'auto': auto,
        })
        template = environment.from_string(s)

        # Collected chat messages.  Stays None when no user/assistant/system
        # block ran, so "never set" is distinguishable from "set but empty".
        _messages = None
        def AddMessage(role: str, content: str):
            nonlocal _messages
            if _messages is None:
                _messages = []
            content = content.strip()
            _messages.append(dict(
                role=role,
                content=content,
            ))
            return f'<Message({role!r}, {content!r})>'
        environment.globals |= dict(
            user=lambda caller: AddMessage('user', caller()),
            assistant=lambda caller: AddMessage('assistant', caller()),
            system=lambda caller: AddMessage('system', caller()),
        )

        # Raw completion prompt (mutually exclusive with messages).
        _prompt = None
        def SetPrompt(prompt: str):
            nonlocal _prompt
            _prompt = prompt
            return f'<Prompt({prompt!r})>'
        environment.globals |= dict(
            prompt=lambda caller: SetPrompt(caller()),
        )

        # Optional GBNF grammar constraining generation.
        _grammar = None
        def SetGrammar(grammar: str):
            nonlocal _grammar
            _grammar = grammar
            return f'<Grammar({grammar!r})>'
        environment.globals |= dict(
            grammar=lambda caller: SetGrammar(caller()),
        )

        # Optional parser (e.g. a regex) for post-processing model output.
        _parser = None
        def SetParser(parser: str):
            nonlocal _parser
            _parser = parser
            return f'<Parser({parser!r})>'
        environment.globals |= dict(
            parser=lambda caller: SetParser(caller()),
        )

        context = {}
        context |= query

        # Render purely for the helpers' side effects; discard the text.
        _ = template.render(
            **context,
        )

        prompt = auto.collections.UserDict(
        )

        assert (bool(_messages) != bool(_prompt)), \
            f"Exactly one of 'messages' or 'prompt' must be specified."
        if _messages is not None:
            prompt |= dict(
                messages=_messages,
            )
        elif _prompt is not None:
            prompt |= dict(
                prompt=_prompt,
            )
        else:
            assert False

        if _grammar is not None:
            prompt |= dict(
                grammar=_grammar,
            )

        # Deliberately an attribute, not a dict entry: the parser is for
        # the caller, not part of the completion API payload.
        if _parser is not None:
            prompt.parser = _parser
        return prompt

    return Prompt
|
|
|
# Named templates importable from prompt sources,
# e.g. {% from 'capital' import capital %}.
PROMPT.templates = {}
|
|
|
def scope():
    """Demo: named macros plus grammar/parser blocks, shown via FormatChat."""
    # A reusable macro, importable as {% from 'capital' import capital %}.
    PROMPT.templates['capital'] = r"""

        {% macro capital(where) -%}

            {% call system() %}

                You are a helpful AI assistant.

            {% endcall %}



            {% call user() %}

                What is the capital of {{ where }}?

            {% endcall %}

        {% endmacro %}

    """

    display(auto.mediocreatbest.FormatChat(
        prompt=PROMPT(r"""

            {% from 'capital' import capital %}

            {{ capital("France") }}



            {% call grammar() %}

                root ::= intro

                intro ::= "The capital of {{ where }} is " quoted

                quoted ::= "\"" [^"]+ "\""

            {% endcall %}



            {% call parser() %}

                "(?P<quoted>[^"]+)"

            {% endcall %}

        """)(
            where='France',
        )))

# /scope
|
|
|
|
|
#--- ChatML |
|
|
|
def ChatML(messages: list[dict], /):
    """Render a list of chat messages into ChatML prompt text."""
    return ChatML.template.render(
        messages=messages,
    )
|
|
|
# ChatML prompt template: each message becomes an <|im_start|> block.  A
# message with falsy content (e.g. a trailing assistant turn with None) is
# left open — no <|im_end|> — cueing the model to continue from there.
ChatML.template = auto.jinja2.Environment().from_string('''

{%- for message in messages -%}

<|im_start|>{{ message.role }}

{%- if message.content %}

{{ message.content }}

<|im_end|>

{% else %}

{% endif -%}

{%- endfor -%}

''')
|
|
|
|
|
def scope():
    """Demo: render a ChatML prompt ending on an open assistant turn."""
    turns = [
        ('user', 'What is the capital of France?'),
        ('assistant', 'Paris'),
        ('user', 'What is the capital of Germany?'),
        ('assistant', None),
    ]
    messages = [
        {
            'role': role,
            'content': content,
        }
        for role, content in turns
    ]
    # /auto.pprint.pp messages width=144

    prompt = ChatML(messages)
    # /auto.pprint.pp prompt width=144

# /scope
|
|
|
|
|
#--- Embed |
|
|
|
#@title Embed { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown def Embed( |
|
#@markdown *, |
|
#@markdown query: str | list[str] | None = None, |
|
#@markdown passage: str | list[str] | None = None, |
|
#@markdown batch: int | None = None, |
|
#@markdown progress: None | auto.typing.Any = None, |
|
#@markdown ) -> auto.np.ndarray[float]: |
|
#@markdown ... |
|
#@markdown ``` |
|
|
|
def Embed(
    *,
    query: str | list[str] | None = None,
    passage: str | list[str] | None = None,
    batch: int | None = None,
    progress: None | auto.typing.Any = None,
    config: auto.typing.Any | None = None,
) -> auto.np.ndarray[float]:
    """Embed one or more texts via an embeddings endpoint, with caching.

    Exactly one of ``query``/``passage`` must be given; each input is
    prefixed with "query: " or "passage: " (e5-style convention) before
    being sent.  Results are cached by SHA-256 of the prefixed input, so
    only misses hit the network, in batches of ``batch``.  ``progress``
    is an optional tqdm-like object (reset/update).  Returns a 1-D
    vector for a single input, otherwise a 2-D array in input order.
    """
    def Batch(seq, /, *, batch: int | None) -> auto.typing.Iterable[list]:
        # batch=None means "one batch containing everything".
        if batch is None:
            yield seq
            return

        for i in range(0, len(seq), batch):
            yield seq[i:i+batch]

    assert (query is not None) != (passage is not None), \
        f"Exactly one of 'query' or 'passage' must be specified."

    if config is None:
        config = Embed.config

    inputs = []
    if query is not None:
        if isinstance(query, str):
            query = [query]
        inputs = [
            f'query: {q}'
            for q in query
        ]

    elif passage is not None:
        if isinstance(passage, str):
            passage = [passage]
        inputs = [
            f'passage: {p}'
            for p in passage
        ]

    # Collect only the inputs not already embedded.
    needs = []
    for input in inputs:
        key = auto.hashlib.sha256(input.encode()).hexdigest()
        if key not in Embed.cache:
            needs.append(input)

    if needs:
        if progress is not None:
            progress.reset(len(needs))

        # NOTE: the loop variable deliberately shadows `needs`; inside the
        # body `needs` refers to the current batch only.
        for needs in Batch(needs, batch=batch):
            with auto.requests.request(
                'POST',
                f'{config.base_url}embeddings',
                headers={
                    'Accept': 'application/json',
                    'Authorization': f'Bearer {config.api_key}',
                    'Content-Type': 'application/json',
                },
                json={
                    'input': needs,
                },
            ) as r:
                r.raise_for_status()
                output = r.json()

            if progress is not None:
                progress.update(len(needs))

            # Responses come back in request order; cache each embedding.
            for input, data in zip(needs, output['data']):
                embed = data['embedding']
                key = auto.hashlib.sha256(input.encode()).hexdigest()
                Embed.cache[key] = embed

    # Assemble results in original input order, entirely from the cache.
    embeds = []
    for input in inputs:
        key = auto.hashlib.sha256(input.encode()).hexdigest()
        assert key in Embed.cache
        embeds.append(Embed.cache[key])

    embeds = auto.np.stack(embeds, axis=0)

    # Single input -> single vector, for ergonomic dot products.
    if len(embeds) == 1:
        return embeds[0]
    return embeds
|
|
|
# NOTE(review): unlike Complete.cache/__Remember_cache, this cache is
# recreated on every cell re-run (no NameError guard) — presumably fine,
# but confirm if embedding calls are expensive.
Embed.cache = {}
# Embed.config = config.embed.default
|
|
|
def scope():
    """Demo: embed a query, then its similarity to a passage."""
    display(Embed(query='What is the capital of France?')[:10])

    similarity = auto.np.dot(
        Embed(query='What is the capital of France?'),
        Embed(passage='Paris is the capital of France.'),
    )
    display(similarity)

# /scope
|
|
|
|
|
#--- Enlookup |
|
|
|
#@title Enlookup { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown class Enlookup: |
|
#@markdown ... |
|
#@markdown ``` |
|
|
|
class Enlookup(auto.collections.UserList):
    """A list with fuzzy, embedding-based lookup.

    Integer and slice subscripts behave like a normal list.  Any other
    key first tries an exact match on ``key(item)``; failing that, the
    key is embedded and the nearest item by cosine distance is returned.
    ``lookup[key, k]`` returns the (k+1)-th nearest match, and a list of
    keys returns an Enlookup of per-key results.
    """

    def __init__(self, *args, key=str, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps an item to the string that is matched/embedded for it.
        self.key = key

    def __getitem__(self, key):
        # 1) Plain list behavior for ints/slices; non-index keys raise
        #    TypeError and fall through to the lookup paths below.
        try:
            return super().__getitem__(key)
        except (KeyError, TypeError):
            pass

        # 2) Exact match on the item's key string.
        for d in self.data:
            if self.key(d) == key:
                return d

        # 3) Batch lookup: one result per requested key.
        if isinstance(key, list):
            return Enlookup([self[k] for k in key])

        # 4) Fuzzy lookup; (key, offset) selects the offset-th nearest.
        if isinstance(key, tuple):
            key, offset = key
        else:
            offset = 0

        embeds = Embed(
            passage=list(map(self.key, self.data)),
            batch=1_000,
        )

        embed = Embed(
            query=str(key),
        )

        cdist = auto.scipy.spatial.distance.cdist(
            embeds,
            [embed],
            metric='cosine',
        )
        assert len(cdist.shape) == 2, \
            f'{cdist.shape=} is not length 2'
        assert cdist.shape[1] == 1, \
            f'{cdist.shape[1]=} is not 1'
        cdist = cdist[:, 0]
        assert len(cdist.shape) == 1, \
            f'{cdist.shape=} is not length 1'

        # Consistency fix: the rest of the file uses the `auto.np` alias
        # (this line previously said `auto.numpy`, the same module under a
        # different lazy-import name).
        inds = auto.np.argsort(cdist)
        ind = inds[offset]

        return self.data[ind]
|
|
|
def scope():
    """Demo: fuzzy lookups by embedding similarity, with offsets and batches."""
    # NOTE(review): the blank separator lines inside the literal become
    # empty entries after split('\n'); the fuzzy lookup tolerates them.
    lookup = Enlookup('''Social Vulnerability - Score

Prevention: Health Insurance: Current lack of health insurance among adults aged 18-64 years

tot_park_area_sqmiles'''.split('\n'))

    for k in ['risk', 'prevention', 'park area']:
        auto.pprint.pp({ k: lookup[k] })
        auto.pprint.pp({ (k, 1): lookup[k, 1] })

    auto.pprint.pp(lookup[['risk', 'prevention', 'park area']])
    auto.pprint.pp(lookup[[('risk', 1), ('prevention', 1), ('park area', 1)]])

# /scope
|
|
|
|
|
#--- Novelty |
|
|
|
#@title Novelty { display-mode: "form" } |
|
#@markdown ```python |
|
#@markdown Novelty = ( |
|
#@markdown document: str, |
|
#@markdown /, |
|
#@markdown *, |
|
#@markdown config = Novelty.config, |
|
#@markdown cache: bool=True, |
|
#@markdown ) -> auto.types.SimpleNamespace( |
|
#@markdown tokens: list[str], |
|
#@markdown scores: list[float], |
|
#@markdown ) |
|
#@markdown |
|
#@markdown Novelty.tokenize = ( |
|
#@markdown document: str, |
|
#@markdown /, |
|
#@markdown *, |
|
#@markdown config = Novelty.config, |
|
#@markdown ) -> list[str] |
|
#@markdown ``` |
|
|
|
|
|
def Novelty(document: str, /, *, config=None, cache: bool=True):
    """Score per-token novelty of *document* via the /novelty endpoint.

    Returns a SimpleNamespace with parallel ``tokens`` and ``scores``
    lists.  Responses are cached (as JSON text, shelve-friendly) under a
    hash of the request; pass ``cache=False`` to force a fresh request
    that bypasses and does not populate the cache.
    """
    if config is None:
        config = Novelty.config

    url = config.base_url
    url = auto.urllib.parse.urljoin(
        url,
        'novelty',
    )

    headers = {}
    headers['Authorization'] = f'Bearer {config.api_key}'

    json = {}
    json['document'] = document

    # Cache key: hash of the canonicalized request body.
    identity = auto.json.dumps(json, sort_keys=True)
    identity = auto.hashlib.sha256(identity.encode('utf-8')).hexdigest()
    identity = f'Novelty:{identity}'

    # if identity not in Novelty.cache:
    if (not cache) or (identity not in Novelty.cache):
        with Novelty.session.request(
            'POST',
            url,
            headers=headers,
            json=json,
        ) as response:
            response.raise_for_status()
            # Rebinds `json` from the request body to the response body.
            json = response.json()

        if cache:
            # Stored as a JSON string (works with dict or shelve backends).
            Novelty.cache[identity] = auto.json.dumps(json)

    else:
        json = auto.json.loads(Novelty.cache[identity])

    tokens = json.pop('tokens')
    scores = json.pop('scores')
    # Any leftover key means the API response shape changed; fail loudly.
    assert not json, list(json.keys())

    novelty = auto.types.SimpleNamespace(
        tokens=tokens,
        scores=scores,
    )
    return novelty
|
|
|
def __Novelty_tokenize(document: str, /, *, config=None):
    """Tokenize *document* via the service's /tokenize endpoint (uncached)."""
    if config is None:
        config = Novelty.config

    endpoint = auto.urllib.parse.urljoin(
        config.base_url,
        'tokenize',
    )

    headers = {
        'Authorization': f'Bearer {config.api_key}',
    }

    payload = {
        'document': document,
    }

    with Novelty.session.request(
        'POST',
        endpoint,
        headers=headers,
        json=payload,
    ) as response:
        response.raise_for_status()
        body = response.json()

    body = auto.copy.copy(body)
    tokens = body.pop('tokens')
    # Anything left over means the API changed shape; fail loudly.
    assert not body, list(body.keys())

    return tokens
|
|
|
# Preserve the novelty cache across notebook cell re-execution; a shelve
# can be swapped in for on-disk persistence (values are JSON strings).
try:
    __Novelty_cache
except NameError:
    __Novelty_cache = (
        {}
        # auto.shelve.open('Novelty.cache', 'c')
    )

# Novelty.config = config.learned_quality
# One Session so repeated requests reuse the HTTP connection.
Novelty.session = auto.requests.Session()
Novelty.cache = __Novelty_cache
Novelty.tokenize = __Novelty_tokenize
|
|
|
def scope():
    """Demo: score several documents (repeats included) without caching."""
    samples = (
        ['The quick brown fox jumps over the lazy dog.'] * 2
        + ['What is the meaning of life?'] * 2
        + ['What is the purpose of life?'] * 2
    )

    for document in samples:
        auto.pprint.pp(Novelty(
            document,
            cache=False,
        ))

# /scope
|
|
|
def scope():
    """Demo: tokenize a sentence."""
    auto.pprint.pp(Novelty.tokenize(
        'The quick brown fox jumps over the lazy dog.',
    ))

# /scope
|
|
|
|
|
#--- Clipboard |
|
|
|
#@title Clipboard |
|
def Clipboard(
    d: dict[str, str] | None = None,
    /,
    *,
    ipynb: None | str | list[str] = None,
) -> auto.IPython.display.HTML:
    """Render a "Copy" button that writes *d* ({mime type: payload}) to the clipboard.

    ``ipynb`` — a cell source string, or a list of them — is converted to
    notebook code cells and attached under application/ipynb so the result
    can be pasted directly into Colab/Jupyter.
    """
    # Bug fix: the default used to be a shared mutable dict (``d={}``),
    # which this function mutates below — the 'application/ipynb' entry
    # from one call leaked into every later call that used the default.
    if d is None:
        d = {}

    if ipynb is not None:
        if isinstance(ipynb, str):
            ipynb = [ipynb]

        # d['text/plain'] = ipynb
        d['application/ipynb'] = auto.json.dumps([
            {
                'cell_type': 'code',
                'metadata': {},
                'execution_count': None,
                'source': source.splitlines(keepends=True),
                'outputs': [],
            }
            for source in ipynb
        ])

    # One ClipboardItem per mime type, written together in a single call.
    js = auto.google.colab.syntax.javascript(TEMPLATE(r"""

        (() => {

            clipboard.write([

                {%- for k, v in d.items() %}

                (() => {

                    const t = {{ auto.json.dumps(k) | safe }};

                    const v = {{ auto.json.dumps(v) | safe }};

                    const b = new Blob([v], { type: t });

                    const c = new clipboard.ClipboardItem({ [t]: b });

                    return c;

                })(),

                {%- endfor %}

            ]);

        })();

    """, **locals()))

    html = auto.google.colab.syntax.html(TEMPLATE(r"""

        <script src="https://unpkg.com/clipboard-polyfill@4.0.2/dist/es5/window-var/clipboard-polyfill.window-var.es5.js"></script>

        <button onclick="javascript:{{ js | escape }}">Copy</button>

    """, **locals()))

    html = auto.IPython.display.HTML(html)
    return html
|
|
|
def scope():
    """Demo: copy plain text, then copy a notebook code cell."""
    display(Clipboard({
        'text/plain': 'Hello, world!',
    }))

    display(Clipboard(ipynb=r"""

        def scope():

            print("Hello, world!")



        /scope

    """))

# /scope
|
|