Skip to content

Instantly share code, notes, and snippets.

@koorukuroo
Created November 14, 2023 07:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koorukuroo/4efdfa36394ee1f925b4397156c58660 to your computer and use it in GitHub Desktop.
Save koorukuroo/4efdfa36394ee1f925b4397156c58660 to your computer and use it in GitHub Desktop.
import os
import random
import string
import time
from collections.abc import Iterator

import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
from diffusers import LCMScheduler
from dotenv import load_dotenv
from openai import OpenAI
# Load settings from the local .env file. override=True asks python-dotenv to
# let values from the file replace variables already set in the environment.
load_dotenv(dotenv_path="./.env", verbose=True, override=True)

# Create the OpenAI client; the API key is read from OPENAI_API_KEY
# (None is passed through if the variable is not set).
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
###############
# OpenAI GPT
###############
def chat_completion(messages: list) -> list[str]:
    """Request a streamed chat completion and collect its text fragments.

    The whole stream is consumed before returning, so the caller receives
    every fragment at once rather than incrementally.

    Args:
        messages (list): Chat messages forwarded unchanged to the
            chat-completions endpoint (callers in this file pass dicts with
            'role' and 'content' keys).

    Returns:
        list[str]: The non-empty content fragments of the reply, in the
            order they arrived. Joining them yields the full reply text.
    """
    response_stream = client.chat.completions.create(
        model='gpt-4-1106-preview',
        messages=messages,
        stream=True,
    )
    # Chunks whose delta carries no text (None or '') are skipped.
    return [
        fragment
        for chunk in response_stream
        if (fragment := chunk.choices[0].delta.content)
    ]
def generate_messages(messages: list) -> list:
    """Convert chat-history pairs into role/content message dicts.

    The result always starts with the fixed system prompt, followed by one
    'user' entry per turn and an 'assistant' entry for every turn that
    already has a reply.

    Args:
        messages (list): Chat history as a sequence of
            ``[user_message, assistant_response]`` pairs. The response is
            ``None`` for a turn that has not been answered yet. ``None`` or
            an empty history yields just the system prompt.

    Returns:
        list: Dicts with 'role' and 'content' keys, in conversation order.
    """
    formatted_messages = [
        {
            'role': 'system',
            # Korean persona prompt; roughly: "Your name is Pikachu. Because
            # you are Pikachu you can only say things like pika,
            # pikapikapi, pikachu, pikabi."
            'content': '네 이름은 피카츄야. 너는 피카츄라서 피카, 피카피카피, 피카츄, 피카비 같은 말 밖에 못해.',
        }
    ]
    for turn in messages or ():
        user_text, assistant_text = turn[0], turn[1]
        formatted_messages.append({'role': 'user', 'content': user_text})
        # Identity check: only a missing reply (None) is skipped; an empty
        # string is still a reply and is kept.
        if assistant_text is not None:
            formatted_messages.append(
                {'role': 'assistant', 'content': assistant_text}
            )
    return formatted_messages
def generate_response(chat_history: list) -> Iterator[list]:
    """Fill in the assistant reply for the last turn of ``chat_history``.

    This is a generator: after each reply fragment is appended to the last
    turn, the (mutated) history is yielded so a UI can redraw progressively.

    Args:
        chat_history (list): Sequence of mutable
            ``[user_message, assistant_response]`` pairs; the last pair is
            the turn being answered and is modified in place.

    Yields:
        list: The same ``chat_history`` object, once per appended fragment.
        Nothing is yielded when the history is empty or ``None``.
    """
    # Nothing to answer: stop before spending an API call, and avoid the
    # IndexError that ``chat_history[-1]`` would raise on an empty history.
    if not chat_history:
        return

    messages = generate_messages(chat_history)
    reply_fragments = chat_completion(messages)

    # Reset the pending reply (None) to an empty string so fragments can be
    # concatenated onto it.
    chat_history[-1][1] = ''
    for fragment in reply_fragments:
        chat_history[-1][1] += fragment
        yield chat_history
def set_user_response(user_message: str, chat_history: list) -> tuple[str, list]:
    """Append a new, not-yet-answered user turn to the chat history.

    Args:
        user_message (str): Text the user just submitted.
        chat_history (list): Existing ``[user_message, assistant_response]``
            pairs. It is extended in place. ``None`` is treated as an empty
            history.

    Returns:
        tuple[str, list]: An empty string (used by the caller to clear the
            input box) and the history including the new
            ``[user_message, None]`` turn.
    """
    # A missing history would make the in-place extension raise TypeError.
    if chat_history is None:
        chat_history = []
    chat_history.append([user_message, None])
    return '', chat_history
#############
# SDXL
#############
# Model identifiers passed to the loaders below: the SDXL base checkpoint and
# the LCM-LoRA adapter for SDXL.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
adapter_id = "latent-consistency/lcm-lora-sdxl"
# Build the text-to-image pipeline at import time, requesting float32 weights.
# NOTE(review): confirm that an "fp32" weight variant is actually published for
# this model id; if not, this call may warn or fail depending on the diffusers
# version.
pipe = AutoPipelineForText2Image.from_pretrained(
model_id, torch_dtype=torch.float32, variant="fp32")
# Swap in an LCMScheduler built from the pipeline's existing scheduler config.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
# Moving the pipeline to CUDA is disabled, so it stays on whatever device it
# was loaded on (presumably CPU -- confirm).
# pipe.to("cuda")
# Load the LCM-LoRA adapter weights, then fuse them into the pipeline.
pipe.load_lora_weights(adapter_id)
pipe.fuse_lora()
def generate_random_string(length: int) -> str:
    """Return a random string of ASCII letters and digits.

    Uses the ``random`` module, so the result is suitable for things like
    throwaway file names but not for security-sensitive tokens.

    Args:
        length (int): Number of characters to generate. ``0`` gives ``''``.

    Returns:
        str: ``length`` characters drawn uniformly (with replacement) from
            ``a-z``, ``A-Z`` and ``0-9``.
    """
    # Character set used to build the string.
    alphabet = string.ascii_letters + string.digits
    # Draw all characters in one call instead of looping by hand.
    return ''.join(random.choices(alphabet, k=length))
def generate_image_url(text):
    """Generate an image for ``text`` and return the file name it was saved as.

    Despite the name, the return value is a relative file name, not a URL.
    The text is used directly as the prompt of the module-level ``pipe``.

    Args:
        text: Prompt passed to the pipeline.

    Returns:
        The name of the saved file, of the form ``img_<5 random chars>.jpg``.
    """
    # Four inference steps; guidance is disabled by passing a scale of 0.
    result = pipe(prompt=text, num_inference_steps=4, guidance_scale=0)
    picture = result.images[0]

    # A short random suffix keeps successive outputs from sharing a name.
    output_name = f"img_{generate_random_string(5)}.jpg"
    picture.save(output_name)
    return output_name
##########################
# Gradio
##########################
# Gradio UI: a chat panel and a text-to-image panel assembled in one Blocks app.
# NOTE(review): the indentation of this section has been lost. The nesting of
# the statements below (what belongs inside `with gr.Row():`, and the body of
# `update_image`) must be restored before the module can run -- confirm the
# intended layout against the original.
with gr.Blocks() as demo:
# Chat widgets. The Korean UI strings read roughly "Chatbot", "Input message"
# and "Restart".
chatbot = gr.components.Chatbot(label='챗봇', height=600)
msg = gr.components.Textbox(label='입력 메시지', autofocus=True)
clear = gr.components.ClearButton(value="다시 시작")
# On submit: set_user_response appends the user's turn and returns '' for the
# textbox (run with queue=False); then generate_response updates the chatbot
# with each history it yields.
msg.submit(set_user_response, [msg, chatbot], [msg, chatbot], queue=False).then(
generate_response, chatbot, chatbot)
# Clicking the button sends None to the chatbot component (presumably
# clearing the conversation -- confirm).
clear.click(lambda: None, None, chatbot, queue=False)
# Image-generation widgets. The Korean UI strings read roughly "Enter text..."
# and "Update image".
with gr.Row():
text_input = gr.Textbox(placeholder="텍스트를 입력하세요...")
submit_button = gr.Button("이미지 업데이트")
image_output = gr.Image()
# Click handler: generates an image for the text and returns the value
# produced by generate_image_url (a saved file name) for image_output.
def update_image(text):
image_url = generate_image_url(text)
return image_url
# Disabled alternative: image_output.update(value=image_url)
submit_button.click(update_image, text_input, image_output)
# When run as a script, call queue() on the app and launch it.
if __name__ == '__main__':
demo.queue().launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment