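"""Flask microservice wrapping Stable Diffusion text-to-image and image-to-image
pipelines, a Remacri ESRGAN 4x upscaler, and BLIP image captioning behind simple
form-POST endpoints. All models are loaded once at startup and kept on the GPU."""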
import torch
import cv2
import RRDBNet_arch as arch
import numpy as np
import BLIP.models.blip
import os
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, LMSDiscreteScheduler
from torch import autocast
from PIL import Image, PngImagePlugin
from flask import Flask, request
app = Flask(__name__)
sd_model_path = '../models/stable-diffusion-v1-4'
# sd_model_path = '../models/waifu-diffusion'
esrgan_model_path = './4x_foolhardy_Remacri_out.pth'
device = torch.device('cuda')
# create Stable Diffusion pipelines
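# the beta values below match the scheduler config shipped with the Stable Diffusion v1 checkpoints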
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear'
)
text_pipe = StableDiffusionPipeline.from_pretrained(sd_model_path, scheduler=lms, revision='fp16', torch_dtype=torch.float16)
text_pipe = text_pipe.to(device)
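# attention slicing lowers peak VRAM usage at a small speed cost, which helps when
# both fp16 pipelines share the GPU with the ESRGAN and BLIP models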
text_pipe.enable_attention_slicing()
image_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(sd_model_path, scheduler=lms, revision='fp16', torch_dtype=torch.float16)
image_pipe = image_pipe.to(device)
image_pipe.enable_attention_slicing()
# create ESRGAN model
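# RRDBNet(3, 3, 64, 23, gc=32) is the standard 4x ESRGAN layout: 3-channel input and
# output, 64 base features, 23 RRDB blocks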
upscale_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
upscale_model.load_state_dict(torch.load(esrgan_model_path), strict=True)
upscale_model.eval()
upscale_model = upscale_model.to(device)
# create interrogation model
blip_num_beams = 32
blip_min_length = 4
blip_max_length = 30
blip_image_eval_size = 384
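# inputs are resized to 384x384 before captioning (see /interrogate below)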
blip_model_path = './model_base_caption_capfilt_large.pth'
blip_config_path = os.path.join("./", "BLIP", "configs", "med_config.json")
blip_model = BLIP.models.blip.blip_decoder(pretrained=blip_model_path, image_size=blip_image_eval_size, vit='base', med_config=blip_config_path).half()
blip_model = blip_model.to(device)

@app.route('/txt2img', methods=['POST'])
def txt2img():
    prompt = request.form['prompt']
    out_file = request.form['outFile']
    seed = int(request.form['seed'])
    height = int(request.form['height'])
    width = int(request.form['width'])
    steps = int(request.form['steps'])
    generator = torch.Generator('cuda').manual_seed(seed)

    # run inference and take the first (and only) image from the batch
    with autocast("cuda"):
        image = text_pipe(
            prompt=prompt,
            num_inference_steps=steps,
            generator=generator,
            height=height,
            width=width
        )["sample"][0]

    # embed the generation parameters in a PNG text chunk
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('parameters', f'Prompt: {prompt} Seed: {seed} Steps: {steps}')
    image.save(out_file, 'PNG', pnginfo=pnginfo)
    return 'OK'

@app.route('/img2img', methods=['POST'])
def img2img():
    prompt = request.form['prompt']
    in_file = request.form['inFile']
    out_file = request.form['outFile']
    seed = int(request.form['seed'])
    steps = int(request.form['steps'])
    strength = float(request.form['strength'])
    generator = torch.Generator('cuda').manual_seed(seed)
    in_image = Image.open(in_file).convert('RGB')

    # scale the longer side to 512 and the shorter side to a matching multiple of 8;
    # patches to unet_blocks.py in diffusers are needed for non-square sizes to work
    width, height = in_image.size
    aspect_ratio = width / height
    if width >= height:
        width = 512
        height = round(64 / aspect_ratio) * 8
    else:
        height = 512
        width = round(64 * aspect_ratio) * 8
    in_image = in_image.resize((width, height), resample=Image.Resampling.LANCZOS)

    # run inference and take the first (and only) image from the batch
    with autocast("cuda"):
        image = image_pipe(
            prompt=prompt,
            init_image=in_image,
            strength=strength,
            num_inference_steps=steps,
            generator=generator
        )["sample"][0]

    # embed the generation parameters in a PNG text chunk
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('parameters', f'Prompt: {prompt} Seed: {seed} Steps: {steps} Strength: {strength}')
    image.save(out_file, 'PNG', pnginfo=pnginfo)
    return 'OK'

@app.route('/upscale', methods=['POST'])
def upscale():
    in_file = request.form['inFile']
    out_file = request.form['outFile']

    # read image as BGR, scale to [0, 1], and convert to an NCHW RGB tensor
    img = cv2.imread(in_file, cv2.IMREAD_COLOR)
    img = img * 1.0 / 255
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img_LR = img.unsqueeze(0)
    img_LR = img_LR.to(device)

    # run the 4x upscale, then convert back to a BGR array in the 0-255 range
    with torch.no_grad():
        output = upscale_model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round()

    # downscale the result so it fits within Discord's upload limits
    res = cv2.resize(output, dsize=(1536, 1536), interpolation=cv2.INTER_LANCZOS4)
    cv2.imwrite(out_file, res)
    return 'OK'

@app.route('/interrogate', methods=['POST'])
def interrogate():
    in_file = request.form['inFile']
    img = Image.open(in_file).convert('RGB')

    # resize, normalize, and move to the GPU as a half-precision batch of one
    gpu_image = transforms.Compose([
        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])(img).unsqueeze(0).type(torch.cuda.HalfTensor).to(device)

    # beam-search caption generation; the first caption is returned as the response body
    with torch.no_grad():
        caption = blip_model.generate(gpu_image, sample=False, num_beams=blip_num_beams, min_length=blip_min_length, max_length=blip_max_length)
    return caption[0]
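
# The gist does not show how the server is launched; a minimal sketch, assuming the
# Flask development server is sufficient (host and port are assumptions, not part of
# the original):
if __name__ == '__main__':
    app.run(host='127.0.0.1', port=5000)

# Example request against the /txt2img endpoint (paths and values are illustrative only):
#
#   curl -X POST http://127.0.0.1:5000/txt2img \
#     -F prompt='a watercolor painting of a lighthouse' \
#     -F outFile=/tmp/out.png \
#     -F seed=42 -F height=512 -F width=512 -F steps=50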