Skip to content

Instantly share code, notes, and snippets.

@mrmaheshrajput
Created May 2, 2025 05:33
Show Gist options
  • Select an option

  • Save mrmaheshrajput/c647c93a6ebc77533e93031c60ddbe87 to your computer and use it in GitHub Desktop.
def build_tft_model(static_shape, past_shape, future_shape, forecast_horizon):
    """Build a simplified Temporal Fusion Transformer (TFT) forecasting model.

    The network gates past and known-future covariates with sigmoid
    "variable selection" layers conditioned on a static context, encodes
    each stream with an LSTM, fuses them with multi-head self-attention
    plus a position-wise feed-forward block, and projects the last
    ``forecast_horizon`` timesteps to the forecast.

    Args:
        static_shape: Shape (excluding batch) of the static covariate input.
        past_shape: Shape (excluding batch) of the past covariate input;
            the last dimension is the number of past features.
        future_shape: Shape (excluding batch) of the known-future covariate
            input; the last dimension is the number of future features.
        forecast_horizon: Number of trailing timesteps to project into the
            output, and the width of the output projection.

    Returns:
        A compiled ``keras.Model`` (Adam optimizer, MSE loss) taking
        ``[static_inputs, past_inputs, future_inputs]``.
    """
    static_inputs = Input(shape=static_shape)
    past_inputs = Input(shape=past_shape)
    future_inputs = Input(shape=future_shape)  # known future covariates

    # Static context conditions the variable-selection gates below.
    static_context = Dense(64, activation='relu')(static_inputs)

    # Variable selection: per-feature sigmoid gates, conditioned on the
    # static context, multiplied element-wise onto the raw covariates.
    # NOTE(review): tf.concat requires past_inputs and static_context to
    # have the same rank. If static_shape is per-series (1-D), the context
    # is (batch, 64) while past_inputs is (batch, time, features) and this
    # concat will fail — the context would need tiling across the time
    # axis. Confirm the intended static input shape with the caller.
    past_selected = Dense(past_shape[-1], activation='sigmoid')(
        tf.concat([past_inputs, static_context], axis=-1)
    )
    past_weighted = past_inputs * past_selected
    future_selected = Dense(future_shape[-1], activation='sigmoid')(
        tf.concat([future_inputs, static_context], axis=-1)
    )
    future_weighted = future_inputs * future_selected

    # Past LSTM encoder
    past_encoded = LSTM(64, return_sequences=True)(past_weighted)
    # Future LSTM decoder
    future_encoded = LSTM(64, return_sequences=True)(future_weighted)

    # Concatenate past and future encodings along the time axis so
    # attention can mix information across the full window.
    combined_features = tf.concat([past_encoded, future_encoded], axis=1)

    # Self-attention (query = key = value = combined sequence), followed by
    # a residual connection and layer norm (post-norm transformer block).
    attn_output = MultiHeadAttention(
        num_heads=4, key_dim=16
    )(combined_features, combined_features, combined_features)
    x = LayerNormalization()(combined_features + attn_output)

    # Position-wise feed-forward with a second residual + layer norm.
    ffn = Dense(128, activation='relu')(x)
    ffn = Dense(64)(ffn)
    x = LayerNormalization()(x + ffn)

    # Output projection over the last `forecast_horizon` timesteps.
    # NOTE(review): Dense is applied per-timestep, so the output shape is
    # (batch, forecast_horizon, forecast_horizon) — verify this is the
    # intended target shape rather than (batch, forecast_horizon).
    outputs = Dense(forecast_horizon)(x[:, -forecast_horizon:, :])

    model = Model(inputs=[static_inputs, past_inputs, future_inputs], outputs=outputs)
    model.compile(optimizer='adam', loss='mse')
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment