Skip to content

Instantly share code, notes, and snippets.

@Mega-JC
Last active June 5, 2022 01:36
Show Gist options
  • Save Mega-JC/544aa8ba282efd6d9f1e2aad8c3409be to your computer and use it in GitHub Desktop.
Save Mega-JC/544aa8ba282efd6d9f1e2aad8c3409be to your computer and use it in GitHub Desktop.
This is an example of a `discord.ext.commands.Bot` subclass that allows you to pass setup arguments to extension `setup()` functions in a backwards-compatible way.
"""This is an example of a `discord.ext.commands.Bot` subclass that allows you to pass
an `options=` keyword argument to `load_extension`. The argument can optionally receive
a dictionary that is passed to the `setup()` function of an extension. This works by
checking if an extension also implements an `options=` keyword argument and passing the dictionary
to it. Receiving this dictionary is opt-in, if `options=` is not implemented for the `setup()` function
of an extension, the resulting behavior will be as if `options=` were never passed to `load_extension`, thereby
making it backwards-compatible with any older extensions.
Tested with discord.py 2.0.0 alpha.
"""
import importlib
import importlib.machinery
import importlib.util
import inspect
import sys
import types
from typing import Any, Dict, Optional

from discord.ext import commands
from discord.ext.commands import errors
# Public API: the extended bot classes, exported under discord.py's familiar names.
__all__ = ("Bot", "AutoShardedBot")
def _is_submodule(parent: str, child: str) -> bool:
return parent == child or child.startswith(parent + ".")
class ExtBotBase(commands.bot.BotBase):
    """``BotBase`` extension whose extension-management methods accept ``options=``.

    Each overridden method takes an optional ``options`` keyword argument. When
    ``options`` is not a ``dict`` (including the default ``None``), the method
    defers unchanged to the parent class implementation via ``super()``. When it
    is a ``dict``, it is forwarded to the extension's ``setup()`` / ``teardown()``
    functions, but only to those that declare an ``options`` parameter (see
    ``_call_extension_function``), so older extensions keep working.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # HACK: BotBase stores its loaded-extension registry in a name-mangled
        # private attribute. Alias that same dict object so the overrides below
        # read and write the registry BotBase itself uses.
        self.__extensions: Dict[str, types.ModuleType] = self._BotBase__extensions

    async def _call_extension_function(self, function, options: Dict[str, Any]) -> None:
        """Await ``function(self, options=options)`` if supported, else ``function(self)``.

        ``options`` is passed only when ``function``'s signature has a parameter
        named ``options`` that can be given by keyword (positional-or-keyword or
        keyword-only). If the signature cannot be introspected, ``options`` is not
        passed.
        """
        try:
            sig = inspect.signature(function)
        except (ValueError, TypeError):
            # No introspectable signature: fall back to the classic call.
            sig = None

        param = sig.parameters.get("options") if sig is not None else None
        if param is not None and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            await function(self, options=options)
        else:
            await function(self)

    async def _load_from_module_spec(
        self,
        spec: importlib.machinery.ModuleSpec,
        key: str,
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Execute the module described by ``spec`` as ``key`` and run its ``setup()``.

        ``options`` is forwarded to ``setup()``. If ``setup()`` fails, the module is
        unregistered and its ``teardown()`` (if any) is given the same ``options``.

        Raises:
            errors.ExtensionFailed: executing the module or its ``setup()`` raised.
            errors.NoEntryPointError: the module has no ``setup`` attribute.
        """
        if not isinstance(options, dict):
            return await super()._load_from_module_spec(spec, key)

        # precondition: key not in self.__extensions
        lib = importlib.util.module_from_spec(spec)
        sys.modules[key] = lib
        try:
            spec.loader.exec_module(lib)  # type: ignore
        except Exception as e:
            del sys.modules[key]
            raise errors.ExtensionFailed(key, e) from e

        try:
            setup = getattr(lib, "setup")
        except AttributeError:
            del sys.modules[key]
            raise errors.NoEntryPointError(key)

        try:
            await self._call_extension_function(setup, options)
        except Exception as e:
            del sys.modules[key]
            await self._remove_module_references(lib.__name__)
            # Give teardown() the same options setup() received, so an
            # options-aware extension can undo a partial setup consistently.
            await self._call_module_finalizers(lib, key, options)
            raise errors.ExtensionFailed(key, e) from e
        else:
            self.__extensions[key] = lib

    async def load_extension(
        self,
        name: str,
        *,
        package: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Load extension ``name``, passing ``options`` to its ``setup()`` if accepted.

        Raises:
            errors.ExtensionAlreadyLoaded: ``name`` is already loaded.
            errors.ExtensionNotFound: ``importlib.util.find_spec`` found no module.
            errors.ExtensionFailed, errors.NoEntryPointError: from
                ``_load_from_module_spec``.
        """
        if not isinstance(options, dict):
            return await super().load_extension(name, package=package)

        name = self._resolve_name(name, package)
        if name in self.__extensions:
            raise errors.ExtensionAlreadyLoaded(name)

        spec = importlib.util.find_spec(name)
        if spec is None:
            raise errors.ExtensionNotFound(name)

        await self._load_from_module_spec(spec, name, options)

    async def reload_extension(
        self,
        name: str,
        *,
        package: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Unload and re-load extension ``name``, forwarding ``options`` to teardown/setup.

        If loading the new version fails, the previously loaded module object is
        restored (its ``setup()`` is re-run with ``options``, it is re-registered,
        and the saved ``sys.modules`` entries are put back) and the error is
        re-raised.

        Raises:
            errors.ExtensionNotLoaded: ``name`` is not currently loaded.
        """
        if not isinstance(options, dict):
            return await super().reload_extension(name, package=package)

        name = self._resolve_name(name, package)
        lib = self.__extensions.get(name)
        if lib is None:
            raise errors.ExtensionNotLoaded(name)

        # Snapshot the extension's module and all of its submodules so they can
        # be restored in sys.modules if the reload fails.
        # fmt: off
        modules = {
            module_name: module
            for module_name, module in sys.modules.items()
            if _is_submodule(lib.__name__, module_name)
        }
        # fmt: on

        try:
            # Unload and then load the module...
            await self._remove_module_references(lib.__name__)
            # Forward options to teardown() as well, matching unload_extension().
            await self._call_module_finalizers(lib, name, options)
            await self.load_extension(name, options=options)
        except Exception:
            # if the load failed, the remnants should have been
            # cleaned from the load_extension function call
            # so let's load it from our old compiled library.
            await self._call_extension_function(lib.setup, options)
            self.__extensions[name] = lib

            # revert sys.modules back to normal and raise back to caller
            sys.modules.update(modules)
            raise

    async def _call_module_finalizers(
        self, lib: types.ModuleType, key: str, options: Optional[Dict[str, Any]] = None
    ) -> None:
        """Run ``lib``'s optional ``teardown()``, then forget ``lib`` and its submodules.

        ``options`` is forwarded to ``teardown()`` if it accepts it. Exceptions from
        ``teardown()`` are swallowed so the cleanup in ``finally`` always completes.
        """
        if not isinstance(options, dict):
            return await super()._call_module_finalizers(lib, key)

        try:
            func = getattr(lib, "teardown")
        except AttributeError:
            pass
        else:
            try:
                await self._call_extension_function(func, options)
            except Exception:
                # Best-effort: a failing teardown() must not block unloading.
                pass
        finally:
            self.__extensions.pop(key, None)
            sys.modules.pop(key, None)
            name = lib.__name__
            # Iterate over a copy because entries are deleted during the loop.
            for module in list(sys.modules.keys()):
                if _is_submodule(name, module):
                    del sys.modules[module]

    async def unload_extension(
        self,
        name: str,
        *,
        package: Optional[str] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Unload extension ``name``, passing ``options`` to its ``teardown()`` if accepted.

        Raises:
            errors.ExtensionNotLoaded: ``name`` is not currently loaded.
        """
        if not isinstance(options, dict):
            return await super().unload_extension(name, package=package)

        name = self._resolve_name(name, package)
        lib = self.__extensions.get(name)
        if lib is None:
            raise errors.ExtensionNotLoaded(name)

        await self._remove_module_references(lib.__name__)
        await self._call_module_finalizers(lib, name, options)
class ExtBot(commands.Bot, ExtBotBase):
    """``discord.ext.commands.Bot`` whose extension methods also accept ``options=``.

    Adds no code of its own; everything comes from ``commands.Bot`` and the
    ``ExtBotBase`` overrides of load/reload/unload_extension.
    """
class ExtAutoShardedBot(commands.AutoShardedBot, ExtBotBase):
    """``discord.ext.commands.AutoShardedBot`` whose extension methods also accept ``options=``.

    Adds no code of its own; everything comes from ``commands.AutoShardedBot`` and
    the ``ExtBotBase`` overrides of load/reload/unload_extension.
    """
# Re-export the extended classes under discord.py's familiar names (the names
# listed in __all__), so callers can write ``Bot(...)`` / ``AutoShardedBot(...)``.
Bot = ExtBot  # drop-in for discord.ext.commands.Bot with options= support
AutoShardedBot = ExtAutoShardedBot  # drop-in for discord.ext.commands.AutoShardedBot
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment