Created
June 13, 2024 05:41
-
-
Save laksjdjf/742fa0a17415f809bfce737351667102 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ================================================================================================================================================================ | |
| Layer (type (var_name)) Input Shape Output Shape Param # Kernel Shape | |
| ================================================================================================================================================================ | |
| SD3Transformer2DModel (SD3Transformer2DModel) -- [1, 16, 128, 128] -- -- | |
| ├─PatchEmbed (pos_embed) [1, 16, 128, 128] [1, 4096, 1536] -- -- | |
| │ └─Conv2d (proj) [1, 16, 128, 128] [1, 1536, 64, 64] 99,840 [2, 2] | |
| ├─CombinedTimestepTextProjEmbeddings (time_text_embed) [1] [1, 1536] -- -- | |
| │ └─Timesteps (time_proj) [1] [1, 256] -- -- | |
| │ └─TimestepEmbedding (timestep_embedder) [1, 256] [1, 1536] -- -- | |
| │ │ └─Linear (linear_1) [1, 256] [1, 1536] 394,752 -- | |
| │ │ └─SiLU (act) [1, 1536] [1, 1536] -- -- | |
| │ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 -- | |
| │ └─PixArtAlphaTextProjection (text_embedder) [1, 2048] [1, 1536] -- -- | |
| │ │ └─Linear (linear_1) [1, 2048] [1, 1536] 3,147,264 -- | |
| │ │ └─SiLU (act_1) [1, 1536] [1, 1536] -- -- | |
| │ │ └─Linear (linear_2) [1, 1536] [1, 1536] 2,360,832 -- | |
| ├─Linear (context_embedder) [1, 154, 4096] [1, 154, 1536] 6,292,992 -- | |
| ├─ModuleList (transformer_blocks) -- -- -- -- | |
| │ └─JointTransformerBlock (0) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (1) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (2) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (3) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (4) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (5) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (6) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (7) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (8) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (9) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (10) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (11) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (12) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (13) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (14) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (15) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (16) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (17) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (18) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (19) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (20) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (21) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (22) -- [1, 154, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormZero (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_add_out) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| │ │ └─LayerNorm (norm2_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─FeedForward (ff_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 154, 1536] [1, 154, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 154, 1536] [1, 154, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 154, 6144] [1, 154, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 154, 6144] [1, 154, 1536] 9,438,720 -- | |
| │ └─JointTransformerBlock (23) -- -- -- -- | |
| │ │ └─AdaLayerNormZero (norm1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 9216] 14,164,992 -- | |
| │ │ │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─AdaLayerNormContinuous (norm1_context) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ │ │ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 -- | |
| │ │ │ └─LayerNorm (norm) [1, 154, 1536] [1, 154, 1536] -- -- | |
| │ │ └─Attention (attn) -- [1, 4096, 1536] -- -- | |
| │ │ │ └─Linear (to_q) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_k) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (to_v) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_q_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_k_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─Linear (add_v_proj) [1, 154, 1536] [1, 154, 1536] 2,360,832 -- | |
| │ │ │ └─ModuleList (to_out) -- -- -- -- | |
| │ │ │ │ └─Linear (0) [1, 4096, 1536] [1, 4096, 1536] 2,360,832 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─LayerNorm (norm2) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ └─FeedForward (ff) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ │ │ └─ModuleList (net) -- -- -- -- | |
| │ │ │ │ └─GELU (0) [1, 4096, 1536] [1, 4096, 6144] -- -- | |
| │ │ │ │ │ └─Linear (proj) [1, 4096, 1536] [1, 4096, 6144] 9,443,328 -- | |
| │ │ │ │ └─Dropout (1) [1, 4096, 6144] [1, 4096, 6144] -- -- | |
| │ │ │ │ └─Linear (2) [1, 4096, 6144] [1, 4096, 1536] 9,438,720 -- | |
| ├─AdaLayerNormContinuous (norm_out) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| │ └─SiLU (silu) [1, 1536] [1, 1536] -- -- | |
| │ └─Linear (linear) [1, 1536] [1, 3072] 4,721,664 -- | |
| │ └─LayerNorm (norm) [1, 4096, 1536] [1, 4096, 1536] -- -- | |
| ├─Linear (proj_out) [1, 4096, 1536] [1, 4096, 64] 98,368 -- | |
| ================================================================================================================================================================ | |
| Total params: 2,028,328,000 | |
| Trainable params: 2,028,328,000 | |
| Non-trainable params: 0 | |
| Total mult-adds (G): 2.44 | |
| ================================================================================================================================================================ | |
| Input size (MB): 1.79 | |
| Forward/backward pass size (MB): 5663.46 | |
| Params size (MB): 4056.66 | |
| Estimated Total Size (MB): 9721.90 | |
| ================================================================================================================================================================ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment