Created November 14, 2023 07:26
-
-
Save koorukuroo/4efdfa36394ee1f925b4397156c58660 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from diffusers import LCMScheduler | |
from diffusers import AutoPipelineForText2Image | |
import torch | |
from dotenv import load_dotenv | |
import os | |
from openai import OpenAI | |
import gradio as gr | |
import random | |
import time | |
import string | |
# Populate os.environ from ./.env, letting values in the file override any
# variables that are already set.
load_dotenv(override=True, verbose=True, dotenv_path="./.env")

# Create the OpenAI client. The API key comes from the environment, which the
# load_dotenv call above may have filled in from ./.env.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
############### | |
# OpenAI GPT | |
############### | |
def chat_completion(messages: list, model: str = 'gpt-4-1106-preview') -> list[str]:
    """
    Request a streamed chat completion and collect its non-empty text pieces.

    Args:
        messages (list): Chat messages in OpenAI format, i.e. dicts with
            'role' and 'content' keys (as produced by generate_messages).
        model (str): Model name sent to the API. Defaults to
            'gpt-4-1106-preview'.

    Returns:
        list[str]: The streamed content deltas in arrival order, with empty or
            missing deltas dropped. Joining them gives the full reply.
    """
    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
    )
    collected_messages = []
    for part in stream:
        # A streamed chunk may carry an empty `choices` list (the API does this
        # e.g. for usage-only chunks); indexing [0] on it would raise IndexError.
        if not part.choices:
            continue
        content = part.choices[0].delta.content
        if content:
            collected_messages.append(content)
    return collected_messages
# Default system prompt (Korean). In English: "Your name is Pikachu. Because
# you are Pikachu, you can only say things like Pika, Pikapikapi, Pikachu,
# Pikabi."
DEFAULT_SYSTEM_PROMPT = '네 이름은 피카츄야. 너는 피카츄라서 피카, 피카피카피, 피카츄, 피카비 같은 말 밖에 못해.'


def generate_messages(messages: list, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> list:
    """
    Convert chat history pairs into the OpenAI chat message format.

    Args:
        messages (list): History entries, each indexable as
            ``[user_message, assistant_reply]``. ``assistant_reply`` is None for
            a turn that has not been answered yet. None is treated as an empty
            history.
        system_prompt (str): Content of the leading system message.

    Returns:
        list: A system message followed by the user/assistant messages in
            order, each a dict with 'role' and 'content' keys. A turn whose
            reply is None contributes only its user message.
    """
    formatted_messages = [{'role': 'system', 'content': system_prompt}]
    for entry in messages or ():
        user_message, assistant_reply = entry[0], entry[1]
        formatted_messages.append({'role': 'user', 'content': user_message})
        # An empty-string reply is still a real (empty) answer; only None
        # means "not answered yet".
        if assistant_reply is not None:
            formatted_messages.append({'role': 'assistant', 'content': assistant_reply})
    return formatted_messages
def generate_response(chat_history: list):
    """
    Fill in the assistant reply for the last chat turn, yielding progress.

    Meant to be used as a Gradio generator callback: each yielded value is
    the updated history for the Chatbot component.

    Args:
        chat_history (list): ``[user_message, assistant_reply]`` pairs. The
            last pair's reply slot is overwritten with the model's answer.

    Yields:
        list: ``chat_history`` itself (mutated in place) after each piece of
            the reply is appended. If the model returns no text, it is
            yielded once with an empty reply. Nothing is yielded for an
            empty history.
    """
    if not chat_history:
        # Nothing to answer (e.g. the history was cleared).
        return
    messages = generate_messages(chat_history)
    bot_message = chat_completion(messages)
    chat_history[-1][1] = ''
    if not bot_message:
        # Without this, an empty reply would produce no yield at all and the
        # caller would never see the final (empty) answer.
        yield chat_history
        return
    for piece in bot_message:
        chat_history[-1][1] += piece
        yield chat_history
def set_user_response(user_message: str, chat_history: list) -> tuple:
    """
    Append the user's message to the history as a new, unanswered turn.

    Args:
        user_message (str): Text submitted from the input box.
        chat_history (list): Existing ``[user, assistant]`` pairs. None is
            treated as an empty history (the clear handler in this file sets
            the chatbot value to None).

    Returns:
        tuple: ``('', new_history)``. The empty string clears the input box.
            ``new_history`` is a new list, so the caller's list is not mutated,
            ending with ``[user_message, None]``; None marks the reply as
            pending.
    """
    new_history = list(chat_history or []) + [[user_message, None]]
    return '', new_history
############# | |
# SDXL | |
############# | |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter_id = "latent-consistency/lcm-lora-sdxl"

# Run on the GPU when one is available, otherwise on the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
# On CUDA, load the repo's half-precision "fp16" weight variant. On CPU, use
# float32 and the default weights. There is no "fp32" variant to request;
# the default weights are already full precision.
_dtype = torch.float16 if _device == "cuda" else torch.float32
_variant = "fp16" if _device == "cuda" else None

pipe = AutoPipelineForText2Image.from_pretrained(
    model_id, torch_dtype=_dtype, variant=_variant)
# LCM-LoRA requires the LCM scheduler in place of the pipeline's default one.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(_device)

# Load the LCM LoRA adapter and fuse it into the base weights.
pipe.load_lora_weights(adapter_id)
pipe.fuse_lora()
def generate_random_string(length):
    """
    Return ``length`` random characters drawn from ASCII letters and digits.

    Characters are picked independently, with replacement, using the
    ``random`` module, so the result is not suitable for security purposes.
    A non-positive ``length`` yields an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def generate_image_url(text, num_inference_steps=4, guidance_scale=0):
    """
    Render an image for a text prompt with the SDXL pipeline and save it.

    Despite the name, this returns a local file path, not a URL.

    Args:
        text (str): The prompt passed to the pipeline.
        num_inference_steps (int): Number of denoising steps. Defaults to 4.
        guidance_scale (float): Guidance scale passed to the pipeline.
            0 (the default) disables guidance.

    Returns:
        str: Name of the JPEG written to the current working directory, of
            the form ``img_XXXXX.jpg`` where XXXXX is 5 random alphanumerics.
    """
    image = pipe(
        prompt=text,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    # NOTE(review): files are never cleaned up and accumulate in the working
    # directory; a name collision would silently overwrite an earlier image.
    filename = f"img_{generate_random_string(5)}.jpg"
    image.save(filename)
    return filename
########################## | |
# Gradio | |
########################## | |
with gr.Blocks() as demo:
    # --- Chat section ---
    # Labels (Korean): "Chatbot", "Input message", "Start over".
    chatbot = gr.components.Chatbot(label='챗봇', height=600)
    msg = gr.components.Textbox(label='입력 메시지', autofocus=True)
    clear = gr.components.ClearButton(value="다시 시작")

    # On submit, first append the user's turn and clear the textbox
    # (unqueued), then run generate_response to fill in the reply.
    user_turn = msg.submit(
        set_user_response, [msg, chatbot], [msg, chatbot], queue=False)
    user_turn.then(generate_response, chatbot, chatbot)

    # Reset the chat by setting the chatbot value to None.
    clear.click(lambda: None, None, chatbot, queue=False)

    # --- Image section ---
    # Placeholder / button text (Korean): "Enter text...", "Update image".
    with gr.Row():
        text_input = gr.Textbox(placeholder="텍스트를 입력하세요...")
        submit_button = gr.Button("이미지 업데이트")
    image_output = gr.Image()

    # Kept as a named wrapper: Gradio derives the event's API name from it.
    def update_image(text):
        """Generate an image for ``text`` and return its file path."""
        return generate_image_url(text)

    submit_button.click(update_image, text_input, image_output)


if __name__ == '__main__':
    demo.queue().launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.