Skip to content

Instantly share code, notes, and snippets.

@raeidsaqur
Created June 24, 2022 17:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raeidsaqur/174d745cf4d3df60d0623f5cf07edac8 to your computer and use it in GitHub Desktop.
Save raeidsaqur/174d745cf4d3df60d0623f5cf07edac8 to your computer and use it in GitHub Desktop.
Create a WSGI server to allow declarative headless rendering for Unity-based gaming environments.
## For general architecture see: https://silverweed.github.io/assets/docs/distributed_rendering_in_vulkan.pdf
import Xlib.display
import glob
import warnings
import os
import ctypes.util
import xml.etree.ElementTree
class Request:
    """Value object describing the rendering setup a caller is asking for.

    Attributes:
        system: host operating-system name; ``select_platforms`` looks it up
            against the keys ``"Linux"`` and ``"Darwin"``.
        width: requested frame width.
        height: requested frame height.
        x_display: X display string to use, or a falsy value when the caller
            did not name one.
        headless: when truthy, ``Linux64.validate`` skips all X display checks.
    """

    def __init__(self, system, width, height, x_display, headless):
        self.system = system
        # requested frame size
        self.width, self.height = width, height
        # display selection / headless switch
        self.x_display, self.headless = x_display, headless
class BasePlatform:
    """Default behaviour shared by every rendering platform.

    Platforms are used purely through classmethods; subclasses override the
    hooks they need. The defaults here describe a platform that is always
    valid and needs no extra environment.
    """

    # select_platforms() skips any platform whose ``enabled`` flag is falsy
    enabled = True

    @classmethod
    def name(cls):
        """Return the platform's class name."""
        return cls.__name__

    @classmethod
    def validate(cls, r):
        """Return a list of error strings for request ``r`` (empty = valid)."""
        return []

    @classmethod
    def is_valid(cls, request):
        """Return True when ``validate`` reports no errors for ``request``."""
        problems = cls.validate(request)
        return len(problems) == 0

    @classmethod
    def dependency_instructions(cls, request):
        """Return a hint describing missing dependencies, or None."""
        return None

    @classmethod
    def launch_env(cls, width, height, x_display):
        """Return extra environment variables for launching the build."""
        return {}
class BaseLinuxPlatform(BasePlatform):
    """Shared path handling for Linux builds.

    The executable is a file directly inside the build directory.
    """

    @classmethod
    def executable_path(cls, base_dir, name):
        """Return ``base_dir/name``, the location of the build's executable."""
        path = os.path.join(base_dir, name)
        return path

    @classmethod
    def old_executable_path(cls, base_dir, name):
        """Return the legacy executable location (same as the current one)."""
        return cls.executable_path(base_dir, name)
class Linux64(BaseLinuxPlatform):
    """Linux platform that renders through an X11 display with GLX.

    All helpers are classmethods. Screens are addressed with display strings
    of the form ``":<display>.<screen>"`` (for example ``":0.1"``).
    """

    @classmethod
    def dependency_instructions(cls, request):
        """Return a hint about the X11/GLX requirement for ``request``.

        Lists the usable displays when any exist, otherwise suggests how to
        start an X server.
        """
        message = "Linux64 requires a X11 server to be running with GLX. "
        displays = cls._valid_x_displays(request.width, request.height)
        if displays:
            message += "The following valid displays were found %s" % (
                ", ".join(displays)
            )
        else:
            message += "If you have a NVIDIA GPU, please run: sudo ai2thor-xorg start"
        return message

    @classmethod
    def _select_x_display(cls, width, height):
        """Return the first valid display string, or None if there is none."""
        valid_displays = cls._valid_x_displays(width, height)
        return valid_displays[0] if valid_displays else None

    @classmethod
    def launch_env(cls, width, height, x_display):
        """Return the environment for launching the build.

        ``DISPLAY`` is ``x_display`` when given; otherwise the first valid
        display found (which may still be None if none is usable).
        """
        env = dict(DISPLAY=x_display)
        if env["DISPLAY"] is None:
            env["DISPLAY"] = cls._select_x_display(width, height)
        return env

    @classmethod
    def _validate_screen(cls, display_screen_str, width, height):
        """Return a list of reasons why a screen cannot be used.

        Args:
            display_screen_str: display string such as ``":0.1"``.
            width: minimum required screen width in pixels.
            height: minimum required screen height in pixels.

        Returns:
            A list of error strings; empty when the screen is usable.
            Connection failures are reported as an error entry rather than
            raised.
        """
        errors = []
        disp_screen = None
        try:
            disp_screen = Xlib.display.Display(
                display_screen_str
            )  # display_screen_str will have the format ":0.1"
            screen_parts = display_screen_str.split(".")
            if len(screen_parts) > 1:
                # this Xlib.display will find a valid screen if an
                # invalid screen was passed in (e.g. :0.9999999 -> :0.1)
                if screen_parts[1] != str(disp_screen.get_default_screen()):
                    errors.append(
                        "Invalid display, non-existent screen: %s" % display_screen_str
                    )
            if "GLX" not in disp_screen.list_extensions():
                errors.append(
                    "Display %s does not have the GLX extension loaded. GLX is required by Unity3D."
                    % display_screen_str
                )
            screen = disp_screen.screen()
            if (
                screen["width_in_pixels"] < width
                or screen["height_in_pixels"] < height
            ):
                errors.append(
                    "Display %s does not have a large enough resolution for the target resolution: %sx%s vs. %sx%s"
                    % (
                        display_screen_str,
                        width,
                        height,
                        screen["width_in_pixels"],
                        screen["height_in_pixels"],
                    )
                )
            if screen["root_depth"] != 24:
                errors.append(
                    "Display %s does not have a color depth of 24: %s"
                    % (display_screen_str, screen["root_depth"])
                )
        except (Xlib.error.DisplayNameError, Xlib.error.DisplayConnectionError) as e:
            errors.append(
                "Invalid display: %s. Failed to connect %s " % (display_screen_str, e)
            )
        finally:
            # release the X connection instead of leaving it to garbage
            # collection; this helper runs once per screen per validation
            if disp_screen is not None:
                disp_screen.close()
        return errors

    @classmethod
    def _is_valid_screen(cls, display_screen_str, width, height):
        """Return True when ``_validate_screen`` finds no problems."""
        return len(cls._validate_screen(display_screen_str, width, height)) == 0

    @classmethod
    def _valid_x_displays(cls, width, height):
        """Return the display strings of every usable screen.

        Displays are discovered from the sockets under ``/tmp/.X11-unix``.
        A display that refuses the connection produces a warning and is
        skipped.
        """
        open_display_strs = [
            int(os.path.basename(s)[1:]) for s in glob.glob("/tmp/.X11-unix/X[0-9]*")
        ]
        valid_displays = []
        for display_str in open_display_strs:
            disp = None
            try:
                disp = Xlib.display.Display(":%s" % display_str)
                for screen in range(0, disp.screen_count()):
                    disp_screen_str = ":%s.%s" % (display_str, screen)
                    if cls._is_valid_screen(disp_screen_str, width, height):
                        valid_displays.append(disp_screen_str)
            except Xlib.error.DisplayConnectionError as e:
                warnings.warn(
                    "could not connect to X Display: %s, %s" % (display_str, e)
                )
            finally:
                # close the probing connection for this display
                if disp is not None:
                    disp.close()
        return valid_displays

    @classmethod
    def validate(cls, request):
        """Return the list of errors that prevent ``request`` from running.

        Headless requests are always valid. An explicit ``x_display`` is
        validated directly; otherwise at least one usable display must exist.
        """
        if request.headless:
            return []
        elif request.x_display:
            return cls._validate_screen(
                request.x_display, request.width, request.height
            )
        elif cls._select_x_display(request.width, request.height) is None:
            return ["No valid X display found"]
        else:
            return []
class OSXIntel64(BasePlatform):
    """macOS platform; the build is an ``.app`` bundle inside ``base_dir``."""

    @classmethod
    def old_executable_path(cls, base_dir, name):
        """Return the legacy path, where the executable is named after the bundle."""
        return os.path.join(base_dir, name + ".app", "Contents/MacOS", name)

    @classmethod
    def executable_path(cls, base_dir, name):
        """Return the executable path named by the bundle's ``CFBundleExecutable``.

        Raises:
            KeyError: if Info.plist has no string-valued ``CFBundleExecutable``.
        """
        plist = cls.parse_plist(base_dir, name)
        return os.path.join(
            base_dir, name + ".app", "Contents/MacOS", plist["CFBundleExecutable"]
        )

    @classmethod
    def parse_plist(cls, base_dir, name):
        """Return the string-valued entries of the bundle's Info.plist.

        Each ``<key>`` is paired with the element that immediately follows
        it, and the pair is kept only when that element is a ``<string>``.
        Pairing all keys with all strings positionally would shift every key
        that follows a non-string value (``<true/>``, ``<array>``, ...) onto
        the wrong value.

        Returns:
            dict mapping key text to string text.
        """
        plist_path = os.path.join(base_dir, name + ".app", "Contents/Info.plist")
        with open(plist_path) as f:
            plist = f.read()
        root = xml.etree.ElementTree.fromstring(plist)
        entries = {}
        for top_dict in root.findall("dict"):
            pending_key = None
            for child in top_dict:
                if child.tag == "key":
                    pending_key = child.text
                elif pending_key is not None:
                    # this element is the value belonging to pending_key
                    if child.tag == "string":
                        entries[pending_key] = child.text
                    pending_key = None
        return entries
class CloudRendering(BaseLinuxPlatform):
    """Linux platform that renders through Vulkan."""

    enabled = True

    @classmethod
    def dependency_instructions(cls, request):
        """Return the install hint for the Vulkan loader library."""
        # apt-get needs the "install" subcommand; without it the suggested
        # command is not valid
        return "CloudRendering requires libvulkan1. Please install by running: sudo apt-get -y install libvulkan1"

    @classmethod
    def failure_message(cls):
        """No platform-specific failure message; returns None."""
        return None

    @classmethod
    def validate(cls, request):
        """Return an error list; empty when a ``vulkan`` library can be located."""
        if ctypes.util.find_library("vulkan") is not None:
            return []
        return ["Vulkan API driver missing."]
class WebGL(BasePlatform):
    """WebGL platform; adds nothing to the BasePlatform defaults."""
def select_platforms(request):
    """Return the enabled platform classes for ``request.system``.

    Candidates keep the order in which they are registered for the system;
    an unknown system yields an empty list.
    """
    system_platform_map = {
        "Linux": (Linux64, CloudRendering),
        "Darwin": (OSXIntel64,),
    }
    # The original carries a commented-out rule that would skip
    # CloudRendering when request.x_display is set; it stays disabled.
    return [
        platform
        for platform in system_platform_map.get(request.system, ())
        if platform.enabled
    ]
# Lookup from a platform's name to its class.
STR_PLATFORM_MAP = {
    "CloudRendering": CloudRendering,
    "Linux64": Linux64,
    "OSXIntel64": OSXIntel64,
    "WebGL": WebGL,
}
### ------------------------------------------------------------------ ###
# Copyright Allen Institute for Artificial Intelligence 2017
"""
ai2thor.server
Handles all communication with Unity through a pair of named pipes (FIFOs).
Messages are framed with a type/length header and exchanged over a server
pipe (read by the controller) and a client pipe (written by the controller).
"""
import ai2thor.server
import json
import msgpack
import os
import tempfile
from ai2thor.exceptions import UnityCrashException
from enum import IntEnum, unique
from collections import defaultdict
import struct
# FifoFields
@unique
class FieldType(IntEnum):
    """Type tags for the fields of a FIFO message.

    The integer value is the first item packed into each message header
    (see ``FifoServer.header_format``).
    """

    # metadata / action traffic
    METADATA = 1
    ACTION = 2
    ACTION_RESULT = 3

    # image payloads
    RGB_IMAGE = 4
    DEPTH_IMAGE = 5
    NORMALS_IMAGE = 6
    FLOWS_IMAGE = 7
    CLASSES_IMAGE = 8
    IDS_IMAGE = 9
    THIRD_PARTY_IMAGE = 10

    METADATA_PATCH = 11

    # further third-party payloads
    THIRD_PARTY_DEPTH = 12
    THIRD_PARTY_NORMALS = 13
    THIRD_PARTY_IMAGE_IDS = 14
    THIRD_PARTY_CLASSES = 15
    THIRD_PARTY_FLOW = 16

    # terminates a message
    END_OF_MESSAGE = 255
class FifoServer(ai2thor.server.Server):
    """Server that talks to Unity over two named pipes.

    Each field on the wire is a header (``header_format``: one unsigned byte
    for the field type, one unsigned 32-bit length, network byte order)
    followed by ``length`` bytes of body. A message ends with an
    END_OF_MESSAGE header.
    """

    header_format = "!BI"
    header_size = struct.calcsize(header_format)
    # reverse lookup: integer read from a header -> FieldType member
    field_types = {f.value: f for f in FieldType}
    server_type = "FIFO"

    def __init__(
        self,
        width,
        height,
        depth_format=ai2thor.server.DepthFormat.Meters,
        add_depth_noise=False,
    ):
        """Create the temporary directory that will hold the two pipes.

        The pipes themselves are created in ``start()`` and opened lazily on
        first receive/send.
        """
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.server_pipe_path = os.path.join(self.tmp_dir.name, "server.pipe")
        self.client_pipe_path = os.path.join(self.tmp_dir.name, "client.pipe")
        self.server_pipe = None
        self.client_pipe = None
        # last complete metadata/files, needed to apply METADATA_PATCH messages
        self.raw_metadata = None
        self.raw_files = None
        self._last_action_message = None
        # allows us to map the enum to form field names
        # for backwards compatibility
        # this can be removed when the wsgi server is removed
        self.form_field_map = {
            FieldType.RGB_IMAGE: "image",
            FieldType.DEPTH_IMAGE: "image_depth",
            FieldType.CLASSES_IMAGE: "image_classes",
            FieldType.IDS_IMAGE: "image_ids",
            FieldType.NORMALS_IMAGE: "image_normals",
            FieldType.FLOWS_IMAGE: "image_flow",
            FieldType.THIRD_PARTY_IMAGE: "image-thirdParty-camera",
            FieldType.THIRD_PARTY_DEPTH: "image_thirdParty_depth",
            FieldType.THIRD_PARTY_NORMALS: "image_thirdParty_normals",
            FieldType.THIRD_PARTY_IMAGE_IDS: "image_thirdParty_image_ids",
            FieldType.THIRD_PARTY_CLASSES: "image_thirdParty_classes",
            FieldType.THIRD_PARTY_FLOW: "image_thirdParty_flow",
        }
        self.image_fields = {
            FieldType.IDS_IMAGE,
            FieldType.CLASSES_IMAGE,
            FieldType.FLOWS_IMAGE,
            FieldType.NORMALS_IMAGE,
            FieldType.DEPTH_IMAGE,
            FieldType.RGB_IMAGE,
            FieldType.THIRD_PARTY_IMAGE,
            FieldType.THIRD_PARTY_DEPTH,
            FieldType.THIRD_PARTY_NORMALS,
            FieldType.THIRD_PARTY_IMAGE_IDS,
            FieldType.THIRD_PARTY_CLASSES,
            FieldType.THIRD_PARTY_FLOW,
        }
        self.eom_header = self._create_header(FieldType.END_OF_MESSAGE, b"")
        super().__init__(width, height, depth_format, add_depth_noise)

    def _create_header(self, message_type, body):
        """Pack the header for a field of type ``message_type`` carrying ``body``."""
        return struct.pack(self.header_format, message_type, len(body))

    def _recv_message(self):
        """Read one message from the server pipe.

        Returns:
            ``(metadata, files)`` where ``files`` maps form-field names to
            lists of raw image bodies. ``metadata`` is None if the message
            contained no metadata field.

        Raises:
            UnityCrashException: the pipe hit EOF and Unity exited with
                return code -6 or -11 (SIGABRT / SIGSEGV).
            Exception: the pipe hit EOF and Unity exited any other way.
            ValueError: a known field type that this reader does not handle.
        """
        if self.server_pipe is None:
            self.server_pipe = open(self.server_pipe_path, "rb")
        metadata = None
        files = defaultdict(list)
        while True:
            header = self.server_pipe.read(self.header_size)  # message type + length
            if len(header) == 0:
                # EOF on the pipe: the writer side (Unity) has gone away
                self.unity_proc.wait(timeout=5)
                returncode = self.unity_proc.returncode
                message = (
                    "Unity process has exited - check Player.log for errors. Last action message: %s, returncode=%s"
                    % (self._last_action_message, returncode)
                )
                # we don't want to restart all process exits since its possible that a user
                # kills off a Unity process with SIGTERM to end a training run
                # SIGABRT is the returncode for when Unity crashes due to a segfault
                if returncode in [-6, -11]:  # SIGABRT, SIGSEGV
                    raise UnityCrashException(message)
                else:
                    raise Exception(message)
            if header[0] == FieldType.END_OF_MESSAGE.value:
                break
            field_type_int, message_length = struct.unpack(self.header_format, header)
            field_type = self.field_types[field_type_int]
            body = self.server_pipe.read(message_length)
            if field_type is FieldType.METADATA:
                metadata = msgpack.loads(body, raw=False)
            elif field_type is FieldType.METADATA_PATCH:
                # a patch only carries one agent's update; rebuild complete
                # metadata from the previous message and overlay the patch
                metadata_patch = msgpack.loads(body, raw=False)
                agents = self.raw_metadata["agents"]
                metadata = dict(
                    agents=[dict(agent) for agent in agents],
                    thirdPartyCameras=self.raw_metadata["thirdPartyCameras"],
                    sequenceId=self.sequence_id,
                    activeAgentId=metadata_patch["agentId"],
                )
                metadata["agents"][metadata_patch["agentId"]].update(metadata_patch)
                # images are reused from the previous message
                files = self.raw_files
            elif field_type in self.image_fields:
                files[self.form_field_map[field_type]].append(body)
            else:
                raise ValueError("Invalid field type: %s" % field_type)
        self.raw_metadata = metadata
        self.raw_files = files
        return metadata, files

    def _send_message(self, message_type, body):
        """Write one field followed by END_OF_MESSAGE to the client pipe."""
        if self.client_pipe is None:
            self.client_pipe = open(self.client_pipe_path, "wb")
        header = self._create_header(message_type, body)
        # used for debugging in case of an error
        self._last_action_message = body
        self.client_pipe.write(header + body + self.eom_header)
        self.client_pipe.flush()

    def receive(self):
        """Read the next message and return it as an event.

        Raises:
            ValueError: the message carried no metadata.
        """
        metadata, files = self._recv_message()
        if metadata is None:
            raise ValueError("no metadata received from recv_message")
        return self.create_event(metadata, files)

    def send(self, action):
        """Send ``action`` (a dict) to Unity as a JSON-encoded ACTION field.

        An explicit ``sequenceId`` in ``action`` becomes the current sequence
        id; otherwise the counter is incremented and written into ``action``.
        """
        if "sequenceId" in action:
            self.sequence_id = action["sequenceId"]
        else:
            self.sequence_id += 1
            action["sequenceId"] = self.sequence_id
        # need to switch this to msgpack
        self._send_message(
            FieldType.ACTION,
            json.dumps(action, cls=ai2thor.server.NumpyAwareEncoder).encode("utf8"),
        )

    def start(self):
        """Create both named pipes and mark the server as started."""
        os.mkfifo(self.server_pipe_path)
        os.mkfifo(self.client_pipe_path)
        self.started = True

    # params to pass up to unity
    def unity_params(self):
        """Return the pipe paths to hand to the Unity process."""
        params = dict(
            fifo_server_pipe_path=self.server_pipe_path,
            fifo_client_pipe_path=self.client_pipe_path,
        )
        return params

    def stop(self):
        """Close whichever pipes have been opened.

        Pipes are opened lazily, so either may still be None if the server
        is stopped before the first send/receive.
        """
        if self.client_pipe is not None:
            self.client_pipe.close()
        if self.server_pipe is not None:
            self.server_pipe.close()
### ------------------------------------------------------------------ ###
import ai2thor.server
import json
import logging
import threading
import os
try:
from queue import Queue, Empty
except ImportError:
from Queue import Queue, Empty
import time
from flask import Flask, request, make_response, abort
import werkzeug
import werkzeug.serving
import werkzeug.http
# Only let ERROR-and-above records through the "werkzeug" logger.
logging.getLogger("werkzeug").setLevel(logging.ERROR)
# Make every werkzeug request handler speak HTTP/1.1
# (NOTE(review): presumably to keep the connection to Unity alive between frames — confirm).
werkzeug.serving.WSGIRequestHandler.protocol_version = "HTTP/1.1"
# get with timeout to allow quit
def queue_get(que, unity_proc=None):
    """Block until an item is available on ``que`` and return it.

    The queue is polled in 0.5 s slices rather than with one blocking call
    so the wait can be abandoned.

    Args:
        que: queue to read from.
        unity_proc: optional process handle; when given it is polled after
            every empty slice.

    Raises:
        Exception: ``unity_proc`` has exited, or 200 consecutive slices
            (about 100 s) passed without a message.
    """
    poll_limit = 200
    failed_polls = 0
    while True:
        try:
            return que.get(block=True, timeout=0.5)
        except Empty:
            failed_polls += 1
            # we poll here for the unity proc in the event that it has
            # exited otherwise we would wait indefinitely for the queue
            if unity_proc and unity_proc.poll() is not None:
                raise Exception("Unity process exited %s" % unity_proc.returncode)
            # no Action should take > 100s to complete, so we assume that
            # something has gone wrong within Unity
            # the limit can also be reached if an Exception is thrown from
            # within the thread used to run the wsgi server, in which case
            # Unity will receive a corrupted response
            if failed_polls >= poll_limit:
                raise Exception(
                    "Could not get a message from the queue after %s attempts "
                    % failed_polls
                )
class BufferedIO(object):
    """Write-side wrapper that holds writes back until ``flush()``.

    Every ``write`` is stored in memory; ``flush`` hands all pending chunks
    to the wrapped file in a single ``write`` call and then flushes it.
    """

    def __init__(self, wfile):
        """Wrap ``wfile``, a binary file-like object."""
        self.wfile = wfile
        self.data = []

    def write(self, output):
        """Queue ``output`` (bytes) for the next flush."""
        self.data.append(output)

    def flush(self):
        """Send all pending chunks to the wrapped file and flush it.

        Pending chunks are dropped once written; otherwise a later flush
        would send the same bytes again.
        """
        if self.data:
            self.wfile.write(b"".join(self.data))
            self.data = []
        self.wfile.flush()

    def close(self):
        """Close the wrapped file."""
        return self.wfile.close()

    @property
    def closed(self):
        """Whether the wrapped file is closed."""
        return self.wfile.closed
class ThorRequestHandler(werkzeug.serving.WSGIRequestHandler):
    """Request handler that routes response writes through ``BufferedIO``."""

    def run_wsgi(self):
        """Run the WSGI app with ``self.wfile`` temporarily wrapped.

        The original ``wfile`` is put back once the parent implementation
        returns.
        """
        raw_wfile = self.wfile
        self.wfile = BufferedIO(raw_wfile)
        outcome = super().run_wsgi()
        self.wfile = raw_wfile
        return outcome
class MultipartFormParser(object):
    """Minimal ``multipart/form-data`` parser.

    Part bodies are sliced out of a ``memoryview`` of the request, so file
    payloads are not copied.

    Attributes:
        form: field name -> list of values. A value is a ``str`` when the
            part is ``text/plain`` with a charset, otherwise a memoryview.
        files: field name -> list of memoryview bodies, for parts whose
            Content-disposition carries a ``filename``.
    """

    @staticmethod
    def get_boundary(request_headers):
        """Return the multipart boundary as bytes, or None.

        ``request_headers`` is iterated as ``(name, value)`` pairs; only a
        header named exactly ``Content-Type`` is considered.
        """
        for h, value in request_headers:
            if h == "Content-Type":
                ctype, ct_opts = werkzeug.http.parse_options_header(value)
                return ct_opts["boundary"].encode("ascii")
        return None

    def __init__(self, data, boundary):
        """Parse ``data`` (bytes) using ``boundary`` (bytes, without ``--``).

        Raises:
            ValueError: a part header line contains no ``:``.
            KeyError: a part lacks ``Content-Type`` or ``Content-disposition``.
            AssertionError: a part's disposition is not ``form-data``.
        """
        self.form = {}
        self.files = {}
        full_boundary = b"--" + boundary
        mid_boundary = b"\r\n" + full_boundary
        view = memoryview(data)
        i = data.find(full_boundary) + len(full_boundary)
        while i >= 0:
            next_offset = data.find(mid_boundary, i)
            if next_offset < 0:
                break
            headers_offset = i + 2  # add 2 for CRLF
            body_offset = data.find(b"\r\n\r\n", headers_offset)
            raw_headers = view[headers_offset:body_offset]
            body = view[body_offset + 4 : next_offset]
            i = next_offset + len(mid_boundary)
            headers = {}
            for header in raw_headers.tobytes().decode("ascii").strip().split("\r\n"):
                # split on the first colon only: header values (for example a
                # quoted filename) may themselves contain ":"
                k, v = header.split(":", 1)
                headers[k.strip()] = v.strip()
            ctype, ct_opts = werkzeug.http.parse_options_header(headers["Content-Type"])
            cdisp, cd_opts = werkzeug.http.parse_options_header(
                headers["Content-disposition"]
            )
            assert cdisp == "form-data"
            field_name = cd_opts["name"]
            if "filename" in cd_opts:
                self.files.setdefault(field_name, []).append(body)
            else:
                if ctype == "text/plain" and "charset" in ct_opts:
                    body = body.tobytes().decode(ct_opts["charset"])
                self.form.setdefault(field_name, []).append(body)
class WsgiServer(ai2thor.server.Server):
    """Server that exchanges frames and actions with Unity over HTTP.

    Unity POSTs each frame to ``/train``. The handler turns it into an event
    on ``request_queue`` and then blocks until an action is placed on
    ``response_queue``, which it returns as the HTTP response body.
    """

    server_type = "WSGI"

    def __init__(
        self,
        host,
        port=0,
        threaded=False,
        depth_format=ai2thor.server.DepthFormat.Meters,
        add_depth_noise=False,
        width=300,
        height=300,
    ):
        """Build the Flask app, bind the werkzeug server and register routes."""
        here = os.path.dirname(os.path.abspath(__file__))
        app = Flask(
            __name__,
            template_folder=os.path.realpath(os.path.join(here, "..", "templates")),
        )
        # one-slot queues: a single frame/action is in flight at a time
        self.request_queue = Queue(maxsize=1)
        self.response_queue = Queue(maxsize=1)
        self.app = app
        self.app.config.update(
            PROPAGATE_EXCEPTIONS=False, JSONIFY_PRETTYPRINT_REGULAR=False
        )
        self.port = port
        self.last_rate_timestamp = time.time()
        self.frame_counter = 0
        self.debug_frames_per_interval = 50
        self.unity_proc = None
        self.wsgi_server = werkzeug.serving.make_server(
            host,
            self.port,
            self.app,
            threaded=threaded,
            request_handler=ThorRequestHandler,
        )
        # used to ensure that we are receiving frames for the action we sent
        super().__init__(width, height, depth_format, add_depth_noise)

        @app.route("/ping", methods=["get"])
        def ping():
            return "pong"

        @app.route("/train", methods=["post"])
        def train():
            action_returns = []
            content_type = request.headers["Content-Type"].split(";")[0]
            if content_type == "multipart/form-data":
                form = MultipartFormParser(
                    request.get_data(),
                    MultipartFormParser.get_boundary(request.headers),
                )

                def field(key):
                    # the hand-rolled parser stores a list per field
                    return form.form[key][0]

            else:
                form = request

                def field(key):
                    return form.form[key]

            metadata = json.loads(field("metadata"))
            # backwards compatibility
            if "actionReturns" in form.form and len(field("actionReturns")) > 0:
                action_returns = json.loads(field("actionReturns"))
            token = field("token")

            if self.client_token and token != self.client_token:
                abort(403)

            if self.frame_counter % self.debug_frames_per_interval == 0:
                # frame-rate reporting is disabled; only the timestamp is kept
                self.last_rate_timestamp = time.time()

            for index, agent in enumerate(metadata["agents"]):
                if "actionReturn" in agent or index >= len(action_returns):
                    continue
                agent["actionReturn"] = action_returns[index]

            self.request_queue.put_nowait(self.create_event(metadata, form.files))
            self.frame_counter += 1

            # block until the controller supplies the next action
            next_action = queue_get(self.response_queue)
            if "sequenceId" in next_action:
                self.sequence_id = next_action["sequenceId"]
            else:
                self.sequence_id += 1
                next_action["sequenceId"] = self.sequence_id

            payload = json.dumps(next_action, cls=ai2thor.server.NumpyAwareEncoder)
            return make_response(payload)

    def _start_server_thread(self):
        """Thread target: serve requests until ``shutdown()`` is called."""
        self.wsgi_server.serve_forever()

    def start(self):
        """Mark the server as started and serve from a daemon thread."""
        self.started = True
        self.server_thread = threading.Thread(
            target=self._start_server_thread, daemon=True
        )
        self.server_thread.start()

    def receive(self):
        """Return the next event posted by Unity (see ``queue_get``)."""
        return queue_get(self.request_queue, unity_proc=self.unity_proc)

    def send(self, action):
        """Queue ``action`` as the response to the pending ``/train`` request."""
        assert self.request_queue.empty()
        self.response_queue.put_nowait(action)

    # params to pass up to unity
    def unity_params(self):
        """Return the bound host and port (port as a string) for Unity."""
        host, port = self.wsgi_server.socket.getsockname()
        return dict(host=host, port=str(port))

    def stop(self):
        """Send an empty action, then shut the werkzeug server down."""
        self.send({})
        self.wsgi_server.shutdown()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment