Skip to content

Instantly share code, notes, and snippets.

@torridgristle
Last active July 26, 2022 20:24
Show Gist options
  • Save torridgristle/cfa644a74e1a3ad96080127a7ef627ba to your computer and use it in GitHub Desktop.
Save torridgristle/cfa644a74e1a3ad96080127a7ef627ba to your computer and use it in GitHub Desktop.
VQGAN F8 Decoding with downscaled attention
def _lores_attn_residual(attn, h):
    """Run `attn` on a half-resolution copy of `h` and return its effect at `h`'s size.

    Args:
        attn: Callable applied to the downscaled tensor.
        h: Feature tensor; bicubic interpolation requires it to be 4-D, and the
            last two dimensions are treated as the spatial size.

    Returns:
        `attn(h_half) - h_half`, bicubically resized to `h`'s spatial size.
    """
    full_size = tuple(h.shape[-2:])
    h_half = F.interpolate(h, scale_factor=0.5, mode='bicubic', align_corners=False)
    # Keep only what the attention layer changed. Presumably the layer returns
    # its input plus the attention output, so this isolates the attention
    # contribution -- confirm against the attention module in use.
    h_half = attn(h_half) - h_half
    # Resize to the exact original size instead of by a factor of 2: an odd
    # height/width is floored when halved (5 -> 2), and doubling would give 4,
    # which can no longer be added to `h`. For even sizes this is equivalent.
    return F.interpolate(h_half, size=full_size, mode='bicubic', align_corners=False)


def vqgan_dec_skip_lores_attn(h, temb=None):
    """Run the VQGAN decoder after `conv_in`, with attention at half resolution.

    Every attention layer (the middle one and any attached to the upsampling
    levels) is run on a bicubically downscaled copy of the features; the
    change it produces is upscaled and added back to the full-resolution
    features.

    Reads the module-level `vqgan` model. `h` must already have been passed
    through `vqgan.decoder.conv_in`; this function starts at `mid.block_1`.

    Args:
        h: Decoder features after `conv_in`.
        temb: Passed through unchanged to every residual block.

    Returns:
        The pre-output features if `vqgan.decoder.give_pre_end` is set,
        otherwise `conv_out(...) * 0.5 + 0.5` (presumably mapping [-1, 1]
        images to [0, 1] -- confirm against the model's output range).
    """
    decoder = vqgan.decoder

    # middle
    h = decoder.mid.block_1(h, temb)
    h = h + _lores_attn_residual(decoder.mid.attn_1, h)
    h = decoder.mid.block_2(h, temb)

    # upsampling
    for i_level in reversed(range(decoder.num_resolutions)):
        level = decoder.up[i_level]
        for i_block in range(decoder.num_res_blocks + 1):
            h = level.block[i_block](h, temb)
            if len(level.attn) > 0:
                h = h + _lores_attn_residual(level.attn[i_block], h)
        if i_level != 0:
            h = level.upsample(h)

    # end
    if decoder.give_pre_end:
        return h
    h = decoder.norm_out(h)
    h = h * torch.sigmoid(h)
    h = decoder.conv_out(h)
    return h * 0.5 + 0.5
# conv_in is applied separately from the rest of the decoder because this came
# from a project that modifies an image's latents with a small conv model, in
# an attempt to retain content while changing style.
decoded_image = vqgan.decoder.conv_in(vqgan.post_quant_conv(encoded_image))
# Pass the conv_in output (not the raw latents) on to the decoder body:
# vqgan_dec_skip_lores_attn starts at mid.block_1 and does not apply conv_in.
decoded_image = vqgan_dec_skip_lores_attn(decoded_image)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment