@voluntas · Last active February 1, 2024 08:14
Laughing Man (笑い男): overlay the Laughing Man GIF on faces detected with MediaPipe and stream the result to WebRTC SFU Sora.
import argparse
import json
import math
import os
from pathlib import Path

import cv2
import mediapipe as mp
import numpy as np
from PIL import Image, ImageSequence
from sora_sdk import Sora
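
# Dependencies (assumed pip package names): opencv-python, mediapipe, numpy,
# pillow, and the Sora Python SDK (sora_sdk). The overlay image
# img_mark_04.gif is expected to sit next to this script.
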
class LogoStreamer:
    def __init__(
        self,
        signaling_urls,
        role,
        channel_id,
        metadata,
        camera_id,
        video_width,
        video_height,
    ):
        self.mp_face_detection = mp.solutions.face_detection

        self.sora = Sora()
        self.video_source = self.sora.create_video_source()
        self.connection = self.sora.create_connection(
            signaling_urls=signaling_urls,
            role=role,
            channel_id=channel_id,
            metadata=metadata,
            video_source=self.video_source,
        )
        self.connection.on_disconnect = self.on_disconnect

        self.video_capture = cv2.VideoCapture(camera_id)
        if video_width is not None:
            self.video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, video_width)
        if video_height is not None:
            self.video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height)
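
        # CAP_PROP_FRAME_WIDTH / CAP_PROP_FRAME_HEIGHT are only requests to
        # the capture backend; drivers may ignore them, which is why the CLI
        # help calls these values hints.
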
        self.running = True

        # Load the GIF whose frames are pasted over detected faces
        self.load_gif(Path(__file__).parent.joinpath("img_mark_04.gif"))

    def load_gif(self, filepath):
        gif = Image.open(filepath)
        self.gif_frames = []
        # Walk every frame of the (possibly animated) GIF
        for frame in ImageSequence.Iterator(gif):
            # Convert the frame to RGBA so the transparency information is kept
            rgba_frame = frame.convert("RGBA")
            self.gif_frames.append(rgba_frame)
        self.current_frame = 0

    def get_next_gif_frame(self):
        frame = self.gif_frames[self.current_frame]
        self.current_frame = (self.current_frame + 1) % len(self.gif_frames)
        return frame
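
    # The animation advances one GIF frame per captured camera frame, so the
    # playback speed follows the camera FPS rather than the GIF's own timing.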

    def on_disconnect(self, error_code, message):
        print(f"Disconnected from Sora: error_code='{error_code}' message='{message}'")
        self.running = False

    def run(self):
        self.connection.connect()

        try:
            # Set up the MediaPipe face detector
            with self.mp_face_detection.FaceDetection(
                model_selection=0, min_detection_confidence=0.5
            ) as face_detection:
                while self.running and self.video_capture.isOpened():
                    self.run_one_frame(face_detection)
        except KeyboardInterrupt:
            pass
        finally:
            self.connection.disconnect()
            self.video_capture.release()

    def run_one_frame(self, face_detection):
        # Grab one frame from the camera
        success, frame = self.video_capture.read()
        if not success:
            return

        # Mark the frame read-only so MediaPipe can process it without copying
        frame.flags.writeable = False
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Detect faces with MediaPipe
        results = face_detection.process(frame)

        frame_height, frame_width, _ = frame.shape
        pil_image = Image.fromarray(frame)
        if results.detections:
            for detection in results.detections:
                location = detection.location_data
                if not location.HasField("relative_bounding_box"):
                    continue
                bb = location.relative_bounding_box

                # Denormalize the relative coordinates into pixels
                w_px = math.floor(bb.width * frame_width)
                h_px = math.floor(bb.height * frame_height)
                x_px = min(math.floor(bb.xmin * frame_width), frame_width - 1)
                y_px = min(math.floor(bb.ymin * frame_height), frame_height - 1)

                # Adjust the detected region
                fixed_w_px = math.floor(w_px * 2.6)
                fixed_h_px = math.floor(h_px * 2.6)
                fixed_x_px = max(0, math.floor(x_px - (fixed_w_px - w_px) / 2))
                fixed_y_px = max(0, math.floor(y_px - (fixed_h_px - h_px) / 1.5))
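                # Net effect: the box grows 2.6x, stays horizontally centered,
                # and shifts upward (vertical offset /1.5 rather than /2) so
                # the logo covers the whole head instead of just the face.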

                # Fetch the next GIF frame and resize it to the adjusted box
                gif_frame = self.get_next_gif_frame().resize((fixed_w_px, fixed_h_px))
                # Composite it onto the PIL image; passing the RGBA frame as
                # the mask keeps the GIF's transparent pixels from overwriting
                # the camera image
                pil_image.paste(gif_frame, (fixed_x_px, fixed_y_px), gif_frame)

        frame.flags.writeable = True
        frame = np.array(pil_image)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        # Hand the frame to WebRTC
        self.video_source.on_captured(frame)
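
        # For local debugging without a Sora connection, the composited frame
        # could be shown in a window instead (hypothetical snippet, not part
        # of the original gist):
        #
        #     cv2.imshow("laughing man", frame)
        #     if cv2.waitKey(1) & 0xFF == ord("q"):
        #         self.running = False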
        return

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required arguments (fall back to environment variables)
    default_signaling_urls = os.getenv("SORA_SIGNALING_URLS")
    parser.add_argument(
        "--signaling-urls",
        # the CLI yields a list via nargs="+", so split the env value too
        # (assuming a comma-separated list)
        default=default_signaling_urls.split(",") if default_signaling_urls else None,
        type=str,
        nargs="+",
        required=not default_signaling_urls,
        help="Signaling URLs",
    )
    default_channel_id = os.getenv("SORA_CHANNEL_ID")
    parser.add_argument(
        "--channel-id",
        default=default_channel_id,
        required=not default_channel_id,
        help="Channel ID",
    )
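
    # required=not default_... makes a flag mandatory only when its
    # environment variable is unset, so either source can supply the value.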

    # Optional arguments
    parser.add_argument("--metadata", help="Metadata as a JSON string")
    parser.add_argument(
        "--camera-id", type=int, default=0, help="Camera ID passed to cv2.VideoCapture()"
    )
    default_video_width = os.getenv("SORA_VIDEO_WIDTH")
    parser.add_argument(
        "--video-width",
        type=int,
        # type= only converts command-line values, so convert the env default here
        default=int(default_video_width) if default_video_width else None,
        help="Hint for the input camera frame width",
    )
    default_video_height = os.getenv("SORA_VIDEO_HEIGHT")
    parser.add_argument(
        "--video-height",
        type=int,
        default=int(default_video_height) if default_video_height else None,
        help="Hint for the input camera frame height",
    )

    args = parser.parse_args()

    metadata = None
    if args.metadata:
        metadata = json.loads(args.metadata)

    streamer = LogoStreamer(
        signaling_urls=args.signaling_urls,
        role="sendonly",
        channel_id=args.channel_id,
        # pass the parsed dict rather than the raw JSON string
        metadata=metadata,
        camera_id=args.camera_id,
        video_height=args.video_height,
        video_width=args.video_width,
    )
    streamer.run()
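
# Example invocation (script name, URL, and channel ID are placeholders):
#
#     python3 laughing_man.py \
#         --signaling-urls wss://sora.example.com/signaling \
#         --channel-id laughing-man
#
# The SORA_SIGNALING_URLS and SORA_CHANNEL_ID environment variables can be
# set instead of passing the flags.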