@mrmaheshrajput
Created May 2, 2025 05:33
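```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, LayerNormalization
from tensorflow.keras.models import Model


def build_tsmixer_model(input_shape, forecast_horizon=1, hidden_dim=128, num_layers=2):
    # input_shape is (time_steps, num_features); time_steps must be a fixed integer
    inputs = Input(shape=input_shape)  # [batch, time, features]
    x = inputs

    # Time mixing layers
    for _ in range(num_layers):
        # Mix across time dimension (one Dense shared by every feature)
        time_mix = tf.keras.layers.Permute((2, 1))(x)  # [batch, features, time]
        time_mix = Dense(input_shape[0], activation='relu')(time_mix)  # Project each feature across time
        time_mix = tf.keras.layers.Permute((2, 1))(time_mix)  # Back to [batch, time, features]
        # Residual connection, then normalize over the feature axis
        x = x + time_mix
        x = LayerNormalization()(x)

        # Mix across feature dimension (two-layer MLP shared by every time step)
        feature_mix = Dense(hidden_dim, activation='relu')(x)
        feature_mix = Dense(input_shape[-1])(feature_mix)  # Back to num_features
        # Residual connection, then normalize over the feature axis
        x = x + feature_mix
        x = LayerNormalization()(x)

    # Output layer using the last time step: [batch, features] -> [batch, forecast_horizon]
    x = x[:, -1, :]
    outputs = Dense(forecast_horizon)(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='mse')
    return model
```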
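To make the tensor shapes in `build_tsmixer_model` concrete, here is a minimal NumPy sketch of the forward pass of one mixing layer plus the last-step output head. It assumes Keras' `Dense` weight layout of `[in, out]` (so a Dense is `x @ W + b`) and `LayerNormalization` at its initial scale 1, offset 0 and default epsilon 1e-3. The helper names (`layer_norm`, `mixer_layer`, `forecast_head`) and the random weights are hypothetical, for illustration only; this is not the trained Keras model.

```python
import numpy as np


def layer_norm(x, eps=1e-3):
    # Normalize over the last (feature) axis; gamma=1 and beta=0, as at Keras initialization.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def relu(x):
    return np.maximum(x, 0.0)


def mixer_layer(x, w_time, b_time, w1, b1, w2, b2):
    """One time-mixing + feature-mixing step on x of shape [batch, time, features].

    Hypothetical weight names: w_time is [time, time], w1 is [features, hidden],
    w2 is [hidden, features]; each bias matches its layer's output size.
    """
    # Time mixing: transpose to [batch, features, time], apply the shared Dense over time, transpose back.
    t = relu(np.transpose(x, (0, 2, 1)) @ w_time + b_time)
    x = layer_norm(x + np.transpose(t, (0, 2, 1)))
    # Feature mixing: two-layer MLP over the feature axis, shared across time steps.
    f = relu(x @ w1 + b1) @ w2 + b2
    return layer_norm(x + f)


def forecast_head(x, w_out, b_out):
    # Keep only the last time step, then project to forecast_horizon outputs.
    return x[:, -1, :] @ w_out + b_out


rng = np.random.default_rng(0)
batch, time_steps, num_features, hidden_dim, horizon = 4, 24, 3, 16, 6
x = rng.normal(size=(batch, time_steps, num_features))
params = (
    rng.normal(size=(time_steps, time_steps)), np.zeros(time_steps),
    rng.normal(size=(num_features, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(size=(hidden_dim, num_features)), np.zeros(num_features),
)
h = mixer_layer(x, *params)
y = forecast_head(h, rng.normal(size=(num_features, horizon)), np.zeros(horizon))
print(h.shape, y.shape)  # (4, 24, 3) (4, 6)
```

Because the time-mixing Dense maps `time_steps` to `time_steps`, `input_shape[0]` has to be a fixed length rather than `None` when building the Keras model.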