Skip to content

Instantly share code, notes, and snippets.

@pjxiao
Created April 17, 2021 09:51
Show Gist options
  • Save pjxiao/a2df934aad6fe829f20b5a4d4dfe5155 to your computer and use it in GitHub Desktop.
Save pjxiao/a2df934aad6fe829f20b5a4d4dfe5155 to your computer and use it in GitHub Desktop.

A web driver API proxy

A proxy app that reuses a web driver session during development. This avoids repeated browser launches and lets you run your script against the already-running session.

Quick start

> pip install aiohttp aiohttp-devtools
> aiohttp-devtools runserver --livereload -p 8080 --root . webdriver_proxy.py
from contextlib import asynccontextmanager
import logging
import logging.config
import re
import json
from typing import (
AsyncGenerator,
Awaitable,
Callable,
Iterable,
Mapping,
Optional,
)
from aiohttp import (
ClientResponse,
ClientSession,
)
from aiohttp.web import (
Application,
json_response,
middleware,
Response,
Request
)
from dataclasses import dataclass
# Module-level logger; handlers and levels are installed by create_app() via
# logging.config.dictConfig.
logger = logging.getLogger(__name__)
@asynccontextmanager
async def _request(*args, **kwargs) -> AsyncGenerator[ClientResponse, None]:
    """Perform one HTTP request on a short-lived client session.

    Every positional and keyword argument is handed unchanged to
    ``ClientSession.request``. The response is yielded while both the
    response and the client session are still open; both are closed when
    the ``async with`` block that uses this helper exits.
    """
    async with ClientSession() as client, client.request(*args, **kwargs) as upstream:
        yield upstream
@dataclass
class Session:
    """A WebDriver session that was created through the proxy.

    Instances are built from a new-session request/response pair and kept so
    that the same upstream session can be handed out again later.
    """

    # Root URL of the upstream server; '/wd/hub/session/...' is appended to it.
    base_url: str
    # Identifier taken from the upstream new-session response.
    session_id: str
    # The 'desiredCapabilities' object the client sent when asking for a session.
    desired_capabilities: dict
    # The decoded 'capabilities' value of the upstream new-session response
    # (a JSON object, hence ``dict`` rather than ``str``).
    capabilities: dict

    async def is_alive(self) -> bool:
        """Check if this session is still alive.

        Issues ``GET {base_url}/wd/hub/session/{session_id}/window`` and
        reports whether the upstream answered with HTTP 200.
        """
        url = f'{self.base_url}/wd/hub/session/{self.session_id}/window'
        async with _request('GET', url) as resp:
            return resp.status == 200

    def make_response(self) -> 'Response':
        """Build the new-session reply that is sent for this session.

        The JSON body carries ``sessionId`` at the top level and the stored
        capabilities under ``value``; the status is always 200.
        """
        return json_response(
            {
                'sessionId': self.session_id,
                'value': self.capabilities,
            },
            status=200,
        )

    @classmethod
    async def from_request_and_response(
        cls,
        base_url: str,
        request: 'Request',
        response: 'Response',
    ) -> 'Session':
        """Create a Session from a new-session request and its upstream reply.

        Args:
            base_url: Root URL of the upstream server.
            request: The client's new-session request; its JSON body must
                contain ``desiredCapabilities``.
            response: The upstream reply; its body must be JSON of the form
                ``{"value": {"sessionId": ..., "capabilities": ...}}``.

        Raises:
            KeyError: If any of the expected keys is missing.
        """
        desired_capabilities = (await request.json())['desiredCapabilities']
        value = json.loads(response.body)['value']
        capabilities = value['capabilities']
        session_id = value['sessionId']
        return cls(
            base_url=base_url,
            session_id=session_id,
            desired_capabilities=desired_capabilities,
            capabilities=capabilities,
        )
class Sessions(Mapping):
    """Read-mostly mapping from desired capabilities to a cached session.

    Callers index the container with a capabilities ``dict``. Internally each
    dict is serialised to a canonical JSON string (sorted keys), so two
    dicts with the same content map to the same entry regardless of key
    order. Iterating yields those canonical string keys.
    """

    __slots__ = ['sessions']

    def __init__(self, sessions: Optional[list['Session']] = None) -> None:
        """Populate the cache from an optional list of sessions."""
        self.sessions = {
            self._generate_key(sess.desired_capabilities): sess
            for sess in sessions or []
        }

    def add(self, session: 'Session') -> None:
        """Store ``session``, replacing any entry with equal capabilities."""
        key = self._generate_key(session.desired_capabilities)
        self.sessions[key] = session

    def get(
        self,
        key: dict,
        default: Optional['Session'] = None,
    ) -> Optional['Session']:
        """Return the session cached for ``key``, or ``default`` if absent.

        ``default`` mirrors ``Mapping.get`` and defaults to ``None``, so the
        one-argument call behaves as before. The key is serialised once.
        """
        return self.sessions.get(self._generate_key(key), default)

    def _generate_key(self, desired_capabilities: dict) -> str:
        """Serialise capabilities to a canonical, order-independent string."""
        return json.dumps(
            desired_capabilities,
            ensure_ascii=True,
            sort_keys=True,
        )

    def __iter__(self) -> Iterable[str]:
        # Iterates the canonical string keys, not the Session values.
        return iter(self.sessions)

    def __len__(self) -> int:
        return len(self.sessions)

    def __contains__(self, desired_capabilities: dict) -> bool:
        key = self._generate_key(desired_capabilities)
        return key in self.sessions

    def __getitem__(self, desired_capabilities: dict) -> 'Session':
        key = self._generate_key(desired_capabilities)
        return self.sessions[key]
@dataclass
class WebDriverRequest:
    """Thin wrapper that classifies an incoming WebDriver HTTP request."""

    # The wrapped request; only ``method``, ``rel_url`` and ``json()`` are used.
    request: 'Request'

    async def desired_capabilities(self):
        """Return the ``desiredCapabilities`` entry of the JSON request body.

        Raises:
            KeyError: If the body has no ``desiredCapabilities`` key.
        """
        return (await self.request.json())['desiredCapabilities']

    @property
    def is_new_session_request(self) -> bool:
        """True for ``POST /wd/hub/session``."""
        return (
            self.request.method == 'POST' and
            str(self.request.rel_url) == '/wd/hub/session'
        )

    @property
    def is_delete_session_request(self) -> bool:
        """True for ``DELETE /wd/hub/session/<id>`` and ``.../<id>/window``.

        The whole URL must match: an unanchored match would also classify
        other DELETE endpoints below a session (for example
        ``/wd/hub/session/<id>/cookie``) as session deletion. Hyphens are
        accepted in the id so UUID-style session ids match in full.
        """
        if self.request.method != 'DELETE':
            return False
        return bool(
            re.fullmatch(
                r'/wd/hub/session/[0-9a-fA-F-]+(/window)?',
                str(self.request.rel_url),
            )
        )
def session_cache_middleware_factory():
    """Create the middleware that reuses WebDriver sessions across clients.

    The returned middleware owns one ``Sessions`` cache and treats requests
    as follows:

    * New-session requests: if a cached session exists for the same desired
      capabilities and is still alive, its stored response is returned
      without contacting the upstream. Otherwise the request is forwarded,
      and a successful reply is cached before being answered.
    * Delete-session requests: answered locally with an empty success body
      so the upstream session stays open.
    * Everything else: passed through to the handler.
    """
    sessions = Sessions()

    @middleware
    async def _middleware(
        request: Request,
        handler: Callable[[Request], Awaitable[Response]],
    ):
        wd_req = WebDriverRequest(request)

        if wd_req.is_new_session_request:
            session: Optional[Session] = sessions.get(
                await wd_req.desired_capabilities()
            )
            logger.info('session=%s', session and session.session_id)

            # new or renew a session
            if not (session and await session.is_alive()):
                resp = await handler(request)
                if resp.status != 200:
                    # The upstream refused to create a session. Its body is an
                    # error payload without sessionId/capabilities, so relay
                    # it untouched instead of failing while parsing it.
                    logger.warning(
                        'upstream new session request failed: status=%s',
                        resp.status,
                    )
                    return resp
                session = await Session.from_request_and_response(
                    base_url=request.app['x_app_base_url'],
                    request=request,
                    response=resp,
                )
                sessions.add(session)
            return session.make_response()

        # ignore close session request
        # TODO: close all sessions on exit
        elif wd_req.is_delete_session_request:
            logger.info('delete session request is ignored')
            return json_response({'value': []}, status=200)

        # pass through
        else:
            return await handler(request)

    return _middleware
# Headers that apply to a single connection only (RFC 7230, section 6.1) and
# therefore must not be copied from one side of the proxy to the other.
_HOP_BY_HOP_HEADERS = (
    'Connection',
    'Keep-Alive',
    'Proxy-Authenticate',
    'Proxy-Authorization',
    'TE',
    'Trailer',
    'Transfer-Encoding',
    'Upgrade',
)


async def proxy_handler(request: Request) -> Response:
    """Forward ``request`` to the upstream server and relay its reply.

    The upstream URL is ``app['x_app_base_url']`` followed by the request's
    relative URL. Method, body and end-to-end headers are forwarded; the
    reply's status, reason, body and end-to-end headers are returned.
    """
    request_headers = request.headers.copy()
    # Drop 'Host' as well so the HTTP client derives it from the upstream URL
    # instead of sending the proxy's own host name.
    for name in (*_HOP_BY_HOP_HEADERS, 'Host'):
        request_headers.popall(name, None)

    async with _request(
        request.method,
        f'{request.app["x_app_base_url"]}{request.rel_url}',
        headers=request_headers,
        data=await request.read(),
    ) as resp:
        response_headers = resp.headers.copy()
        # resp.read() returns the already-decoded body, so the upstream's
        # Content-Encoding / Content-Length no longer describe what is sent;
        # Response recomputes the length from the body.
        for name in (*_HOP_BY_HOP_HEADERS, 'Content-Encoding', 'Content-Length'):
            response_headers.popall(name, None)

        return Response(
            headers=response_headers,
            status=resp.status,
            reason=resp.reason,
            body=await resp.read(),
        )
def create_app(
    base_url: str = 'http://127.0.0.1:4444',
    log_level: str = 'DEBUG',
) -> Application:
    """Configure logging and build the proxy application.

    Args:
        base_url: Root URL of the upstream server; stored on the app under
            ``'x_app_base_url'``.
        log_level: Level applied to the console handler. The root logger
            itself is always set to DEBUG.

    Returns:
        An application with the session-cache middleware installed and a
        catch-all route that sends every method and path to
        ``proxy_handler``.
    """
    default_formatter = {
        'format': '%(asctime)s %(levelname)-8s %(name)-15s %(message)s',
        'datefmt': '%Y-%m-%dT%H:%M:%S',
    }
    console_handler = {
        'class': 'logging.StreamHandler',
        'formatter': 'default',
        'level': log_level,
    }
    root_logger = {
        'handlers': ['console'],
        'level': 'DEBUG',
    }
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {'default': default_formatter},
        'handlers': {'console': console_handler},
        'root': root_logger,
        'loggers': {},
    })

    proxy_app = Application(middlewares=[session_cache_middleware_factory()])
    proxy_app['x_app_base_url'] = base_url
    # Catch-all: any method, any path.
    proxy_app.router.add_route('*', '/{path:.*}', proxy_handler)
    return proxy_app
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment