Last active
March 29, 2022 15:18
-
-
Save ttt733/f8fa68a7c5b500d36d5ded465d059d78 to your computer and use it in GitHub Desktop.
# Usage: place this file in the "notebooks" directory of the RQ-VAE repository (it imports notebook_utils from there).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import os
import random
import uuid
from concurrent.futures import ThreadPoolExecutor

import clip
import discord
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

from notebook_utils import TextEncoder, load_model, get_generated_images_by_texts
# Stage 1 model: RQ-VAE. The second value returned by load_model is unused here.
vqvae_path = '/path/to/model/cc3m_cc12m_yfcc/stage1/model.pt'
model_vqvae, _ = load_model(vqvae_path, map_location='cuda')
print('stage 1 loaded')

# Stage 2 model: RQ-Transformer. Its config is kept because the text encoder
# below is built from config.dataset.
model_path = '/path/to/model/cc3m_cc12m_yfcc/stage2/model.pt'
model_ar, config = load_model(model_path, ema=False, map_location='cuda')
print('stage 2 loaded')

# Make sure both models live on the GPU and are in eval mode.
model_ar = model_ar.cuda()
model_ar = model_ar.eval()
model_vqvae = model_vqvae.cuda()
model_vqvae = model_vqvae.eval()
print('moved to gpu')

# The CLIP checkpoint is downloaded the first time this runs.
# NOTE(review): CLIP is loaded on the CPU and only then moved to the GPU;
# presumably this is deliberate (e.g. to control weight precision) - confirm
# before changing device='cpu'.
model_clip, preprocess_clip = clip.load("ViT-B/32", device='cpu')
model_clip = model_clip.cuda()
model_clip = model_clip.eval()
print('loaded clip')

# Text encoder used to tokenize natural-language prompts; tokenizer name and
# context length come from the stage 2 config.
_dataset_cfg = config.dataset
text_encoder = TextEncoder(
    tokenizer_name=_dataset_cfg.txt_tok_name,
    context_length=_dataset_cfg.context_length,
)
print('encoder loaded')
# Build the gateway intents explicitly. discord.py 2.x makes `intents` a
# required argument of Client, and reading message.content (which the '!rq'
# prefix commands rely on) needs the privileged message-content intent, which
# must also be enabled for the bot in the Discord developer portal.
# On discord.py 1.x the `message_content` flag does not exist and
# Intents.default() is what Client() used anyway, so behavior there is
# unchanged.
_intents = discord.Intents.default()
if hasattr(_intents, 'message_content'):
    _intents.message_content = True
client = discord.Client(intents=_intents)
@client.event
async def on_ready():
    """Print a confirmation line naming the bot user once the client is ready."""
    bot_user = client.user
    print(f'{bot_user} has connected to Discord!')
# Command prefix -> (temp_low, temp_high): the sampling-temperature range that
# is handed to gen_image for that command. Every prefix ends with a space, so
# a message can match at most one of them ('!rq ' never matches '!rq-crazy x').
_RQ_COMMANDS = {
    '!rq ': (.6, 1.1),
    '!rq-crazy ': (.9, 1.3),
    '!rq-real ': (.525, .725),
    '!rq-tooreal ': (.3, .55),
}

# Image generation is a long, synchronous call. Running it directly inside the
# coroutine would stall the event loop (and with it Discord's heartbeats) for
# the whole generation. A single worker keeps generations strictly one at a
# time, exactly as when the call was made inline.
_GENERATION_EXECUTOR = ThreadPoolExecutor(max_workers=1)


@client.event
async def on_message(message):
    """Answer '!rq', '!rq-crazy', '!rq-real' and '!rq-tooreal' commands.

    The text after the command prefix is used as the prompt. An image is
    generated into a uniquely named temporary JPEG in the working directory,
    sent to the originating channel as a reply to the triggering message, and
    the file is then deleted.

    Messages written by the bot itself, and messages that match no command
    prefix, are ignored.
    """
    if message.author == client.user:
        return

    for prefix, (temp_low, temp_high) in _RQ_COMMANDS.items():
        if message.content.startswith(prefix):
            break
    else:
        # Not one of our commands.
        return

    prompt = message.content[len(prefix):]
    # Random name so overlapping requests never share a file.
    image_id = str(uuid.uuid4())
    image_path = f'{image_id}.jpg'

    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(
            _GENERATION_EXECUTOR, gen_image, prompt, image_id, temp_low, temp_high
        )
        await message.channel.send(file=discord.File(image_path), reference=message)
    finally:
        # Clean up even when generation or sending fails; the file may not
        # exist yet if gen_image raised before saving it.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
def gen_image(prompt, id, temp_low, temp_high, *, num_samples=8, top_k=1024,
              top_p=0.95):
    """Generate an image for ``prompt`` and save it as ``{id}.jpg``.

    Args:
        prompt: Text prompt passed to ``get_generated_images_by_texts``.
        id: Stem of the output file; the image is written to ``f'{id}.jpg'``
            relative to the current working directory. (The name shadows the
            ``id`` builtin but is kept so existing callers keep working.)
        temp_low: One bound of the sampling-temperature range.
        temp_high: The other bound; the temperature used for this call is
            drawn uniformly at random between the two.
        num_samples: Number of candidate images to sample. Increasing it
            gives a better chance at a better result.
        top_k: Top-k value forwarded to the sampler.
        top_p: Top-p value forwarded to the sampler.

    Returns:
        None. The result is the JPEG written to disk.
    """
    temperature = random.uniform(temp_low, temp_high)
    pixels = get_generated_images_by_texts(
        model_ar,
        model_vqvae,
        text_encoder,
        model_clip,
        preprocess_clip,
        prompt,
        num_samples,
        temperature,
        top_k,
        top_p,
        amp=False,
    )

    # Only the first sample(s) are kept.
    # NOTE(review): presumably the helper returns samples ordered best-first
    # by CLIP score - confirm in notebook_utils.
    num_visualize_samples = 1

    # Convert just the samples that are kept instead of copying every
    # generated sample to the CPU. zip() with a range stops after
    # num_visualize_samples items without requiring `pixels` to be sliceable.
    # x * 0.5 + 0.5 followed by the clamp maps values into [0, 1]
    # (assumes the model output is roughly in [-1, 1] - TODO confirm).
    images = [
        pixel.cpu().numpy() * 0.5 + 0.5
        for _, pixel in zip(range(num_visualize_samples), pixels)
    ]
    images = torch.clamp(torch.from_numpy(np.array(images)), 0, 1)

    grid = torchvision.utils.make_grid(images, nrow=1)
    # Channels-first -> channels-last, then scale to 8-bit for PIL.
    img = Image.fromarray(np.uint8(grid.numpy().transpose([1, 2, 0]) * 255))
    img.save(f'{id}.jpg')
print('running client')
# The bot token is a secret, so it is read from the environment rather than
# being written into this file (where it could end up in source control).
TOKEN = os.environ.get('DISCORD_BOT_TOKEN')
if not TOKEN:
    raise SystemExit(
        "Set the DISCORD_BOT_TOKEN environment variable to your bot's token."
    )
client.run(TOKEN)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment