Created
August 3, 2025 08:59
-
-
Save RF5/4d5de0713fcefea252ea684c447559e8 to your computer and use it in GitHub Desktop.
Simple 2D STFT and 2D ISTFT Pytorch module
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| A simple 2D STFT and 2D ISTFT Pytorch module. | |
| Author: M. Baas | |
| 2025 | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from scipy.signal import get_window, check_COLA # type: ignore | |
| import warnings | |
class STFT2D(nn.Module):
    """2D short-time Fourier transform and its overlap-add inverse.

    Images are cut into overlapping 2D patches, each patch is multiplied by a
    separable window (outer product of two 1D windows) and transformed with
    ``torch.fft.rfft2``. The inverse applies the same window again after
    ``torch.fft.irfft2``, overlap-adds the patches and divides by the
    accumulated window-product weights.

    NOTE: the module is stateful. ``transform`` stores the spatial size of its
    most recent input in ``crop_dims_HW`` and ``inverse`` uses that size for
    cropping, so ``inverse`` always crops according to the latest
    ``transform`` call.
    """

    def __init__(self, win_len: tuple = (64, 64), win_hop: tuple = (32, 32),
                 fft_len: tuple = (64, 64), win_type: str = 'hann',
                 win_sqrt: bool = False, pad_center: bool = True):
        """
        2D Short-Time Fourier Transform (STFT) module for multi-channel images.
        The parameters win_len, win_hop, fft_len are tuples (value_for_width, value_for_height).

        Args:
            win_len (tuple of int): Window lengths for width and height dimensions. (len_w, len_h).
                Defaults to (64, 64).
            win_hop (tuple of int): Hop lengths for width and height dimensions. (hop_w, hop_h).
                Defaults to (32, 32).
            fft_len (tuple of int): FFT lengths for width and height dimensions. (fft_w, fft_h).
                Must be >= win_len component-wise. Defaults to (64, 64).
            win_type (str): Window type passed to scipy.signal.get_window (e.g. 'hann', 'hamming').
                Defaults to 'hann'.
            win_sqrt (bool): If True, the square root of the 2D window is used for both analysis
                and synthesis. Defaults to False.
            pad_center (bool): If True, the input is reflect-padded by fft_len // 2 on every side
                before framing, and the inverse crops that padding away. Defaults to True.

        Raises:
            ValueError: If a length argument is not a tuple of two ints, if fft_len < win_len,
                if win_len or win_hop is not positive, or if the window cannot be created.
        """
        super().__init__()
        self._validate_int_pair(win_len, "win_len must be a tuple of two integers (len_w, len_h)")
        self._validate_int_pair(win_hop, "win_hop must be a tuple of two integers (hop_w, hop_h)")
        self._validate_int_pair(fft_len, "fft_len must be a tuple of two integers (fft_w, fft_h)")

        self.win_len = win_len  # (len_w, len_h)
        self.win_hop = win_hop  # (hop_w, hop_h)
        self.fft_len = fft_len  # (fft_w, fft_h)

        if not (self.fft_len[0] >= self.win_len[0] and self.fft_len[1] >= self.win_len[1]):
            raise ValueError("fft_len components must be greater than or equal to win_len components.")
        if not (self.win_len[0] > 0 and self.win_len[1] > 0):
            raise ValueError("win_len components must be positive.")
        if not (self.win_hop[0] > 0 and self.win_hop[1] > 0):
            raise ValueError("win_hop components must be positive.")

        self.win_type = win_type
        self.win_sqrt = win_sqrt
        self.pad_center = pad_center

        # Reflect padding applied on each side when pad_center is True.
        self.pad_amount_w = self.fft_len[0] // 2
        self.pad_amount_h = self.fft_len[1] // 2

        try:
            win_w_1d_np = get_window(win_type, self.win_len[0], fftbins=True)  # Width window
            win_h_1d_np = get_window(win_type, self.win_len[1], fftbins=True)  # Height window
        except Exception as e:
            # Chain the original error so the scipy traceback is not lost.
            raise ValueError(
                f"Failed to create window of type '{win_type}' with scipy.signal.get_window. Error: {e}"
            ) from e

        win_w_1d = torch.tensor(win_w_1d_np, dtype=torch.float32)
        win_h_1d = torch.tensor(win_h_1d_np, dtype=torch.float32)

        # Outer product to create the separable 2D window: (win_len_w, win_len_h)
        window_2d_base = torch.outer(win_w_1d, win_h_1d)

        # The same window is applied at analysis and synthesis, so the
        # overlap-add weight is analysis_window ** 2: the base window itself
        # when win_sqrt is True, otherwise the squared base window.
        if self.win_sqrt:
            self.analysis_window = torch.sqrt(window_2d_base)
            norm_window_component = window_2d_base
        else:
            self.analysis_window = window_2d_base
            norm_window_component = window_2d_base.pow(2)

        self.register_buffer('analysis_window_buffer', self.analysis_window, persistent=False)
        self.register_buffer('norm_window_component_buffer', norm_window_component, persistent=False)

        # State recorded by transform()/inverse().
        self.crop_dims_HW = None
        self.original_N_at_transform = None
        self.original_C_at_transform = None
        self.input_orig_ndim = None
        self.inverse_input_orig_ndim = None

        # check_COLA takes (window, nperseg, noverlap); noverlap = win_len - hop.
        self.cola_w = check_COLA(win_w_1d_np, self.win_len[0], self.win_len[0] - self.win_hop[0], tol=1e-4)
        self.cola_h = check_COLA(win_h_1d_np, self.win_len[1], self.win_len[1] - self.win_hop[1], tol=1e-4)
        if not (self.cola_w and self.cola_h) and self.win_sqrt:
            warnings.warn(f"COLA condition not met for win_type='{self.win_type}', win_len={self.win_len}, win_hop={self.win_hop}. "
                          "Perfect reconstruction may not be achieved even with win_sqrt=True.", UserWarning)

    @staticmethod
    def _validate_int_pair(value, message):
        """Raise ValueError(message) unless value is a tuple of exactly two ints."""
        if not (isinstance(value, tuple) and len(value) == 2 and
                isinstance(value[0], int) and isinstance(value[1], int)):
            raise ValueError(message)

    def _sync_buffers(self, device):
        """Move both window buffers to device if they currently live elsewhere."""
        if self.analysis_window_buffer.device != device:
            self.analysis_window_buffer = self.analysis_window_buffer.to(device)
            self.norm_window_component_buffer = self.norm_window_component_buffer.to(device)

    def is_perfect(self) -> bool:
        """Return True when both 1D windows pass the COLA check and pad_center is enabled.

        The result does not depend on win_sqrt. It is wrapped in bool() because
        check_COLA may return a NumPy boolean rather than a Python bool.
        """
        return bool(self.cola_w and self.cola_h and self.pad_center)

    def transform(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Performs 2D STFT on multi-channel images.
        Input H, W order is maintained for freq and frame count output dimensions.

        Side effect: records the input rank in input_orig_ndim and the spatial
        size in crop_dims_HW (used later by inverse for cropping).

        Args:
            inputs (torch.Tensor): Input tensor of shape (H, W), (C, H, W), or (N, C, H, W).
                H is height, W is width.

        Returns:
            torch.Tensor: STFT coefficients, always 6D, of shape
                (N, C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W).
                num_freqs_H = fft_len_h // 2 + 1 (one-sided axis of rfft2)
                num_freqs_W = fft_len_w (full axis of rfft2)
                num_frames_H / num_frames_W = number of frames extracted along H / W.

        Raises:
            ValueError: If inputs is not 2D, 3D or 4D.
        """
        self.input_orig_ndim = inputs.dim()
        if inputs.dim() == 2:
            inputs_batched = inputs.unsqueeze(0).unsqueeze(0)
        elif inputs.dim() == 3:
            inputs_batched = inputs.unsqueeze(0)
        elif inputs.dim() == 4:
            inputs_batched = inputs
        else:
            raise ValueError(f"Input must be 2D (H,W), 3D (C,H,W) or 4D (N,C,H,W). Got {inputs.dim()}D.")

        n_orig, c_orig, h_orig, w_orig = inputs_batched.shape
        self.crop_dims_HW = (h_orig, w_orig)

        # Internal layout is (B_eff, W, H) so that axis order matches the
        # (width, height) convention of win_len / win_hop / fft_len.
        core = inputs_batched.reshape(n_orig * c_orig, h_orig, w_orig).permute(0, 2, 1)
        self._sync_buffers(core.device)

        if self.pad_center:
            # F.pad lists the last dimension first: (H_left, H_right, W_left, W_right).
            core = F.pad(core, (self.pad_amount_h, self.pad_amount_h,
                                self.pad_amount_w, self.pad_amount_w), mode='reflect')

        # (B_eff, num_w_frames, H_padded, win_len_w)
        frames = core.unfold(1, self.win_len[0], self.win_hop[0])
        # (B_eff, num_w_frames, num_h_frames, win_len_w, win_len_h)
        frames = frames.unfold(2, self.win_len[1], self.win_hop[1])

        windowed_frames = frames * self.analysis_window_buffer.view(1, 1, 1, self.win_len[0], self.win_len[1])

        # (B_eff, num_w_frames, num_h_frames, fft_len_w, fft_len_h // 2 + 1)
        stft_coeffs_core = torch.fft.rfft2(windowed_frames, s=self.fft_len, norm='ortho')

        # Reorder to (B_eff, freq_H, freq_W, frames_H, frames_W).
        stft_coeffs_permuted = stft_coeffs_core.permute(0, 4, 3, 2, 1)
        final_output_shape = (n_orig, c_orig) + tuple(stft_coeffs_permuted.shape[1:])
        return stft_coeffs_permuted.reshape(final_output_shape)

    def inverse(self, stft_coeffs_user: torch.Tensor) -> torch.Tensor:
        """
        Performs 2D inverse STFT for multi-channel STFT coefficients.
        Expects coeff order (N, C, freq_H, freq_W, frames_H, frames_W).

        The overlap-add is carried out in the dtype produced by irfft2, so
        double-precision coefficients yield a double-precision signal. The
        windows themselves are stored in float32.

        Args:
            stft_coeffs_user (torch.Tensor): STFT coefficients of shape
                (N, C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W) or
                (C, num_freqs_H, num_freqs_W, num_frames_H, num_frames_W).

        Returns:
            torch.Tensor: Reconstructed signal of shape (N, C, H, W) for 6D input,
                or (C, H, W) for 5D input. H and W come from the most recent
                transform() call when available.

        Raises:
            ValueError: If the input is not 5D or 6D.
            RuntimeError: If pad_center is True but transform() has never been called.
        """
        self.inverse_input_orig_ndim = stft_coeffs_user.dim()
        if stft_coeffs_user.dim() == 5:
            stft_coeffs_input = stft_coeffs_user.unsqueeze(0)
        elif stft_coeffs_user.dim() == 6:
            stft_coeffs_input = stft_coeffs_user
        else:
            raise ValueError(f"Inverse input STFT coeffs must be 5D or 6D. Got {stft_coeffs_user.dim()}D.")

        n_curr, c_curr, freq_h, freq_w, num_h_frames, num_w_frames = stft_coeffs_input.shape
        b_eff = n_curr * c_curr

        # (B_eff, freq_H, freq_W, frames_H, frames_W)
        #   -> (B_eff, frames_W, frames_H, freq_W, freq_H_rfft) as irfft2 expects.
        coeffs = stft_coeffs_input.reshape(b_eff, freq_h, freq_w, num_h_frames, num_w_frames)
        coeffs = coeffs.permute(0, 4, 3, 2, 1)

        current_device = coeffs.device
        self._sync_buffers(current_device)

        # (B_eff, num_w_frames, num_h_frames, fft_len_w, fft_len_h)
        frames_rec = torch.fft.irfft2(coeffs, s=self.fft_len, norm='ortho')
        # Drop the zero-padding added by the forward FFT (fft_len >= win_len).
        frames_rec = frames_rec[..., :self.win_len[0], :self.win_len[1]]

        # Accumulate in the dtype of the reconstructed frames; allocating the
        # accumulators with the default dtype would silently downcast float64.
        work_dtype = frames_rec.dtype
        synthesis_window = self.analysis_window_buffer.to(work_dtype)
        norm_patch = self.norm_window_component_buffer.to(work_dtype)

        windowed_frames_rec = frames_rec * synthesis_window.view(1, 1, 1, self.win_len[0], self.win_len[1])

        out_w_full = (num_w_frames - 1) * self.win_hop[0] + self.win_len[0]
        out_h_full = (num_h_frames - 1) * self.win_hop[1] + self.win_len[1]

        output_signal_core = torch.zeros((b_eff, out_w_full, out_h_full),
                                         dtype=work_dtype, device=current_device)
        # The window weights are identical for every batch element, so the
        # normaliser is accumulated once in 2D and broadcast afterwards.
        norm_sum = torch.zeros((out_w_full, out_h_full), dtype=work_dtype, device=current_device)

        for i_w_frame in range(num_w_frames):
            start_w = i_w_frame * self.win_hop[0]
            end_w = start_w + self.win_len[0]
            for i_h_frame in range(num_h_frames):
                start_h = i_h_frame * self.win_hop[1]
                end_h = start_h + self.win_len[1]
                output_signal_core[:, start_w:end_w, start_h:end_h] += \
                    windowed_frames_rec[:, i_w_frame, i_h_frame, :, :]
                norm_sum[start_w:end_w, start_h:end_h] += norm_patch

        # Only normalise where the accumulated weight is non-negligible; other
        # positions are divided by 1 and therefore left untouched.
        valid_norm = norm_sum > 1e-8
        safe_norm = torch.where(valid_norm, norm_sum, torch.ones_like(norm_sum))
        reconstructed_padded = output_signal_core / safe_norm  # (B_eff, W_padded, H_padded)

        if self.pad_center:
            if self.crop_dims_HW is None:
                raise RuntimeError("crop_dims_HW is None. Cannot crop.")
            h_orig, w_orig = self.crop_dims_HW
            start_w_crop = self.pad_amount_w
            end_w_crop = start_w_crop + w_orig
            start_h_crop = self.pad_amount_h
            end_h_crop = start_h_crop + h_orig
            if end_w_crop > out_w_full or end_h_crop > out_h_full:
                warnings.warn(f"Cropping region for W_proc [{start_w_crop}:{end_w_crop}] or H_proc [{start_h_crop}:{end_h_crop}] "
                              f"exceeds reconstructed W_proc={out_w_full}, H_proc={out_h_full}.", UserWarning)
            reconstructed_cropped = reconstructed_padded[:, start_w_crop:end_w_crop, start_h_crop:end_h_crop]
        else:
            reconstructed_cropped = reconstructed_padded
            if self.crop_dims_HW is not None:
                h_orig, w_orig = self.crop_dims_HW
                reconstructed_cropped = reconstructed_cropped[:, :w_orig, :h_orig]

        # (B_eff, W, H) -> (B_eff, H, W) -> (N, C, H, W)
        reconstructed_hw = reconstructed_cropped.permute(0, 2, 1)
        final_reconstructed_shape = (n_curr, c_curr,
                                     reconstructed_hw.shape[1], reconstructed_hw.shape[2])
        output_final = reconstructed_hw.reshape(final_reconstructed_shape)

        if self.inverse_input_orig_ndim == 5 and output_final.shape[0] == 1:
            return output_final.squeeze(0)
        return output_final

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run transform followed by inverse and return a tensor with the rank of inputs.

        Args:
            inputs (torch.Tensor): Tensor of shape (H, W), (C, H, W) or (N, C, H, W).

        Returns:
            torch.Tensor: The reconstructed signal; the batch and/or channel
                axes added by transform are squeezed away again for 2D/3D input.
        """
        input_orig_ndim_for_fwd = inputs.dim()
        stft_coeffs = self.transform(inputs)
        reconstructed_signal = self.inverse(stft_coeffs)

        if input_orig_ndim_for_fwd == 2:
            return reconstructed_signal.squeeze(0).squeeze(0)
        elif input_orig_ndim_for_fwd == 3:
            return reconstructed_signal.squeeze(0)
        return reconstructed_signal
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment