Skip to content

Instantly share code, notes, and snippets.

@MatrixManAtYrService
Last active March 1, 2024 20:44
Show Gist options
  • Save MatrixManAtYrService/eedd5e09a3b8ae116514ce38b30e61e8 to your computer and use it in GitHub Desktop.
Save MatrixManAtYrService/eedd5e09a3b8ae116514ce38b30e61e8 to your computer and use it in GitHub Desktop.
libcst vs refactor, an unfair trial

This analysis of libcst vs refactor is incomplete. I probably won't come back and complete it (maybe your case warrants more analysis than my case did).

I had a difficult time with refactor a while back, so I recently tried libcst. I couldn't remember exactly why I had a difficult time, maybe I just didn't sleep well the night before or some kind of non-reason like that. Rather than justify this to myself, I just went with libcst.

I ended up with this:

import sys
from pathlib import Path
from shutil import copytree
from textwrap import dedent
import libcst as cst
def apply_transformers(file: Path, transformers: list[cst.CSTTransformer]):
"""
Applies the given transformers (which mutate code based on its AST) to the
indicated file.
"""
tree = cst.parse_module(file.read_text())
for transformer in transformers:
tree = tree.visit(transformer)
file.write_text(tree.code)
class Prefixer(cst.CSTTransformer):
"""
Adds code to the beginning of a module
"""
def __init__(self, prefix: str):
self.prefix = prefix
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
prefix_module = cst.parse_module(self.prefix)
prefix_statements = list(prefix_module.body)
return updated_node.with_changes(body=prefix_statements + list(updated_node.body))
class Suffixer(cst.CSTTransformer):
"""
Adds code to the end of a module
"""
def __init__(self, suffix: str):
self.suffix = suffix
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
suffix_module = cst.parse_module(self.suffix)
suffix_statements = list(suffix_module.body)
return updated_node.with_changes(body=list(updated_node.body) + suffix_statements)
class AppLifespanReplacer(cst.CSTTransformer):
"""
Used by patch_build_app on src/laminar/app.py
Disables all background services.
"""
def leave_FunctionDef(self, original_node, updated_node):
if original_node.name.value == "app_lifespan":
replacement_code = dedent(
"""
async def app_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
"""
).strip()
replacement_module = cst.parse_module(replacement_code)
replacement_function = replacement_module.body[0]
return replacement_function
return updated_node
class MatchStatementReplacer(cst.CSTTransformer):
"""
Used by patch_build_app on src/laminar/app.py
- initializes the Mocks class
- enables apiserver routes
- enables /mock route
"""
def leave_FunctionDef(self, original_node, updated_node):
if original_node.name.value == "build_app":
new_body = []
for statement in updated_node.body.body:
if isinstance(statement, cst.Expr) and "db_connection_scope_middleware" in statement.code:
continue # don't enable any db stuff
if isinstance(statement, cst.Match):
# replace the match block which checks the AppKind
replacement_module = cst.parse_module(
dedent(
"""
from laminar.apitest.mocks import Mocks
from .apiserver import router as module # type: ignore[no-redef]
from .apitest import router as mock_module
app.state.mocks = Mocks()
app.include_router(mock_module.router)
"""
).strip()
)
new_body.extend(replacement_module.body)
else:
new_body.append(statement)
return updated_node.with_changes(body=updated_node.body.with_changes(body=tuple(new_body)))
return updated_node
def patch_build_app(src: Path):
file = src / "laminar" / "app.py"
apply_transformers(file, [MatchStatementReplacer(), AppLifespanReplacer()])
class MockHandlerReplacer(cst.CSTTransformer):
def __init__(self, path, method, new_handler):
self.path = f'"{path}"'
self.method = method
self.new_handler = new_handler
def leave_FunctionDef(self, original_node, updated_node):
for decorator in original_node.decorators:
if isinstance(decorator.decorator, cst.Call):
call = decorator.decorator
if isinstance(call.func, cst.Attribute) and call.func.attr.value == self.method:
if len(call.args) > 0 and isinstance(call.args[0].value, cst.SimpleString):
arg_value = call.args[0].value.value
if arg_value == self.path:
replacement_module = cst.parse_module(self.new_handler)
return replacement_module.body[0]
return updated_node
dag_structure_replacer = MockHandlerReplacer(
path="/deployments/{deployment_id}/dags/{dag_id}/structure",
method="get",
new_handler=dedent(
"""
@router.get("/deployments/{deployment_id}/dags/{dag_id}/structure")
async def dag_structure(request: Request) -> JSONResponse:
mock = get_mock_handler(
route_path="/deployments/{deployment_id}/dags/{dag_id}/structure",
RequestModel=None,
ResponseModel=schemas.v1.DagStructure
)
return await mock(request=request, log=log)
"""
),
)
def mock_handlers(src: Path):
apply_transformers(
src / "laminar/apiserver/routers/v1/dags.py",
[
Prefixer(
dedent(
"""
from fastapi import Request
from fastapi.responses import JSONResponse
from laminar.apitest.mocks import get_mock_handler
"""
).strip()
),
dag_structure_replacer,
],
)
def patch_dependencies(src: Path):
apply_transformers(
src / "laminar/apiserver/dependencies.py",
[
Prefixer("from laminar.apitest.mocks import Mocks"),
Suffixer(
dedent(
"""
def mocks() -> "Mocks":
from ..app import app
return cast("Mocks", app.state.mocks)
"""
).strip()
),
],
)
def main():
src = Path.cwd() / sys.argv[1]
assert src.exists()
assert src.is_dir()
copytree(Path(__file__).parent, src / "laminar" / "apitest")
patch_build_app(src)
mock_handlers(src)
patch_dependencies(src)
if __name__ == "__main__":
main()

It does what I need it to:

 ❯ poetry run python tools/apitest/install.py ./src/

 ❯ git --no-pager diff src
diff --git a/src/laminar/apiserver/dependencies.py b/src/laminar/apiserver/dependencies.py
index 1306d72..9027dc2 100644
--- a/src/laminar/apiserver/dependencies.py
+++ b/src/laminar/apiserver/dependencies.py
@@ -1,3 +1,4 @@
+from laminar.apitest.mocks import Mocks
 from typing import TYPE_CHECKING, cast

 if TYPE_CHECKING:
@@ -8,3 +9,7 @@ def deployment_service() -> "AstroDeploymentService":
     from ..app import app

     return cast("AstroDeploymentService", app.state.deployment_service)
+def mocks() -> "Mocks":
+    from ..app import app
+
+    return cast("Mocks", app.state.mocks)
diff --git a/src/laminar/apiserver/routers/v1/dags.py b/src/laminar/apiserver/routers/v1/dags.py
index af85135..f748e5e 100644
--- a/src/laminar/apiserver/routers/v1/dags.py
+++ b/src/laminar/apiserver/routers/v1/dags.py
@@ -1,3 +1,6 @@
+from fastapi import Request
+from fastapi.responses import JSONResponse
+from laminar.apitest.mocks import get_mock_handler
 from collections import Counter
 from collections.abc import AsyncIterable
 from typing import Annotated
@@ -35,28 +38,14 @@ async def get_deployment(
     # Make this Async so that the context vars we set persist to the views too.
     with structlog.contextvars.bound_contextvars(deployment=deployment.deployment_id):
         yield deployment
-
-
-@router.get(
-    "/deployments/{deployment_id}/dags/{dag_id}/structure",
-    dependencies=[Depends(read_permission)],
-)
-async def dag_structure(
-    dag_id: str,
-    deployment: Deployment = Depends(get_deployment),
-) -> schemas.v1.DagStructure:
-    async with await deployment.session as session:
-        airflow_models = await session.airflow_models
-
-        serialized_dag: SerializedDag | None = await session.get(airflow_models.SerializedDag, dag_id)
-        if not serialized_dag:
-            raise DagNotFoundError(deployment.deployment_id, dag_id)
-
-        try:
-            return serialized_dag.graph_data()  # type: ignore[no-any-return]
-        except Exception:
-            log.exception("Error generating graph_data", dag_id=dag_id)
-            raise HTTPException(status_code=500, detail="Unable to understand SerializedDag data")
+@router.get("/deployments/{deployment_id}/dags/{dag_id}/structure")
+async def dag_structure(request: Request) -> JSONResponse:
+    mock = get_mock_handler(
+        route_path="/deployments/{deployment_id}/dags/{dag_id}/structure",
+        RequestModel=None,
+        ResponseModel=schemas.v1.DagStructure
+    )
+    return await mock(request=request, log=log)


 @router.get(
diff --git a/src/laminar/app.py b/src/laminar/app.py
index a885c0f..2a6b843 100644
--- a/src/laminar/app.py
+++ b/src/laminar/app.py
@@ -108,55 +108,8 @@ def _deployment_service(
         airflows_mutator=airflows_mutator,
         global_connection_pool=global_connection_pool,
     )
-
-
-@asynccontextmanager
 async def app_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
-    settings = get_settings()
-
-    if settings.mode == Mode.kube:
-        await load_kube_config(settings)
-
-    if settings.enable_metrics and instrumentator is not None:
-        instrumentator.expose(app)
-
-    services: dict[str, Service] = {}
-    async with asyncio.TaskGroup() as tg:
-        if not hasattr(app.state, "connection_pool"):
-            pool_settings = settings.pool
-            pool = GlobalLimiterPool(
-                pool_size=pool_settings.max_connections,
-                max_overflow=pool_settings.overflow,
-                idle_ttl=pool_settings.idle_ttl,
-                timeout=pool_settings.wait_timeout,
-            )
-            app.state.connection_pool = pool
-            tg.create_task(pool.run(), name=f"GlobalLimiterPool-{hex(id(pool))}")
-            services["connection_pool"] = pool  # type: ignore[assignment]
-        if not hasattr(app.state, "deployment_service"):
-            svc = app.state.deployment_service = _deployment_service(settings, app.state.connection_pool)
-            services["astro_deployments"] = svc
-            tg.create_task(svc.run(), name=f"AstroDeploymentService-{hex(id(svc))}")
-        # If we're in the hypervisor, set up the watchers
-        if settings.app_kind == AppKind.hypervisor:
-            services = await set_up_incident_reaper(app, services, tg)
-            services = await set_up_kube_watchers(app, services, tg)
-        if permissions_enforced() and not hasattr(app.state, "jwt_validator"):
-            app.state.jwt_validator = jwt_validator()
-
-        app.state.services = services
-
-        # Since we are passing a lifespan, the startup event for routers don't get fired. TODO: report bug
-        # against FastAPI that adding a lifespan to a router and then doing `app.add_router` doesn't fire the
-        # router's lifespan!
-        await app.router.startup()
-
-        yield
-
-        await app.router.shutdown()
-
-        # Stop all services in parallel
-        await asyncio.gather(*[svc.stop() for svc in services.values()])
+    yield


 async def _add_svc(
@@ -327,19 +280,14 @@ def build_app(settings: Settings | None = None) -> FastAPI:
     app.add_exception_handler(WorkspaceNotFound, workspace_not_found_exception_handler)  # type: ignore[arg-type]
     app.include_router(debug_router)
     app.include_router(health_router)
-
-    match settings.app_kind:
-        case AppKind.hypervisor:
-            from .hypervisor import router as module
-
-            # If we do `router.mount` it doesn't get included when doing `app.include_router`, so we have to
-            # set up the static files here
-            app.mount(
-                "/static", StaticFiles(directory=Path(__file__).parent / "hypervisor/static"), name="static"
-            )
-
-        case AppKind.api_server:
-            from .apiserver import router as module  # type: ignore[no-redef]
+    from laminar.apitest.mocks import Mocks
+
+    from .apiserver import router as module  # type: ignore[no-redef]
+    from .apitest import router as mock_module
+
+    app.state.mocks = Mocks()
+
+    app.include_router(mock_module.router)

     log.info("Starting", app=module.__name__, prefix=app.root_path)
     app.include_router(module.router)

It was only later, when prompted by a colleague, that I thought to myself:

self, why did you introduce libcst when this project was already using refactor?

So I started switching to refactor. Here's what I came up with:

import ast
import sys
from pathlib import Path
from shutil import copytree
from textwrap import dedent
import refactor
from refactor import Action, Replace, Rule
def match_add_prefix(node: ast.AST, prefix_code: str) -> Action | None:
if isinstance(node, ast.Module):
prefix_nodes = ast.parse(prefix_code).body
new_body = prefix_nodes + node.body
return Replace(node, ast.Module(body=new_body, type_ignores=node.type_ignores))
return None
def match_add_suffix(node: ast.AST, prefix_code: str) -> Action | None:
if isinstance(node, ast.Module):
prefix_nodes = ast.parse(prefix_code).body
new_body = node.body + prefix_nodes
return Replace(node, ast.Module(body=new_body, type_ignores=node.type_ignores))
return None
def match_replace_function_body(node: ast.AST, function_name: str, new_body_code: str) -> Action | None:
if isinstance(node, ast.FunctionDef) and node.name == function_name:
new_body = ast.parse(new_body_code).body
return Replace(node, ast.FunctionDef(name=node.name, args=node.args, body=new_body, decorator_list=node.decorator_list, returns=node.returns))
return None
def match_replace_handler_with_mock(node: ast.AST, path: str, method: str, new_handler_code: str) -> Action | None:
if isinstance(node, ast.FunctionDef):
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call) and hasattr(decorator.func, 'attr') and decorator.func.attr == method:
if path in ast.unparse(decorator):
new_handler_node = ast.parse(new_handler_code).body[0]
return Replace(node, new_handler_node)
return None
def apply_rules(file_path: Path, rule_classes: list):
session = refactor.Session(rules=rule_classes)
change = session.run_file(file_path)
if change is not None:
change.apply_diff()
print(f"Applied changes to {file_path}")
else:
print(f"No changes applied to {file_path}")
def mock_handlers(src: Path):
class DagsRouteImports(refactor.Rule):
def match(self, node: ast.AST) -> refactor.Action | None:
return match_add_prefix(
node=node,
prefix_code=dedent(
"""
from fastapi import Request
from fastapi.responses import JSONResponse
from laminar.apitest.mocks import get_mock_handler
"""
).strip()
)
class DagStructureReplacer(refactor.Rule):
def match(self, node: ast.AST) -> refactor.Action | None:
return match_replace_handler_with_mock(
node,
path="/deployments/{deployment_id}/dags/{dag_id}/structure",
method="get",
new_handler_code=dedent(
"""
@router.get("/deployments/{deployment_id}/dags/{dag_id}/structure")
async def dag_structure(request: Request) -> JSONResponse:
mock = get_mock_handler(
route_path="/deployments/{deployment_id}/dags/{dag_id}/structure",
RequestModel=None,
ResponseModel=schemas.v1.DagStructure
)
return await mock(request=request, log=log)
"""
),
)
apply_rules(
src / "laminar/apiserver/routers/v1/dags.py",
[
DagsRouteImports,
DagStructureReplacer
],
)
def patch_dependencies(src: Path):
class DependenciesPrefixer(Rule):
def match(self, node: ast.AST) -> Action | None:
return match_add_prefix(
node=node,
prefix_code=dedent("""
from laminar.apitest.mocks import Mocks
""").strip()
)
class DependenciesSuffixer(Rule):
def match(self, node: ast.AST) -> Action | None:
return match_add_suffix(
node=node,
prefix_code=dedent("""
def get_mocks() -> Mocks:
return Mocks()
""").strip()
)
apply_rules(
src / "laminar/apiserver/dependencies.py",
[DependenciesPrefixer, DependenciesSuffixer]
)
def patch_build_app(src: Path):
class AppLifespanReplacer(Rule):
def match(self, node: ast.AST) -> Action | None:
return match_replace_function_body(
node=node,
function_name="app_lifespan",
new_body_code=dedent("""
async def app_lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
""").strip()
)
class BuildAppPatcher(Rule):
def match(self, node: ast.AST) -> Action | None:
if isinstance(node, ast.FunctionDef) and node.name == "build_app":
new_body = []
for stmt in node.body:
if (isinstance(stmt, ast.Expr) and
isinstance(stmt.value, ast.Call) and
'db_connection_scope_middleware' in ast.unparse(stmt)):
continue
if isinstance(stmt, ast.Match):
new_stmts_code = dedent("""
from laminar.apitest.mocks import Mocks
from .apiserver import router as module # type: ignore[no-redef]
from .apitest import router as mock_module
app.state.mocks = Mocks()
app.include_router(mock_module.router)
""")
new_stmts = ast.parse(new_stmts_code).body
new_body.extend(new_stmts)
else:
new_body.append(stmt)
# Constructing the new FunctionDef node with the modified body
new_node = ast.FunctionDef(
name=node.name,
args=node.args,
body=new_body,
decorator_list=node.decorator_list,
returns=node.returns
)
# Apply fix_missing_locations to the entire new function definition
ast.fix_missing_locations(new_node)
return Replace(node, new_node)
return None
apply_rules(
src / "laminar" / "app.py",
[AppLifespanReplacer, BuildAppPatcher]
)
def main():
src = Path.cwd() / sys.argv[1]
assert src.exists()
assert src.is_dir()
copytree(Path(__file__).parent, src / "laminar" / "apitest")
patch_build_app(src)
mock_handlers(src)
patch_dependencies(src)
if __name__ == "__main__":
main()

There's a few bugs here (in my code, not in refactor). When I run it, it only succeeds at matching one of the three files that I want to mutate.

 ❯ poetry run python tools/apitest/install.py ./src/
Applied changes to /Users/matt/src/laminar/src/laminar/app.py
No changes applied to /Users/matt/src/laminar/src/laminar/apiserver/routers/v1/dags.py
No changes applied to /Users/matt/src/laminar/src/laminar/apiserver/dependencies.py

That's fine, I'm sure I can find and fix those bugs. But there's enough here that I can start to compare the two libraries. But before I do, let's look at a diff:

 ❯ git --no-pager diff src
diff --git a/src/laminar/app.py b/src/laminar/app.py
index a885c0f..28978ef 100644
--- a/src/laminar/app.py
+++ b/src/laminar/app.py
@@ -298,12 +298,10 @@ def custom_openapi(app: FastAPI, settings: Settings) -> dict[str, Any]:
     return app.openapi_schema


-def build_app(settings: Settings | None = None) -> FastAPI:
+def build_app(settings: Settings | None=None) -> FastAPI:
     global app, instrumentator
-
     if settings is None:
         settings = get_settings()
-
     api_url = urlparse(settings.api_url)
     app = FastAPI(
         title="Laminar " + settings.app_kind.value,
@@ -311,37 +309,24 @@ def build_app(settings: Settings | None = None) -> FastAPI:
         root_path=api_url.path,
         lifespan=app_lifespan,
     )
-    app.openapi = functools.partial(custom_openapi, app, settings)  # type: ignore[method-assign]
-
+    app.openapi = functools.partial(custom_openapi, app, settings)
     if settings.enable_metrics:
         instrumentator = build_prom_instrumentator(app, settings)
-
-    app.middleware("http")(db_connection_scope_middleware)
     if not settings.enable_pretty_log:
         app.middleware("http")(structured_exception_logs)
     app.add_middleware(
         AccessLoggerMiddleware, prefix=api_url.path, request_id_header_name=settings.request_id_header_name
     )
-    app.add_exception_handler(ValidationError, pydantic_validation_exception_handler)  # type: ignore[arg-type]
-    app.add_exception_handler(DeploymentNotFound, deployment_not_found_exception_handler)  # type: ignore[arg-type]
-    app.add_exception_handler(WorkspaceNotFound, workspace_not_found_exception_handler)  # type: ignore[arg-type]
+    app.add_exception_handler(ValidationError, pydantic_validation_exception_handler)
+    app.add_exception_handler(DeploymentNotFound, deployment_not_found_exception_handler)
+    app.add_exception_handler(WorkspaceNotFound, workspace_not_found_exception_handler)
     app.include_router(debug_router)
     app.include_router(health_router)
-
-    match settings.app_kind:
-        case AppKind.hypervisor:
-            from .hypervisor import router as module
-
-            # If we do `router.mount` it doesn't get included when doing `app.include_router`, so we have to
-            # set up the static files here
-            app.mount(
-                "/static", StaticFiles(directory=Path(__file__).parent / "hypervisor/static"), name="static"
-            )
-
-        case AppKind.api_server:
-            from .apiserver import router as module  # type: ignore[no-redef]
-
+    from laminar.apitest.mocks import Mocks
+    from .apiserver import router as module
+    from .apitest import router as mock_module
+    app.state.mocks = Mocks()
+    app.include_router(mock_module.router)
     log.info("Starting", app=module.__name__, prefix=app.root_path)
     app.include_router(module.router)
-
     return app

diff noise

What I notice about this diff is that libcst preserved comments and whitespace, while refactor was a bit more heavy handed. This only matters because I have to read my refactored code while developing it, and these omissions distract from the change I'm looking to achieve. Otherwise they're functionally the same.

Confession time: I leaned on ChatGPT to write both of these. I am an expert at neither.

Is it fair to praise libcst and criticize refactor because ChatGPT makes decisions that I like when it uses libcst and it makes decisions that I dislike when it uses refactor? Probably not. Maybe there's a way to make refactor kick out diffs with less noise, but I don't know if it's worth taking the time to investigate.

rules as classes (refactor) vs rules as objects (libcst)

Compared with libcst's interface, which accepts objects, I dislike that refactor's rule interface accepts classes. It makes things awkward when I want to create a generic rule (Prefix, say) and then parameterize it later. libcst conforms to the zen re:

There should be one-- and preferably only one --obvious way to do it

I wanted to parameterize a rule, so I used __init__ to store parameters on its object and then I referenced those parameters in the rule code.

refactor sent me for a bit of a loop here. For a while I was messing around with metaclasses trying to figure out how to provide parameters and get classes. Probably I could've made a factory method for them. I ended up doing what you see above: defining the match methods at module scope and then creating rule classes which just call into them.

project liveness

As leery as I am of letting mere numbers inform an analysis like this one, especially numbers emitted by a microsoft product, there's something to be said for libcst being more actively developed:

Screenshot 2024-03-01 at 1 13 20 PM

So I've got a few weak reasons to prefer libcst over refactor. And then there's this one, which is a questionable:

If I chose libcst, I can move on to the next task, because I tried it first and I'm done with it. If I chose refactor, I still have to finish the job.

I'm not sure how important being consistent everywhere is, I'll have to run it by my teammate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment