Skip to content

Instantly share code, notes, and snippets.

@max-arnold
Last active March 18, 2024 09:28
Show Gist options
  • Save max-arnold/004f0827d563638039bed8adb413df95 to your computer and use it in GitHub Desktop.
Save max-arnold/004f0827d563638039bed8adb413df95 to your computer and use it in GitHub Desktop.
Test Yandex Cloud Functions written in Python and invoked as API Gateway integrations locally
#!/usr/bin/env python
"""
Test Yandex Cloud Functions written in Python and invoked as API Gateway integrations locally.
Heavily based on https://github.com/amancevice/python-lambda-gateway
MIT License
Copyright (c) 2020 Alexander Mancevice
Copyright (c) 2022-2023 Max Arnold
"""
import argparse
import asyncio
import datetime
import importlib
import json
import logging
import os
import random
import re
import socket
import string
import sys
import time
import uuid
from contextlib import contextmanager
from http import server
from urllib import parse
def set_stream_logger(name, level=logging.DEBUG, format_string=None):
    """
    Attach a stream handler to the named logger and wrap it in an adapter.

    Adapted from boto3.set_stream_logger()

    :param str name: Name of the logger to configure
    :param int level: Level applied to both the logger and the new handler
    :param str format_string: Log line format; defaults to an access-log
        style line that includes the ``addr`` extra
    :returns logging.LoggerAdapter: Adapter that injects ``addr="::1"``
        into every record
    """
    if format_string is None:
        format_string = "%(addr)s - - [%(asctime)s] %(levelname)s - %(message)s"

    # "%d" instead of "%-d": the dash flag is a glibc extension that is not
    # accepted by strftime on every platform (notably Windows).
    formatter = logging.Formatter(format_string, "%d/%b/%Y %H:%M:%S")

    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)

    # The format string references %(addr)s, so records must always carry it.
    return logging.LoggerAdapter(logger, dict(addr="::1"))
# Module-level logger adapter used throughout this file; every record it
# emits carries the fixed addr value supplied by set_stream_logger().
logger = set_stream_logger(__name__)
@contextmanager
def context_start(timeout=None):
    """
    Context manager handing out a fresh mock YCF runtime context.

    :param int timeout: YCF timeout in seconds, forwarded to RuntimeContext
    """
    runtime_context = RuntimeContext(timeout)
    yield runtime_context
class RuntimeContext:
    """
    Mock YCF runtime context object.

    All identifiers (function name/version/folder, request id, log group,
    token) are generated once in the constructor, so repeated reads on the
    same context return the same value.

    :param int timeout: YCF timeout in seconds (30 is used when falsy)
    """

    def __init__(self, timeout=None):
        # timezone.utc rather than datetime.UTC: the latter needs Python 3.11+.
        self._start = datetime.datetime.now(datetime.timezone.utc)
        self._timeout = timeout or 30
        self._function_name = self._random_id()
        self._function_version = self._random_id()
        self._function_folder_id = self._random_id()
        self._log_group_name = self._random_id()
        self._request_id = str(uuid.uuid1())
        self._token = {
            "access_token": str(uuid.uuid1()),
            "expires_in": 42299,
            "token_type": "Bearer",
        }

    @staticmethod
    def _random_id():
        """Return a random 20-character id of lowercase letters and digits."""
        return "".join(
            random.choice(string.ascii_lowercase + string.digits) for _ in range(20)
        )

    @property
    def function_name(self):
        return self._function_name

    @property
    def function_version(self):
        return self._function_version

    @property
    def function_folder_id(self):
        return self._function_folder_id

    @property
    def invoked_function_arn(self):
        return self.function_name

    @property
    def memory_limit_in_mb(self):
        return 256

    @property
    def aws_request_id(self):
        return self.request_id

    @property
    def request_id(self):
        return self._request_id

    @property
    def log_group_name(self):
        return self._log_group_name

    @property
    def log_stream_name(self):
        return self.function_version

    @property
    def deadline_ms(self):
        """
        Deadline of the invocation as milliseconds since the Unix epoch.
        """
        deadline = self._start + datetime.timedelta(seconds=self._timeout)
        # timestamp() honours the datetime's UTC tzinfo; time.mktime() would
        # reinterpret the value as local time and drop sub-second precision.
        return int(deadline.timestamp() * 1000)

    @property
    def token(self):
        # Hand out a copy so callers cannot alter the context's own token.
        return dict(self._token)

    def get_remaining_time_in_millis(self):
        """
        Get remaining TTL for YCF context.

        :returns int: Milliseconds left before the timeout, never below 0
        """
        delta = datetime.datetime.now(datetime.timezone.utc) - self._start
        remaining_time_in_s = self._timeout - delta.total_seconds()
        return max(0, int(remaining_time_in_s * 1000))
class EventProxy:
    """
    Loads the YCF handler function and invokes it with API Gateway events.

    :param str handler: Handler signature in "module.path.function" form
    :param str src_path: Directory added to sys.path to locate the module
    :param str url_pattern: Regex that request paths must match
    :param str operation: Operation ID reported in v1.0 events
    :param int timeout: Invocation timeout in seconds (None = no limit)
    """

    def __init__(self, handler, src_path, url_pattern, operation, timeout=None):
        self.handler = handler
        self.src_path = src_path
        self.url_pattern = url_pattern
        self.operation = operation
        self.timeout = timeout

    def get_handler(self):
        """
        Load handler function.

        The module is imported on first use and reloaded on every later
        call, so edits to the handler source are picked up between requests.

        :returns function: YCF handler function
        :raises ValueError: If the signature has no module part, the module
            cannot be imported, or the function is missing from the module
        """
        *path, func = self.handler.split(".")
        name = ".".join(path)
        if not name:
            raise ValueError(f"Bad handler signature '{self.handler}'")

        # Compare the same (absolute) form that gets appended; otherwise a
        # relative src_path is appended again on every call.
        src_path = os.path.abspath(self.src_path)
        if src_path not in sys.path:
            sys.path.append(src_path)

        try:
            loaded = sys.modules.get(name)
            if loaded is None:
                # First load: a plain import already executes the module once.
                module = importlib.import_module(name)
            else:
                module = importlib.reload(loaded)
            return getattr(module, func)
        except ModuleNotFoundError as e:
            logger.exception("Module import error: %s", e)
            raise ValueError(f"Unable to import module '{name}'") from e
        except AttributeError as e:
            logger.exception("Handler import error: %s", e)
            raise ValueError(f"Handler '{func}' missing on module '{name}'") from e

    def get_httpMethod(self, event):
        """
        Helper to get httpMethod from v0.1 or v1.0 events.

        :param dict event: YCF event object
        :raises ValueError: If the event matches neither payload version
        """
        if event.get("version") == "1.0":
            return event["requestContext"]["httpMethod"]
        elif event.get("httpMethod"):
            return event["httpMethod"]
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {event.get('version')}"
        )

    def get_path(self, event):
        """
        Helper to get path from v0.1 or v1.0 events.

        :param dict event: YCF event object
        :raises ValueError: If the event matches neither payload version
        """
        if event.get("version") == "1.0":
            return event["path"]
        elif event.get("url"):
            return event["url"]
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {event.get('version')}"
        )

    def invoke(self, event):
        """
        Invoke the handler synchronously with a fresh mock runtime context.

        :param dict event: YCF event object
        :returns dict: YCF invocation result
        """
        with context_start(self.timeout) as context:
            logger.info('Invoking "%s"', self.handler)
            return asyncio.run(self.invoke_async_with_timeout(event, context))

    def match_pattern(self, path):
        """
        Match a request path against the configured URL pattern.

        :param str path: Request path
        :returns: re.Match on success, otherwise None
        """
        return re.match(self.url_pattern, path)

    async def invoke_async(self, event, context=None):
        """
        Wrapper to invoke the YCF handler asynchronously.

        :param dict event: YCF event object
        :param RuntimeContext context: Mock YCF runtime context
        :returns dict: YCF invocation result, 403 when the path does not
            match the URL pattern, or 502 when the handler raises
        """
        httpMethod = self.get_httpMethod(event)
        path = self.get_path(event)

        # Reject request if not matching the pattern
        if not self.match_pattern(path):
            err = f"Rejected {path} :: URL pattern is {self.url_pattern}"
            logger.error(err)
            return self.jsonify(httpMethod, 403, message="Forbidden")

        # Get & invoke YCF handler
        try:
            handler = self.get_handler()
            loop = asyncio.get_running_loop()
            # The handler is a blocking callable, so run it off the loop.
            return await loop.run_in_executor(None, handler, event, context)
        except Exception as err:
            logger.exception(err)
            message = "Internal server error"
            return self.jsonify(httpMethod, 502, message=message)

    async def invoke_async_with_timeout(self, event, context=None):
        """
        Wrapper to invoke the YCF handler with a timeout.

        :param dict event: YCF event object
        :param RuntimeContext context: Mock YCF runtime context
        :returns dict: YCF invocation result or a 504 timeout response
        """
        # NOTE(review): the handler runs in an executor thread, which
        # wait_for cannot interrupt; presumably it keeps running after the
        # timeout fires -- confirm against asyncio behaviour.
        try:
            coroutine = self.invoke_async(event, context)
            return await asyncio.wait_for(coroutine, self.timeout)
        except asyncio.TimeoutError:
            httpMethod = self.get_httpMethod(event)
            message = "Endpoint request timed out"
            return self.jsonify(httpMethod, 504, message=message)

    @staticmethod
    def jsonify(httpMethod, statusCode, **kwargs):
        """
        Convert dict into API Gateway response object.

        :params str httpMethod: HTTP request method
        :params int statusCode: Response status code
        :params dict kwargs: Response object
        :returns dict: Response with body, statusCode and JSON headers;
            the body is empty for HEAD requests
        """
        body = "" if httpMethod in ["HEAD"] else json.dumps(kwargs)
        return {
            "body": body,
            "statusCode": statusCode,
            "headers": {
                "Content-Type": "application/json",
                "Content-Length": len(body),
            },
        }
class YCFRequestHandler(server.SimpleHTTPRequestHandler):
    """
    HTTP request handler that turns each request into an API Gateway event,
    passes it to the configured EventProxy and relays the result.

    ``proxy`` and ``version`` are class attributes installed by set_proxy().
    """

    # Methods advertised in CORS responses; kept in sync with the do_* methods.
    CORS_METHODS = "OPTIONS, GET, HEAD, POST, PUT, PATCH, DELETE"

    def do_DELETE(self):
        self.invoke("DELETE")

    def do_GET(self):
        self.invoke("GET")

    def do_HEAD(self):
        self.invoke("HEAD")

    def do_OPTIONS(self):
        self.invoke("OPTIONS")

    def do_PATCH(self):
        self.invoke("PATCH")

    def do_POST(self):
        self.invoke("POST")

    def do_PUT(self):
        self.invoke("PUT")

    def get_body(self):
        """
        Get request body to forward to YCF handler.

        :returns str: Decoded body, or "" when Content-Length is missing,
            not a number, or not positive
        """
        try:
            content_length = int(self.headers.get("Content-Length"))
        except (TypeError, ValueError):
            return ""
        # A negative length would make read() consume the stream until EOF.
        if content_length <= 0:
            return ""
        return self.rfile.read(content_length).decode()

    def get_event(self, httpMethod):
        """
        Get YCF input event object.

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        :raises ValueError: If the configured payload version is unknown
        """
        if self.version == "0.1":
            return self.get_event_v01(httpMethod)
        elif self.version == "1.0":
            return self.get_event_v10(httpMethod)
        raise ValueError(  # pragma: no cover
            f"Unknown API Gateway payload version: {self.version}"
        )

    def get_params(self):
        """
        Extract named URL pattern groups from the request path.

        :return dict: Group name -> matched value; empty when nothing matches
        """
        # Match on the path component only: the proxy validates the path
        # without its query string, so parameters must be extracted likewise.
        match = self.proxy.match_pattern(parse.urlparse(self.path).path)
        return match.groupdict() if match else {}

    def _get_common_fields(self, httpMethod):
        """
        Build the pieces shared by the v0.1 and v1.0 event payloads.

        :param str httpMethod: HTTP request method
        :return tuple: (path, headers, multi_headers, query, multi_query,
            request_context)
        """
        url = parse.urlparse(self.path)

        # parse_qs keeps every value of a repeated parameter; the
        # single-value view uses the last occurrence.
        multi_query = parse.parse_qs(url.query)
        query = {key: values[-1] for key, values in multi_query.items()}

        headers = dict(self.headers)
        multi_headers = {key: self.headers.get_all(key) for key in headers}
        forwarded = (("X-Forwarded-For", "1.1.1.1"), ("X-Forwarded-Proto", "http"))
        for key, value in forwarded:
            headers[key] = value
            multi_headers[key] = [value]

        req_time = datetime.datetime.now(datetime.timezone.utc)
        request_context = {
            "identity": {
                "sourceIp": "1.1.1.1",
                "userAgent": "Mozilla/5.0",
            },
            "httpMethod": httpMethod,
            "requestId": str(uuid.uuid1()),
            "requestTime": req_time.strftime("%d/%b/%Y:%H:%M:%S +0000"),
            # timestamp() respects the UTC tzinfo; time.mktime() would treat
            # the value as local time.
            "requestTimeEpoch": int(req_time.timestamp()),
        }
        return url.path, headers, multi_headers, query, multi_query, request_context

    def get_event_v01(self, httpMethod):
        """
        Get YCF input event object (v0.1).

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        """
        (
            path,
            headers,
            multi_headers,
            query,
            multi_query,
            request_context,
        ) = self._get_common_fields(httpMethod)
        path_params = self.get_params()
        return {
            "httpMethod": httpMethod,
            "headers": headers,
            "url": path,
            "params": dict(path_params),  # {"param": "123"}
            "multiValueParams": {
                k: [v] for k, v in path_params.items()
            },  # {"param": ["123"]}
            "pathParams": dict(path_params),  # {"param": "123"}
            "multiValueHeaders": multi_headers,
            "queryStringParameters": query,
            "multiValueQueryStringParameters": multi_query,
            "requestContext": request_context,
            "body": self.get_body(),
            "isBase64Encoded": False,
            "path": path,
        }

    def get_event_v10(self, httpMethod):
        """
        Get YCF input event object (v1.0).

        :param str httpMethod: HTTP request method
        :return dict: YCF event object
        """
        (
            path,
            headers,
            multi_headers,
            query,
            multi_query,
            request_context,
        ) = self._get_common_fields(httpMethod)
        path_params = self.get_params()
        return {
            "httpMethod": httpMethod,
            "headers": headers,
            "multiValueHeaders": multi_headers,
            "queryStringParameters": query,
            "multiValueQueryStringParameters": multi_query,
            "requestContext": request_context,
            "version": "1.0",
            "resource": path,
            "path": path,  # /slug/123
            "pathParameters": dict(path_params),  # {"param": "123"}
            "body": self.get_body(),
            "isBase64Encoded": False,
            "parameters": dict(path_params),  # {"param": "123"}
            "multiValueParameters": {
                k: [v] for k, v in path_params.items()
            },  # {"param": ["123"]}
            "operationId": self.proxy.operation,
        }

    def invoke(self, httpMethod):
        """
        Proxy requests to YCF handler and write the HTTP response.

        OPTIONS requests are answered directly with the CORS headers and
        never reach the handler.

        :param str httpMethod: HTTP request method
        """
        # Get YCF event
        event = self.get_event(httpMethod)
        cors = {
            "access-control-allow-headers": "*",
            "access-control-allow-methods": self.CORS_METHODS,
            "access-control-allow-origin": "*",
        }
        if httpMethod == "OPTIONS":
            status = 200
            headers = cors
            mvheaders = {}
            body = ""
        else:
            # Get YCF result
            res = self.proxy.invoke(event)
            # Parse response
            status = res.get("statusCode") or 500
            # Merge into a new dict so the handler's own headers object is
            # left untouched; CORS values win on key clashes.
            headers = {**(res.get("headers") or {}), **cors}
            mvheaders = res.get("multiValueHeaders") or {}
            body = res.get("body") or ""

        # Send response
        self.send_response(status)
        for key, val in headers.items():
            self.send_header(key, val)
        for key, val in mvheaders.items():
            for v in val:
                self.send_header(key, v)
        self.end_headers()
        self.wfile.write(body.encode())

    @classmethod
    def set_proxy(cls, proxy, version):
        """
        Set up YCFRequestHandler.

        :param EventProxy proxy: Proxy used to invoke the handler
        :param str version: API Gateway payload version ("0.1" or "1.0")
        """
        cls.proxy = proxy
        cls.version = version
def get_best_family(*address):  # pragma: no cover
    """
    Resolve the address family and socket address to bind to.

    Uses http.server's own helper when it exists and otherwise falls back
    to the equivalent getaddrinfo() lookup (Python 3.7 compat).

    :params tuple address: host/port tuple
    :returns tuple: (address family, socket address)
    """
    try:
        # Python 3.8+
        return server._get_best_family(*address)
    except AttributeError:
        # Python 3.7 -- taken from http.server._get_best_family() in 3.8
        candidates = socket.getaddrinfo(
            *address,
            type=socket.SOCK_STREAM,
            flags=socket.AI_PASSIVE,
        )
        first_candidate = next(iter(candidates))
        family, _socktype, _proto, _canonname, sockaddr = first_candidate
        return family, sockaddr
def run(httpd):
    """
    Announce the listening address on stderr and serve until interrupted.

    The server is shut down on the way out, whether serving ended normally,
    by Ctrl-C or by an error.

    :param object httpd: ThreadingHTTPServer instance
    """
    host, port = httpd.socket.getsockname()[:2]

    # IPv6 literals need brackets inside a URL.
    if ":" in host:
        url_host = f"[{host}]"
    else:
        url_host = host

    banner = f"Serving HTTP on {host} port {port} (http://{url_host}:{port}) ...\n"
    sys.stderr.write(banner)

    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        sys.stderr.write("\nKeyboard interrupt received, exiting.\n")
    finally:
        httpd.shutdown()
def export_variables(env_file):
    """
    Export environment variables from JSON file

    The file must contain a single JSON object; every key and value is
    converted with str() before being written to os.environ.

    :param str env_file: Path to the JSON file
    :raises ValueError: If the top-level JSON value is not an object
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(env_file, encoding="utf-8") as json_file:
        env_vars = json.load(json_file)

    if not isinstance(env_vars, dict):
        raise ValueError(
            f"Environment file '{env_file}' must contain a JSON object"
        )

    for env_name, env_value in env_vars.items():
        os.environ[str(env_name)] = str(env_value)
def get_opts():
    """
    Parse the command line of the local API Gateway server.

    :returns argparse.Namespace: Parsed options plus the HANDLER positional
    """
    parser = argparse.ArgumentParser(
        description="Start a simple YC API Gateway server",
    )

    # (flags, add_argument keyword settings), registered in this order.
    argument_specs = (
        (
            ("-e", "--env"),
            dict(dest="env", help="Path to environment JSON file", metavar="ENV"),
        ),
        (
            ("-s", "--src-path"),
            dict(
                dest="src_path",
                help="Set base path for source code",
                metavar="SRC_PATH",
                default="",
            ),
        ),
        (
            ("-o", "--operation"),
            dict(
                dest="operation",
                help="Operation ID",
                metavar="OPERATION",
                default="operation_id",
            ),
        ),
        (
            ("-u", "--url-pattern"),
            dict(
                dest="url_pattern",
                help="URL pattern regex with named parameter groups",
                metavar="URL_PATTERN",
                default="/",
            ),
        ),
        (
            ("-b", "--bind"),
            dict(
                dest="bind",
                metavar="ADDR",
                help="Specify alternate bind address [default: all interfaces]",
            ),
        ),
        (
            ("-p", "--port"),
            dict(
                dest="port",
                default=8000,
                help="Specify alternate port [default: 8000]",
                type=int,
            ),
        ),
        (
            ("-t", "--timeout"),
            dict(dest="timeout", help="YCF timeout.", metavar="SECONDS", type=int),
        ),
        (
            ("-V", "--payload-version"),
            dict(
                choices=["0.1", "1.0"],
                default="1.0",
                help="API Gateway payload version [default: 1.0]",
            ),
        ),
        (
            ("HANDLER",),
            dict(help="YCF handler signature"),
        ),
    )
    for flags, settings in argument_specs:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()
def main():
    """
    Script entry point: read options, wire up the proxy, start the server.
    """
    opts = get_opts()

    # Optional environment file is exported before anything else runs.
    if opts.env:
        export_variables(opts.env)

    address_family, addr = get_best_family(opts.bind, opts.port)

    # The user-supplied pattern is anchored at both ends.
    anchored_pattern = f"^{opts.url_pattern}$"
    proxy = EventProxy(
        opts.HANDLER,
        opts.src_path,
        anchored_pattern,
        opts.operation,
        opts.timeout,
    )
    YCFRequestHandler.set_proxy(proxy, opts.payload_version)

    server_class = server.ThreadingHTTPServer
    server_class.address_family = address_family

    with server_class(addr, YCFRequestHandler) as httpd:
        run(httpd)


if __name__ == "__main__":  # pragma: no cover
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment