Skip to content

Instantly share code, notes, and snippets.

@noelnico
Created February 3, 2021 22:14
Show Gist options
  • Save noelnico/e343055290f770cb2f861cf69b9b3f33 to your computer and use it in GitHub Desktop.
Save noelnico/e343055290f770cb2f861cf69b9b3f33 to your computer and use it in GitHub Desktop.
Custom server code for damaged phone detection
import argparse
import asyncio
import json
import logging
import os
import ssl
import uuid
import cv2
from opencv_detection import analyse_frame
import time
import numpy as np
from aiohttp import web
from av import VideoFrame
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder
ROOT = os.path.dirname(__file__)
logger = logging.getLogger("pc")
pcs = set()
class VideoTransformTrack(MediaStreamTrack):
    """
    A video stream track that transforms frames from another track.

    Depending on ``transform`` the incoming frames are passed through
    unchanged ("none"), edge-detected with OpenCV ("edges"), or run through
    the phone-damage detector ("detection").  Frame counters are accumulated
    into the shared ``stats`` dict so the client can display them.
    """

    kind = "video"

    def __init__(self, track, transform, stats):
        super().__init__()  # don't forget this!
        # NOTE(review): relies on the module-level ``args`` namespace parsed
        # in the __main__ block — confirm this class is only instantiated
        # after argument parsing.
        if args.play_from:
            player = MediaPlayer(args.play_from)
            self.track = player.video
        else:
            self.track = track
        self.transform = transform
        self.stats = stats

    async def recv(self):
        frame = await self.track.recv()
        self.stats['total_frames'] += 1
        # Drop all frames waiting in the buffer so we always analyse the
        # freshest frame (detection is slower than the incoming frame rate).
        # Solution based on https://github.com/aiortc/aiortc/issues/91
        # NOTE(review): ``_queue`` is a private aiortc attribute and may
        # break on library upgrade.
        while not self.track._queue.empty():
            frame = await self.track.recv()
            self.stats['total_frames'] += 1
            self.stats['dropped_frames'] += 1
        if self.transform == "edges":
            # perform edge detection
            img = frame.to_ndarray(format="bgr24")
            img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
            # rebuild a VideoFrame, preserving timing information
            new_frame = VideoFrame.from_ndarray(img, format="bgr24")
            new_frame.pts = frame.pts
            new_frame.time_base = frame.time_base
            return new_frame
        if self.transform == "detection":
            # Run the CPU-heavy detector in a worker thread so the event
            # loop keeps servicing other connections.
            loop = asyncio.get_event_loop()
            logger.debug(str(self.stats))
            return await loop.run_in_executor(None, self.detect_phone, frame)
        # "none" or any unrecognised transform: pass the frame through.
        # BUG FIX: the original implicitly returned None for transform
        # values other than "none"/"edges"/"detection".
        return frame

    def detect_phone(self, frame):
        """Take a video frame, apply phone detection, and bundle back into a frame."""
        self.stats['analyzed_frames'] += 1
        img = frame.to_ndarray(format="bgr24")
        new_img = analyse_frame(img)
        new_frame = VideoFrame.from_ndarray(new_img, format="bgr24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
async def index(request):
    """Serve the demo page (index.html) located next to this file."""
    # BUG FIX: use a context manager so the file handle is closed even if
    # the read raises (the original leaked the handle).
    with open(os.path.join(ROOT, "index.html"), "r") as f:
        content = f.read()
    return web.Response(content_type="text/html", text=content)
async def javascript(request):
    """Serve the browser-side WebRTC client script (client.js)."""
    # BUG FIX: use a context manager so the file handle is closed even if
    # the read raises (the original leaked the handle).
    with open(os.path.join(ROOT, "client.js"), "r") as f:
        content = f.read()
    return web.Response(content_type="application/javascript", text=content)
async def offer(request):
    """Handle a WebRTC offer POSTed by the browser client.

    Creates an RTCPeerConnection, registers data-channel / ICE / track
    callbacks, performs the SDP offer/answer exchange and returns the
    answer as JSON.
    """
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
    # Per-connection counters shared with VideoTransformTrack and reported
    # back to the client over the data channel.
    stats = {
        'total_frames': 0,
        'dropped_frames': 0,
        'analyzed_frames': 0,
        'ping_ts': ''
    }
    pc = RTCPeerConnection()
    pc_id = "PeerConnection(%s)" % uuid.uuid4()
    pcs.add(pc)

    def log_info(msg, *args):
        # Prefix every log line with this connection's id.
        logger.info(pc_id + " " + msg, *args)

    log_info("Created for %s", request.remote)

    @pc.on("datachannel")
    def on_datachannel(channel):
        @channel.on("message")
        def on_message(message):
            # The client sends "ping <timestamp>"; reply with the timestamp
            # and current stats so the client can compute latency.
            if isinstance(message, str) and message.startswith("ping"):
                stats['ping_ts'] = message[5:]
                channel.send("pong " + json.dumps(stats))

    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        log_info("ICE connection state is %s", pc.iceConnectionState)
        if pc.iceConnectionState == "failed":
            # Tear down failed connections so they are not re-closed on shutdown.
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track):
        log_info("Track %s received", track.kind)
        if track.kind == "video":
            # Wrap the incoming video in the transforming track and send
            # the processed stream back to the client.
            local_video = VideoTransformTrack(
                track, transform=params["video_transform"], stats=stats
            )
            pc.addTrack(local_video)

        @track.on("ended")
        async def on_ended():
            log_info("Track %s ended", track.kind)

    # handle offer
    await pc.setRemoteDescription(offer)
    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )
async def on_shutdown(app):
    """Close every open peer connection concurrently when the server stops."""
    await asyncio.gather(*(connection.close() for connection in pcs))
    pcs.clear()
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging and
    # optional TLS, then start the aiohttp server.
    parser = argparse.ArgumentParser(
        description="WebRTC audio / video / data-channels demo"
    )
    parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
    parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
    parser.add_argument(
        "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
    )
    parser.add_argument("--verbose", "-v", action="count")
    parser.add_argument("--write-audio", help="Write received audio to a file")
    # FIX: help text typo ("sent it" -> "send it").
    parser.add_argument("--play-from", help="Read the media from a file and send it.")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    if args.cert_file:
        # FIX: select the server-side TLS protocol explicitly — the bare
        # SSLContext() constructor is deprecated and does not pick
        # server-appropriate defaults.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(args.cert_file, args.key_file)
    else:
        ssl_context = None

    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.router.add_get("/", index)
    app.router.add_get("/client.js", javascript)
    app.router.add_post("/offer", offer)
    web.run_app(app, access_log=None, port=args.port, ssl_context=ssl_context)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment