@leontrolski
Created December 3, 2021 16:04
pubsub REST
import base64
from functools import lru_cache
import logging
from time import sleep, time
from typing import Dict, TypedDict, Tuple, List, TypeVar
import requests
from google.auth.transport.requests import AuthorizedSession
import google.auth
from bulb.platform.common.config import get_project_id
from bulb.platform.common.program_handlers import (
    HealthStatusHandler,
    ShutdownHandler,
)
from event_meta import EventMeta
logger = logging.getLogger(__name__)
T = TypeVar("T")
CYCLE_SECONDS = 1
# PubSub does not allow messages over 10MB; this is a slightly lower
# threshold.
MAX_PAYLOAD_SIZE = 10 * 1000 * 1000

# tmp: hardcoded project id
get_project_id = lambda: "bulb-platform-rapid-53e7"


class PayloadTooBigException(Exception):
    pass


@lru_cache(None)
def get_project_id_and_session() -> Tuple[str, requests.Session]:
    # Cached so default credentials are only fetched once per process.
    creds, _ = google.auth.default()
    return get_project_id(), AuthorizedSession(creds)


class RawMessage(TypedDict):
    data: str  # base64 encoded
    attributes: Dict[str, str]


class ReceivedRawMessage(TypedDict):
    ackId: str
    message: RawMessage


def create_topic(topic: str) -> None:
    project_id, session = get_project_id_and_session()
    topic = f"projects/{project_id}/topics/{topic}"
    url = f"https://pubsub.googleapis.com/v1/{topic}"
    session.put(url)


def create_subscription(subscription: str, topic: str) -> None:
    project_id, session = get_project_id_and_session()
    subscription = f"projects/{project_id}/subscriptions/{subscription}"
    url = f"https://pubsub.googleapis.com/v1/{subscription}"
    # The create call requires the topic to subscribe to in the request body.
    session.put(url, json={"topic": f"projects/{project_id}/topics/{topic}"})


def publish(topic: str, messages: List[RawMessage]) -> List[str]:
    # TODO: add compress with gzip
    # TODO: add async publish with pool?
    # TODO: do this *per message*
    for message in messages:
        # `data` is base64 encoded, so decode before measuring the raw size.
        if len(base64.b64decode(message["data"])) > MAX_PAYLOAD_SIZE:
            raise PayloadTooBigException("Payload too big")
    project_id, session = get_project_id_and_session()
    topic = f"projects/{project_id}/topics/{topic}"
    url = f"https://pubsub.googleapis.com/v1/{topic}:publish"
    resp = session.post(url, json={"messages": messages})
    resp.raise_for_status()
    return resp.json()["messageIds"]


def pull(subscription: str, max_messages: int = 10) -> List[ReceivedRawMessage]:
    project_id, session = get_project_id_and_session()
    subscription = f"projects/{project_id}/subscriptions/{subscription}"
    url = f"https://pubsub.googleapis.com/v1/{subscription}:pull"
    resp = session.post(url, json={"maxMessages": max_messages})
    resp.raise_for_status()
    # `receivedMessages` is omitted from the response when nothing is delivered.
    return resp.json().get("receivedMessages", [])


def ack(subscription: str, ack_ids: List[str]) -> None:
    project_id, session = get_project_id_and_session()
    subscription = f"projects/{project_id}/subscriptions/{subscription}"
    url = f"https://pubsub.googleapis.com/v1/{subscription}:acknowledge"
    resp = session.post(url, json={"ackIds": ack_ids})
    resp.raise_for_status()


def get_topics() -> List[str]:
    project_id, session = get_project_id_and_session()
    url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/topics"
    resp = session.get(url)
    resp.raise_for_status()
    return [t["name"] for t in resp.json().get("topics", [])]


def get_subscriptions() -> List[str]:
    project_id, session = get_project_id_and_session()
    url = f"https://pubsub.googleapis.com/v1/projects/{project_id}/subscriptions"
    resp = session.get(url)
    resp.raise_for_status()
    return [s["name"] for s in resp.json().get("subscriptions", [])]
from typing import Callable, Sequence
import multiprocessing.pool
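

# Neither `Message` nor `pull_n_messages` (both used below) is defined in the
# gist. The following is a hedged sketch: a minimal message shape inferred from
# its usage in the consumer loop, plus a plausible `pull_n_messages` wrapping
# the REST `pull` helper above.
from dataclasses import dataclass


@dataclass
class Message:
    subscription: str
    ack_id: str
    data: bytes


def pull_n_messages(subscription: str, n: int = 10) -> List[Message]:
    return [
        Message(
            subscription=subscription,
            ack_id=raw["ackId"],
            data=base64.b64decode(raw["message"]["data"]),
        )
        for raw in pull(subscription, max_messages=n)
    ]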


F = Callable[[Message], None]
# Map of subscription name -> processors to run for each pulled message.
processors: Dict[str, List[F]] = {"foo": [lambda message: sleep(1)]}


def pull_some_messages(subscriptions: Sequence[str]) -> List[Message]:
    messages: List[Message] = []
    pull_number = 0
    while True:
        # Round-robin over the subscriptions.
        subscription = subscriptions[pull_number % len(subscriptions)]
        try:
            messages.extend(pull_n_messages(subscription))
        except Exception:
            logger.exception("failed to pull from %s", subscription)
        if len(messages) > 50:
            return messages
        if pull_number > 10:
            if messages:
                return messages
            pull_number = 0
            sleep(CYCLE_SECONDS)  # so as not to hammer the API
        pull_number += 1


def run_in_pool(message: Message, fs: List[F]) -> None:
    try:
        for f in fs:
            f(message)
    except Exception:
        # On any processor failure, send the message to the dead-letter queue
        # (see the sketch of `add_to_dlq` below).
        add_to_dlq(message)
    # ack(message.ack_id)
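

# `add_to_dlq` is not defined in the gist; this is a hedged stand-in that
# republishes the failed message to a hypothetical dead-letter topic using the
# `publish` helper above.
def add_to_dlq(message: Message) -> None:
    publish(
        "dead-letter",  # hypothetical topic name
        [
            RawMessage(
                data=base64.b64encode(message.data).decode(),
                attributes={"subscription": message.subscription},
            )
        ],
    )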


def run(n: int = 4, timeout: int = 10) -> None:
    shutdown_handler = ShutdownHandler()
    workers = set()
    with multiprocessing.pool.ThreadPool(n) as pool:
        while True:
            # `pull_some_messages` expects a sequence of subscription names.
            for message in pull_some_messages(list(processors)):
                if shutdown_handler.should_exit():
                    for worker in workers:
                        worker.get(timeout=timeout)
                    return
                # wait until there's a slot in the pool
                if len(workers) == n:
                    ready = None
                    while ready is None:
                        ready = next((w for w in workers if w.ready()), None)
                    workers.remove(ready)
                print("adding worker")
                fs = processors[message.subscription]
                workers.add(pool.apply_async(run_in_pool, [message, fs]))


# Excerpt: `publish` as it exists as a method on a publisher class elsewhere
# in the codebase (hence the `self` references; `uuid`, `datetime`, `Optional`,
# `territories` and `constants` are imported there and not shown here).
def publish(
    data: T,
    aggregate_id: str,
    territory: territories.Territory,
    parent_event_id: Optional[uuid.UUID] = None,
    parent_aggregate_id: Optional[str] = None,
    correlation_id: Optional[uuid.UUID] = None,
    event_id: Optional[uuid.UUID] = None,
    timestamp: Optional[datetime] = None,
    reason: Optional[constants.EventReason] = None,
    sync: bool = True,
) -> str:
    if self.topic is None:
        raise ValueError("topic must be set on the publisher")
    publisher = _get_envelope_publisher(self.topic)
    # we remove None values as these have defaults in EventMeta
    kwargs = dict(
        aggregate_id=aggregate_id,
        territory=territory,
        parent_event_id=parent_event_id,
        parent_aggregate_id=parent_aggregate_id,
        correlation_id=correlation_id,
        event_id=event_id,
        timestamp=timestamp,
        reason=reason,
        event_type=self.event_type,
        aggregate_type=(
            self.event_type.aggregate_type
            if self.aggregate_type is None
            else self.aggregate_type
        ),
        parent_event_type=self.parent_event_type,
        parent_aggregate_type=self.parent_aggregate_type,
    )
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    event_meta = EventMeta(**kwargs)  # type: ignore
    if sync:
        return publisher.sync_publish(
            data=self.schema().dump(data).data, event_meta=event_meta,
        )
    return publisher.async_publish(
        data=self.schema().dump(data).data, event_meta=event_meta,
    )
#

@app.publisher(
    topic="v1.payment.payment-transactions",
    event_type=constants.EventType.PAYMENT_TRANSACTION,
    schema=schemas.Transaction,
)
def publish_transaction(data: dict) -> Tuple[EventMeta, types.Transaction]:
    transaction = translate(types.Transaction, data)
    # NOTE: the EventMeta arguments below are placeholders carried over from
    # the publisher method above; they are not defined in this function.
    event_meta = EventMeta(
        aggregate_id=aggregate_id,
        territory=territory,
        parent_event_id=parent_event_id,
        parent_aggregate_id=parent_aggregate_id,
        correlation_id=correlation_id,
        event_id=event_id,
        timestamp=timestamp,
        reason=reason,
        event_type=self.event_type,
        aggregate_type=(
            self.event_type.aggregate_type
            if self.aggregate_type is None
            else self.aggregate_type
        ),
    )
    return event_meta, transaction
#
#
# @app.subscriber(
#     topic="v1.payment.gocardless-webhook-event",
#     subscription="v1.payment~v1.payment.gocardless-webhook-event",
#     event_type=constants.EventType.PAYMENT_GOCARDLESS_WEBHOOK_EVENT,
#     schema=schemas.GocardlessWebhookEvent,
# )
# def handle_payment_events(meta: EventMeta, event: types.GocardlessWebhookEvent):
#     pass
#


if __name__ == '__main__':
    run()