Skip to content

Instantly share code, notes, and snippets.

@betodealmeida
Created April 13, 2023 21:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save betodealmeida/588289331d49add3dbb937e419d2e8f6 to your computer and use it in GitHub Desktop.
Save betodealmeida/588289331d49add3dbb937e419d2e8f6 to your computer and use it in GitHub Desktop.
import logging
from typing import Any, Dict, Generic, List, TypeVar

from flask_appbuilder import Model
from marshmallow import Schema, fields

from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import (
    DatasetForbiddenError,
    DatasetNotFoundError,
    DatasetRefreshFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class CommandSchema(Schema):
    """
    Default payload schema used by the base ``NewCommand``.

    Declares a single optional integer field, ``number``. Concrete commands
    override ``schema`` with their own schema class.
    """

    # Optional integer field (marshmallow fields default to not required).
    number = fields.Integer(required=False)
# Type variable for the model class a concrete command operates on.
# ``Generic[...]`` only accepts type variables, so subscripting it with the
# concrete ``Model`` class raises ``TypeError`` when the class is defined.
ModelT = TypeVar("ModelT", bound=Model)


class NewCommand(Generic[ModelT]):
    """
    Base class for commands that load one or more models and act on them.

    Subclasses set ``schema`` to the Marshmallow schema describing their
    payload and implement ``load_models`` and ``run``.
    """

    # Schema used by ``from_serialized`` to deserialize the raw payload.
    schema: Schema = CommandSchema()

    def __init__(self, deserialized: Dict[str, Any]):
        """
        Build the command from an already-deserialized payload.

        Models are loaded eagerly via ``load_models``, so any exception it
        raises propagates from the constructor, before ``run`` is called.
        """
        self.models: List[ModelT] = self.load_models(deserialized)

    @classmethod
    def from_serialized(
        cls, serialized: Dict[str, Any]
    ) -> "NewCommand[ModelT]":
        """
        Instantiate the command from a serialized dictionary.

        The payload is passed through ``cls.schema.load``; any validation
        error raised by the schema propagates to the caller.
        """
        deserialized = cls.schema.load(serialized)
        # Dispatch through ``cls`` so a subclass builds an instance of itself
        # rather than of this base class (whose methods are unimplemented).
        return cls.from_deserialized(deserialized)

    @classmethod
    def from_deserialized(
        cls, deserialized: Dict[str, Any]
    ) -> "NewCommand[ModelT]":
        """
        Instantiate the command from a deserialized dictionary.

        This is useful when the command is called from an API endpoint, since
        the API can deserialize the request body into a dictionary and run the
        Marshmallow schema validation itself.
        """
        return cls(deserialized)

    def load_models(self, deserialized: Dict[str, Any]) -> List[ModelT]:
        """
        Load models from the deserialized dictionary.

        Even if the command works on a single model, it should return a list
        with a single element.
        """
        raise NotImplementedError("Subclasses must implement load_models")

    def run(self) -> Any:
        """Execute the command; subclasses must override this."""
        raise NotImplementedError("Subclasses must implement run")
class RefreshDatasetSchema(Schema):
    """Payload schema for ``RefreshDatasetCommand``."""

    # Primary key of the dataset to refresh; mandatory in the payload.
    pk = fields.Integer(required=True)
class RefreshDatasetCommand(NewCommand[SqlaTable]):
    """
    Refresh a single dataset by calling its ``fetch_metadata`` method.

    The payload is described by ``RefreshDatasetSchema`` and carries the
    dataset primary key under ``pk``.
    """

    schema = RefreshDatasetSchema()

    def load_models(self, deserialized: Dict[str, Any]) -> List[SqlaTable]:
        """
        Look up the dataset by ``pk`` and check ownership of it.

        :raises DatasetNotFoundError: ``DatasetDAO.find_by_id`` found nothing
        :raises DatasetForbiddenError: the ownership check raised
            ``SupersetSecurityException``
        """
        dataset = DatasetDAO.find_by_id(deserialized['pk'])
        if not dataset:
            raise DatasetNotFoundError()

        # Translate the security-layer error into the command's own exception.
        try:
            security_manager.raise_for_ownership(dataset)
        except SupersetSecurityException as err:
            raise DatasetForbiddenError() from err

        return [dataset]

    def run(self) -> SqlaTable:
        """
        Fetch fresh metadata for the loaded dataset and return the dataset.

        :raises DatasetRefreshFailedError: any exception from
            ``fetch_metadata`` is logged and re-raised as this error.
        """
        dataset = self.models[0]
        try:
            dataset.fetch_metadata()
        except Exception as err:
            logger.exception("An error occurred while fetching dataset metadata")
            raise DatasetRefreshFailedError() from err
        return dataset
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment