Created
June 24, 2022 17:58
-
-
Save raeidsaqur/174d745cf4d3df60d0623f5cf07edac8 to your computer and use it in GitHub Desktop.
Create a WSGI server that enables declarative headless rendering for Unity-based game environments.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## For general architecture see: https://silverweed.github.io/assets/docs/distributed_rendering_in_vulkan.pdf | |
import Xlib.display | |
import glob | |
import warnings | |
import os | |
import ctypes.util | |
import xml.etree.ElementTree | |
class Request:
    """Describe where a Unity build is about to run, for platform validation.

    ``system`` is matched against names such as "Linux"/"Darwin",
    ``width``/``height`` are the target resolution, ``x_display`` is an
    explicit X display string (e.g. ":0.1") or None, and ``headless`` skips
    X display checks.
    """

    def __init__(self, system, width, height, x_display, headless):
        (
            self.system,
            self.width,
            self.height,
            self.x_display,
            self.headless,
        ) = (system, width, height, x_display, headless)
class BasePlatform:
    """Default platform description; subclasses override only the hooks they need."""

    # Disabled platforms are skipped by select_platforms.
    enabled = True

    @classmethod
    def validate(cls, r):
        """Return a list of problems preventing use of this platform (empty = usable)."""
        return list()

    @classmethod
    def dependency_instructions(cls, request):
        """Return install hints for missing dependencies; the default has none."""
        return None

    @classmethod
    def is_valid(cls, request):
        """Return True when ``validate`` reports no problems."""
        problems = cls.validate(request)
        return len(problems) == 0

    @classmethod
    def name(cls):
        """Return the platform's class name."""
        return cls.__name__

    @classmethod
    def launch_env(cls, width, height, x_display):
        """Return extra environment variables for launching Unity; none by default."""
        return dict()
class BaseLinuxPlatform(BasePlatform):
    """Common base for Linux builds, whose executable sits directly in ``base_dir``."""

    @classmethod
    def executable_path(cls, base_dir, name):
        """Return the path of the build executable ``name`` inside ``base_dir``."""
        path = os.path.join(base_dir, name)
        return path

    @classmethod
    def old_executable_path(cls, base_dir, name):
        """Legacy layout; on Linux it is identical to ``executable_path``."""
        return cls.executable_path(base_dir, name)
class Linux64(BaseLinuxPlatform):
    """Native Linux build that renders through an X11 server with GLX.

    Every X connection opened while probing displays is closed again, so
    repeated validation does not accumulate open client connections.
    """

    @classmethod
    def dependency_instructions(cls, request):
        """Explain the X11/GLX requirement, listing the usable displays if any exist."""
        message = "Linux64 requires a X11 server to be running with GLX. "
        displays = cls._valid_x_displays(request.width, request.height)
        if displays:
            message += "The following valid displays were found %s" % (
                ", ".join(displays)
            )
        else:
            message += "If you have a NVIDIA GPU, please run: sudo ai2thor-xorg start"
        return message

    @classmethod
    def _select_x_display(cls, width, height):
        """Return the first valid ``":<display>.<screen>"`` string, or None."""
        valid_displays = cls._valid_x_displays(width, height)
        return valid_displays[0] if valid_displays else None

    @classmethod
    def launch_env(cls, width, height, x_display):
        """Return ``{"DISPLAY": ...}``, auto-selecting a display when ``x_display`` is None."""
        if x_display is None:
            x_display = cls._select_x_display(width, height)
        return dict(DISPLAY=x_display)

    @staticmethod
    def _close_display(display):
        """Close an Xlib connection, ignoring one the server has already dropped."""
        try:
            display.close()
        except Xlib.error.ConnectionClosedError:
            # Best effort: the socket is already gone, nothing is left to release.
            pass

    @classmethod
    def _validate_screen(cls, display_screen_str, width, height):
        """Return the problems that prevent Unity from rendering on a screen.

        ``display_screen_str`` has the format ":0.1". The screen must exist,
        expose the GLX extension, be at least ``width`` x ``height`` pixels and
        have a root depth of 24. An unreachable display yields a single
        "Invalid display" entry instead of raising.
        """
        errors = []
        try:
            disp_screen = Xlib.display.Display(display_screen_str)
            try:
                screen_parts = display_screen_str.split(".")
                # Xlib.display falls back to a valid screen if an invalid one
                # was requested (e.g. :0.9999999 -> :0.1), so compare explicitly.
                if len(screen_parts) > 1 and screen_parts[1] != str(
                    disp_screen.get_default_screen()
                ):
                    errors.append(
                        "Invalid display, non-existent screen: %s" % display_screen_str
                    )
                if "GLX" not in disp_screen.list_extensions():
                    errors.append(
                        "Display %s does not have the GLX extension loaded. GLX is required by Unity3D."
                        % display_screen_str
                    )
                screen = disp_screen.screen()
                screen_width = screen["width_in_pixels"]
                screen_height = screen["height_in_pixels"]
                if screen_width < width or screen_height < height:
                    errors.append(
                        "Display %s does not have a large enough resolution for the target resolution: %sx%s vs. %sx%s"
                        % (display_screen_str, width, height, screen_width, screen_height)
                    )
                if screen["root_depth"] != 24:
                    errors.append(
                        "Display %s does not have a color depth of 24: %s"
                        % (display_screen_str, screen["root_depth"])
                    )
            finally:
                cls._close_display(disp_screen)
        except (Xlib.error.DisplayNameError, Xlib.error.DisplayConnectionError) as e:
            errors.append(
                "Invalid display: %s. Failed to connect %s " % (display_screen_str, e)
            )
        return errors

    @classmethod
    def _is_valid_screen(cls, display_screen_str, width, height):
        """Return True when ``_validate_screen`` finds no problems."""
        return len(cls._validate_screen(display_screen_str, width, height)) == 0

    @classmethod
    def _valid_x_displays(cls, width, height):
        """Return every ``":<display>.<screen>"`` that passes ``_validate_screen``.

        Displays are discovered from the sockets in /tmp/.X11-unix and probed
        in ascending display-number order, so ``_select_x_display`` picks
        deterministically.
        """
        display_nums = []
        for socket_path in glob.glob("/tmp/.X11-unix/X[0-9]*"):
            suffix = os.path.basename(socket_path)[1:]
            # Skip anything that is not a plain display socket such as "X0".
            if suffix.isdecimal():
                display_nums.append(int(suffix))
        display_nums.sort()

        valid_displays = []
        for display_num in display_nums:
            try:
                disp = Xlib.display.Display(":%s" % display_num)
            except Xlib.error.DisplayConnectionError as e:
                warnings.warn(
                    "could not connect to X Display: %s, %s" % (display_num, e)
                )
                continue
            try:
                screen_count = disp.screen_count()
            finally:
                # Each screen is validated over its own connection below.
                cls._close_display(disp)
            for screen in range(screen_count):
                disp_screen_str = ":%s.%s" % (display_num, screen)
                if cls._is_valid_screen(disp_screen_str, width, height):
                    valid_displays.append(disp_screen_str)
        return valid_displays

    @classmethod
    def validate(cls, request):
        """Return X display problems for ``request``; headless requests need none."""
        if request.headless:
            return []
        if request.x_display:
            return cls._validate_screen(
                request.x_display, request.width, request.height
            )
        if cls._select_x_display(request.width, request.height) is None:
            return ["No valid X display found"]
        return []
class OSXIntel64(BasePlatform):
    """macOS build packaged as a ``<name>.app`` bundle."""

    @classmethod
    def old_executable_path(cls, base_dir, name):
        """Legacy layout: the binary inside the bundle is named after the bundle."""
        return os.path.join(base_dir, name + ".app", "Contents/MacOS", name)

    @classmethod
    def executable_path(cls, base_dir, name):
        """Resolve the bundle binary from ``CFBundleExecutable`` in Info.plist."""
        plist = cls.parse_plist(base_dir, name)
        return os.path.join(
            base_dir, name + ".app", "Contents/MacOS", plist["CFBundleExecutable"]
        )

    @classmethod
    def parse_plist(cls, base_dir, name):
        """Return the string-valued entries of the bundle's top-level Info.plist dict.

        Each ``<key>`` is paired with the element that immediately follows it.
        Entries whose value is not a ``<string>`` (``<true/>``, ``<array>``,
        ``<integer>``, ...) are skipped, so they no longer shift later string
        values onto the wrong keys.
        """
        plist_path = os.path.join(base_dir, name + ".app", "Contents/Info.plist")
        # parse() reads raw bytes, so the encoding in the XML declaration is honoured.
        root = xml.etree.ElementTree.parse(plist_path).getroot()
        plist = {}
        top_dict = root.find("dict")
        if top_dict is None:
            return plist
        pending_key = None
        for child in top_dict:
            if child.tag == "key":
                pending_key = child.text
            elif pending_key is not None:
                if child.tag == "string":
                    plist[pending_key] = child.text
                pending_key = None
        return plist
class CloudRendering(BaseLinuxPlatform):
    """Linux build whose only validated dependency is a Vulkan loader library."""

    enabled = True

    @classmethod
    def dependency_instructions(cls, request):
        """Return the apt command that installs the Vulkan loader."""
        # The hint previously omitted apt-get's "install" subcommand.
        return "CloudRendering requires libvulkan1. Please install by running: sudo apt-get install -y libvulkan1"

    @classmethod
    def failure_message(cls):
        """Placeholder hook; currently returns None."""
        return None

    @classmethod
    def validate(cls, request):
        """Report a problem when no Vulkan library can be located via ctypes."""
        if ctypes.util.find_library("vulkan") is None:
            return ["Vulkan API driver missing."]
        return []
class WebGL(BasePlatform):
    """WebGL build target; relies entirely on the BasePlatform defaults."""
def select_platforms(request):
    """Return the enabled platform classes available for ``request.system``.

    "Linux" offers Linux64 then CloudRendering, "Darwin" offers OSXIntel64,
    and any other system yields an empty list.
    """
    system_platform_map = dict(Linux=(Linux64, CloudRendering), Darwin=(OSXIntel64,))
    # NOTE: CloudRendering is offered even when an explicit x_display was requested.
    return [
        platform
        for platform in system_platform_map.get(request.system, ())
        if platform.enabled
    ]
# Look up a platform class by its class name (e.g. "Linux64").
STR_PLATFORM_MAP = {
    platform.__name__: platform
    for platform in (CloudRendering, Linux64, OSXIntel64, WebGL)
}
### ------------------------------------------------------------------ ### | |
# Copyright Allen Institute for Artificial Intelligence 2017 | |
""" | |
ai2thor.server | |
Handles all communication with Unity through a pair of named pipes (FIFOs):
actions are written to the client pipe, and metadata and image frames are read
back from the server pipe.
""" | |
import ai2thor.server | |
import json | |
import msgpack | |
import os | |
import tempfile | |
from ai2thor.exceptions import UnityCrashException | |
from enum import IntEnum, unique | |
from collections import defaultdict | |
import struct | |
@unique
class FieldType(IntEnum):
    """Field identifiers used in the type byte of each FIFO message header."""

    # metadata and actions
    METADATA = 1
    ACTION = 2
    ACTION_RESULT = 3
    # main-camera image passes
    RGB_IMAGE = 4
    DEPTH_IMAGE = 5
    NORMALS_IMAGE = 6
    FLOWS_IMAGE = 7
    CLASSES_IMAGE = 8
    IDS_IMAGE = 9
    # third-party camera passes, interleaved with the incremental metadata field
    THIRD_PARTY_IMAGE = 10
    METADATA_PATCH = 11
    THIRD_PARTY_DEPTH = 12
    THIRD_PARTY_NORMALS = 13
    THIRD_PARTY_IMAGE_IDS = 14
    THIRD_PARTY_CLASSES = 15
    THIRD_PARTY_FLOW = 16
    # terminates a message
    END_OF_MESSAGE = 255
class FifoServer(ai2thor.server.Server):
    """Exchange actions and frames with Unity over a pair of named pipes (FIFOs).

    A message is a sequence of fields. Each field is framed by a header packed
    with ``header_format`` ("!BI": a 1-byte FieldType, then a 4-byte
    big-endian body length), and the message ends with an END_OF_MESSAGE
    header. Frames are read from ``server_pipe_path``; actions are written
    to ``client_pipe_path``.
    """

    header_format = "!BI"
    header_size = struct.calcsize(header_format)
    # Reverse lookup from the raw header byte to its FieldType member.
    field_types = {f.value: f for f in FieldType}
    server_type = "FIFO"

    def __init__(
        self,
        width,
        height,
        depth_format=ai2thor.server.DepthFormat.Meters,
        add_depth_noise=False,
    ):
        # Both FIFOs live in a private temporary directory; start() creates them.
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.server_pipe_path = os.path.join(self.tmp_dir.name, "server.pipe")
        self.client_pipe_path = os.path.join(self.tmp_dir.name, "client.pipe")
        # The pipes are opened lazily on first read/write.
        self.server_pipe = None
        self.client_pipe = None
        # Last complete message, reused when Unity sends only a METADATA_PATCH.
        self.raw_metadata = None
        self.raw_files = None
        # Used for debugging in case of an error.
        self._last_action_message = None

        # Allows us to map the enum to form field names for backwards
        # compatibility; this can be removed when the wsgi server is removed.
        self.form_field_map = {
            FieldType.RGB_IMAGE: "image",
            FieldType.DEPTH_IMAGE: "image_depth",
            FieldType.CLASSES_IMAGE: "image_classes",
            FieldType.IDS_IMAGE: "image_ids",
            FieldType.NORMALS_IMAGE: "image_normals",
            FieldType.FLOWS_IMAGE: "image_flow",
            FieldType.THIRD_PARTY_IMAGE: "image-thirdParty-camera",
            FieldType.THIRD_PARTY_DEPTH: "image_thirdParty_depth",
            FieldType.THIRD_PARTY_NORMALS: "image_thirdParty_normals",
            FieldType.THIRD_PARTY_IMAGE_IDS: "image_thirdParty_image_ids",
            FieldType.THIRD_PARTY_CLASSES: "image_thirdParty_classes",
            FieldType.THIRD_PARTY_FLOW: "image_thirdParty_flow",
        }
        # Every image field has a form-field name, so derive the set to keep both in sync.
        self.image_fields = set(self.form_field_map)

        self.eom_header = self._create_header(FieldType.END_OF_MESSAGE, b"")
        super().__init__(width, height, depth_format, add_depth_noise)

    def _create_header(self, message_type, body):
        """Pack the header announcing a field of ``message_type`` carrying ``body``."""
        return struct.pack(self.header_format, message_type, len(body))

    def _raise_unity_exited(self):
        """Raise after the server pipe hit EOF, reporting Unity's return code.

        Waits up to 5 seconds for ``self.unity_proc`` to exit. Raises
        UnityCrashException for return codes -6/-11 (SIGABRT/SIGSEGV) and a
        plain Exception for any other exit.
        """
        self.unity_proc.wait(timeout=5)
        returncode = self.unity_proc.returncode
        message = (
            "Unity process has exited - check Player.log for errors. Last action message: %s, returncode=%s"
            % (self._last_action_message, returncode)
        )
        # We don't want to restart on every process exit, since it's possible
        # that a user kills off a Unity process with SIGTERM to end a training run.
        # SIGABRT is the returncode for when Unity crashes due to a segfault.
        if returncode in (-6, -11):  # SIGABRT, SIGSEGV
            raise UnityCrashException(message)
        raise Exception(message)

    def _recv_message(self):
        """Read one complete message from Unity.

        Returns ``(metadata, files)``, where ``files`` maps form-field names to
        lists of raw image bodies. A METADATA_PATCH field rebuilds the metadata
        from the previous full message and reuses that message's files.
        Raises ValueError for a field type that is not expected from Unity.
        """
        if self.server_pipe is None:
            self.server_pipe = open(self.server_pipe_path, "rb")

        metadata = None
        files = defaultdict(list)
        while True:
            header = self.server_pipe.read(self.header_size)  # message type + length
            if len(header) == 0:
                # EOF: Unity closed its end of the pipe.
                self._raise_unity_exited()
            if header[0] == FieldType.END_OF_MESSAGE.value:
                break
            if len(header) < self.header_size:
                # EOF part-way through a header; struct.unpack would fail cryptically.
                self._raise_unity_exited()

            field_type_int, message_length = struct.unpack(self.header_format, header)
            # Unknown bytes fall through to the "Invalid field type" error below.
            field_type = self.field_types.get(field_type_int, field_type_int)
            body = self.server_pipe.read(message_length)
            if len(body) < message_length:
                # EOF part-way through a body; decoding a truncated field is meaningless.
                self._raise_unity_exited()

            if field_type is FieldType.METADATA:
                metadata = msgpack.loads(body, raw=False)
            elif field_type is FieldType.METADATA_PATCH:
                metadata_patch = msgpack.loads(body, raw=False)
                agents = self.raw_metadata["agents"]
                metadata = dict(
                    # Shallow-copy each agent so the patch doesn't mutate the cached message.
                    agents=[dict(agent) for agent in agents],
                    thirdPartyCameras=self.raw_metadata["thirdPartyCameras"],
                    sequenceId=self.sequence_id,
                    activeAgentId=metadata_patch["agentId"],
                )
                metadata["agents"][metadata_patch["agentId"]].update(metadata_patch)
                files = self.raw_files
            elif field_type in self.image_fields:
                files[self.form_field_map[field_type]].append(body)
            else:
                raise ValueError("Invalid field type: %s" % field_type)

        self.raw_metadata = metadata
        self.raw_files = files
        return metadata, files

    def _send_message(self, message_type, body):
        """Write one field plus END_OF_MESSAGE to Unity and flush the pipe."""
        if self.client_pipe is None:
            self.client_pipe = open(self.client_pipe_path, "wb")
        header = self._create_header(message_type, body)
        # Used for debugging in case of an error.
        self._last_action_message = body
        self.client_pipe.write(header + body + self.eom_header)
        self.client_pipe.flush()

    def receive(self):
        """Read the next message and wrap it in an event via ``create_event``."""
        metadata, files = self._recv_message()
        if metadata is None:
            raise ValueError("no metadata received from recv_message")
        return self.create_event(metadata, files)

    def send(self, action):
        """Stamp ``action`` (in place) with a sequence id and send it to Unity as JSON."""
        if "sequenceId" in action:
            self.sequence_id = action["sequenceId"]
        else:
            self.sequence_id += 1
            action["sequenceId"] = self.sequence_id
        # TODO: switch this to msgpack.
        self._send_message(
            FieldType.ACTION,
            json.dumps(action, cls=ai2thor.server.NumpyAwareEncoder).encode("utf8"),
        )

    def start(self):
        """Create both FIFOs and mark the server as started."""
        os.mkfifo(self.server_pipe_path)
        os.mkfifo(self.client_pipe_path)
        self.started = True

    def unity_params(self):
        """Return the parameters passed up to Unity: the two FIFO paths."""
        params = dict(
            fifo_server_pipe_path=self.server_pipe_path,
            fifo_client_pipe_path=self.client_pipe_path,
        )
        return params

    def stop(self):
        """Close whichever pipes were opened; safe to call before any traffic."""
        if self.client_pipe is not None:
            self.client_pipe.close()
        if self.server_pipe is not None:
            self.server_pipe.close()
### ------------------------------------------------------------------ ### | |
import ai2thor.server | |
import json | |
import logging | |
import threading | |
import os | |
try: | |
from queue import Queue, Empty | |
except ImportError: | |
from Queue import Queue, Empty | |
import time | |
from flask import Flask, request, make_response, abort | |
import werkzeug | |
import werkzeug.serving | |
import werkzeug.http | |
# Silence werkzeug's request logging; only errors are reported.
logging.getLogger(name="werkzeug").setLevel(level=logging.ERROR)
# Use HTTP/1.1 so connections from Unity can be kept alive between requests.
setattr(werkzeug.serving.WSGIRequestHandler, "protocol_version", "HTTP/1.1")
def queue_get(que, unity_proc=None, timeout=0.5, max_attempts=200):
    """Block until ``que`` yields an item, polling so a dead Unity is noticed.

    Each attempt waits up to ``timeout`` seconds (get with timeout to allow
    quit). After every empty attempt, ``unity_proc`` (if given) is polled and
    an Exception is raised once it has exited. After ``max_attempts`` empty
    attempts (100 s with the defaults) an Exception is raised as well.

    Args:
        que: queue to read from.
        unity_proc: optional process handle exposing ``poll()`` and ``returncode``.
        timeout: seconds to wait per attempt.
        max_attempts: number of empty attempts tolerated before giving up.

    Returns:
        The item taken from ``que``.
    """
    attempts = 0
    while True:
        try:
            return que.get(block=True, timeout=timeout)
        except Empty:
            attempts += 1
            # We poll here for the unity proc in the event that it has exited;
            # otherwise we would wait indefinitely for the queue.
            if unity_proc is not None and unity_proc.poll() is not None:
                raise Exception("Unity process exited %s" % unity_proc.returncode)
            # No Action should take > 100s to complete, so we assume that
            # something has gone wrong within Unity. max_attempts can also be
            # triggered if an Exception is thrown from within the thread used
            # to run the wsgi server, in which case Unity will receive a
            # corrupted response.
            if attempts >= max_attempts:
                raise Exception(
                    "Could not get a message from the queue after %s attempts "
                    % attempts
                )
class BufferedIO(object):
    """Write-buffering wrapper around a binary output stream.

    ``write`` only queues data. ``flush`` sends everything queued since the
    previous flush as a single ``write`` on the wrapped stream and then
    flushes it.
    """

    def __init__(self, wfile):
        self.wfile = wfile
        self.data = []

    def write(self, output):
        """Queue ``output`` (bytes) until the next flush."""
        self.data.append(output)

    def flush(self):
        """Send the queued chunks as one write, then clear the queue.

        The queue must be cleared: if flush is called more than once (e.g.
        once per response chunk), keeping old chunks would resend them and
        corrupt the stream.
        """
        pending, self.data = self.data, []
        self.wfile.write(b"".join(pending))
        self.wfile.flush()

    def close(self):
        """Close the wrapped stream (queued, unflushed data is not sent)."""
        return self.wfile.close()

    @property
    def closed(self):
        """Whether the wrapped stream is closed."""
        return self.wfile.closed
class ThorRequestHandler(werkzeug.serving.WSGIRequestHandler):
    """WSGI request handler whose response writes go through a BufferedIO."""

    def run_wsgi(self):
        """Run the WSGI app with ``wfile`` wrapped, restoring it afterwards."""
        old_wfile = self.wfile
        self.wfile = BufferedIO(self.wfile)
        try:
            return super(ThorRequestHandler, self).run_wsgi()
        finally:
            # Restore even if the app raises, so later requests handled on
            # this connection write to the real stream, not a stale wrapper.
            self.wfile = old_wfile
class MultipartFormParser(object):
    """Minimal ``multipart/form-data`` parser over a raw request body.

    After construction, ``form`` maps each non-file field name to a list of
    values and ``files`` maps each file field name (a part whose
    Content-Disposition has a ``filename``) to a list of bodies. Bodies are
    ``memoryview`` slices of ``data`` (no copy). A ``text/plain`` field that
    declares a charset is decoded to ``str`` instead.
    """

    @staticmethod
    def get_boundary(request_headers):
        """Return the multipart boundary from the Content-Type header as bytes.

        ``request_headers`` is an iterable of ``(name, value)`` pairs; the
        header name is matched case-insensitively. Returns None when there is
        no Content-Type header or it has no ``boundary`` parameter.
        """
        for h, value in request_headers:
            if h.lower() == "content-type":
                ctype, ct_opts = werkzeug.http.parse_options_header(value)
                boundary = ct_opts.get("boundary")
                return None if boundary is None else boundary.encode("ascii")
        return None

    @staticmethod
    def _parse_part_headers(raw_headers):
        """Parse one part's CRLF-separated header block into a dict.

        Header names are lower-cased (they are case-insensitive), and only the
        first ':' on each line separates name from value, so values such as
        ``form-data; name="a:b"`` survive intact. Blank lines are skipped.
        """
        headers = {}
        for line in bytes(raw_headers).decode("utf-8").strip().split("\r\n"):
            if not line:
                continue
            name, _, value = line.partition(":")
            headers[name.strip().lower()] = value.strip()
        return headers

    def __init__(self, data, boundary):
        """Parse ``data`` (raw body bytes) using ``boundary`` (bytes).

        Raises ValueError if ``boundary`` is None or a part's
        Content-Disposition is not ``form-data``.
        """
        if boundary is None:
            raise ValueError("multipart/form-data request has no boundary")
        self.form = {}
        self.files = {}
        full_boundary = b"--" + boundary
        mid_boundary = b"\r\n" + full_boundary
        view = memoryview(data)

        i = data.find(full_boundary) + len(full_boundary)
        while True:
            next_offset = data.find(mid_boundary, i)
            if next_offset < 0:
                # No delimiter left: every part has been consumed.
                break
            headers_offset = i + 2  # skip the CRLF that ends the boundary line
            body_offset = data.find(b"\r\n\r\n", headers_offset)
            headers = self._parse_part_headers(view[headers_offset:body_offset])
            body = view[body_offset + 4 : next_offset]
            i = next_offset + len(mid_boundary)

            # A part may omit Content-Type; treat it as having no type/options.
            ctype, ct_opts = werkzeug.http.parse_options_header(
                headers.get("content-type", "")
            )
            cdisp, cd_opts = werkzeug.http.parse_options_header(
                headers["content-disposition"]
            )
            if cdisp != "form-data":
                raise ValueError("Unexpected Content-Disposition: %s" % cdisp)

            field_name = cd_opts["name"]
            if "filename" in cd_opts:
                self.files.setdefault(field_name, []).append(body)
            else:
                if ctype == "text/plain" and "charset" in ct_opts:
                    body = body.tobytes().decode(ct_opts["charset"])
                self.form.setdefault(field_name, []).append(body)
class WsgiServer(ai2thor.server.Server):
    """Exchange frames and actions with Unity over HTTP using an embedded Flask app.

    Unity POSTs each frame to ``/train``. The handler turns it into an event,
    hands it to the controller through ``request_queue`` and blocks until the
    next action arrives on ``response_queue``; that action becomes the HTTP
    response body.
    """

    server_type = "WSGI"

    def __init__(
        self,
        host,
        port=0,
        threaded=False,
        depth_format=ai2thor.server.DepthFormat.Meters,
        add_depth_noise=False,
        width=300,
        height=300,
    ):
        app = Flask(
            __name__,
            template_folder=os.path.realpath(
                os.path.join(
                    os.path.dirname(os.path.abspath(__file__)), "..", "templates"
                )
            ),
        )

        # maxsize=1: at most one event / action is queued in each direction.
        self.request_queue = Queue(maxsize=1)
        self.response_queue = Queue(maxsize=1)
        self.app = app
        self.app.config.update(
            PROPAGATE_EXCEPTIONS=False, JSONIFY_PRETTYPRINT_REGULAR=False
        )
        self.port = port
        self.last_rate_timestamp = time.time()
        self.frame_counter = 0
        self.debug_frames_per_interval = 50
        self.unity_proc = None
        # Set by start(); stop() must only call shutdown() while serve_forever() runs.
        self.server_thread = None
        self.wsgi_server = werkzeug.serving.make_server(
            host,
            self.port,
            self.app,
            threaded=threaded,
            request_handler=ThorRequestHandler,
        )
        super().__init__(width, height, depth_format, add_depth_noise)

        @app.route("/ping", methods=["get"])
        def ping():
            return "pong"

        @app.route("/train", methods=["post"])
        def train():
            if request.headers["Content-Type"].split(";")[0] == "multipart/form-data":
                form = MultipartFormParser(
                    request.get_data(),
                    MultipartFormParser.get_boundary(request.headers),
                )
                # MultipartFormParser keeps a list per field; take the first
                # value, matching what request.form[...] returns below.
                fields = {name: values[0] for name, values in form.form.items()}
            else:
                form = request
                fields = form.form

            metadata = json.loads(fields["metadata"])
            action_returns = []
            # backwards compatibility
            if "actionReturns" in fields and len(fields["actionReturns"]) > 0:
                action_returns = json.loads(fields["actionReturns"])
            token = fields["token"]

            if self.client_token and token != self.client_token:
                abort(403)

            if self.frame_counter % self.debug_frames_per_interval == 0:
                # Start of a new frame-rate measurement interval.
                self.last_rate_timestamp = time.time()

            for i, agent in enumerate(metadata["agents"]):
                if "actionReturn" not in agent and i < len(action_returns):
                    agent["actionReturn"] = action_returns[i]

            event = self.create_event(metadata, form.files)
            self.request_queue.put_nowait(event)
            self.frame_counter += 1

            next_action = queue_get(self.response_queue)
            if "sequenceId" not in next_action:
                self.sequence_id += 1
                next_action["sequenceId"] = self.sequence_id
            else:
                self.sequence_id = next_action["sequenceId"]

            return make_response(
                json.dumps(next_action, cls=ai2thor.server.NumpyAwareEncoder)
            )

    def _start_server_thread(self):
        """Thread target: serve requests until shutdown() is called."""
        self.wsgi_server.serve_forever()

    def start(self):
        """Start serving on a daemon thread."""
        self.started = True
        self.server_thread = threading.Thread(target=self._start_server_thread)
        self.server_thread.daemon = True
        self.server_thread.start()

    def receive(self):
        """Return the next event posted by Unity, failing if Unity exits."""
        return queue_get(self.request_queue, unity_proc=self.unity_proc)

    def send(self, action):
        """Hand ``action`` to the pending /train request as its response."""
        assert self.request_queue.empty()
        self.response_queue.put_nowait(action)

    def unity_params(self):
        """Return the parameters passed up to Unity: the bound host and port."""
        host, port = self.wsgi_server.socket.getsockname()
        params = dict(host=host, port=str(port))
        return params

    def stop(self):
        """Release any waiting /train request and stop the HTTP server."""
        self.send({})
        if getattr(self, "server_thread", None) is not None:
            self.wsgi_server.shutdown()
        else:
            # shutdown() blocks forever unless serve_forever() is running, so
            # a never-started server only needs its socket closed.
            self.wsgi_server.server_close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment