@mrmaheshrajput
Created May 2, 2025 05:31
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, LayerNormalization, MultiHeadAttention
from tensorflow.keras.models import Model


def create_patches(x, patch_len, stride=1):
    """Convert a [batch, time, channels] series into overlapping patches.

    The time dimension must be statically known (set via Input(shape=...)).
    """
    patches = []
    for i in range(0, x.shape[1] - patch_len + 1, stride):
        patches.append(x[:, i:i + patch_len, :])
    # Stack patches along a new dimension: [batch, num_patches, patch_len, channels]
    patches = tf.stack(patches, axis=1)
    # Flatten each patch to [batch, num_patches, patch_len * channels]
    batch_size = tf.shape(patches)[0]
    num_patches = tf.shape(patches)[1]
    channels = tf.shape(patches)[3]
    patches = tf.reshape(patches, [batch_size, num_patches, patch_len * channels])
    return patches
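
# Shape sanity check (hypothetical sizes, not part of the original gist):
# with a 96-step window, patch_len=16 and stride=8 give
# (96 - 16) // 8 + 1 = 11 patches, each flattened to 16 * 7 = 112 values.
# Uncomment to verify:
# demo = create_patches(tf.random.normal([32, 96, 7]), patch_len=16, stride=8)
# assert demo.shape == (32, 11, 112)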
class PatchPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned positional embedding to each patch embedding.

    Wrapping the Embedding lookup in a layer keeps its weights trainable;
    calling Embedding on a plain tf.range tensor at model-build time would
    bake the positional embeddings in as untrainable constants.
    """

    def __init__(self, num_patches, d_model, **kwargs):
        super().__init__(**kwargs)
        self.pos_emb = tf.keras.layers.Embedding(input_dim=num_patches, output_dim=d_model)

    def call(self, x):
        positions = tf.range(start=0, limit=tf.shape(x)[1], delta=1)
        return x + self.pos_emb(positions)


def build_patchtst_model(input_shape, patch_len=16, stride=8, d_model=128,
                         num_heads=4, ff_dim=256, num_layers=3, forecast_horizon=1):
    inputs = Input(shape=input_shape)

    # Split the series into overlapping patches and flatten each one
    patches = tf.keras.layers.Lambda(
        lambda x: create_patches(x, patch_len=patch_len, stride=stride)
    )(inputs)

    # Linear patch embedding plus learned positional embeddings
    num_patches = (input_shape[0] - patch_len) // stride + 1
    embedded_patches = Dense(d_model)(patches)
    x = PatchPositionEmbedding(num_patches, d_model)(embedded_patches)

    # Transformer encoder: post-norm self-attention and feed-forward blocks
    for _ in range(num_layers):
        attn_output = MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads
        )(x, x, x)
        x = LayerNormalization(epsilon=1e-6)(attn_output + x)
        ffn_output = Dense(ff_dim, activation="relu")(x)
        ffn_output = Dense(d_model)(ffn_output)
        x = LayerNormalization(epsilon=1e-6)(ffn_output + x)

    # Pool over patches and project to the forecast horizon
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    outputs = Dense(forecast_horizon)(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer="adam", loss="mse")
    return model
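
A minimal usage sketch (the window length, channel count, and random training data below are made up for illustration, not from the original gist): build the model for a 96-step lookback with 7 channels, forecasting 24 steps ahead, and fit it briefly to confirm the graph compiles.

import numpy as np

model = build_patchtst_model(input_shape=(96, 7), forecast_horizon=24)
model.summary()

X = np.random.rand(256, 96, 7).astype("float32")
y = np.random.rand(256, 24).astype("float32")
model.fit(X, y, epochs=2, batch_size=32)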