Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created June 23, 2023 15:35
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save laksjdjf/22d1522b52c9668314882e7ad6c897ea to your computer and use it in GitHub Desktop.
Save laksjdjf/22d1522b52c9668314882e7ad6c897ea to your computer and use it in GitHub Desktop.
================================================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Kernel Shape
================================================================================================================================================================
SdxlUNet2DConditionModel (SdxlUNet2DConditionModel) [1, 4, 128, 128] [1, 4, 128, 128] -- --
├─Sequential (time_embed) [1, 320] [1, 1280] -- --
│ └─Linear (0) [1, 320] [1, 1280] 410,880 --
│ └─SiLU (1) [1, 1280] [1, 1280] -- --
│ └─Linear (2) [1, 1280] [1, 1280] 1,639,680 --
├─Sequential (label_emb) [1, 2816] [1, 1280] -- --
│ └─Sequential (0) [1, 2816] [1, 1280] -- --
│ │ └─Linear (0) [1, 2816] [1, 1280] 3,605,760 --
│ │ └─SiLU (1) [1, 1280] [1, 1280] -- --
│ │ └─Linear (2) [1, 1280] [1, 1280] 1,639,680 --
├─ModuleList (input_blocks) -- -- -- --
│ └─Sequential (0) -- -- -- --
│ │ └─Conv2d (0) [1, 4, 128, 128] [1, 320, 128, 128] 11,840 [3, 3]
│ └─ModuleList (1) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ └─Sequential (in_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (2) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 320] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 320] 409,920 --
│ │ │ └─Sequential (out_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Identity (2) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (3) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Identity (skip_connection) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ └─ModuleList (2) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ └─Sequential (in_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (2) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 320] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 320] 409,920 --
│ │ │ └─Sequential (out_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Identity (2) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (3) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Identity (skip_connection) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ └─Sequential (3) -- -- -- --
│ │ └─Downsample2D (0) [1, 320, 128, 128] [1, 320, 64, 64] -- --
│ │ │ └─Conv2d (op) [1, 320, 128, 128] [1, 320, 64, 64] 921,920 [3, 3]
│ └─ModuleList (4) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 320, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─Sequential (in_layers) [1, 320, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 64, 64] [1, 320, 64, 64] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 64, 64] [1, 320, 64, 64] -- --
│ │ │ │ └─Conv2d (2) [1, 320, 64, 64] [1, 640, 64, 64] 1,843,840 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 640] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 640] 819,840 --
│ │ │ └─Sequential (out_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Identity (2) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (3) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 320, 64, 64] [1, 640, 64, 64] 205,440 [1, 1]
│ │ └─Transformer2DModel (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─GroupNorm (norm) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ └─Linear (proj_in) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ └─Linear (proj_out) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ └─ModuleList (5) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─Sequential (in_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (2) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 640] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 640] 819,840 --
│ │ │ └─Sequential (out_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Identity (2) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (3) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Identity (skip_connection) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ └─Transformer2DModel (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─GroupNorm (norm) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ └─Linear (proj_in) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ └─Linear (proj_out) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ └─Sequential (6) -- -- -- --
│ │ └─Downsample2D (0) [1, 640, 64, 64] [1, 640, 32, 32] -- --
│ │ │ └─Conv2d (op) [1, 640, 64, 64] [1, 640, 32, 32] 3,687,040 [3, 3]
│ └─ModuleList (7) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 640, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Sequential (in_layers) [1, 640, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 32, 32] [1, 640, 32, 32] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 32, 32] [1, 640, 32, 32] -- --
│ │ │ │ └─Conv2d (2) [1, 640, 32, 32] [1, 1280, 32, 32] 7,374,080 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 640, 32, 32] [1, 1280, 32, 32] 820,480 [1, 1]
│ │ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ └─ModuleList (8) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Sequential (in_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (2) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Identity (skip_connection) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
├─ModuleList (middle_block) -- -- -- --
│ └─ResnetBlock2D (0) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ └─Sequential (in_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Conv2d (2) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ └─Identity (skip_connection) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ └─ResnetBlock2D (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ └─Sequential (in_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Conv2d (2) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ └─Identity (skip_connection) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
├─ModuleList (output_blocks) -- -- -- --
│ └─ModuleList (0) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 2560, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Sequential (in_layers) [1, 2560, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 2560, 32, 32] [1, 2560, 32, 32] 5,120 --
│ │ │ │ └─SiLU (1) [1, 2560, 32, 32] [1, 2560, 32, 32] -- --
│ │ │ │ └─Conv2d (2) [1, 2560, 32, 32] [1, 1280, 32, 32] 29,492,480 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 2560, 32, 32] [1, 1280, 32, 32] 3,278,080 [1, 1]
│ │ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ └─ModuleList (1) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 2560, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Sequential (in_layers) [1, 2560, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 2560, 32, 32] [1, 2560, 32, 32] 5,120 --
│ │ │ │ └─SiLU (1) [1, 2560, 32, 32] [1, 2560, 32, 32] -- --
│ │ │ │ └─Conv2d (2) [1, 2560, 32, 32] [1, 1280, 32, 32] 29,492,480 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 2560, 32, 32] [1, 1280, 32, 32] 3,278,080 [1, 1]
│ │ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ └─ModuleList (2) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 1920, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─Sequential (in_layers) [1, 1920, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1920, 32, 32] [1, 1920, 32, 32] 3,840 --
│ │ │ │ └─SiLU (1) [1, 1920, 32, 32] [1, 1920, 32, 32] -- --
│ │ │ │ └─Conv2d (2) [1, 1920, 32, 32] [1, 1280, 32, 32] 22,119,680 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 1280] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 1280] 1,639,680 --
│ │ │ └─Sequential (out_layers) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Identity (2) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ │ └─Conv2d (3) [1, 1280, 32, 32] [1, 1280, 32, 32] 14,746,880 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 1920, 32, 32] [1, 1280, 32, 32] 2,458,880 [1, 1]
│ │ └─Transformer2DModel (1) [1, 1280, 32, 32] [1, 1280, 32, 32] -- --
│ │ │ └─GroupNorm (norm) [1, 1280, 32, 32] [1, 1280, 32, 32] 2,560 --
│ │ │ └─Linear (proj_in) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (3) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (4) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (5) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (6) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (7) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (8) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ │ └─BasicTransformerBlock (9) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_v) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 1024, 1280] [1, 1024, 1280] 1,638,400 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 1280] 2,621,440 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 1024, 1280] [1, 1024, 1280] 2,560 --
│ │ │ │ │ └─FeedForward (ff) [1, 1024, 1280] [1, 1024, 1280] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 1024, 1280] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 1024, 1280] [1, 1024, 10240] 13,117,440 --
│ │ │ │ │ │ │ └─Identity (1) [1, 1024, 5120] [1, 1024, 5120] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 1024, 5120] [1, 1024, 1280] 6,554,880 --
│ │ │ └─Linear (proj_out) [1, 1024, 1280] [1, 1024, 1280] 1,639,680 --
│ │ └─Upsample2D (2) [1, 1280, 32, 32] [1, 1280, 64, 64] -- --
│ │ │ └─Conv2d (conv) [1, 1280, 64, 64] [1, 1280, 64, 64] 14,746,880 [3, 3]
│ └─ModuleList (3) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 1920, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─Sequential (in_layers) [1, 1920, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1920, 64, 64] [1, 1920, 64, 64] 3,840 --
│ │ │ │ └─SiLU (1) [1, 1920, 64, 64] [1, 1920, 64, 64] -- --
│ │ │ │ └─Conv2d (2) [1, 1920, 64, 64] [1, 640, 64, 64] 11,059,840 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 640] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 640] 819,840 --
│ │ │ └─Sequential (out_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Identity (2) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (3) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 1920, 64, 64] [1, 640, 64, 64] 1,229,440 [1, 1]
│ │ └─Transformer2DModel (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─GroupNorm (norm) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ └─Linear (proj_in) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ └─Linear (proj_out) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ └─ModuleList (4) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 1280, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─Sequential (in_layers) [1, 1280, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 1280, 64, 64] [1, 1280, 64, 64] 2,560 --
│ │ │ │ └─SiLU (1) [1, 1280, 64, 64] [1, 1280, 64, 64] -- --
│ │ │ │ └─Conv2d (2) [1, 1280, 64, 64] [1, 640, 64, 64] 7,373,440 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 640] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 640] 819,840 --
│ │ │ └─Sequential (out_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Identity (2) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (3) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 1280, 64, 64] [1, 640, 64, 64] 819,840 [1, 1]
│ │ └─Transformer2DModel (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─GroupNorm (norm) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ └─Linear (proj_in) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ └─Linear (proj_out) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ └─ModuleList (5) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 960, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─Sequential (in_layers) [1, 960, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 960, 64, 64] [1, 960, 64, 64] 1,920 --
│ │ │ │ └─SiLU (1) [1, 960, 64, 64] [1, 960, 64, 64] -- --
│ │ │ │ └─Conv2d (2) [1, 960, 64, 64] [1, 640, 64, 64] 5,530,240 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 640] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 640] 819,840 --
│ │ │ └─Sequential (out_layers) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Identity (2) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ │ └─Conv2d (3) [1, 640, 64, 64] [1, 640, 64, 64] 3,687,040 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 960, 64, 64] [1, 640, 64, 64] 615,040 [1, 1]
│ │ └─Transformer2DModel (1) [1, 640, 64, 64] [1, 640, 64, 64] -- --
│ │ │ └─GroupNorm (norm) [1, 640, 64, 64] [1, 640, 64, 64] 1,280 --
│ │ │ └─Linear (proj_in) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ └─ModuleList (transformer_blocks) -- -- -- --
│ │ │ │ └─BasicTransformerBlock (0) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ │ └─BasicTransformerBlock (1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ └─LayerNorm (norm1) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn1) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_v) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm2) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─CrossAttention (attn2) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─Linear (to_q) [1, 4096, 640] [1, 4096, 640] 409,600 --
│ │ │ │ │ │ └─Linear (to_k) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─Linear (to_v) [1, 77, 2048] [1, 77, 640] 1,310,720 --
│ │ │ │ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ │ │ │ └─Linear (0) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ │ │ │ └─LayerNorm (norm3) [1, 4096, 640] [1, 4096, 640] 1,280 --
│ │ │ │ │ └─FeedForward (ff) [1, 4096, 640] [1, 4096, 640] -- --
│ │ │ │ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ │ │ │ └─GEGLU (0) [1, 4096, 640] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ │ └─Linear (proj) [1, 4096, 640] [1, 4096, 5120] 3,281,920 --
│ │ │ │ │ │ │ └─Identity (1) [1, 4096, 2560] [1, 4096, 2560] -- --
│ │ │ │ │ │ │ └─Linear (2) [1, 4096, 2560] [1, 4096, 640] 1,639,040 --
│ │ │ └─Linear (proj_out) [1, 4096, 640] [1, 4096, 640] 410,240 --
│ │ └─Upsample2D (2) [1, 640, 64, 64] [1, 640, 128, 128] -- --
│ │ │ └─Conv2d (conv) [1, 640, 128, 128] [1, 640, 128, 128] 3,687,040 [3, 3]
│ └─ModuleList (6) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 960, 128, 128] [1, 320, 128, 128] -- --
│ │ │ └─Sequential (in_layers) [1, 960, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 960, 128, 128] [1, 960, 128, 128] 1,920 --
│ │ │ │ └─SiLU (1) [1, 960, 128, 128] [1, 960, 128, 128] -- --
│ │ │ │ └─Conv2d (2) [1, 960, 128, 128] [1, 320, 128, 128] 2,765,120 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 320] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 320] 409,920 --
│ │ │ └─Sequential (out_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Identity (2) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (3) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 960, 128, 128] [1, 320, 128, 128] 307,520 [1, 1]
│ └─ModuleList (7) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 640, 128, 128] [1, 320, 128, 128] -- --
│ │ │ └─Sequential (in_layers) [1, 640, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 128, 128] [1, 640, 128, 128] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 128, 128] [1, 640, 128, 128] -- --
│ │ │ │ └─Conv2d (2) [1, 640, 128, 128] [1, 320, 128, 128] 1,843,520 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 320] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 320] 409,920 --
│ │ │ └─Sequential (out_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Identity (2) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (3) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 640, 128, 128] [1, 320, 128, 128] 205,120 [1, 1]
│ └─ModuleList (8) -- -- -- --
│ │ └─ResnetBlock2D (0) [1, 640, 128, 128] [1, 320, 128, 128] -- --
│ │ │ └─Sequential (in_layers) [1, 640, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 640, 128, 128] [1, 640, 128, 128] 1,280 --
│ │ │ │ └─SiLU (1) [1, 640, 128, 128] [1, 640, 128, 128] -- --
│ │ │ │ └─Conv2d (2) [1, 640, 128, 128] [1, 320, 128, 128] 1,843,520 [3, 3]
│ │ │ └─Sequential (emb_layers) [1, 1280] [1, 320] -- --
│ │ │ │ └─SiLU (0) [1, 1280] [1, 1280] -- --
│ │ │ │ └─Linear (1) [1, 1280] [1, 320] 409,920 --
│ │ │ └─Sequential (out_layers) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ │ │ │ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Identity (2) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ │ │ │ └─Conv2d (3) [1, 320, 128, 128] [1, 320, 128, 128] 921,920 [3, 3]
│ │ │ └─Conv2d (skip_connection) [1, 640, 128, 128] [1, 320, 128, 128] 205,120 [1, 1]
├─ModuleList (out) -- -- -- --
│ └─GroupNorm32 (0) [1, 320, 128, 128] [1, 320, 128, 128] 640 --
│ └─SiLU (1) [1, 320, 128, 128] [1, 320, 128, 128] -- --
│ └─Conv2d (2) [1, 320, 128, 128] [1, 4, 128, 128] 11,524 [3, 3]
================================================================================================================================================================
Total params: 2,567,463,684
Trainable params: 2,567,463,684
Non-trainable params: 0
Total mult-adds (G): 813.94
================================================================================================================================================================
Input size (MB): 0.90
Forward/backward pass size (MB): 17986.64
Params size (MB): 10269.85
Estimated Total Size (MB): 28257.40
================================================================================================================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment