Skip to content

Instantly share code, notes, and snippets.

@ttt733
Last active March 29, 2022 15:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ttt733/f8fa68a7c5b500d36d5ded465d059d78 to your computer and use it in GitHub Desktop.
Save ttt733/f8fa68a7c5b500d36d5ded465d059d78 to your computer and use it in GitHub Desktop.
Place this script in the `notebooks` directory of the RQ-VAE repository (it imports `notebook_utils` from there).
import asyncio
import itertools
import os
import random
import uuid
from concurrent.futures import ThreadPoolExecutor

import clip
import discord
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image

from notebook_utils import TextEncoder, load_model, get_generated_images_by_texts
# Checkpoint locations. Set RQVAE_STAGE1_PATH / RQVAE_STAGE2_PATH to point at
# the downloaded checkpoints instead of editing this file; the defaults below
# are placeholders.
vqvae_path = os.environ.get(
    'RQVAE_STAGE1_PATH', '/path/to/model/cc3m_cc12m_yfcc/stage1/model.pt')
model_path = os.environ.get(
    'RQVAE_STAGE2_PATH', '/path/to/model/cc3m_cc12m_yfcc/stage2/model.pt')

# load stage 1 model: RQ-VAE
model_vqvae, _ = load_model(vqvae_path, map_location='cuda')
print('stage 1 loaded')

# load stage 2 model: RQ-Transformer. Its config also supplies the text
# tokenizer settings used for the TextEncoder below.
model_ar, config = load_model(model_path, ema=False, map_location='cuda')
print('stage 2 loaded')

# make sure both models are on the GPU and in evaluation mode
model_ar = model_ar.cuda().eval()
model_vqvae = model_vqvae.cuda().eval()
print('moved to gpu')

# The CLIP checkpoint is downloaded the first time it is requested.
# NOTE(review): CLIP is loaded on the CPU and then moved to the GPU, presumably
# to obtain fp32 weights (clip.load appears to cast to float only for
# device='cpu') — confirm before changing the device argument.
model_clip, preprocess_clip = clip.load("ViT-B/32", device='cpu')
model_clip = model_clip.cuda().eval()
print('loaded clip')

# prepare the text encoder that tokenizes natural-language prompts
text_encoder = TextEncoder(tokenizer_name=config.dataset.txt_tok_name,
                           context_length=config.dataset.context_length)
print('encoder loaded')

# NOTE(review): this matches the discord.py 1.x constructor; discord.py 2.x
# requires an explicit `intents=` argument (with message_content enabled so
# command text can be read) — adjust if upgrading.
client = discord.Client()
@client.event
async def on_ready():
    """Report, on stdout, which account the bot is logged in as."""
    bot_user = client.user
    print(f'{bot_user} has connected to Discord!')
# Command prefix -> (temp_low, temp_high) bounds for the sampling temperature.
# The trailing space is part of each prefix, so "!rq " can never match a
# "!rq-crazy ..." style message and at most one entry applies per message.
_RQ_COMMANDS = (
    ('!rq ', .6, 1.1),
    ('!rq-crazy ', .9, 1.3),
    ('!rq-real ', .525, .725),
    ('!rq-tooreal ', .3, .55),
)

# Image generation is slow, blocking GPU work. Running it directly inside the
# coroutine would freeze the event loop (and the gateway heartbeat) until it
# finished. A single worker moves it off the loop while still running one
# generation at a time, so concurrent commands don't compete for GPU memory.
_GEN_EXECUTOR = ThreadPoolExecutor(max_workers=1)


@client.event
async def on_message(message):
    """Handle the ``!rq`` family of commands.

    When a message starts with one of the prefixes in ``_RQ_COMMANDS``, the
    rest of the message is used as the prompt. An image is generated with
    ``gen_image`` using that command's temperature range, posted as a reply
    to the message, and the temporary JPEG is then deleted. Messages sent by
    the bot itself, and messages without a known prefix, are ignored.

    Args:
        message: The incoming ``discord.Message``.
    """
    if message.author == client.user:
        return

    command = next(
        (entry for entry in _RQ_COMMANDS if message.content.startswith(entry[0])),
        None,
    )
    if command is None:
        return
    prefix, temp_low, temp_high = command

    prompt = message.content[len(prefix):]
    image_id = str(uuid.uuid4())
    path = f'{image_id}.jpg'
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(
            _GEN_EXECUTOR, gen_image, prompt, image_id, temp_low, temp_high)
        await message.channel.send(file=discord.File(path), reference=message)
    finally:
        # Always clean up, even if generation or the upload failed.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # generation failed before the file was written
def gen_image(prompt, id, temp_low, temp_high, *, num_samples=8, top_k=1024,
              top_p=0.95, num_visualize_samples=1):
    """Generate an image for ``prompt`` and save it as ``f'{id}.jpg'``.

    A sampling temperature is drawn uniformly between ``temp_low`` and
    ``temp_high`` (``random.uniform`` accepts the bounds in either order).
    ``num_samples`` candidates are then generated by
    ``get_generated_images_by_texts`` with top-k / top-p sampling. The first
    ``num_visualize_samples`` candidates are tiled vertically (one per row)
    into a single JPEG in the current working directory.

    Args:
        prompt: Text prompt to generate an image for.
        id: Base name of the output file (``f'{id}.jpg'``). The name shadows the
            builtin but is kept so existing keyword callers keep working.
        temp_low: One bound of the random sampling temperature.
        temp_high: The other bound of the random sampling temperature.
        num_samples: Number of candidates to generate. More candidates give a
            better chance at a good result, at the cost of time and GPU memory.
        top_k: Top-k cutoff passed to the sampler.
        top_p: Top-p (nucleus) cutoff passed to the sampler.
        num_visualize_samples: How many candidates to include in the saved image.

    Returns:
        The path of the written JPEG.

    Raises:
        ValueError: If ``num_visualize_samples`` is less than 1.
    """
    if num_visualize_samples < 1:
        raise ValueError('num_visualize_samples must be at least 1')

    temperature = random.uniform(temp_low, temp_high)
    # Inference only: make sure no autograd graph is recorded (harmless if the
    # helper already disables gradients itself).
    with torch.no_grad():
        pixels = get_generated_images_by_texts(model_ar,
                                               model_vqvae,
                                               text_encoder,
                                               model_clip,
                                               preprocess_clip,
                                               prompt,
                                               num_samples,
                                               temperature,
                                               top_k,
                                               top_p,
                                               amp=False,
                                               )

    # Copy only the samples we actually display to host memory.
    # NOTE(review): assumes the helper returns candidates best-first (e.g.
    # re-ranked with CLIP), so the first ones are worth showing — confirm.
    selected = [pixel.detach().cpu()
                for pixel in itertools.islice(pixels, num_visualize_samples)]
    # NOTE(review): assumes pixel values lie in [-1, 1]; map them to [0, 1].
    images = torch.clamp(torch.stack(selected) * 0.5 + 0.5, 0, 1)
    grid = torchvision.utils.make_grid(images, nrow=1)

    # Round to the nearest 8-bit value (as torchvision.utils.save_image does)
    # instead of truncating, then switch to HWC layout for PIL.
    array = (grid.mul(255).add_(0.5).clamp_(0, 255)
             .permute(1, 2, 0).to(torch.uint8).numpy())
    path = f'{id}.jpg'
    Image.fromarray(array).save(path)
    return path
# The bot token is a secret: read it from the environment instead of keeping
# it in source control.
TOKEN = os.environ.get('DISCORD_TOKEN')
if not TOKEN:
    raise SystemExit('Set the DISCORD_TOKEN environment variable to your bot token.')

print('running client')
client.run(TOKEN)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment