Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created June 13, 2024 05:41
Show Gist options
  • Save laksjdjf/742fa0a17415f809bfce737351667102 to your computer and use it in GitHub Desktop.
Save laksjdjf/742fa0a17415f809bfce737351667102 to your computer and use it in GitHub Desktop.
================================================================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Kernel Shape
================================================================================================================================================================
SD3Transformer2DModel (SD3Transformer2DModel) -- [1, 16, 128, 128] -- --
├─PatchEmbed (pos_embed) [1, 16, 128, 128] [1, 4096, 1536] -- --
│ └─Conv2d (proj) [1, 16, 128, 128] [1, 1536, 64, 64] 99,840 [2, 2]
├─CombinedTimestepTextProjEmbeddings (time_text_embed) [1] [1, 1536] -- --
│ └─Timesteps (time_proj) [1] [1, 256] -- --
│ └─TimestepEmbedding (timestep_embedder) [1, 256] [1, 1536] -- --
│ │ └─Linear (linear_1) [1, 256] [1, 1536] 394,752 --
│ │ └─SiLU (act) [1, 1536] [1, 1536] -- --
│ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 --
│ └─PixArtAlphaTextProjection (text_embedder) [1, 2048] [1, 1536] -- --
│ │ └─Linear (linear_1) [1, 2048] [1, 1536] 3,147,264 --
│ │ └─SiLU (act_1) [1, 1536] [1, 1536] -- --
│ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 --
├─Linear (context_embedder) [1, 154, 4096] [1, 154, 1536] 6,292,992 --
├─ModuleList (transformer_blocks) -- -- -- --
│ └─JointTransformerBlock (0) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (1) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (2) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (3) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (4) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (5) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (6) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (7) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (8) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (9) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (10) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (11) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (12) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (13) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (14) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (15) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (16) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (17) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (18) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (19) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (20) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (21) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (22) -- [1, 154, 1536] -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
│ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- --
│ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 --
│ └─JointTransformerBlock (23) -- -- -- --
│ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 --
│ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─AdaLayerNormContinuous (norm1_context) [1, 154, 1536] [1, 154, 1536] -- --
│ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ │ │ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 --
│ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- --
│ │ └─Attention (attn) -- [1, 4096, 1536] -- --
│ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 --
│ │ │ └─ModuleList (to_out) -- -- -- --
│ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 --
│ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- --
│ │ │ └─ModuleList (net) -- -- -- --
│ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- --
│ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 --
│ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- --
│ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 --
├─AdaLayerNormContinuous (norm_out) [1, 4096, 1536] [1, 4096, 1536] -- --
│ └─SiLU (silu) [1, 1536] [1, 1536] -- --
│ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 --
│ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- --
├─Linear (proj_out) [1, 4096, 1536] [1, 4096, 64] 98,368 --
================================================================================================================================================================
Total params: 2,028,328,000
Trainable params: 2,028,328,000
Non-trainable params: 0
Total mult-adds (G): 2.44
================================================================================================================================================================
Input size (MB): 1.79
Forward/backward pass size (MB): 5663.46
Params size (MB): 4056.66
Estimated Total Size (MB): 9721.90
================================================================================================================================================================
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment