# Moto Service Helper Class
# Source: GitHub gist thehesiod/2e4094a1db1190f7e122e7043f1973a0 (last active October 27, 2020)
import asyncio
import functools
import logging
import os
import threading
import socket
import http.server
from typing import Dict, Any, Optional
# Third Party
import aiohttp
import moto.server
import botocore.session
import netifaces
import wrapt
import werkzeug.serving
import aiobotocore.session
# Name of the environment variable holding a mocked service's endpoint URL;
# format it with ``service_name=...`` (e.g. "s3" -> "s3_mock_endpoint_url").
_SERVICE_ENDPOINT_TEMPLATE = "{service_name}_mock_endpoint_url"
def get_free_tcp_port(release_socket: bool = False):
    """Ask the OS for a free TCP port by binding an IPv4 socket to port 0.

    :param release_socket: if true, close the socket right away and return
        only the port number; nothing then keeps the port reserved, so another
        process could take it before the caller binds it.
    :return: ``port`` when ``release_socket`` is true, otherwise a
        ``(socket, port)`` tuple. The caller owns the still-bound socket and
        must close it before binding the port itself.
    """
    sckt = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sckt.bind(('', 0))  # port 0: let the kernel pick a free port
        port = sckt.getsockname()[1]
    except BaseException:
        # Don't leak the file descriptor if binding fails.
        sckt.close()
        raise
    if release_socket:
        sckt.close()
        return port
    return sckt, port
# Interface-name prefixes searched for an IPv4 address: "eth*" (Amazon AMIs)
# and "en*" (OS X).
_iface_whitelist_prefixes = {'eth', 'en'}


def get_ip_address():
    """Return the IPv4 address of the first whitelisted network interface.

    Interfaces are visited in the order ``netifaces.interfaces()`` reports
    them. Only interfaces whose names start with one of
    ``_iface_whitelist_prefixes`` are considered. The first one that has an
    ``AF_INET`` address wins.

    :return: the address string taken from the interface's ``'addr'`` entry.
    :raises RuntimeError: if no whitelisted interface has an IPv4 address.
    """
    prefixes = tuple(_iface_whitelist_prefixes)  # str.startswith accepts a tuple
    for iface in netifaces.interfaces():
        if not iface.startswith(prefixes):
            continue
        # Each value is a list of address dicts for that address family.
        addrs: Dict[int, Any] = netifaces.ifaddresses(iface)
        ipv4_addrs = addrs.get(netifaces.AF_INET)
        if ipv4_addrs:
            # An interface may carry several IPv4 addresses (aliases); all are
            # local to this host, so the first one is good enough.
            return ipv4_addrs[0]['addr']
    # Raise explicitly: ``assert False`` is stripped under ``python -O`` and
    # the function would then silently return None.
    raise RuntimeError(
        'No IPv4 address found on interfaces with prefixes: '
        + ', '.join(sorted(_iface_whitelist_prefixes)))
# Enable keep-alive: switch BaseHTTPRequestHandler (and any subclass that does
# not override ``protocol_version``) to HTTP/1.1. This is a process-wide side
# effect of importing this module; presumably it is meant to reach the
# werkzeug server that MotoService starts below.
setattr(http.server.BaseHTTPRequestHandler, 'protocol_version', "HTTP/1.1")
class MotoService:
    """Run a moto mock of one AWS service on a local HTTP server thread.

    Instances are registered per service name in ``_services`` and
    ref-counted, so nested ``async with MotoService(name)`` blocks in one
    process share a single server. ``__aenter__`` returns the instance that
    actually runs the server, which may not be the one it was called on.

    An instance can also decorate a coroutine function. In that form it
    starts and stops its own server around every call and does not take part
    in the ref-counting.
    """
    _services: Dict[str, Any] = dict()  # {name: instance}

    def __init__(self, service_name: str, port: Optional[int] = None,
                 set_endpoint_url_env_var: bool = False):
        """
        :param service_name: moto service name passed to
            ``moto.server.DomainDispatcherApplication``.
        :param port: TCP port to listen on. When falsy, a free port is
            reserved now and its socket is held until the server binds it.
        :param set_endpoint_url_env_var: if true, publish ``endpoint_url`` in
            the service's endpoint environment variable once the server is up.
        """
        self._service_name = service_name
        if port:
            self._socket = None
            self._port = port
        else:
            self._socket, self._port = get_free_tcp_port()
        self._thread: Optional[threading.Thread] = None
        self._logger = logging.getLogger('MotoService')
        # Only meaningful on the instance registered in ``_services``.
        self._refcount: Optional[int] = None
        self._ip_address = get_ip_address()
        self._server: Optional[werkzeug.serving.ThreadedWSGIServer] = None
        self._main_app = None  # built on the server thread by _server_entry
        self._set_endpoint_url_env_var = set_endpoint_url_env_var

    @staticmethod
    def _endpoint_env_var(service_name: str) -> str:
        """Return the environment variable name for *service_name*'s endpoint.

        "dynamodb2" is normalized to "dynamodb" (presumably moto's backend name
        vs. the botocore client name). Both the read and the write path use
        this, so a server started as "dynamodb2" is found by a "dynamodb"
        client and by ``get_service_endpoint_url_from_env("dynamodb2")``.
        """
        if service_name == "dynamodb2":
            service_name = "dynamodb"
        return _SERVICE_ENDPOINT_TEMPLATE.format(service_name=service_name)

    @staticmethod
    def get_service_endpoint_url_from_env(service_name: str) -> Optional[str]:
        """Return the mock endpoint URL published for *service_name*, or None."""
        return os.environ.get(MotoService._endpoint_env_var(service_name))

    @staticmethod
    def set_service_endpoint_url_from_env(service_name: str, endpoint_url: str):
        """Publish *endpoint_url* as *service_name*'s mock endpoint.

        Despite the name, this writes the environment variable.
        """
        os.environ[MotoService._endpoint_env_var(service_name)] = endpoint_url

    @property
    def endpoint_url(self) -> str:
        """Base URL of the mock server: ``http://<ip>:<port>``."""
        return f'http://{self._ip_address}:{self._port}'

    def __call__(self, func):
        """Decorate coroutine function *func* so it runs with the server up.

        The server is started before every call and stopped after it, even if
        the call raises. This path bypasses the shared ``_services`` registry.
        """
        @functools.wraps(func)  # copies metadata and sets __wrapped__
        async def wrapper(*args, **kwargs):
            await self._start()
            try:
                return await func(*args, **kwargs)
            finally:
                await self._stop()
        return wrapper

    async def __aenter__(self):
        """Start the server, or share an already registered one.

        :return: the registered instance serving this service name.
        """
        svc = self._services.get(self._service_name)
        if svc is not None:
            # Another instance already serves this name. Share it, and free
            # the port this instance reserved, since it will never be used.
            svc._refcount += 1
            self._release_socket()
            return svc

        self._services[self._service_name] = self
        self._refcount = 1
        try:
            await self._start()
        except BaseException:
            # __aexit__ won't run when __aenter__ raises, so don't leave a
            # dead server registered for later users. (_start already stopped
            # the server on its own failure path.)
            del self._services[self._service_name]
            self._refcount = None
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Drop one reference and stop the server when the last one is gone."""
        self._release_socket()
        # ``async with`` calls __aexit__ on the object it constructed. That may
        # be a duplicate whose __aenter__ returned the registered owner, so the
        # ref-count is kept on the registered instance, not necessarily self.
        svc = self._services[self._service_name]
        svc._refcount -= 1
        if svc._refcount == 0:
            del self._services[self._service_name]
            await svc._stop()

    def _release_socket(self):
        """Close the port-reservation socket if this instance still holds one."""
        if self._socket is not None:
            self._socket.close()
            self._socket = None

    def _server_entry(self):
        """Server-thread target: build the moto app, bind the port, serve forever."""
        self._main_app = moto.server.DomainDispatcherApplication(
            moto.server.create_backend_app, service=self._service_name)
        self._main_app.debug = True
        # Release the reserved port right before werkzeug binds it, keeping
        # the window in which someone else could grab it small.
        self._release_socket()
        self._server = werkzeug.serving.make_server(
            self._ip_address, self._port, self._main_app, threaded=True)
        self._server.serve_forever()

    async def _start(self):
        """Start the server thread and wait until it answers HTTP requests.

        :raises Exception: if the server thread dies, or does not respond
            within 10 probes spaced 0.5 s apart.
        """
        self._thread = threading.Thread(target=self._server_entry, daemon=True)
        self._thread.start()
        if not await self._wait_until_up():
            # Callers (__aenter__, the decorator) don't reach their own _stop
            # when _start raises, so clean up here.
            await self._stop()
            raise Exception(f"Can not start service: {self._service_name}")
        if self._set_endpoint_url_env_var:
            self.set_service_endpoint_url_from_env(self._service_name, self.endpoint_url)

    async def _wait_until_up(self, attempts: int = 10, delay: float = 0.5) -> bool:
        """Poll ``<endpoint_url>/static/`` until any HTTP response arrives.

        :return: True once the server responds. False if the server thread
            died (e.g. it could not bind the port) or all *attempts* failed.
        """
        timeout = aiohttp.ClientTimeout(total=delay)
        # aiohttp ignores proxy environment variables unless trust_env=True,
        # so this probe goes straight to the local server.
        async with aiohttp.ClientSession() as session:
            for _ in range(attempts):
                if not self._thread.is_alive():
                    return False
                try:
                    async with session.get(self.endpoint_url + '/static/', timeout=timeout):
                        return True
                except (asyncio.TimeoutError, aiohttp.ClientConnectionError):
                    await asyncio.sleep(delay)
        return False

    async def _stop(self):
        """Shut the server down and wait for its thread. Safe to call twice.

        NOTE: ``shutdown()`` and ``join()`` block the event loop until the
        server thread exits.
        """
        server, self._server = self._server, None
        thread, self._thread = self._thread, None
        if server is not None:
            server.shutdown()
        if thread is not None:
            thread.join()
def _wrapt_boto_create_client(wrapped, instance, args, kwargs):
    """``wrapt`` wrapper for ``(Aio)Session.create_client``.

    When the caller passed no ``endpoint_url``, one is read from the service's
    mock-endpoint environment variable via
    ``MotoService.get_service_endpoint_url_from_env``. Dummy credentials are
    always substituted.

    :param wrapped: the original ``create_client``; it is called without the
        session, i.e. it is expected to be already bound.
    :param instance: the session object (not used here).
    :param args: positional arguments of the intercepted call.
    :param kwargs: keyword arguments of the intercepted call.
    :return: whatever ``wrapped`` returns.
    """
    def unwrap_args(service_name, region_name=None, api_version=None,
                    use_ssl=True, verify=None, endpoint_url=None,
                    aws_access_key_id=None, aws_secret_access_key=None,
                    aws_session_token=None, config=None, **extra_kwargs):
        if endpoint_url is None:
            endpoint_url = MotoService.get_service_endpoint_url_from_env(service_name)

        # Credentials are replaced even if the caller supplied some:
        # https://github.com/spulec/moto/issues/2058
        aws_access_key_id = "foobar_key"
        aws_secret_access_key = "foobar_secret"

        # Forward by keyword and pass through any extra keyword arguments
        # (e.g. ones added by later botocore releases) instead of failing.
        return wrapped(service_name,
                       region_name=region_name,
                       api_version=api_version,
                       use_ssl=use_ssl,
                       verify=verify,
                       endpoint_url=endpoint_url,
                       aws_access_key_id=aws_access_key_id,
                       aws_secret_access_key=aws_secret_access_key,
                       aws_session_token=aws_session_token,
                       config=config,
                       **extra_kwargs)

    return unwrap_args(*args, **kwargs)
# Remove AWS credentials from the environment at import time (workaround for
# https://github.com/spulec/moto/issues/2058). The loop variable is deleted
# afterwards so it does not linger as a module global.
for _aws_credential_var in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
    os.environ.pop(_aws_credential_var, None)
del _aws_credential_var
def patch_boto():
    """Wrap ``botocore.session.Session.create_client`` with ``_wrapt_boto_create_client``.

    Afterwards, clients created without an explicit ``endpoint_url`` take it
    from the service's ``{service_name}_mock_endpoint_url`` environment
    variable when that is set, and dummy credentials are substituted. Does
    nothing if ``create_client`` is already a ``wrapt`` proxy.
    """
    already_wrapped = isinstance(botocore.session.Session.create_client, wrapt.ObjectProxy)
    if already_wrapped:
        return
    wrapt.wrap_function_wrapper('botocore.session', 'Session.create_client',
                                _wrapt_boto_create_client)
def unpatch_boto():
    """Undo ``patch_boto``: put the unwrapped ``Session.create_client`` back.

    No-op when ``create_client`` is not currently a ``wrapt`` proxy.
    """
    current = botocore.session.Session.create_client
    if isinstance(current, wrapt.ObjectProxy):
        botocore.session.Session.create_client = current.__wrapped__
def patch_aioboto():
    """Wrap ``aiobotocore.session.AioSession.create_client`` with ``_wrapt_boto_create_client``.

    Afterwards, clients created without an explicit ``endpoint_url`` take it
    from the service's ``{service_name}_mock_endpoint_url`` environment
    variable when that is set, and dummy credentials are substituted. Does
    nothing if ``create_client`` is already a ``wrapt`` proxy.
    """
    already_wrapped = isinstance(aiobotocore.session.AioSession.create_client, wrapt.ObjectProxy)
    if already_wrapped:
        return
    wrapt.wrap_function_wrapper('aiobotocore.session', 'AioSession.create_client',
                                _wrapt_boto_create_client)
def unpatch_aioboto():
    """Undo ``patch_aioboto``: put the unwrapped ``AioSession.create_client`` back.

    No-op when ``create_client`` is not currently a ``wrapt`` proxy.
    """
    current = aiobotocore.session.AioSession.create_client
    if isinstance(current, wrapt.ObjectProxy):
        aiobotocore.session.AioSession.create_client = current.__wrapped__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment