Testing AI avatar project using Gradio
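This gist contains three scripts: the segDINet facial-dubbing pipeline (facial_dubbing.py), a FastAPI server that wraps it behind an /inference endpoint, and a Gradio viewer/user-study app (viewer.py) that calls the server and collects survey results.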
import os
import glob
import time
import random
import subprocess
import numpy as np
import cv2
import torch
from collections import OrderedDict
from utils.deep_speech import DeepSpeech
from utils.data_processing import load_landmark_openface,compute_crop_radius
from config.config import InferenceOptions
from models.segDINet import segDINet
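# Pipeline overview:
#   1. OpenFaceExtractor   - runs OpenFace FeatureExtraction to detect 2D facial landmarks
#   2. FrameExtractor      - splits the target video into individual frames (expects 25 fps)
#   3. DeepSpeechExtractor - computes DeepSpeech features for the source audio
#   4. VideoSynchronizer   - cycles/pads the frames and landmarks to match the audio length
#   5. DrivingImageSelector + ModelHandler (segDINet) - synthesize the mouth region frame by frame
#   6. combine_audio       - uses ffmpeg to mux the source audio into the synthesized video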
class OpenFaceExtractor:
def __init__(self, install_path):
self.install_path = install_path
def extract_features(self, video, output_dir):
temp_cwd = os.getcwd()
os.chdir(self.install_path)
command = f"./FeatureExtraction -f {video} -out_dir {output_dir} -2Dfp"
os.system(command)
os.chdir(temp_cwd)
os.system("pwd")
print(f"\nCompleted! Please check that it is extracted to {output_dir}\n")
class DeepSpeechExtractor:
def __init__(self, model_path):
if not os.path.exists(model_path):
raise FileNotFoundError('Please download the pretrained model of DeepSpeech')
self.model = DeepSpeech(model_path)
def extract_features(self, audio_path):
if not os.path.exists(audio_path):
raise FileNotFoundError(f'Wrong audio path: {audio_path}')
ds_feature = self.model.compute_audio_feature(audio_path)
return ds_feature
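# compute_audio_feature() returns one feature vector per (25 fps) video frame; ds_feature.shape[0]
# is therefore used below as the number of frames to synthesize. The feature dimension depends on
# the pretrained DeepSpeech model loaded by utils.deep_speech.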
class FrameExtractor:
def __init__(self, video_path):
self.video_path = video_path
def extract_frames_from_video(self, video_path, save_dir):
videoCapture = cv2.VideoCapture(video_path)
fps = videoCapture.get(cv2.CAP_PROP_FPS)
if int(fps) != 25:
print('warning: the input video is not 25 fps; it is better to convert it to 25 fps first!')
# frames = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
frame_height = videoCapture.get(cv2.CAP_PROP_FRAME_HEIGHT)
frame_width = videoCapture.get(cv2.CAP_PROP_FRAME_WIDTH)
i=0
while(True):
ret, frame = videoCapture.read()
if not ret:
break
result_path = os.path.join(save_dir, str(i).zfill(6) + '.jpg')
cv2.imwrite(result_path, frame)
i+=1
print(f"\nall frames: {i} frame_height: {frame_height} frame_width: {frame_width}\n")
return (int(frame_width),int(frame_height))
def extract_frames(self):
video_frame_dir = self.video_path.replace('.mp4', '')
if not os.path.exists(video_frame_dir):
os.mkdir(video_frame_dir)
video_size = self.extract_frames_from_video(self.video_path, video_frame_dir)
print(f"\nCompleted! Please check that it is extracted to {video_frame_dir}\n")
return video_frame_dir, video_size
class DrivingImageSelector:
def __init__(self, mouth_region_size):
self.resize_w = int(mouth_region_size + mouth_region_size // 4)
self.resize_h = int((mouth_region_size // 2) * 3 + mouth_region_size // 8)
def select_images(self, video_frame_path_list_pad, video_landmark_data_pad, video_size):
driving_img_list = []
driving_index_list = random.sample(range(5, len(video_frame_path_list_pad) - 2), 5)
for driving_index in driving_index_list:
crop_flag, crop_radius = compute_crop_radius(video_size, video_landmark_data_pad[driving_index - 5:driving_index, :, :])
if not crop_flag:
raise ValueError('Our method cannot handle videos with a large change in facial size!')
crop_radius_1_4 = crop_radius // 4
driving_img = cv2.imread(video_frame_path_list_pad[driving_index - 3])[:, :, ::-1]
driving_landmark = video_landmark_data_pad[driving_index - 3, :, :]
driving_img_crop = driving_img[
driving_landmark[29, 1] - crop_radius: driving_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
driving_landmark[33, 0] - crop_radius - crop_radius_1_4: driving_landmark[33, 0] + crop_radius + crop_radius_1_4, :]
driving_img_crop = cv2.resize(driving_img_crop, (self.resize_w, self.resize_h))
driving_img_crop = driving_img_crop / 255.0
driving_img_list.append(driving_img_crop)
driving_video_frame = np.concatenate(driving_img_list, 2)
driving_img_tensor = torch.from_numpy(driving_video_frame).permute(2, 0, 1).unsqueeze(0).float().cuda()
return driving_img_tensor
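# The 5 randomly chosen reference crops are concatenated along the channel axis
# (5 RGB crops -> 15 channels), so the resulting driving_img_tensor should match
# the ref_channel input expected by segDINet.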
class ModelHandler:
def __init__(self, model_path, source_channel, ref_channel, audio_channel):
if not os.path.exists(model_path):
raise FileNotFoundError(f'Wrong path of the pretrained model weight: {model_path}')
self.model = segDINet(source_channel, ref_channel, audio_channel).cuda()
state_dict = torch.load(model_path)['state_dict']['net_g']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
self.model.load_state_dict(new_state_dict)
self.model.eval()
def infer_frame(self, crop_frame_tensor, driving_img_tensor, deepspeech_tensor, gt_frame_tensor):
with torch.no_grad():
out = self.model(crop_frame_tensor, driving_img_tensor, deepspeech_tensor)
out_frame, out_mask = out[:, :-1], out[:, -1:]
idx = out_mask < 250 / 255.0
pre_frame = torch.where(idx, gt_frame_tensor, out_frame)
pre_frame = pre_frame.squeeze(0).permute(1, 2, 0).detach().cpu().numpy() * 255
return pre_frame
class VideoSynchronizer:
def __init__(self, video_landmark_data, video_frame_path_list, ds_feature_padding, video_size, mouth_region_size):
self.video_landmark_data = video_landmark_data
self.video_frame_path_list = video_frame_path_list
self.ds_feature_padding = ds_feature_padding
self.video_size = video_size
self.mouth_region_size = mouth_region_size
def align_frames_with_audio(self, res_frame_length):
# Frame Alignment Code
video_frame_path_list_cycle = self.video_frame_path_list + self.video_frame_path_list[::-1]
video_landmark_data_cycle = np.concatenate([self.video_landmark_data, np.flip(self.video_landmark_data, 0)], 0)
video_frame_path_list_cycle_length = len(video_frame_path_list_cycle)
# res_frame_length = self.ds_feature_padding.shape[0]
if video_frame_path_list_cycle_length >= res_frame_length:
res_video_frame_path_list = video_frame_path_list_cycle[:res_frame_length]
res_video_landmark_data = video_landmark_data_cycle[:res_frame_length, :, :]
else:
divisor = res_frame_length // video_frame_path_list_cycle_length
remainder = res_frame_length % video_frame_path_list_cycle_length
res_video_frame_path_list = video_frame_path_list_cycle * divisor + video_frame_path_list_cycle[:remainder]
res_video_landmark_data = np.concatenate([video_landmark_data_cycle] * divisor + [video_landmark_data_cycle[:remainder, :, :]], 0)
res_video_frame_path_list_pad = [video_frame_path_list_cycle[0]] * 2 + res_video_frame_path_list + [video_frame_path_list_cycle[-1]] * 2
res_video_landmark_data_pad = np.pad(res_video_landmark_data, ((2, 2), (0, 0), (0, 0)), mode='edge')
assert self.ds_feature_padding.shape[0] == len(res_video_frame_path_list_pad) == res_video_landmark_data_pad.shape[0]
return res_video_frame_path_list_pad, res_video_landmark_data_pad
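# The frame list is cycled forward then backward and repeated/truncated so that its length
# matches the number of audio feature frames; the 2-frame edge padding on each side gives
# every synthesized frame a full 5-frame landmark/audio window.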
class VideoPlayer:
@staticmethod
def play_video(video_path):
vid = cv2.VideoCapture(video_path)
if vid.isOpened():
fps = vid.get(cv2.CAP_PROP_FPS)
f_count = vid.get(cv2.CAP_PROP_FRAME_COUNT)
f_width = vid.get(cv2.CAP_PROP_FRAME_WIDTH)
f_height = vid.get(cv2.CAP_PROP_FRAME_HEIGHT)
print('Frames per second : ', fps, 'FPS')
print('Frame count : ', f_count)
print('Frame width : ', f_width)
print('Frame height : ', f_height)
while vid.isOpened():
ret, frame = vid.read()
if ret:
cv2.imshow('Generated Video', frame)
key = cv2.waitKey(1)
if key == ord('q'):
break
else:
break
vid.release()
cv2.destroyAllWindows()
class FacialDubbingPipeline:
def __init__(self, opt):
self.opt = opt
self.openface_extractor = OpenFaceExtractor(opt.OpenFace_install_path)
self.frame_extractor = FrameExtractor(opt.target_video_path)
self.deepspeech_extractor = DeepSpeechExtractor(opt.deepspeech_model_path)
self.driving_image_selector = DrivingImageSelector(opt.mouth_region_size)
self.model_handler = ModelHandler(opt.pretrained_clip_segDinet_path, opt.source_channel, opt.ref_channel, opt.audio_channel)
def run(self):
# OpenFace Feature Extraction
target_video = os.path.abspath(self.opt.target_video_path)
output_landmark_dir = os.path.abspath(os.path.split(self.opt.target_openface_landmark_path)[0])
self.openface_extractor.extract_features(target_video, output_landmark_dir)
# input("if you want next step, press any key.")
# Frame Extraction
video_frame_dir, video_size = self.frame_extractor.extract_frames()
# input("if you want next step, press any key.")
# DeepSpeech Feature Extraction
ds_feature = self.deepspeech_extractor.extract_features(self.opt.source_audio_path)
res_frame_length = ds_feature.shape[0]
ds_feature_padding = np.pad(ds_feature, ((2, 2), (0, 0)), mode='edge')
# input("if you want next step, press any key.")
# Frame & Audio Alignment
## Load OpenFace Landmark Data
if not os.path.exists(self.opt.target_openface_landmark_path):
raise FileNotFoundError(f'Wrong target openface landmark path: {self.opt.target_openface_landmark_path}')
video_landmark_data = load_landmark_openface(self.opt.target_openface_landmark_path).astype(int)  # np.int was removed in NumPy 1.24+
video_frame_path_list = glob.glob(os.path.join(video_frame_dir, '*.jpg'))
video_frame_path_list.sort()
if len(video_frame_path_list) != video_landmark_data.shape[0]:
raise ValueError('video frames are misaligned with detected landmarks')
video_synchronizer = VideoSynchronizer(video_landmark_data, video_frame_path_list, ds_feature_padding, video_size, self.opt.mouth_region_size)
res_video_frame_path_list_pad, res_video_landmark_data_pad = video_synchronizer.align_frames_with_audio(res_frame_length)
pad_length = ds_feature_padding.shape[0]
print("complet aligning frames with source audio")
# Select Driving Images
print("select randomly select 5 driving images")
driving_img_tensor = self.driving_image_selector.select_images(res_video_frame_path_list_pad, res_video_landmark_data_pad, video_size)
# Create Video Writer
res_video_path = os.path.join(self.opt.res_video_dir, os.path.basename(self.opt.target_video_path)[:-4] + '_facial_dubbing.mp4')
res_face_path = res_video_path.replace('_facial_dubbing.mp4', '_synthetic_face.mp4')
videowriter = cv2.VideoWriter(res_video_path, cv2.VideoWriter_fourcc(*'XVID'), 25, video_size)
videowriter_face = cv2.VideoWriter(res_face_path, cv2.VideoWriter_fourcc(*'XVID'), 25, (int(self.opt.mouth_region_size + self.opt.mouth_region_size // 4),
int((self.opt.mouth_region_size // 2) * 3 + self.opt.mouth_region_size // 8)))
# synthesize video and audio
time_sum = self.inference_frame(res_video_frame_path_list_pad, res_video_landmark_data_pad, ds_feature_padding, video_size, driving_img_tensor, pad_length, videowriter, videowriter_face)
video_add_audio_path = self.combine_audio(res_video_path, self.opt.source_audio_path, time_sum, pad_length)
# self.select_closing(video_add_audio_path)
return video_add_audio_path
def inference_frame(self, res_video_frame_path_list_pad, res_video_landmark_data_pad, ds_feature_padding, video_size, driving_img_tensor, pad_length, videowriter, videowriter_face):
time_sum = 0
for clip_end_index in range(5, pad_length, 1):
print(f'synthesizing {clip_end_index - 5}/{pad_length - 5} frame')
start = time.time()
crop_flag, crop_radius = compute_crop_radius(video_size, res_video_landmark_data_pad[clip_end_index - 5:clip_end_index, :, :], random_scale=1.05)
if not crop_flag:
raise ValueError('Our method cannot handle videos with a large change in facial size!')
crop_radius_1_4 = crop_radius // 4
frame_data = cv2.imread(res_video_frame_path_list_pad[clip_end_index - 3])[:, :, ::-1]
frame_landmark = res_video_landmark_data_pad[clip_end_index - 3, :, :]
crop_frame_data = frame_data[
frame_landmark[29, 1] - crop_radius: frame_landmark[29, 1] + crop_radius * 2 + crop_radius_1_4,
frame_landmark[33, 0] - crop_radius - crop_radius_1_4: frame_landmark[33, 0] + crop_radius + crop_radius_1_4, :]
crop_frame_h, crop_frame_w = crop_frame_data.shape[0], crop_frame_data.shape[1]
crop_frame_data = cv2.resize(crop_frame_data, (int(self.opt.mouth_region_size + self.opt.mouth_region_size // 4), int((self.opt.mouth_region_size // 2) *
3 + self.opt.mouth_region_size // 8)))
crop_frame_data = crop_frame_data / 255.0
gt_frame_data = crop_frame_data.copy()
crop_frame_data[self.opt.mouth_region_size // 2: self.opt.mouth_region_size // 2 + self.opt.mouth_region_size,
self.opt.mouth_region_size // 8: self.opt.mouth_region_size // 8 + self.opt.mouth_region_size, :] = 0
gt_frame_tensor = torch.from_numpy(gt_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
crop_frame_tensor = torch.from_numpy(crop_frame_data).float().cuda().permute(2, 0, 1).unsqueeze(0)
deepspeech_tensor = torch.from_numpy(ds_feature_padding[clip_end_index - 5:clip_end_index, :]).permute(1, 0).unsqueeze(0).float().cuda()
pre_frame = self.model_handler.infer_frame(crop_frame_tensor, driving_img_tensor, deepspeech_tensor, gt_frame_tensor)
videowriter_face.write(pre_frame[:, :, ::-1].copy().astype(np.uint8))
pre_frame_resize = cv2.resize(pre_frame, (crop_frame_w, crop_frame_h))
frame_data[
frame_landmark[29, 1] - crop_radius: frame_landmark[29, 1] + crop_radius * 2,
frame_landmark[33, 0] - crop_radius - crop_radius_1_4: frame_landmark[33, 0] + crop_radius + crop_radius_1_4, :] = pre_frame_resize[:crop_radius * 3, :, :]
videowriter.write(frame_data[:, :, ::-1])
end = time.time()
time_sum += end - start
videowriter.release()
videowriter_face.release()
return time_sum
def combine_audio(self, res_video_path, audio_path, time_sum, pad_length):
video_add_audio_path = res_video_path.replace('.mp4', '_add_audio.mp4')
if os.path.exists(video_add_audio_path):
os.remove(video_add_audio_path)
cmd = f'ffmpeg -i {res_video_path} -i {audio_path} -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 {video_add_audio_path}'
subprocess.call(cmd, shell=True)
print(f"\n\nVideo generation complete (average generation time: {time_sum/(pad_length-5):.2f}s, total time: {time_sum:.2f})\n")
print(f"Saved folder path: {self.opt.res_video_dir}")
return video_add_audio_path
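# Note: the ffmpeg command above is built with an f-string and shell=True, so paths containing
# spaces or shell metacharacters would break it. A safer sketch (same flags, argument-list form):
#   subprocess.run(['ffmpeg', '-i', res_video_path, '-i', audio_path, '-c:v', 'copy',
#                   '-c:a', 'aac', '-strict', 'experimental', '-map', '0:v:0', '-map', '1:a:0',
#                   video_add_audio_path], check=True)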
def select_closing(self, video_add_audio_path):
while True:
print("----------------------------------------------------------------------------")
print("\n 6. Please select the closing option.\n\n")
print(f"1. View created images immediately. ({os.path.basename(video_add_audio_path)})")
print("2. Just turn it off.")
option = input("\n\nSelect option: ")
if option == '1':
print("\nLet's play the generated video.\n")
VideoPlayer.play_video(video_add_audio_path)
elif option == '2':
break
else:
print("Invalid option, please try again.")
print("\nExit the program.")
def usage():
print("Usage: python facial_dubbing.py --target_video_path [target_video_path] --source_audio_path [source_audio_path] --res_video_dir [res_video_dir] --OpenFace_install_path [OpenFace_install_path] --deepspeech_model_path [deepspeech_model_path] --pretrained_clip_segDinet_path [pretrained_clip_segDinet_path] --source_channel [source_channel] --ref_channel [ref_channel] --audio_channel [audio_channel] --mouth_region_size [mouth_region_size]")
def segdinet_banner(opt):
print("""
_____ _____ _
| __ \_ _| | |
___ ___ __ _| | | || | _ __ ___| |_
/ __|/ _ \/ _` | | | || | | '_ \ / _ \ __|
\__ \ __/ (_| | |__| || |_| | | | __/ |_
|___/\___|\__, |_____/_____|_| |_|\___|\__|
__/ |
|___/
""")
print("Load the pre-trained SegDINet to run a program that synthesizes images \nin which the target person (target video) speaks according to the source audio.")
print(f"target video: {os.path.basename(opt.target_video_path)}")
print(f"source audio: {os.path.basename(opt.source_audio_path)}")
if __name__ == "__main__":
opt = InferenceOptions().parse_args()
# Check target video exist
if not os.path.exists(opt.target_video_path):
raise FileNotFoundError(f'Wrong target video path: {opt.target_video_path}')
# Check source audio exist
if not os.path.exists(opt.source_audio_path):
raise FileNotFoundError(f'Wrong source audio path: {opt.source_audio_path}')
segdinet_banner(opt)
pipeline = FacialDubbingPipeline(opt)
pipeline.run()
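# Example invocation (all paths and values below are illustrative placeholders; the flag names
# follow usage() above):
#   python facial_dubbing.py \
#       --target_video_path input_file/target.mp4 \
#       --source_audio_path input_file/source.wav \
#       --res_video_dir result \
#       --OpenFace_install_path /path/to/OpenFace/build/bin \
#       --deepspeech_model_path /path/to/deepspeech_model.pb \
#       --pretrained_clip_segDinet_path /path/to/segDINet_checkpoint.pth \
#       --source_channel 3 --ref_channel 15 --audio_channel 29 \
#       --mouth_region_size 256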
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
import os
from config.config import InferenceOptions
from facial_dubbing import FacialDubbingPipeline, segdinet_banner
import shutil
app = FastAPI()
app.mount("/result", StaticFiles(directory="result"), name="result")
@app.get("/", response_class=HTMLResponse)
def home():
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Facial Dubbing</title>
</head>
<body>
<h1>Facial Dubbing</h1>
<form action="/inference" method="post" enctype="multipart/form-data">
<label for="target_video">Target Video</label>
<input type="file" id="target_video" name="target_video" accept="video/*" required><br><br>
<label for="source_audio">Source Audio</label>
<input type="file" id="source_audio" name="source_audio" accept="audio/*" required><br><br>
<button type="submit">Submit</button>
</form>
</body>
</html>
"""
return HTMLResponse(content=html_content)
@app.post("/inference")
async def inference(target_video: UploadFile = File(...), source_audio: UploadFile = File(...)):
opt = InferenceOptions().parse_args()
if not os.path.exists('uploads'):
os.makedirs('uploads')
target_video_path = os.path.join('uploads', target_video.filename)
source_audio_path = os.path.join('uploads', source_audio.filename)
target_openface_landmark_path = os.path.join('input_file', target_video.filename.split('.')[0] + '.csv')
with open(target_video_path, 'wb') as f:
shutil.copyfileobj(target_video.file, f)
with open(source_audio_path, 'wb') as f:
shutil.copyfileobj(source_audio.file, f)
segdinet_banner(opt)
opt_ = vars(opt)
opt_['target_video_path'] = target_video_path
opt_['source_audio_path'] = source_audio_path
opt_['target_openface_landmark_path'] = target_openface_landmark_path
print(opt)
pipeline = FacialDubbingPipeline(opt)
video_add_audio_path = pipeline.run()
host_ip = os.popen('hostname -I').read().split()[0]
# build the download URL from this host's IP rather than a hard-coded hostname
return JSONResponse(content={'video_path': f'http://{host_ip}:8888/' + video_add_audio_path})
if __name__ == '__main__':
import uvicorn
uvicorn.run(app, host='0.0.0.0', port=8888)
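# Example client call against the /inference endpoint (host/port are whatever this server was
# started with; the filenames are placeholders, and this mirrors what viewer.py does in
# inference_video()):
#   import requests
#   r = requests.post('http://localhost:8888/inference',
#                     files={'target_video': open('target.mp4', 'rb'),
#                            'source_audio': open('source.wav', 'rb')})
#   print(r.json()['video_path'])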
viewer.py
import sqlite3
import gradio as gr
import pandas as pd
import random
from pathlib import Path
from apscheduler.schedulers.background import BackgroundScheduler
import requests
import os
import tempfile
DB_FILE = "./result.db"
db = sqlite3.connect(DB_FILE)
try:
db.execute("SELECT * FROM survey").fetchall()
db.close()
except Exception:
db.execute(
'''
CREATE TABLE survey (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
name TEXT, age INTEGER, model TEXT)
''')
db.commit()
db.execute(
'''
CREATE TABLE user_study (id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
create_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
name TEXT, metric_a TEXT, metric_b TEXT, metric_c TEXT)
''')
db.commit()
db.close()
gr.set_static_paths(paths=["data/"])
def get_surveys(db):
surveys = db.execute("SELECT * FROM survey").fetchall()
total_surveys = db.execute("SELECT COUNT(*) FROM survey").fetchone()[0]
# surveys = [{"name": name, "age": age} for name, age in surveys]
surveys = pd.DataFrame(surveys, columns=["id", "date_created", "name", "age", "model"])
return surveys, total_surveys
def get_user_studies(db):
user_studies = db.execute("SELECT * FROM user_study").fetchall()
total_studies = db.execute("SELECT COUNT(*) FROM user_study").fetchone()[0]
user_studies = pd.DataFrame(user_studies, columns=["id", "date_created", "name", "metric_a", "metric_b", "metric_c"])
return user_studies, total_studies
def insert_survey(name, age, model):
db = sqlite3.connect(DB_FILE)
db.execute("INSERT INTO survey (name, age, model) VALUES (?, ?, ?)", (name, age, model))
db.commit()
surveys, total_surveys = get_surveys(db)
db.close()
return surveys, total_surveys
def insert_user_study(name, metric_a, metric_b, metric_c):
db = sqlite3.connect(DB_FILE)
db.execute("INSERT INTO user_study (name, metric_a, metric_b, metric_c) VALUES (?, ?, ?, ?)", (name, metric_a, metric_b, metric_c))
db.commit()
db.close()
def validate_survey(name, age, model):
if not name:
raise gr.Error("Name is required")
if not age:
raise gr.Error("Age is required")
if not model:
raise gr.Error("Model is required")
return insert_survey(name, age, model)
def validate_user_study(metric_a, metric_b, metric_c, name):
if not metric_a:
raise gr.Error("Metric A is required")
if not metric_b:
raise gr.Error("Metric B is required")
if not metric_c:
raise gr.Error("Metric C is required")
if not name:
raise gr.Error("Name is required")
insert_user_study(name, metric_a, metric_b, metric_c)
gr.Info('Successfully submitted!')
def load_surveys():
db = sqlite3.connect(DB_FILE)
surveys, total_surveys = get_surveys(db)
db.close()
return surveys, total_surveys
def load_user_studies():
db = sqlite3.connect(DB_FILE)
user_studies, total_user_studies = get_user_studies(db)
db.close()
return user_studies, total_user_studies
insert_survey("John", 25, "model1")
insert_survey("Alice", 30, "model2")
# print(load_surveys())
def generate_images():
images = [
(random.choice(
[
"http://www.marketingtool.online/en/face-generator/img/faces/avatar-1151ce9f4b2043de0d2e3b7826127998.jpg",
"http://www.marketingtool.online/en/face-generator/img/faces/avatar-116b5e92936b766b7fdfc242649337f7.jpg",
"http://www.marketingtool.online/en/face-generator/img/faces/avatar-1163530ca19b5cebe1b002b8ec67b6fc.jpg",
"http://www.marketingtool.online/en/face-generator/img/faces/avatar-1116395d6e6a6581eef8b8038f4c8e55.jpg",
"http://www.marketingtool.online/en/face-generator/img/faces/avatar-11319be65db395d0e8e6855d18ddcef0.jpg",
]
), f"model {i}")
for i in range(3)
]
print(images)
return images
def generate_videos():
test_video = f"test00" + random.choice(["1", "2", "3", "4"])
videos = [
f"data/{test_video}/{test_video}_modelA.mp4",
f"data/{test_video}/{test_video}_modelB.mp4",
f"data/{test_video}/{test_video}_modelC.mp4",
f"data/{test_video}/{test_video}_modelD.mp4",
f"data/{test_video}/{test_video}_modelE.mp4",
]
return videos
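# generate_videos() assumes a local layout of data/test00N/test00N_model{A..E}.mp4 clips
# (made servable via gr.set_static_paths(paths=["data/"]) above); the five clips are compared
# side by side in the User Study tab.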
def replay_videos():
return [gr.Video(autoplay=True, value=video) for video in generate_videos()]
API_URL = "http://wearable.lan:8888/inference"
#API_URL = "http://127.0.0.1:8887/inference"
def inference_video(video_model, voice_model, image_input, sound_input, text_input, pose_video_input):
# send pose_video_input, and sound_input to AI model using rest api
r = requests.post(API_URL, files={"target_video": open(pose_video_input, 'rb'), "source_audio": open(sound_input, 'rb')})
inferenced_video_url = r.json()["video_path"]
# mkdtemp keeps the directory around; TemporaryDirectory().name is deleted as soon as the object is garbage collected
inference_video_dir = tempfile.mkdtemp()
inferenced_video = os.path.join(inference_video_dir, Path(inferenced_video_url).name)
with open(inferenced_video, 'wb') as f:
f.write(requests.get(inferenced_video_url).content)
return inferenced_video
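# inference_video() uploads the driving video and source audio to the FastAPI /inference
# endpoint, then downloads the returned video URL into a temporary directory so Gradio can
# serve the result from the local filesystem.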
css = """
.radio-group .wrap {
display: flex !important;
}
.radio-group label {
flex: 1 1 auto;
}
"""
pose_video_url = "./pose_video.mp4"
with gr.Blocks(theme=gr.themes.Soft(), css=css) as SurveyDemo:
gr.Markdown("# AI Avatar")
with gr.Tab(label="AI framework"):
gr.Markdown("## AI Framework")
with gr.Row():
with gr.Column():
gr.Markdown("### Input")
video_model = gr.Textbox(label="Video Model", placeholder="Enter model name", value="segDInet")
voice_model = gr.Textbox(label="Voice Model", placeholder="Enter model name", value="?")
image_input = gr.Image(label="Target Image")
sound_input = gr.Audio(label="Driving Audio", type="filepath")
text_input = gr.Textbox(label="Source Text", placeholder="Enter text")
pose_video_input = gr.Video(label="Driving Video", format="mp4")
def on_pose_video(value):
print(value)
def on_sound_input(value):
print(value)
pose_video_input.upload(on_pose_video, pose_video_input)
sound_input.upload(on_sound_input, sound_input)
submit = gr.Button(value="Submit")
with gr.Column():
gr.Markdown("### Result")
video_output = gr.Video(label="Output Video")
submit.click(inference_video, [video_model, voice_model, image_input, sound_input, text_input, pose_video_input], [video_output])
with gr.Tab(label="AI Avatar Result") as ai_avatar_result_tab:
gr.Markdown("## AI Avatar Result")
with gr.Row():
with gr.Column():
gr.Markdown("### 파일 목록")
gr.FileExplorer(file_count="multiple", root="./", ignore_glob=".*",
interactive=True, glob="**/*.*")
with gr.Column():
gr.Markdown("### 원본 영상")
pose_video_output = gr.Video(autoplay=True)
with gr.Column():
gr.Markdown("### 합성 영상")
synthesis_video_output = gr.Video(autoplay=True)
with gr.Tab(label="User Study"):
gr.Markdown("## User Study")
with gr.Row():
video0 = gr.Video(autoplay=True, label="ModelA")
video1 = gr.Video(autoplay=True, label="ModelB")
video2 = gr.Video(autoplay=True, label="ModelC")
video3 = gr.Video(autoplay=True, label="ModelD")
video4 = gr.Video(autoplay=True, label="ModelE")
metric_a = gr.Radio(["A", "B", "C", "D", "E"], label="Metric A", info="Video sharpness", elem_classes="radio-group")
metric_b = gr.Radio(["A", "B", "C", "D", "E"], label="Metric B", info="Lip sync", elem_classes="radio-group")
metric_c = gr.Radio(["A", "B", "C", "D", "E"], label="Metric C", info="Video quality", elem_classes="radio-group")
with gr.Row():
with gr.Column():
name = gr.Textbox(label="Name", placeholder="Enter your name")
with gr.Column():
retry = gr.ClearButton([metric_a, metric_b, metric_c, name], value='Retry', scale=2)
with gr.Column():
check = gr.Button(value="Check", scale=2)
SurveyDemo.load(generate_videos, None, [video0, video1, video2, video3, video4])
retry.click(replay_videos, None, [video0, video1, video2, video3, video4])
check.click(validate_user_study, [metric_a, metric_b, metric_c, name], None)
with gr.Tab(label="User Study Results") as user_study_results_tab:
gr.Markdown("## User Study Results")
data = gr.Dataframe(headers=["Name", "MetricA", "MetricB", "MetricC"], visible=True)
count = gr.Number(label="Total User Studies")
SurveyDemo.load(load_user_studies, None, [data, count])
def on_ai_avatar_result_tab_update(pose_video_input, video_output):
return pose_video_input, video_output
def on_user_study_results_tabs_update():
return load_user_studies()
ai_avatar_result_tab.select(on_ai_avatar_result_tab_update, [pose_video_input, video_output], [pose_video_output, synthesis_video_output])
user_study_results_tab.select(on_user_study_results_tabs_update, None, [data, count])
def backup_data():
user_studies, _ = load_user_studies()
user_studies.to_csv("./result.csv", index=False)
scheduler = BackgroundScheduler()
scheduler.add_job(backup_data, trigger='interval', seconds=1)
scheduler.start()
SurveyDemo.launch(share=True, server_name='0.0.0.0', server_port=7860)
#SurveyDemo.launch(share=True, server_name='0.0.0.0', server_port=8888)