Skip to content

Instantly share code, notes, and snippets.

@myagues
Created April 8, 2020 11:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save myagues/aac0c597f8ad0fa7ebe7d017b0c5603b to your computer and use it in GitHub Desktop.
Save myagues/aac0c597f8ad0fa7ebe7d017b0c5603b to your computer and use it in GitHub Desktop.
First Order Model Webcam
// Peer connection; created lazily in startAction(), closed in stopAction().
let pc = null;

// Capture constraints for the local webcam feed (kept small / low-rate
// because every frame is round-tripped through the server model).
const mediaStreamConstraints = {
    video: {
        'width': 640,
        'height': 480,
        'frameRate': 10
    }
};

// Action buttons. NOTE: the original also looked up 'callButton', but no
// such element exists in the page and it was never used — removed.
const startButton = document.getElementById('startButton');
const stopButton = document.getElementById('stopButton');

// Add click event handlers for buttons.
startButton.addEventListener('click', startAction);
stopButton.addEventListener('click', stopAction);

// Video element that displays the stream returned by the server.
const localVideo = document.getElementById('localVideo');
// Run the offer/answer dance with the server: create a local offer, wait
// for ICE gathering so the SDP is complete, POST it to /offer, and apply
// the server's answer. Any failure is surfaced via alert().
async function negotiate() {
    try {
        const offer = await pc.createOffer();
        await pc.setLocalDescription(offer);

        // Wait for ICE gathering to complete before sending the SDP.
        await new Promise(function(resolve) {
            if (pc.iceGatheringState === 'complete') {
                resolve();
            } else {
                function checkState() {
                    if (pc.iceGatheringState === 'complete') {
                        pc.removeEventListener('icegatheringstatechange', checkState);
                        resolve();
                    }
                }
                pc.addEventListener('icegatheringstatechange', checkState);
            }
        });

        const localOffer = pc.localDescription;
        const response = await fetch('/offer', {
            body: JSON.stringify({
                sdp: localOffer.sdp,
                type: localOffer.type
            }),
            headers: {
                'Content-Type': 'application/json'
            },
            method: 'POST'
        });
        const answer = await response.json();
        return await pc.setRemoteDescription(answer);
    } catch (e) {
        alert(e);
    }
}
// Start streaming: open the webcam, feed its tracks into a new
// RTCPeerConnection, and negotiate with the server.
function startAction() {
    startButton.disabled = true;
    stopButton.disabled = false;

    const config = null;
    pc = new RTCPeerConnection(config);

    // Show the stream the server sends back (the transformed video).
    pc.addEventListener('track', function(evt) {
        localVideo.srcObject = evt.streams[0];
    });

    navigator.mediaDevices.getUserMedia(mediaStreamConstraints).then(function(stream) {
        stream.getTracks().forEach(function(track) {
            pc.addTrack(track, stream);
            console.log(track.getSettings().frameRate);
        });
        return negotiate();
    }).catch(function(e) {
        // Camera access denied/unavailable: report it and restore the
        // buttons instead of leaving an unhandled promise rejection.
        alert(e);
        startButton.disabled = false;
        stopButton.disabled = true;
    });
}
// Stop streaming: halt transceivers and local tracks, then close the
// peer connection after a short grace period.
function stopAction() {
    stopButton.disabled = true;
    startButton.disabled = false;

    // Close transceivers (feature-detected; not all browsers expose them).
    if (pc.getTransceivers) {
        pc.getTransceivers().forEach(function(transceiver) {
            if (transceiver.stop) {
                transceiver.stop();
            }
        });
    }

    // Stop local audio/video tracks. A sender's track may be null (e.g.
    // after its transceiver was stopped above), so guard before stopping.
    pc.getSenders().forEach(function(sender) {
        if (sender.track) {
            sender.track.stop();
        }
    });

    // Let in-flight packets drain before tearing down the connection.
    setTimeout(function() {
        pc.close();
    }, 500);
}
<!DOCTYPE html>
<!-- Demo page: Start/Stop buttons plus a single <video> element that
     shows the stream the WebRTC server sends back. Client logic lives
     in client_slim.js (served by the Python server). -->
<html>
<head>
<meta charset="UTF-8"/>
<title>WebRTC demo</title>
<style>
button {
padding: 8px 16px;
}
pre {
overflow-x: hidden;
overflow-y: auto;
}
video {
width: 25%;
}
.option {
margin-bottom: 8px;
}
#media {
max-width: 1280px;
}
</style>
</head>
<body>
<div>
<button id="startButton">Start</button>
<button id="stopButton">Stop</button>
</div>
<!-- <video id="localVideo2" autoplay playsinline></video> -->
<!-- autoplay + playsinline so the returned stream renders without user
     interaction, including on mobile browsers. -->
<video id="localVideo" autoplay playsinline></video>
<script type="application/javascript" src="client_slim.js"></script>
</body>
</html>
import argparse
import asyncio
import json
import logging
import os
import ssl
import uuid
import time
import cv2
from aiohttp import web
from av import VideoFrame
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder
from demo import load_checkpoints
from skimage.transform import resize
from skimage import img_as_ubyte
import imageio
import torch
import numpy as np
from animate import normalize_kp
ROOT = os.path.dirname(__file__)
routes = web.RouteTableDef()
logger = logging.getLogger("pc")
pcs = set()
class VideoTransformTrack(MediaStreamTrack):
    """
    A video stream track that transforms frames from another track.

    Each incoming webcam frame drives the first-order motion model: the
    detected keypoints of the live frame animate a fixed source image, and
    the rendered result is returned as the outgoing frame.
    """

    kind = "video"

    def __init__(self, track):
        # track: the source MediaStreamTrack whose frames drive the model.
        super().__init__()  # don't forget this!
        self.track = track
        # Load the pretrained generator + keypoint detector.
        # NOTE(review): checkpoint/config paths are hard-coded to this
        # machine's layout — adjust for other deployments.
        self.generator, self.kp_detector = load_checkpoints(
            config_path='config/vox-256.yaml',
            checkpoint_path='/ckpt/first-order-motion-model/vox-cpk.pth.tar')
        # The still image that will be animated by the webcam motion.
        source_image_path = "/ckpt/first-order-motion-model/statue-02.png"
        source_image = imageio.imread(source_image_path)
        # Resize to the model's 256x256 input and drop any alpha channel.
        source_image = resize(source_image, (256, 256))[..., :3]
        # HWC float image -> NCHW float tensor on the GPU (CUDA required).
        self.source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
        # Keypoints of the source image, computed once and reused per frame.
        self.kp_source = self.kp_detector(self.source)
        # Used to capture the first webcam frame as the motion reference.
        self.first_iter = True

    async def recv(self):
        frame = await self.track.recv()
        # Inference only — no gradients needed.
        with torch.no_grad():
            img = frame.to_ndarray(format="rgb24")
            # print(time.time())
            img = resize(img, (256, 256))[..., :3]
            img = torch.tensor(img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
            if self.first_iter:
                # Remember the first frame: motion is expressed relative
                # to this initial pose.
                self.img_init = img
                self.first_iter = False
            # NOTE(review): keypoints of the (constant) initial frame are
            # recomputed on every call; could be cached — confirm intent.
            kp_driving_initial = self.kp_detector(self.img_init)
            kp_driving = self.kp_detector(img)
            # Normalize driving keypoints relative to the initial frame so
            # only relative motion is transferred to the source image.
            kp_norm = normalize_kp(
                kp_source=self.kp_source, kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial, use_relative_movement=True,
                use_relative_jacobian=True, adapt_movement_scale=True)
            out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm)
            # NCHW tensor -> HWC numpy image for re-encoding.
            out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
        # rebuild a VideoFrame, preserving timing information
        new_frame = VideoFrame.from_ndarray(img_as_ubyte(out), format="rgb24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
@routes.get('/')
async def index(request):
    """Serve the demo page (index.html) next to this script."""
    # Context manager closes the file deterministically instead of
    # relying on garbage collection.
    with open(os.path.join(ROOT, "index.html"), "r") as f:
        content = f.read()
    return web.Response(content_type="text/html", text=content)
@routes.get('/client_slim.js')
async def javascript(request):
    """Serve the browser-side WebRTC client script."""
    # Context manager closes the file deterministically instead of
    # relying on garbage collection.
    with open(os.path.join(ROOT, "client_slim.js"), "r") as f:
        content = f.read()
    return web.Response(content_type="application/javascript", text=content)
@routes.post('/offer')
async def offer(request):
    """Negotiate a WebRTC session.

    Accepts the browser's SDP offer, wires incoming video through
    VideoTransformTrack (first-order-model animation), and returns this
    server's SDP answer as JSON.
    """
    params = await request.json()
    remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    # Track the connection so on_shutdown() can close it.
    # (An unused uuid-based pc_id was removed; nothing ever logged it.)
    pcs.add(pc)

    @pc.on("track")
    def on_track(track):
        # Echo the incoming track back through the transforming track.
        local_video = VideoTransformTrack(track)
        pc.addTrack(local_video)

    # Apply the remote description, then create and send our answer.
    await pc.setRemoteDescription(remote_offer)
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )
async def on_shutdown(app):
    """Close every tracked peer connection when the web app shuts down."""
    # Close all connections concurrently, then drop the references.
    await asyncio.gather(*(conn.close() for conn in pcs))
    pcs.clear()
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging/TLS,
    # and run the aiohttp server hosting the WebRTC demo.
    parser = argparse.ArgumentParser(
        description="WebRTC audio / video / data-channels demo"
    )
    parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
    parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
    parser.add_argument(
        "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
    )
    parser.add_argument("--verbose", "-v", action="count")
    parser.add_argument("--write-audio", help="Write received audio to a file")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    if args.cert_file:
        # PROTOCOL_TLS_SERVER gives secure server-side defaults; the bare
        # ssl.SSLContext() constructor (PROTOCOL_TLS) is deprecated.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(args.cert_file, args.key_file)
    else:
        ssl_context = None

    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.add_routes(routes)
    web.run_app(app, access_log=None, port=args.port, ssl_context=ssl_context)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment