import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, to_2tuple, trunc_normal_
import torch.utils.checkpoint as checkpoint
import nibabel as nib
import numpy as np


def to_3tuple(x):
    """Converts input to a 3-element tuple."""
    if isinstance(x, tuple):
        return x
    return (x, x, x)


# -----------------------------------------------------------------------------------
# STAGE 1: STEM - Shallow Feature Extraction
# -----------------------------------------------------------------------------------
class PatchEmbed3D(nn.Module):
    r""" Volume to Patch Embedding for 3D inputs
    Args:
        img_size (tuple[int]): Input volume size (D, H, W).
        patch_size (tuple[int]): Patch size (Pd, Ph, Pw).
        in_chans (int): Number of input channels.
        embed_dim (int): Number of output channels (tokens).
        norm_layer (nn.Module, optional): Optional normalization layer.
    """
    def __init__(self, img_size=(64, 224, 224), patch_size=(4, 4, 4), in_chans=1, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_3tuple(img_size)
        self.patch_size = to_3tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.patches_resolution = tuple([self.img_size[0] // self.patch_size[0],
                                         self.img_size[1] // self.patch_size[1],
                                         self.img_size[2] // self.patch_size[2]])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] * self.patches_resolution[2]
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        # x shape: (B, C_in, D_img, H_img, W_img)
        x = self.proj(x)  # -> (B, embed_dim, D_p, H_p, W_p)
        B, C, D_p, H_p, W_p = x.shape
        assert (D_p, H_p, W_p) == self.patches_resolution, \
            f"Projected shape {(D_p, H_p, W_p)} does not match internal patches_resolution {self.patches_resolution}"
        x = x.flatten(2).transpose(1, 2)  # -> (B, N_patches, embed_dim)
        if self.norm:
            x = self.norm(x)
        return x

    def flops(self):
        # FLOPs for the projection convolution
        flops = (self.patches_resolution[0] * self.patches_resolution[1] * self.patches_resolution[2]) * \
                self.in_chans * self.embed_dim * self.patch_size[0] * self.patch_size[1] * self.patch_size[2]
        if self.norm:
            flops += self.num_patches * self.embed_dim  # FLOPs for LayerNorm
        return flops
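

# Illustrative (hedged) shape sketch for PatchEmbed3D; the sizes below are
# arbitrary assumptions, not values used elsewhere in this file:
#
#   embed = PatchEmbed3D(img_size=(32, 32, 32), patch_size=(4, 4, 4), in_chans=1, embed_dim=96)
#   vol = torch.randn(2, 1, 32, 32, 32)        # (B, C_in, D, H, W)
#   tokens = embed(vol)                        # -> (2, 8 * 8 * 8, 96) = (B, N_patches, embed_dim)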


class PatchUnEmbed3D(nn.Module):
    """
    Patch Unembedding for 3D volumetric data.
    Converts patch tokens [B, L, C] -> volume [B, C, D_p, H_p, W_p].
    """
    def __init__(self, img_size, embed_dim):
        super().__init__()
        self.patches_resolution = to_3tuple(img_size)
        self.embed_dim = embed_dim

    def forward(self, x, x_size_patch_grid):
        B, L, C = x.shape
        D_p, H_p, W_p = x_size_patch_grid
        assert L == D_p * H_p * W_p, f"Mismatch between token count L={L} and patch grid volume {D_p*H_p*W_p}"
        assert (D_p, H_p, W_p) == self.patches_resolution, \
            f"x_size_patch_grid {x_size_patch_grid} must match internal patches_resolution {self.patches_resolution}"
        x = x.transpose(1, 2).contiguous().view(B, C, D_p, H_p, W_p)
        return x

    def flops(self):
        return 0
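

# Hedged usage sketch: PatchUnEmbed3D inverts the flatten/transpose done in
# PatchEmbed3D (the sizes below are assumptions for illustration):
#
#   unembed = PatchUnEmbed3D(img_size=(8, 8, 8), embed_dim=96)
#   tokens = torch.randn(2, 8 * 8 * 8, 96)     # (B, L, C)
#   grid = unembed(tokens, (8, 8, 8))          # -> (2, 96, 8, 8, 8)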


# -----------------------------------------------------------------------------------
# STAGE 2: BODY - Transformer Backbone
# -----------------------------------------------------------------------------------
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition_3d(x, window_size_tuple):
    """Splits (B, D, H, W, C) into non-overlapping windows of shape (num_windows*B, WD, WH, WW, C)."""
    B, D, H, W, C = x.shape
    WD, WH, WW = window_size_tuple
    x = x.view(B, D // WD, WD, H // WH, WH, W // WW, WW, C)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, WD, WH, WW, C)
    return windows


def window_reverse_3d(windows, window_size_tuple, D, H, W):
    """Inverse of window_partition_3d: merges windows back into a (B, D, H, W, C) volume."""
    WD, WH, WW = window_size_tuple
    num_windows_d = D // WD
    num_windows_h = H // WH
    num_windows_w = W // WW
    B_times_num_windows = windows.shape[0]
    B = B_times_num_windows // (num_windows_d * num_windows_h * num_windows_w)
    x = windows.view(B, num_windows_d, num_windows_h, num_windows_w, WD, WH, WW, -1)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
    return x
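

# Sketch of the partition/reverse round trip (assumed sizes; each spatial dim must
# be divisible by its window size):
#
#   feat = torch.randn(2, 8, 8, 8, 96)                   # (B, D, H, W, C)
#   wins = window_partition_3d(feat, (4, 4, 4))          # -> (16, 4, 4, 4, 96), i.e. 2*2*2 windows per sample
#   back = window_reverse_3d(wins, (4, 4, 4), 8, 8, 8)   # -> (2, 8, 8, 8, 96)
#   assert torch.equal(feat, back)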


class WindowAttention3D(nn.Module):
    """Window-based multi-head self-attention with a learned 3D relative position bias."""
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_3tuple(window_size)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), num_heads))
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid([coords_d, coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        B_, N_win, C = x.shape
        qkv = self.qkv(x).reshape(B_, N_win, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            N_win, N_win, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N_win, N_win) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N_win, N_win)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N_win, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
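

# Hedged sketch of calling WindowAttention3D on already-partitioned windows
# (the sizes are illustrative assumptions):
#
#   attn = WindowAttention3D(dim=96, window_size=(4, 4, 4), num_heads=6)
#   wins = torch.randn(16, 4 * 4 * 4, 96)      # (num_windows * B, tokens_per_window, C)
#   out = attn(wins)                           # same shape; the bias table holds
#                                              # (2*4 - 1) ** 3 = 343 entries per head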


class SwinTransformerBlock3D(nn.Module):
    """A 3D Swin Transformer block: (shifted) window attention followed by an MLP, each with a residual connection."""
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.num_heads = num_heads
        self.window_size_tuple = to_3tuple(window_size)
        self.shift_size_tuple = to_3tuple(shift_size)
        self.mlp_ratio = mlp_ratio
        if any(i_res < w_size for i_res, w_size in zip(self.input_resolution, self.window_size_tuple)):
            self.window_size_tuple = self.input_resolution
            self.shift_size_tuple = (0, 0, 0)
        if any(s_size >= w_size for s_size, w_size in zip(self.shift_size_tuple, self.window_size_tuple)):
            self.shift_size_tuple = (0, 0, 0)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(dim, self.window_size_tuple, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        if any(s > 0 for s in self.shift_size_tuple):
            # Precompute the attention mask needed after the cyclic shift.
            D_p, H_p, W_p = self.input_resolution
            img_mask = torch.zeros((1, D_p, H_p, W_p, 1))
            d_slices = (slice(0, -self.window_size_tuple[0]), slice(-self.window_size_tuple[0], -self.shift_size_tuple[0]), slice(-self.shift_size_tuple[0], None))
            h_slices = (slice(0, -self.window_size_tuple[1]), slice(-self.window_size_tuple[1], -self.shift_size_tuple[1]), slice(-self.shift_size_tuple[1], None))
            w_slices = (slice(0, -self.window_size_tuple[2]), slice(-self.window_size_tuple[2], -self.shift_size_tuple[2]), slice(-self.shift_size_tuple[2], None))
            cnt = 0
            for d in d_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, d, h, w, :] = cnt
                        cnt += 1
            mask_windows = window_partition_3d(img_mask, self.window_size_tuple)
            mask_windows = mask_windows.view(-1, self.window_size_tuple[0] * self.window_size_tuple[1] * self.window_size_tuple[2])
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        B, L, C = x.shape
        D_p, H_p, W_p = self.input_resolution
        assert L == D_p * H_p * W_p, f"Input feature has wrong size L={L}, expected {D_p*H_p*W_p}"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, D_p, H_p, W_p, C)
        if any(s > 0 for s in self.shift_size_tuple):
            x = torch.roll(x, shifts=(-self.shift_size_tuple[0], -self.shift_size_tuple[1], -self.shift_size_tuple[2]), dims=(1, 2, 3))
        x_windows = window_partition_3d(x, self.window_size_tuple)
        x_windows = x_windows.view(-1, self.window_size_tuple[0] * self.window_size_tuple[1] * self.window_size_tuple[2], C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size_tuple[0], self.window_size_tuple[1], self.window_size_tuple[2], C)
        x = window_reverse_3d(attn_windows, self.window_size_tuple, D_p, H_p, W_p)
        if any(s > 0 for s in self.shift_size_tuple):
            x = torch.roll(x, shifts=(self.shift_size_tuple[0], self.shift_size_tuple[1], self.shift_size_tuple[2]), dims=(1, 2, 3))
        x = x.view(B, D_p * H_p * W_p, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
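

# Hedged usage sketch for a single shifted block; the resolution and window size
# below are assumptions chosen so the patch grid divides evenly:
#
#   blk = SwinTransformerBlock3D(dim=96, input_resolution=(8, 8, 8), num_heads=6,
#                                window_size=4, shift_size=2)
#   tokens = torch.randn(2, 8 * 8 * 8, 96)     # (B, L, C) with L = D_p * H_p * W_p
#   out = blk(tokens)                          # -> (2, 512, 96)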


class BasicLayer3D(nn.Module):
    """A stack of SwinTransformerBlock3D blocks (alternating regular / shifted windows) forming one stage."""
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(
                dim=dim, input_resolution=self.input_resolution, num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else tuple(ws // 2 for ws in to_3tuple(window_size)),
                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(self.input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x


class RSTB3D(nn.Module):
    """Residual Swin Transformer Block (3D version)."""
    def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False, resi_connection='1conv'):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.residual_group = BasicLayer3D(dim=dim, input_resolution=self.input_resolution, depth=depth,
                                           num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
                                           qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                                           drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
                                           use_checkpoint=use_checkpoint)
        if resi_connection == '1conv':
            self.conv = nn.Conv3d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            self.conv = nn.Sequential(
                nn.Conv3d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(dim // 4, dim, 3, 1, 1))
        self.patch_unembed_res = PatchUnEmbed3D(img_size=self.input_resolution, embed_dim=dim)
        self.patch_embed_res = PatchEmbed3D(img_size=self.input_resolution, patch_size=(1, 1, 1), in_chans=dim, embed_dim=dim)

    def forward(self, x):
        residual = x
        x_processed_tokens = self.residual_group(x)
        x_spatial_grid = self.patch_unembed_res(x_processed_tokens, self.input_resolution)
        x_conv_out = self.conv(x_spatial_grid)
        x_tokens_from_conv = self.patch_embed_res(x_conv_out)
        return x_tokens_from_conv + residual
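

# Hedged sketch of one residual group: tokens in, tokens out, with the 3D conv
# applied in the spatial domain inside the residual branch (assumed sizes):
#
#   rstb = RSTB3D(dim=96, input_resolution=(8, 8, 8), depth=2, num_heads=6, window_size=4)
#   tokens = torch.randn(1, 8 * 8 * 8, 96)
#   out = rstb(tokens)                         # -> (1, 512, 96), same shape as the input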


# -----------------------------------------------------------------------------------
# STAGE 3: HEAD - High-Quality Image Reconstruction
# -----------------------------------------------------------------------------------
class PixelShuffle3D(nn.Module):
    """3D analogue of nn.PixelShuffle: rearranges channel groups into spatial depth/height/width."""
    def __init__(self, upscale_factor):
        super(PixelShuffle3D, self).__init__()
        self.r = upscale_factor

    def forward(self, x):
        B, C, D, H, W = x.shape
        r = self.r
        assert C % (r ** 3) == 0, f"Channels {C} must be divisible by upscale_factor^3={r**3}"
        out_c = C // (r ** 3)
        x = x.view(B, out_c, r, r, r, D, H, W)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(B, out_c, D * r, H * r, W * r)
        return x
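

# Sketch of the channel-to-space rearrangement performed by PixelShuffle3D
# (assumed sizes; channels must be divisible by upscale_factor ** 3):
#
#   ps = PixelShuffle3D(upscale_factor=2)
#   x = torch.randn(1, 8, 4, 4, 4)             # 8 = 1 * 2**3 channels
#   y = ps(x)                                  # -> (1, 1, 8, 8, 8)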


class Upsample3D(nn.Sequential):
    """Multi-step upsampling head: conv + PixelShuffle3D repeated for power-of-two scales (or once for scale 3)."""
    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0 and scale != 0:
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv3d(num_feat, (2 ** 3) * num_feat, 3, 1, 1))
                m.append(PixelShuffle3D(2))
        elif scale == 3:
            m.append(nn.Conv3d(num_feat, (3 ** 3) * num_feat, 3, 1, 1))
            m.append(PixelShuffle3D(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Use powers of 2 or 3.')
        super(Upsample3D, self).__init__(*m)


class UpsampleOneStep3D(nn.Sequential):
    """Lightweight one-step upsampling head: a single conv followed by one PixelShuffle3D."""
    def __init__(self, scale, num_feat, num_out_ch, input_resolution):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = [
            nn.Conv3d(num_feat, (scale ** 3) * num_out_ch, 3, 1, 1),
            PixelShuffle3D(scale)
        ]
        super(UpsampleOneStep3D, self).__init__(*m)
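

# Hedged sketch comparing the two upsampling heads (feature count and scale below
# are illustrative assumptions):
#
#   up = Upsample3D(scale=4, num_feat=64)      # two conv + PixelShuffle3D(2) stages
#   one = UpsampleOneStep3D(scale=4, num_feat=64, num_out_ch=1, input_resolution=(8, 8, 8))
#   feat = torch.randn(1, 64, 8, 8, 8)
#   up(feat).shape                             # -> (1, 64, 32, 32, 32)
#   one(feat).shape                            # -> (1, 1, 32, 32, 32)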


# -----------------------------------------------------------------------------------
# MAIN SwinIR MODEL
# -----------------------------------------------------------------------------------
class SwinIR3D(nn.Module):
    """SwinIR adapted to 3D volumes: shallow conv stem, a body of RSTB3D groups, and an upsampling reconstruction head."""
    def __init__(self, img_size=(64, 64, 64), patch_size=(1, 1, 1), in_chans=1,
                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='pixelshuffle',
                 resi_connection='1conv', **kwargs):
        super(SwinIR3D, self).__init__()
        self.img_size_tuple = to_3tuple(img_size)
        self.patch_size_tuple = to_3tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.upscale = upscale
        self.upsampler = upsampler
        self.img_range = img_range
        self.mean = torch.zeros(1, in_chans, 1, 1, 1)
        self.conv_first = nn.Conv3d(in_chans, embed_dim, 3, 1, 1)
        self.patch_embed = PatchEmbed3D(
            img_size=self.img_size_tuple, patch_size=self.patch_size_tuple, in_chans=embed_dim,
            embed_dim=embed_dim, norm_layer=norm_layer if patch_norm else None)
        self.patches_resolution = self.patch_embed.patches_resolution
        self.patch_unembed = PatchUnEmbed3D(img_size=self.patches_resolution, embed_dim=embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)
        if ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        else:
            self.absolute_pos_embed = None
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(len(depths)):
            layer = RSTB3D(
                dim=embed_dim, input_resolution=self.patches_resolution, depth=depths[i_layer],
                num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer, downsample=None, use_checkpoint=use_checkpoint,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm_after_body = norm_layer(embed_dim)
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv3d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            self.conv_after_body = nn.Sequential(
                nn.Conv3d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(embed_dim // 4, embed_dim, 3, 1, 1))
        if self.upsampler == 'pixelshuffle':
            self.conv_before_upsample = nn.Sequential(nn.Conv3d(embed_dim, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample3D(upscale, embed_dim)
            self.conv_last = nn.Conv3d(embed_dim, in_chans, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            self.upsample = UpsampleOneStep3D(upscale, embed_dim, in_chans, input_resolution=self.patches_resolution)
            self.conv_last = None  # No final conv needed as UpsampleOneStep3D handles it
        else:  # No upsampling
            self.upsample = None
            self.conv_last = nn.Conv3d(embed_dim, in_chans, 3, 1, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x_feat_in):
        x_tokens = self.patch_embed(x_feat_in)
        if self.absolute_pos_embed is not None:
            x_tokens = x_tokens + self.absolute_pos_embed
        x_tokens = self.pos_drop(x_tokens)
        for layer in self.layers:
            x_tokens = layer(x_tokens)
        x_tokens = self.norm_after_body(x_tokens)
        x_unembedded_main = self.patch_unembed(x_tokens, self.patches_resolution)
        return x_unembedded_main

    def forward(self, x_input):
        self.mean = self.mean.type_as(x_input)
        x_norm = (x_input - self.mean) * self.img_range
        x_shallow = self.conv_first(x_norm)
        res_deep_spatial = self.forward_features(x_shallow)
        res_deep_conv = self.conv_after_body(res_deep_spatial)
        # Long skip connection
        if x_shallow.shape == res_deep_conv.shape:
            x_body_out = x_shallow + res_deep_conv
        else:
            print(f"Warning: Skipping long skip connection due to shape mismatch. "
                  f"Shallow: {x_shallow.shape}, Deep: {res_deep_conv.shape}. "
                  f"This is expected if main patch_size > (1,1,1).")
            x_body_out = res_deep_conv
        if self.upsampler == 'pixelshuffle':
            x_upsample_in = self.conv_before_upsample(x_body_out)
            x_upsampled = self.upsample(x_upsample_in)
            x_final = self.conv_last(x_upsampled)
        elif self.upsampler == 'pixelshuffledirect':
            x_final = self.upsample(x_body_out)
        else:
            x_final = self.conv_last(x_body_out)
        x_final = x_final / self.img_range + self.mean
        return x_final


# # --- Example Instantiation ---
# try:
#     # --- 1. Define Model Architecture ---
#     # NOTE: in_chans is set to 1 to match the input tensor `x`
#     model = SwinIR3D(
#         img_size=32,
#         patch_size=1,
#         in_chans=1,
#         embed_dim=180,
#         depths=[2, 2, 2],  # Reduced depth for faster testing
#         num_heads=[6, 6, 6],
#         window_size=8,
#         mlp_ratio=4.,
#         upscale=2,
#         upsampler='pixelshuffledirect',
#         resi_connection='1conv'
#     )
#
#     # --- 2. Create Input ---
#     # Input tensor with 1 channel, matching `in_chans=1` above
#     original_image = np.load('lq_image.npy')
#     original_shape = original_image.shape  # H, W, D
#     vol = original_image.astype(np.float32)
#     # print(f"Original image shape: {original_image.shape}, dtype: {original_image.dtype}")
#     y = torch.from_numpy(vol).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
#     y = y.float()  # Ensure float type
#     print(f"Input tensor shape after processing: {y.shape}, dtype: {y.dtype}")
#     x = y
#     print(f"Created input tensor with shape: {x.shape}")
#
#     # --- 3. (Optional) Load Pre-trained Weights ---
#     # The whole example is wrapped in try/except so the script still fails
#     # gracefully without the specific weight file.
#     state_dict_3d = torch.load('swinir_inflated3D.pth', weights_only=True)
#     model.load_state_dict(state_dict_3d, strict=False)
#     model = model.cuda()
#     x = x.cuda()
#     print("Model and data moved to GPU.")
#
#     # --- 4. Run Forward Pass ---
#     print("\nExecuting model forward pass...")
#     with torch.no_grad():  # Use no_grad for inference
#         output = model(x)
#     np.save('output_swinir.npy', output.cpu().numpy())
#     print("Model ran successfully!")
#     print(f"Input shape: {x.shape}")
#     print(f"Output shape: {output.shape}")
#     # assert output.shape == (1, 1, 128, 64, 64), "Output shape is not as expected!"
#     print(f"\nOutput shape is consistent with upscale={model.upscale}")
# except Exception as e:
#     print("\n--- An Error Occurred ---")
#     print(f"Error during model execution: {e}")
#     import traceback
#     traceback.print_exc()