|
from contextlib import asynccontextmanager |
|
import logging |
|
import logging.config |
|
import re |
|
import json |
|
from typing import ( |
|
AsyncGenerator, |
|
Awaitable, |
|
Callable, |
|
Iterable, |
|
Mapping, |
|
Optional, |
|
) |
|
|
|
from aiohttp import ( |
|
ClientResponse, |
|
ClientSession, |
|
) |
|
from aiohttp.web import ( |
|
Application, |
|
json_response, |
|
middleware, |
|
Response, |
|
Request |
|
) |
|
from dataclasses import dataclass |
|
|
|
|
|
# Module-wide logger, named after this module's dotted import path so that
# logging configuration can target it by name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
@asynccontextmanager
async def _request(*args, **kwargs) -> AsyncGenerator[ClientResponse, None]:
    """Perform a single HTTP request and yield its response.

    A dedicated ``ClientSession`` is opened for this one request. The session
    and the response are both closed when the ``async with`` block that uses
    this context manager exits. All arguments are passed unchanged to
    ``ClientSession.request``.
    """
    # The second manager is entered after the first one is bound, so it can
    # use ``http``; exit order is the reverse (response first, then session).
    async with ClientSession() as http, http.request(*args, **kwargs) as response:
        yield response
|
|
|
|
|
@dataclass
class Session:
    """A WebDriver session that the proxy has created upstream and caches.

    Holds what is needed to probe the session and to replay the new-session
    reply to later clients that ask for the same desired capabilities.
    """

    # Root URL of the upstream server. Paths starting with '/' are appended
    # to it, so it is expected without a trailing slash.
    base_url: str
    # Session id reported by the upstream server.
    session_id: str
    # The client's ``desiredCapabilities`` object; used as the cache key.
    desired_capabilities: dict
    # ``value['capabilities']`` from the upstream new-session reply: a decoded
    # JSON object (it was previously mis-annotated as ``str``).
    capabilities: dict

    async def is_alive(self) -> bool:
        """Check whether the upstream server still serves this session.

        Issues ``GET {base_url}/wd/hub/session/{session_id}/window``; only an
        HTTP 200 answer counts as alive. Connection errors and timeouts are
        logged and reported as "not alive" instead of propagating, so callers
        can fall back to creating a fresh session.
        """
        import asyncio

        from aiohttp import ClientError

        url = f'{self.base_url}/wd/hub/session/{self.session_id}/window'
        try:
            async with _request('GET', url) as resp:
                return resp.status == 200
        except (ClientError, asyncio.TimeoutError) as exc:
            logger.warning(
                'liveness check failed for session=%s: %r', self.session_id, exc,
            )
            return False

    def make_response(self) -> 'Response':
        """Build the JSON reply returned to a client for a new-session request.

        The body is ``{'sessionId': ..., 'value': <capabilities>}`` with status 200.
        """
        return json_response(
            {
                'sessionId': self.session_id,
                'value': self.capabilities,
            },
            status=200,
        )

    @classmethod
    async def from_request_and_response(
        cls,
        base_url: str,
        request: 'Request',
        response: 'Response',
    ) -> 'Session':
        """Build a ``Session`` from a new-session request and the upstream reply.

        Args:
            base_url: Upstream server root stored on the session.
            request: The client's new-session request. Its JSON body must
                contain ``desiredCapabilities``.
            response: The upstream reply. ``response.body`` must be JSON of the
                form ``{'value': {'sessionId': ..., 'capabilities': ...}}``.

        Raises:
            KeyError: if either body lacks one of the members listed above.
        """
        desired_capabilities = (await request.json())['desiredCapabilities']
        value = json.loads(response.body)['value']
        return cls(
            base_url=base_url,
            session_id=value['sessionId'],
            desired_capabilities=desired_capabilities,
            capabilities=value['capabilities'],
        )
|
|
|
|
|
class Sessions(Mapping):
    """In-memory registry of sessions keyed by desired-capabilities dicts.

    Dicts are unhashable, so each key is stored internally as a canonical
    JSON string (sorted keys, ASCII-only). Two capability dicts with equal
    content but different key order therefore map to the same session.

    Iteration yields the capability dicts decoded back from those strings.
    This keeps ``for k in self: self[k]`` valid, and so also the inherited
    ``keys()``/``values()``/``items()`` views and ``==``.
    """

    __slots__ = ['sessions']

    def __init__(self, sessions: 'Optional[list[Session]]' = None) -> None:
        # canonical JSON key -> Session
        self.sessions = {
            self._generate_key(sess.desired_capabilities): sess
            for sess in sessions or []
        }

    def add(self, session: 'Session') -> None:
        """Store ``session``, replacing any session cached for equal capabilities."""
        self.sessions[self._generate_key(session.desired_capabilities)] = session

    def get(self, key: dict, default=None) -> 'Optional[Session]':
        """Return the session cached for capabilities ``key``, else ``default``."""
        try:
            return self[key]
        except KeyError:
            return default

    def _generate_key(self, desired_capabilities: dict) -> str:
        """Serialize ``desired_capabilities`` into an order-independent string key."""
        return json.dumps(
            desired_capabilities,
            ensure_ascii=True,
            sort_keys=True,
        )

    def __iter__(self) -> Iterable[dict]:
        # Yield decoded capability dicts rather than the internal strings, so
        # every yielded key is a valid argument to __getitem__. For
        # capabilities parsed from JSON, re-serializing the decoded dict with
        # the same options reproduces the stored string.
        return (json.loads(key) for key in self.sessions)

    def __len__(self) -> int:
        return len(self.sessions)

    def __contains__(self, desired_capabilities: dict) -> bool:
        return self._generate_key(desired_capabilities) in self.sessions

    def __getitem__(self, desired_capabilities: dict) -> 'Session':
        return self.sessions[self._generate_key(desired_capabilities)]
|
|
|
|
|
@dataclass
class WebDriverRequest:
    """Classifies an incoming request against the ``/wd/hub`` WebDriver API."""

    request: 'Request'

    # DELETE /wd/hub/session/{id} or /wd/hub/session/{id}/window.
    # '-' is allowed in the id so hyphenated UUID-style ids match as well as
    # plain hex ids. It is used with fullmatch: a prefix match would also
    # catch other per-session DELETE commands such as .../cookie/{name}.
    _DELETE_SESSION_PATTERN = re.compile(r'/wd/hub/session/[0-9a-fA-F-]+(/window)?')

    async def desired_capabilities(self) -> dict:
        """Return the ``desiredCapabilities`` member of the JSON request body.

        Raises:
            KeyError: if the body has no ``desiredCapabilities`` member.
        """
        return (await self.request.json())['desiredCapabilities']

    @property
    def is_new_session_request(self) -> bool:
        """True for ``POST /wd/hub/session`` (exact URL)."""
        return (
            self.request.method == 'POST'
            and str(self.request.rel_url) == '/wd/hub/session'
        )

    @property
    def is_delete_session_request(self) -> bool:
        """True for a DELETE of a whole session or of its ``/window``.

        The entire relative URL must match; other DELETE commands on the
        session are not treated as session deletion.
        """
        return (
            self.request.method == 'DELETE'
            and self._DELETE_SESSION_PATTERN.fullmatch(str(self.request.rel_url)) is not None
        )
|
|
|
|
|
def session_cache_middleware_factory():
    """Build an aiohttp middleware that reuses WebDriver sessions.

    - ``POST /wd/hub/session`` is answered from an in-memory cache keyed by
      the request's ``desiredCapabilities``, as long as the cached session's
      ``is_alive()`` reports it alive. Otherwise the request is forwarded to
      ``handler`` and the newly created session is cached.
    - Delete-session / close-window requests are acknowledged locally
      without being forwarded, so cached sessions stay open.
    - Every other request passes straight through to ``handler``.

    Returns:
        The middleware coroutine, for ``Application(middlewares=[...])``.
    """
    sessions = Sessions()

    @middleware
    async def _middleware(
        request: Request,
        handler: Callable[[Request], Awaitable[Response]],
    ):
        wd_req = WebDriverRequest(request)

        if wd_req.is_new_session_request:
            session: Optional[Session] = sessions.get(await wd_req.desired_capabilities())
            logger.info('session=%s', session and session.session_id)

            # new or renew a session
            # NOTE(review): nothing serializes this check-then-create, so
            # concurrent requests with equal capabilities can each create an
            # upstream session; only the last one added stays cached.
            if not (session and await session.is_alive()):
                resp = await handler(request)
                if resp.status != 200:
                    # Upstream did not create a session: relay its error reply
                    # unchanged (parsing it would raise) and cache nothing.
                    logger.warning('new session request failed: status=%s', resp.status)
                    return resp
                session = await Session.from_request_and_response(
                    base_url=request.app['x_app_base_url'],
                    request=request,
                    response=resp,
                )
                sessions.add(session)

            return session.make_response()

        # ignore close session request
        # TODO: close all sessions on exit
        if wd_req.is_delete_session_request:
            logger.info('delete session request is ignored')
            return json_response({'value': []}, status=200)

        # pass through
        return await handler(request)

    return _middleware
|
|
|
|
|
# Headers that must not be copied between the client-facing connection and
# the upstream connection (or back) by proxy_handler.
_UNFORWARDABLE_HEADERS = frozenset({
    # hop-by-hop headers (RFC 7230 section 6.1): valid for one connection only
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailer',
    'trailers',
    'transfer-encoding',
    'upgrade',
    # must name the upstream server; the HTTP client derives it from the URL
    'host',
    # describe the raw wire body; aiohttp by default decompresses bodies it
    # reads and computes the length of bodies it sends, so these would be stale
    'content-length',
    'content-encoding',
})


def _forwardable_headers(headers) -> list:
    """Return ``headers`` as (name, value) pairs without unforwardable ones."""
    return [
        (name, value)
        for name, value in headers.items()
        if name.lower() not in _UNFORWARDABLE_HEADERS
    ]


async def proxy_handler(request: Request) -> Response:
    """Forward ``request`` to the upstream WebDriver server and relay its reply.

    The upstream URL is ``app['x_app_base_url']`` followed by the request's
    relative URL (path and query). Method and body are forwarded unchanged.
    Headers are copied in both directions except connection-specific and
    body-framing ones (see ``_UNFORWARDABLE_HEADERS``); aiohttp regenerates
    those for each leg.
    """
    async with _request(
        request.method,
        f'{request.app["x_app_base_url"]}{request.rel_url}',
        headers=_forwardable_headers(request.headers),
        data=await request.read(),
    ) as resp:
        return Response(
            headers=_forwardable_headers(resp.headers),
            status=resp.status,
            reason=resp.reason,
            body=await resp.read(),
        )
|
|
|
|
|
def create_app(
    base_url: str = 'http://127.0.0.1:4444',
    log_level: str = 'DEBUG',
) -> Application:
    """Configure logging and build the session-caching WebDriver proxy app.

    Args:
        base_url: Root URL of the upstream WebDriver server. A trailing ``/``
            is stripped, because request paths (which start with ``/``) are
            appended to it verbatim.
        log_level: Threshold of the console log handler. Level names are
            accepted case-insensitively (e.g. ``'info'``). Non-string values
            are passed to ``logging.config.dictConfig`` unchanged.

    Returns:
        An aiohttp ``Application`` that proxies every route to ``base_url``
        through the session-caching middleware.
    """
    if isinstance(log_level, str):
        # logging's registered level names are upper-case ('DEBUG', 'INFO', ...).
        log_level = log_level.upper()

    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': {
            'default': {
                'format': '%(asctime)s %(levelname)-8s %(name)-15s %(message)s',
                'datefmt': '%Y-%m-%dT%H:%M:%S',
            },
        },
        'handlers': {
            'console': {
                'class': 'logging.StreamHandler',
                'formatter': 'default',
                'level': log_level,
            },
        },
        'root': {
            'handlers': ['console'],
            'level': 'DEBUG',
        },
        'loggers': {},
    })

    app = Application(
        middlewares=[
            session_cache_middleware_factory(),
        ],
    )
    # Stored without a trailing '/' so that f'{base}{path}' never yields '//'.
    app['x_app_base_url'] = base_url.rstrip('/')
    app.router.add_route("*", "/{path:.*}", proxy_handler)
    return app