Created May 2, 2025 05:33
-
-
Save mrmaheshrajput/12a3092eaf147496b990f0acbea476de to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_tsmixer_model(input_shape, forecast_horizon=1, hidden_dim=128, num_layers=2):
    """Build and compile a simplified TSMixer-style forecasting model.

    Each of the ``num_layers`` mixing blocks applies:

    1. a time-mixing step: a ReLU ``Dense`` layer applied along the time axis
       (via a transpose), added back as a residual and layer-normalised;
    2. a feature-mixing step: a two-layer MLP (ReLU hidden layer of width
       ``hidden_dim``, linear projection back to ``num_features``) applied at
       every time step, added back as a residual and layer-normalised.

    The representation at the final time step is then mapped to
    ``forecast_horizon`` outputs by a linear ``Dense`` layer.

    Args:
        input_shape: Per-sample shape ``(time_steps, num_features)``, excluding
            the batch dimension. Both entries must be known (not ``None``)
            because the mixing layers are sized from them.
        forecast_horizon: Number of output values per sample.
        hidden_dim: Width of the hidden layer in each feature-mixing MLP.
        num_layers: Number of stacked time/feature mixing blocks.

    Returns:
        A Keras ``Model`` mapping ``(batch, time_steps, num_features)`` to
        ``(batch, forecast_horizon)``, compiled with the Adam optimizer and
        mean-squared-error loss.

    Raises:
        ValueError: If ``input_shape`` is not two-dimensional or contains
            ``None``.
    """
    shape = tuple(input_shape)
    if len(shape) != 2:
        # Permute((2, 1)) below only makes sense for (time, features) samples.
        raise ValueError(
            f"input_shape must be (time_steps, num_features); got {shape!r}"
        )
    time_steps, num_features = shape
    if time_steps is None or num_features is None:
        # Dense(time_steps) / Dense(num_features) need concrete sizes; fail fast
        # with a message that points at input_shape instead of a Keras internal.
        raise ValueError(
            "input_shape must have a static time length and feature count "
            f"(no None entries); got {shape!r}"
        )

    inputs = Input(shape=shape)
    x = inputs

    for _ in range(num_layers):
        # Time mixing: transpose to (batch, num_features, time_steps) so the
        # Dense layer mixes information across time for each feature, then
        # transpose back to (batch, time_steps, num_features).
        time_mix = tf.keras.layers.Permute((2, 1))(x)
        time_mix = Dense(time_steps, activation='relu')(time_mix)
        time_mix = tf.keras.layers.Permute((2, 1))(time_mix)
        x = LayerNormalization()(x + time_mix)  # residual + norm

        # Feature mixing: per-time-step MLP across the feature axis, projected
        # back to num_features so the residual shapes match.
        feature_mix = Dense(hidden_dim, activation='relu')(x)
        feature_mix = Dense(num_features)(feature_mix)
        x = LayerNormalization()(x + feature_mix)  # residual + norm

    # Forecast from the representation at the final time step only.
    last_step = x[:, -1, :]
    outputs = Dense(forecast_horizon)(last_step)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='mse')
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.