Created
April 14, 2020 20:17
-
-
Save bfouts-osaro/afebd1e8e1b167b902f376621c0ccfaa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from google.protobuf.any_pb2 import Any | |
from google.protobuf.timestamp_pb2 import Timestamp | |
import datetime | |
import grpc | |
import os | |
import psycopg2 | |
import psycopg2.extras | |
import pytest | |
import time | |
import uuid | |
import tyr | |
from tyr.garmr.events.v1 import events_pb2 as tyr_garmr_events | |
from tyr.garmr.events.v1 import events_pb2_grpc as tyr_garmr_grpc | |
from tyr.odin.events.v1 import events_pb2 as tyr_odin_events | |
class EventEmitter:
    """Send pick/place events to the Garmr ``EventWriter`` gRPC service.

    Can be used as a context manager; leaving the ``with`` block closes the
    underlying gRPC channel.
    """

    # Maps the ``event_type`` strings accepted by ``send_event`` to the Odin
    # event message classes that get packed into the request.
    EVENT_TYPES = {
        'pick': tyr_odin_events.PickEvent,
        'place': tyr_odin_events.PlaceEvent,
    }

    def __init__(self, server_ip, server_port):
        """Open an insecure (no TLS) channel to ``server_ip:server_port``."""
        self.channel = grpc.insecure_channel(f'{server_ip}:{server_port}')
        self.stub = tyr_garmr_grpc.EventWriterStub(self.channel)

    def close(self):
        """Close the underlying gRPC channel."""
        self.channel.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning None lets any in-flight exception propagate.
        self.close()

    def send_event(self, event_type, request_id, mission_id, robot, success, message=""):
        """Build a pick/place event, wrap it in ``Any`` and send it via ``WriteEvent``.

        Args:
            event_type: Key into ``EVENT_TYPES`` (``'pick'`` or ``'place'``).
            request_id: Forwarded as ``WriteEventRequest.request_id``.
            mission_id: Copied onto the event message.
            robot: Copied onto the event message.
            success: Copied onto the event message.
            message: Copied onto the event message (default ``""``).

        Returns:
            The ``Timestamp`` stamped on both the event and the request. The
            tests compare it against the persisted row's timestamp.

        Raises:
            KeyError: ``event_type`` is not one of ``EVENT_TYPES``.
            Exception: the response's ``status.message`` is not ``"ok"``.
        """
        try:
            event_class = self.EVENT_TYPES[event_type]
        except KeyError:
            raise KeyError(
                f"Unknown event type {event_type!r}; expected one of {sorted(self.EVENT_TYPES)}"
            ) from None

        timestamp = Timestamp()
        timestamp.GetCurrentTime()

        event = event_class(
            mission_id=mission_id,
            robot=robot,
            timestamp=timestamp,
            message=message,
            success=success,
        )
        event_to_send = Any()
        event_to_send.Pack(event)

        resp = self.stub.WriteEvent(
            tyr_garmr_events.WriteEventRequest(
                ts=timestamp,
                event=event_to_send,
                tyr_version=tyr.__version__,
                request_id=request_id,
            )
        )
        # NOTE(review): assumes the service reports success with the literal
        # status message "ok"; confirm against the Garmr server.
        if resp.status.message != "ok":
            raise Exception(f"Failed to send event: {resp}")
        return timestamp
class DatabaseAdapter:
    """Thin psycopg2 wrapper used by the tests to inspect the ``pick_place`` table.

    Rows are fetched through a ``DictCursor`` and returned as plain dicts keyed
    by column name. Can be used as a context manager; leaving the ``with``
    block closes the cursor and connection.
    """

    def __init__(self, host, port, database, user, password):
        """Connect to Postgres and open a ``DictCursor`` on that connection."""
        self.connection = psycopg2.connect(
            user=user,
            password=password,
            host=host,
            port=port,
            database=database,
        )
        self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.DictCursor)

    def get_results(self, expected_count, max_wait_time=1.0, wait_time_interval=0.005):
        """Poll ``pick_place`` until it holds exactly ``expected_count`` rows.

        Args:
            expected_count: Number of rows the caller expects to find.
            max_wait_time: Wall-clock budget in seconds, measured with
                ``time.monotonic()`` so query time counts toward it. The table
                is always queried at least once, even when this is 0 or negative.
            wait_time_interval: Seconds slept before each query.

        Returns:
            A list of row dicts ordered by the ``timestamp`` column.

        Raises:
            Exception: More rows than expected were found (raised as soon as
                that is observed), or fewer rows remain once the budget is
                exhausted.
        """
        deadline = time.monotonic() + max_wait_time
        while True:
            time.sleep(wait_time_interval)
            # POSSIBLE BUG. Events are not persisted in a deterministic order,
            # so ordering relies entirely on the stored timestamps.
            self.cursor.execute("SELECT * from pick_place ORDER BY timestamp")
            records = [dict(row) for row in self.cursor.fetchall()]
            record_count = len(records)
            if record_count > expected_count:
                raise Exception(f"Too many events returned from the database. Expected {expected_count} but found {record_count}")
            if record_count == expected_count:
                return records
            if time.monotonic() >= deadline:
                break
        raise Exception(f"Too few events returned from the database. Expected {expected_count} but found {record_count} after {max_wait_time} seconds.")

    def clear_records(self):
        """Delete every row from ``pick_place`` and commit."""
        self.cursor.execute("DELETE FROM pick_place")
        self.connection.commit()

    def close(self):
        """Close the cursor, then the connection."""
        self.cursor.close()
        self.connection.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning None lets any in-flight exception propagate.
        self.close()
@pytest.fixture
def database():
    """Yield a ``DatabaseAdapter`` built from ``POSTGRES_*`` env vars, with ``pick_place`` emptied.

    The cursor and connection are closed on teardown, including when
    ``clear_records`` fails. Otherwise every test would leak one open
    connection.
    """
    db = DatabaseAdapter(
        os.environ.get('POSTGRES_HOST'),
        os.environ.get('POSTGRES_PORT'),
        os.environ.get('POSTGRES_NAME'),
        os.environ.get('POSTGRES_USER'),
        os.environ.get('POSTGRES_PASSWORD'),
    )
    try:
        db.clear_records()
        yield db
    finally:
        db.cursor.close()
        db.connection.close()
@pytest.fixture
def event_emitter():
    """Yield an ``EventEmitter`` pointed at ``localhost:50057``; close its channel on teardown."""
    emitter = EventEmitter('localhost', 50057)
    try:
        yield emitter
    finally:
        emitter.channel.close()
def get_dt_microseconds(dt):
    """Return ``dt`` as a whole number of microseconds since the Unix epoch.

    Naive datetimes are treated as UTC, matching the original
    ``utcfromtimestamp(0)`` epoch. Aware datetimes are converted using their
    own UTC offset instead of raising ``TypeError``.

    Exact integer ``timedelta`` floor division is used rather than
    ``total_seconds() * 1000000``. That float round trip is not guaranteed to
    reproduce the exact integer for present-day timestamps, which makes
    ``==`` against ``Timestamp.ToMicroseconds()`` flaky.
    """
    if dt.tzinfo is not None and dt.utcoffset() is not None:
        epoch = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
    else:
        epoch = datetime.datetime(1970, 1, 1)
    return (dt - epoch) // datetime.timedelta(microseconds=1)
def test_stuff():
    """Placeholder test with no assertions.

    NOTE(review): the original ``def`` had no body. That is a SyntaxError and
    stops pytest from collecting any test in this module. It now skips
    explicitly rather than passing vacuously; delete it or give it a real body.
    """
    pytest.skip("placeholder test with no assertions")
def test_successful_pick_place(database, event_emitter):
    """A successful pick followed by a successful place persists two clean rows.

    Rows come back ordered by timestamp, so the pick (sent first) is row 0.
    """
    pick_id = str(uuid.uuid4())
    pick_ts = event_emitter.send_event('pick', pick_id, "pick", "successful_pick", True)
    place_id = str(uuid.uuid4())
    place_ts = event_emitter.send_event('place', place_id, "place", "successful_place", True)

    rows = database.get_results(2)
    pick_row, place_row = rows[0], rows[1]

    # Pick row: success recorded, no failure information on either side.
    assert pick_row['mission_id'] == "pick"
    assert pick_row['pick_success'] == 1
    for column in ('pick_failure', 'pick_failure_reason', 'place_failure', 'place_failure_reason'):
        assert not pick_row[column], column
    assert pick_row['request_id'] == pick_id
    assert pick_row['robot'] == "successful_pick"
    assert get_dt_microseconds(pick_row['timestamp']) == pick_ts.ToMicroseconds()

    # Place row: every pick/place status column is falsy.
    assert place_row['mission_id'] == "place"
    for column in ('pick_failure', 'pick_failure_reason', 'pick_success',
                   'place_failure', 'place_failure_reason'):
        assert not place_row[column], column
    assert place_row['request_id'] == place_id
    assert place_row['robot'] == "successful_place"
    assert get_dt_microseconds(place_row['timestamp']) == place_ts.ToMicroseconds()
def test_place_failed(database, event_emitter):
    """A successful pick followed by a failed place records the failure on the place row.

    Rows come back ordered by timestamp, so the pick (sent first) is row 0.
    """
    pick_id = str(uuid.uuid4())
    pick_ts = event_emitter.send_event('pick', pick_id, "pick", "successful_pick", True)
    place_id = str(uuid.uuid4())
    place_ts = event_emitter.send_event(
        'place', place_id, "place", "failed_place", False, "The placement has failed"
    )

    rows = database.get_results(2)
    pick_row, place_row = rows[0], rows[1]

    # Pick row: success recorded, no failure information on either side.
    assert pick_row['mission_id'] == "pick"
    assert pick_row['pick_success'] == 1
    for column in ('pick_failure', 'pick_failure_reason', 'place_failure', 'place_failure_reason'):
        assert not pick_row[column], column
    assert pick_row['request_id'] == pick_id
    assert pick_row['robot'] == "successful_pick"
    assert get_dt_microseconds(pick_row['timestamp']) == pick_ts.ToMicroseconds()

    # Place row: pick columns empty, place failure flagged with its reason.
    assert place_row['mission_id'] == "place"
    for column in ('pick_failure', 'pick_failure_reason', 'pick_success'):
        assert not place_row[column], column
    assert place_row['place_failure'] == 1
    assert place_row['place_failure_reason'] == "The placement has failed"
    assert place_row['request_id'] == place_id
    assert place_row['robot'] == "failed_place"
    assert get_dt_microseconds(place_row['timestamp']) == place_ts.ToMicroseconds()
def test_place_pick(database, event_emitter):
    """A single failed pick persists one row carrying the pick failure and its reason.

    NOTE(review): despite the name, no place event is sent; this exercises a
    failed pick. The name is kept so existing test selections still match.
    """
    pick_id = str(uuid.uuid4())
    pick_ts = event_emitter.send_event(
        'pick', pick_id, "pick", "failed_pick", False, "The pick has failed"
    )

    (row,) = database.get_results(1)

    assert row['mission_id'] == "pick"
    # Failure is flagged with its reason; success is stored as exactly 0.
    assert row['pick_failure'] == 1
    assert row['pick_failure_reason'] == "The pick has failed"
    assert row['pick_success'] == 0
    for column in ('place_failure', 'place_failure_reason'):
        assert not row[column], column
    assert row['request_id'] == pick_id
    assert row['robot'] == "failed_pick"
    assert get_dt_microseconds(row['timestamp']) == pick_ts.ToMicroseconds()
def test_burst(database, event_emitter):
    """Send several failed-pick events back to back and check all of them are persisted."""
    count = 2
    for _ in range(count):
        event_emitter.send_event(
            'pick', str(uuid.uuid4()), "pick", "failed_pick", False, "The pick has failed"
        )
    records = database.get_results(count)
    # get_results already raises unless exactly `count` rows exist; this
    # assertion just states the test's intent explicitly.
    assert len(records) == count
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment