Created
April 8, 2020 11:35
-
-
Save myagues/aac0c597f8ad0fa7ebe7d017b0c5603b to your computer and use it in GitHub Desktop.
First Order Model Webcam
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Handle to the active RTCPeerConnection; null until startAction creates one.
let pc = null;

// Constraints passed to getUserMedia: video only, 640x480 at 10 frames/s.
const mediaStreamConstraints = {
  video: {
    width: 640,
    height: 480,
    frameRate: 10,
  },
};

// Look up the action buttons by their element ids.
const startButton = document.getElementById('startButton');
const callButton = document.getElementById('callButton');
const stopButton = document.getElementById('stopButton');

// Video element that receives the stream delivered by the peer connection.
const localVideo = document.getElementById('localVideo');

// Wire the click handlers (function declarations are hoisted, so they can be
// referenced here even though they are defined further down the file).
startButton.addEventListener('click', startAction);
stopButton.addEventListener('click', stopAction);
/**
 * Performs one offer/answer exchange with the signaling endpoint.
 *
 * Steps: create an SDP offer, install it as the local description, wait until
 * ICE gathering is complete (so the posted SDP carries all candidates), POST
 * the offer as JSON to `/offer`, and install the JSON answer as the remote
 * description.
 *
 * Any failure along the way is reported with `alert` and swallowed, so the
 * returned promise always fulfils with `undefined`.
 *
 * @returns {Promise<void>} Settles once the exchange finished or failed.
 */
function negotiate() {
  // Pin the connection this negotiation belongs to. The global `pc` may be
  // replaced (Stop followed by Start) while the asynchronous steps below are
  // still pending; every step must keep operating on the same connection.
  const connection = pc;

  return connection
    .createOffer()
    .then((offer) => connection.setLocalDescription(offer))
    .then(
      () =>
        // wait for ICE gathering to complete
        new Promise((resolve) => {
          if (connection.iceGatheringState === 'complete') {
            resolve();
            return;
          }
          const checkState = () => {
            if (connection.iceGatheringState === 'complete') {
              connection.removeEventListener('icegatheringstatechange', checkState);
              resolve();
            }
          };
          connection.addEventListener('icegatheringstatechange', checkState);
        })
    )
    .then(() => {
      const offer = connection.localDescription;
      return fetch('/offer', {
        body: JSON.stringify({
          sdp: offer.sdp,
          type: offer.type,
        }),
        headers: {
          'Content-Type': 'application/json',
        },
        method: 'POST',
      });
    })
    .then((response) => {
      // fetch() only rejects on network errors; an HTTP error status still
      // resolves, so it has to be turned into a failure explicitly.
      if (!response.ok) {
        throw new Error(
          `Signaling request to /offer failed: ${response.status} ${response.statusText}`
        );
      }
      return response.json();
    })
    .then((answer) => connection.setRemoteDescription(answer))
    .catch((e) => {
      alert(e);
    });
}
/**
 * Click handler for the Start button.
 *
 * Creates a new RTCPeerConnection (stored in the global `pc`), shows every
 * incoming remote stream in `localVideo`, captures the webcam according to
 * `mediaStreamConstraints`, adds the captured tracks to the connection and
 * starts the offer/answer exchange.
 *
 * If the camera cannot be opened (permission denied, no device, insecure
 * origin without `navigator.mediaDevices`) or the tracks cannot be added, the
 * error is shown with `alert`, the connection is closed, captured tracks are
 * stopped and the buttons return to their idle state.
 *
 * @returns {Promise<void>} Settles when setup finished or failed; the value is
 *   ignored when the function is used as an event listener.
 */
function startAction() {
  startButton.disabled = true;
  stopButton.disabled = false;

  const config = null;
  const connection = new RTCPeerConnection(config);
  pc = connection;

  // connect audio / video
  connection.addEventListener('track', (evt) => {
    localVideo.srcObject = evt.streams[0];
  });

  let localStream = null;

  // Starting from a resolved promise routes synchronous failures (e.g.
  // `navigator.mediaDevices` being undefined) into the same catch handler.
  return Promise.resolve()
    .then(() => navigator.mediaDevices.getUserMedia(mediaStreamConstraints))
    .then((stream) => {
      localStream = stream;
      stream.getTracks().forEach((track) => {
        connection.addTrack(track, stream);
        console.log(track.getSettings().frameRate);
      });
      return negotiate();
    })
    .catch((error) => {
      alert(error);
      if (localStream) {
        // Release the camera if it had already been opened.
        localStream.getTracks().forEach((track) => track.stop());
      }
      connection.close();
      startButton.disabled = false;
      stopButton.disabled = true;
    });
}
/**
 * Click handler for the Stop button.
 *
 * Resets the buttons, stops every transceiver and every local track being
 * sent, and closes the peer connection after a 500 ms delay. Does nothing
 * beyond resetting the buttons when no connection has been created yet.
 */
function stopAction() {
  stopButton.disabled = true;
  startButton.disabled = false;

  // Pin the connection being torn down: the delayed close below must not hit
  // a connection created by a Start click within the next 500 ms.
  const connection = pc;
  if (!connection) {
    return;
  }

  // close transceivers (feature-checked: not every browser implements them)
  if (connection.getTransceivers) {
    connection.getTransceivers().forEach((transceiver) => {
      if (transceiver.stop) {
        transceiver.stop();
      }
    });
  }

  // close local audio / video; a sender may have no track attached
  connection.getSenders().forEach((sender) => {
    if (sender.track) {
      sender.track.stop();
    }
  });

  // close peer connection
  setTimeout(() => {
    connection.close();
  }, 500);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html>
<!-- Demo page: Start/Stop buttons plus one video element driven by client_slim.js. -->
<html lang="en">
<head>
    <meta charset="UTF-8"/>
    <title>WebRTC demo</title>
    <style>
    button {
        padding: 8px 16px;
    }

    pre {
        overflow-x: hidden;
        overflow-y: auto;
    }

    video {
        width: 25%;
    }

    .option {
        margin-bottom: 8px;
    }

    #media {
        max-width: 1280px;
    }
    </style>
</head>
<body>
    <div>
        <button id="startButton">Start</button>
        <!-- Disabled until a connection exists; the script toggles it on Start/Stop. -->
        <button id="stopButton" disabled>Stop</button>
    </div>

    <!-- <video id="localVideo2" autoplay playsinline></video> -->
    <!-- Shows the stream received over the peer connection. -->
    <video id="localVideo" autoplay playsinline></video>

    <script type="application/javascript" src="client_slim.js"></script>
</body>
</html>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import asyncio | |
import json | |
import logging | |
import os | |
import ssl | |
import uuid | |
import time | |
import cv2 | |
from aiohttp import web | |
from av import VideoFrame | |
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription | |
from aiortc.contrib.media import MediaBlackhole, MediaPlayer, MediaRecorder | |
from demo import load_checkpoints | |
from skimage.transform import resize | |
from skimage import img_as_ubyte | |
import imageio | |
import torch | |
import numpy as np | |
from animate import normalize_kp | |
ROOT = os.path.dirname(__file__) | |
routes = web.RouteTableDef() | |
logger = logging.getLogger("pc") | |
pcs = set() | |
class VideoTransformTrack(MediaStreamTrack):
    """
    A video stream track that transforms frames from another track.

    Each incoming frame is used as the driving image for the model returned by
    ``load_checkpoints``: keypoints detected on the frame are normalized
    against the keypoints of the first frame received and applied to a fixed
    source image, and the generated prediction is returned as the output frame.
    """

    kind = "video"

    def __init__(self, track):
        """
        :param track: the incoming track whose frames drive the animation.
        """
        super().__init__()  # don't forget this!
        self.track = track
        self.generator, self.kp_detector = load_checkpoints(
            config_path='config/vox-256.yaml',
            checkpoint_path='/ckpt/first-order-motion-model/vox-cpk.pth.tar')

        source_image_path = "/ckpt/first-order-motion-model/statue-02.png"
        source_image = imageio.imread(source_image_path)
        # Resize to 256x256 and keep only the first three channels.
        source_image = resize(source_image, (256, 256))[..., :3]
        # Add a batch axis and move channels first: (1, 3, 256, 256) on the GPU.
        self.source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
        # Inference only: no autograd graph is needed for the source keypoints.
        with torch.no_grad():
            self.kp_source = self.kp_detector(self.source)

        # State captured from the first received frame (see recv).
        self.first_iter = True
        self.img_init = None
        self.kp_driving_initial = None

    async def recv(self):
        """
        Receive the next frame from the wrapped track and return the generated
        frame, carrying over the original ``pts`` and ``time_base``.
        """
        frame = await self.track.recv()
        # NOTE(review): the model runs synchronously inside this coroutine, so
        # the event loop is blocked for the duration of each inference.
        with torch.no_grad():
            img = frame.to_ndarray(format="rgb24")
            img = resize(img, (256, 256))[..., :3]
            img = torch.tensor(img[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()

            if self.first_iter:
                # The first frame is the reference pose. Its keypoints never
                # change, so detect them once instead of on every frame.
                # Assumes kp_detector is deterministic for a fixed input
                # (TODO confirm it is in eval mode).
                self.img_init = img
                self.kp_driving_initial = self.kp_detector(self.img_init)
                self.first_iter = False

            kp_driving = self.kp_detector(img)
            kp_norm = normalize_kp(
                kp_source=self.kp_source, kp_driving=kp_driving,
                kp_driving_initial=self.kp_driving_initial, use_relative_movement=True,
                use_relative_jacobian=True, adapt_movement_scale=True)
            out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm)
            # Back to channels-last on the CPU; take the single batch element.
            out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]

        # rebuild a VideoFrame, preserving timing information
        new_frame = VideoFrame.from_ndarray(img_as_ubyte(out), format="rgb24")
        new_frame.pts = frame.pts
        new_frame.time_base = frame.time_base
        return new_frame
@routes.get('/')
async def index(request):
    """Serve ``index.html`` from the directory containing this script."""
    # Context manager so the file handle is closed deterministically.
    with open(os.path.join(ROOT, "index.html"), "r") as page:
        content = page.read()
    return web.Response(content_type="text/html", text=content)
@routes.get('/client_slim.js')
async def javascript(request):
    """Serve ``client_slim.js`` from the directory containing this script."""
    # Context manager so the file handle is closed deterministically.
    with open(os.path.join(ROOT, "client_slim.js"), "r") as script:
        content = script.read()
    return web.Response(content_type="application/javascript", text=content)
@routes.post('/offer')
async def offer(request):
    """
    Signaling endpoint: accept an SDP offer and reply with the SDP answer.

    The request body is JSON with ``sdp`` and ``type`` keys. A new peer
    connection is created for the offer; every incoming video track is wrapped
    in a ``VideoTransformTrack`` and sent back to the caller. The response is
    JSON with the ``sdp`` and ``type`` of the local description.
    """
    params = await request.json()
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    pc_id = "PeerConnection(%s)" % uuid.uuid4()
    pcs.add(pc)

    @pc.on("iceconnectionstatechange")
    async def on_iceconnectionstatechange():
        logger.info("%s ICE connection state is %s", pc_id, pc.iceConnectionState)
        # Release connections whose peer is gone; otherwise they stay in
        # ``pcs`` (together with their transform track) until shutdown.
        if pc.iceConnectionState in ("failed", "closed"):
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track):
        if track.kind != "video":
            # The transform only understands video frames.
            logger.info("%s ignoring %s track", pc_id, track.kind)
            return
        local_video = VideoTransformTrack(track)
        pc.addTrack(local_video)

    # handle offer
    await pc.setRemoteDescription(offer)

    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )
async def on_shutdown(app):
    """Application shutdown hook: close all tracked peer connections concurrently, then empty the registry."""
    closing = (connection.close() for connection in pcs)
    await asyncio.gather(*closing)
    pcs.clear()
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging and TLS,
    # then run the aiohttp application.
    parser = argparse.ArgumentParser(
        description="WebRTC audio / video / data-channels demo"
    )
    parser.add_argument("--cert-file", help="SSL certificate file (for HTTPS)")
    parser.add_argument("--key-file", help="SSL key file (for HTTPS)")
    parser.add_argument(
        "--port", type=int, default=8080, help="Port for HTTP server (default: 8080)"
    )
    parser.add_argument("--verbose", "-v", action="count")
    # NOTE(review): --write-audio is accepted but not used anywhere in this
    # script; kept so existing command lines keep working.
    parser.add_argument("--write-audio", help="Write received audio to a file")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    if args.cert_file:
        # Explicit server-side protocol: constructing SSLContext() without a
        # protocol argument is deprecated.
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ssl_context.load_cert_chain(args.cert_file, args.key_file)
    else:
        ssl_context = None

    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.add_routes(routes)
    web.run_app(app, access_log=None, port=args.port, ssl_context=ssl_context)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment