Skip to content

Instantly share code, notes, and snippets.

@RF5
Created August 3, 2025 08:59
Show Gist options
  • Select an option

  • Save RF5/4d5de0713fcefea252ea684c447559e8 to your computer and use it in GitHub Desktop.

Select an option

Save RF5/4d5de0713fcefea252ea684c447559e8 to your computer and use it in GitHub Desktop.
Simple 2D STFT and 2D ISTFT Pytorch module
"""
A simple 2D STFT and 2D ISTFT PyTorch module.
Author: M. Baas
2025
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window, check_COLA # type: ignore
import warnings
class STFT2D(nn.Module):
    """2D short-time Fourier transform and its inverse for image-like tensors.

    ``transform`` cuts the (optionally reflect-padded) input into overlapping
    2D patches, multiplies every patch by a separable 2D window and applies
    ``torch.fft.rfft2`` to it.  ``inverse`` applies ``torch.fft.irfft2``,
    windows the patches again, overlap-adds them and divides by the
    accumulated squared analysis window.

    All size tuples are ordered ``(value_for_width, value_for_height)``.

    Note:
        ``inverse`` crops its result using the spatial size recorded by the
        most recent ``transform`` call (``self.crop_dims_HW``), so the module
        is stateful between the two calls.
    """

    def __init__(self, win_len=(64, 64), win_hop=(32, 32), fft_len=(64, 64),
                 win_type: str = 'hann', win_sqrt: bool = False,
                 pad_center: bool = True):
        """
        Args:
            win_len (tuple of int): Window lengths (len_w, len_h).
                Defaults to (64, 64).
            win_hop (tuple of int): Hop lengths (hop_w, hop_h). Must be
                positive and not larger than ``win_len``. Defaults to (32, 32).
            fft_len (tuple of int): FFT lengths (fft_w, fft_h).
                Must be >= win_len. Defaults to (64, 64).
            win_type (str): Window type passed to scipy.signal.get_window
                (e.g. 'hann', 'hamming'). Defaults to 'hann'.
            win_sqrt (bool): If True, the square root of the 2D window is used
                for both analysis and synthesis. Defaults to False.
            pad_center (bool): If True, the input is reflect-padded by
                ``fft_len // 2`` on each side before framing and the padding
                is cropped away again in ``inverse``. Defaults to True.

        Raises:
            ValueError: If a size tuple is malformed or inconsistent, or if
                the window cannot be created.
        """
        super().__init__()
        self._require_int_pair(win_len, "win_len must be a tuple of two integers (len_w, len_h)")
        self._require_int_pair(win_hop, "win_hop must be a tuple of two integers (hop_w, hop_h)")
        self._require_int_pair(fft_len, "fft_len must be a tuple of two integers (fft_w, fft_h)")
        self.win_len = win_len  # (len_w, len_h)
        self.win_hop = win_hop  # (hop_w, hop_h)
        self.fft_len = fft_len  # (fft_w, fft_h)
        if self.fft_len[0] < self.win_len[0] or self.fft_len[1] < self.win_len[1]:
            raise ValueError("fft_len components must be greater than or equal to win_len components.")
        if self.win_len[0] <= 0 or self.win_len[1] <= 0:
            raise ValueError("win_len components must be positive.")
        if self.win_hop[0] <= 0 or self.win_hop[1] <= 0:
            raise ValueError("win_hop components must be positive.")
        # A hop larger than the window yields a negative overlap, which
        # scipy.signal.check_COLA cannot process; reject it up front.
        if self.win_hop[0] > self.win_len[0] or self.win_hop[1] > self.win_len[1]:
            raise ValueError("win_hop components must not exceed win_len components.")

        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center
        # Amount of reflect padding added on each side when pad_center is True.
        self.pad_amount_w = self.fft_len[0] // 2
        self.pad_amount_h = self.fft_len[1] // 2

        try:
            win_w_1d_np = get_window(win_type, self.win_len[0], fftbins=True)  # width window
            win_h_1d_np = get_window(win_type, self.win_len[1], fftbins=True)  # height window
        except Exception as e:
            raise ValueError(
                f"Failed to create window of type '{win_type}' with scipy.signal.get_window. Error: {e}"
            ) from e

        win_w_1d = torch.tensor(win_w_1d_np, dtype=torch.float32)
        win_h_1d = torch.tensor(win_h_1d_np, dtype=torch.float32)
        # Separable 2D window of shape (win_len_w, win_len_h).
        window_2d_base = torch.outer(win_w_1d, win_h_1d)

        # The same window is applied at analysis and synthesis, so the
        # overlap-add normaliser is the squared analysis window in both modes.
        if self.win_sqrt:
            self.analysis_window = torch.sqrt(window_2d_base)
            norm_window_component = window_2d_base
        else:
            self.analysis_window = window_2d_base
            norm_window_component = window_2d_base.pow(2)
        self.register_buffer('analysis_window_buffer', self.analysis_window, persistent=False)
        self.register_buffer('norm_window_component_buffer', norm_window_component, persistent=False)

        # State recorded by transform(); crop_dims_HW is consumed by inverse().
        self.crop_dims_HW = None
        self.original_N_at_transform = None  # never written elsewhere in this class
        self.original_C_at_transform = None  # never written elsewhere in this class
        self.input_orig_ndim = None

        # Stored as plain Python bools so is_perfect() honours its annotation.
        self.cola_w = bool(check_COLA(win_w_1d_np, self.win_len[0],
                                      self.win_len[0] - self.win_hop[0], tol=1e-4))
        self.cola_h = bool(check_COLA(win_h_1d_np, self.win_len[1],
                                      self.win_len[1] - self.win_hop[1], tol=1e-4))
        if not (self.cola_w and self.cola_h) and self.win_sqrt:
            warnings.warn(f"COLA condition not met for win_type='{self.win_type}', win_len={self.win_len}, win_hop={self.win_hop}. "
                          "Perfect reconstruction may not be achieved even with win_sqrt=True.", UserWarning)

    @staticmethod
    def _require_int_pair(value, message: str) -> None:
        """Raise ValueError(message) unless value is a tuple of exactly two ints."""
        is_pair = (isinstance(value, tuple) and len(value) == 2
                   and isinstance(value[0], int) and isinstance(value[1], int))
        if not is_pair:
            raise ValueError(message)

    def _sync_buffers(self, device: torch.device) -> None:
        """Move both window buffers to ``device`` if they currently live elsewhere."""
        if self.analysis_window_buffer.device != device:
            self.analysis_window_buffer = self.analysis_window_buffer.to(device)
            self.norm_window_component_buffer = self.norm_window_component_buffer.to(device)

    def is_perfect(self) -> bool:
        """Return True when both 1D windows pass the COLA check at the
        configured hops and centre padding is enabled.

        The result does not depend on ``win_sqrt``.
        """
        return bool(self.cola_w and self.cola_h and self.pad_center)

    def transform(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Performs 2D STFT on multi-channel images.
        Input H, W order is maintained for freq and frame count output dimensions.
        Args:
            inputs (torch.Tensor): Input tensor of shape (H, W), (C, H, W), or (N, C, H, W).
                H is height, W is width.
        Returns:
            torch.Tensor: STFT coefficients of shape
                (N, C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W).
                num_freqs_H = fft_len_h // 2 + 1 (from rfft2 on H-dim of patch)
                num_freqs_W = fft_len_w (from rfft2 on W-dim of patch)
                num_frames_H = number of frames extracted along the H dimension.
                num_frames_W = number of frames extracted along the W dimension.
                Missing N / C dimensions of the input appear with size 1.
        Raises:
            ValueError: If the input is not 2D, 3D or 4D.
        """
        ndim = inputs.dim()
        self.input_orig_ndim = ndim
        if ndim == 2:
            batched = inputs.unsqueeze(0).unsqueeze(0)
        elif ndim == 3:
            batched = inputs.unsqueeze(0)
        elif ndim == 4:
            batched = inputs
        else:
            raise ValueError(f"Input must be 2D (H,W), 3D (C,H,W) or 4D (N,C,H,W). Got {inputs.dim()}D.")

        n, c, height, width = batched.shape
        # Remembered so that inverse() can crop back to the input size.
        self.crop_dims_HW = (height, width)

        win_w, win_h = self.win_len
        hop_w, hop_h = self.win_hop

        # Fold N and C together and work in (B_eff, W, H) layout.
        core = batched.reshape(n * c, height, width).permute(0, 2, 1)
        self._sync_buffers(core.device)

        if self.pad_center:
            # F.pad lists the last dimension (H) first, then W.
            core = F.pad(core, (self.pad_amount_h, self.pad_amount_h,
                                self.pad_amount_w, self.pad_amount_w), mode='reflect')

        # (B_eff, num_w_frames, H_padded, win_w) -> (B_eff, num_w_frames, num_h_frames, win_w, win_h)
        frames = core.unfold(1, win_w, hop_w).unfold(2, win_h, hop_h)
        windowed = frames * self.analysis_window_buffer.view(1, 1, 1, win_w, win_h)

        # (B_eff, num_w_frames, num_h_frames, fft_w, fft_h // 2 + 1)
        coeffs = torch.fft.rfft2(windowed, s=self.fft_len, norm='ortho')
        # -> (B_eff, freq_H, freq_W, frames_H, frames_W)
        coeffs = coeffs.permute(0, 4, 3, 2, 1)
        return coeffs.reshape((n, c) + tuple(coeffs.shape[1:]))

    def inverse(self, stft_coeffs_user: torch.Tensor) -> torch.Tensor:
        """
        Performs 2D inverse STFT for multi-channel STFT coefficients.
        Expects coeff order (N, C, freq_H, freq_W, frames_H, frames_W).
        Args:
            stft_coeffs_user (torch.Tensor): STFT coefficients of shape
                (N, C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W) or
                (C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W).
        Returns:
            torch.Tensor: Reconstructed signal of shape (N, C, H, W) for 6D input or
                (C, H, W) for 5D input. The result is cropped with the (H, W)
                recorded by the last ``transform`` call and has the dtype
                produced by the inverse FFT (float64 coefficients give a
                float64 signal).
        Raises:
            ValueError: If the input is not 5D or 6D.
            RuntimeError: If ``pad_center`` is True and ``transform`` has not been
                called yet, so the crop size is unknown.
        """
        ndim = stft_coeffs_user.dim()
        self.inverse_input_orig_ndim = ndim
        if ndim == 5:
            coeffs = stft_coeffs_user.unsqueeze(0)
        elif ndim == 6:
            coeffs = stft_coeffs_user
        else:
            raise ValueError(f"Inverse input STFT coeffs must be 5D or 6D. Got {stft_coeffs_user.dim()}D.")

        n, c, freq_h, freq_w, num_h_frames, num_w_frames = coeffs.shape
        batch = n * c
        win_w, win_h = self.win_len
        hop_w, hop_h = self.win_hop

        # (B_eff, freq_H, freq_W, frames_H, frames_W) -> (B_eff, frames_W, frames_H, freq_W, freq_H)
        core = coeffs.reshape(batch, freq_h, freq_w, num_h_frames, num_w_frames).permute(0, 4, 3, 2, 1)
        self._sync_buffers(core.device)

        # (B_eff, frames_W, frames_H, fft_w, fft_h); drop the zero-padded FFT tail.
        frames = torch.fft.irfft2(core, s=self.fft_len, norm='ortho')[..., :win_w, :win_h]
        frames = frames * self.analysis_window_buffer.view(1, 1, 1, win_w, win_h)

        out_w = (num_w_frames - 1) * hop_w + win_w
        out_h = (num_h_frames - 1) * hop_h + win_h
        num_blocks = num_w_frames * num_h_frames

        # Overlap-add through F.fold: one column per patch, patch pixels
        # flattened row-major, patches ordered row-major over (frames_W, frames_H).
        columns = frames.permute(0, 3, 4, 1, 2).reshape(batch, win_w * win_h, num_blocks)
        signal = F.fold(columns, output_size=(out_w, out_h),
                        kernel_size=(win_w, win_h), stride=(hop_w, hop_h)).squeeze(1)

        # The normaliser is identical for every batch element: build it once
        # in the signal's dtype and broadcast it.
        norm_columns = (self.norm_window_component_buffer.to(dtype=frames.dtype)
                        .reshape(1, win_w * win_h, 1)
                        .expand(1, win_w * win_h, num_blocks)
                        .contiguous())
        norm = F.fold(norm_columns, output_size=(out_w, out_h),
                      kernel_size=(win_w, win_h), stride=(hop_w, hop_h)).squeeze(1)

        # Divide only where the accumulated window energy is non-negligible;
        # other positions keep their un-normalised value.
        valid_norm = norm > 1e-8
        signal = signal / torch.where(valid_norm, norm, torch.ones_like(norm))
        # signal is (B_eff, W_proc_padded, H_proc_padded)

        if self.pad_center:
            if self.crop_dims_HW is None:
                raise RuntimeError("crop_dims_HW is None. Cannot crop.")
            H_orig, W_orig = self.crop_dims_HW
            start_w_crop = self.pad_amount_w
            end_w_crop = start_w_crop + W_orig
            start_h_crop = self.pad_amount_h
            end_h_crop = start_h_crop + H_orig
            if end_w_crop > out_w or end_h_crop > out_h:
                warnings.warn(f"Cropping region for W_proc [{start_w_crop}:{end_w_crop}] or H_proc [{start_h_crop}:{end_h_crop}] "
                              f"exceeds reconstructed W_proc={out_w}, H_proc={out_h}.", UserWarning)
            signal = signal[:, start_w_crop:end_w_crop, start_h_crop:end_h_crop]
        elif self.crop_dims_HW is not None:
            H_orig, W_orig = self.crop_dims_HW
            signal = signal[:, :W_orig, :H_orig]

        # (B_eff, W, H) -> (B_eff, H, W) -> (N, C, H, W)
        signal = signal.permute(0, 2, 1)
        output_final = signal.reshape(n, c, signal.shape[1], signal.shape[2])
        if ndim == 5 and output_final.shape[0] == 1:
            return output_final.squeeze(0)
        return output_final

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run ``transform`` followed by ``inverse`` and return the
        reconstruction with the same number of dimensions as ``inputs``."""
        input_ndim = inputs.dim()
        reconstructed_signal = self.inverse(self.transform(inputs))
        if input_ndim == 2:
            return reconstructed_signal.squeeze(0).squeeze(0)
        if input_ndim == 3:
            return reconstructed_signal.squeeze(0)
        return reconstructed_signal
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment