@johnbanq · Created June 5, 2025 17:28
tmp
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, to_2tuple, trunc_normal_
import torch.utils.checkpoint as checkpoint
import nibabel as nib
import numpy as np

def to_3tuple(x):
    """Converts input to a 3-element tuple."""
    if isinstance(x, tuple):
        return x
    return (x, x, x)
# -----------------------------------------------------------------------------------
# STAGE 1: STEM - Shallow Feature Extraction
# -----------------------------------------------------------------------------------
class PatchEmbed3D(nn.Module):
    r""" Volume to Patch Embedding for 3D inputs.

    Args:
        img_size (tuple[int]): Input volume size (D, H, W).
        patch_size (tuple[int]): Patch size (Pd, Ph, Pw).
        in_chans (int): Number of input channels.
        embed_dim (int): Number of output channels (tokens).
        norm_layer (nn.Module, optional): Optional normalization layer.
    """

    def __init__(self, img_size=(64, 224, 224), patch_size=(4, 4, 4), in_chans=1, embed_dim=96, norm_layer=None):
        super().__init__()
        self.img_size = to_3tuple(img_size)
        self.patch_size = to_3tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.patches_resolution = tuple([self.img_size[0] // self.patch_size[0],
                                         self.img_size[1] // self.patch_size[1],
                                         self.img_size[2] // self.patch_size[2]])
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] * self.patches_resolution[2]
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        # x shape: (B, C_in, D_img, H_img, W_img)
        x = self.proj(x)  # -> (B, embed_dim, D_p, H_p, W_p)
        B, C, D_p, H_p, W_p = x.shape
        assert (D_p, H_p, W_p) == self.patches_resolution, \
            f"Projected shape {(D_p, H_p, W_p)} does not match internal patches_resolution {self.patches_resolution}"
        x = x.flatten(2).transpose(1, 2)  # -> (B, N_patches, embed_dim)
        if self.norm:
            x = self.norm(x)
        return x

    def flops(self):
        # FLOPs for the projection convolution
        flops = (self.patches_resolution[0] * self.patches_resolution[1] * self.patches_resolution[2]) * \
            self.in_chans * self.embed_dim * self.patch_size[0] * self.patch_size[1] * self.patch_size[2]
        if self.norm:
            flops += self.num_patches * self.embed_dim  # FLOPs for LayerNorm
        return flops

class PatchUnEmbed3D(nn.Module):
    """
    Patch Unembedding for 3D volumetric data.
    Converts patch tokens [B, L, C] -> volume [B, C, D_p, H_p, W_p].
    """

    def __init__(self, img_size, embed_dim):
        super().__init__()
        self.patches_resolution = to_3tuple(img_size)
        self.embed_dim = embed_dim

    def forward(self, x, x_size_patch_grid):
        B, L, C = x.shape
        D_p, H_p, W_p = x_size_patch_grid
        assert L == D_p * H_p * W_p, f"Mismatch between token count L={L} and patch grid volume {D_p * H_p * W_p}"
        assert (D_p, H_p, W_p) == self.patches_resolution, \
            f"x_size_patch_grid {x_size_patch_grid} must match internal patches_resolution {self.patches_resolution}"
        x = x.transpose(1, 2).contiguous().view(B, C, D_p, H_p, W_p)
        return x

    def flops(self):
        return 0
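
# A minimal, self-contained sanity-check sketch (added for illustration; not part of the
# original gist). It shows the token round trip PatchEmbed3D -> PatchUnEmbed3D at
# patch_size=(1, 1, 1), where the spatial grid is preserved and only the channel count
# changes. The function name is illustrative only.
def _demo_patch_embed_roundtrip():
    embed = PatchEmbed3D(img_size=(8, 8, 8), patch_size=(1, 1, 1), in_chans=4, embed_dim=16)
    unembed = PatchUnEmbed3D(img_size=embed.patches_resolution, embed_dim=16)
    vol = torch.randn(2, 4, 8, 8, 8)                       # (B, C_in, D, H, W)
    tokens = embed(vol)                                    # (B, 8*8*8, 16)
    grid = unembed(tokens, embed.patches_resolution)       # (B, 16, 8, 8, 8)
    print(tokens.shape, grid.shape)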
# -----------------------------------------------------------------------------------
# STAGE 2: BODY - Transformer Backbone
# -----------------------------------------------------------------------------------
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_partition_3d(x, window_size_tuple):
    B, D, H, W, C = x.shape
    WD, WH, WW = window_size_tuple
    x = x.view(B, D // WD, WD, H // WH, WH, W // WW, WW, C)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, WD, WH, WW, C)
    return windows

def window_reverse_3d(windows, window_size_tuple, D, H, W):
    WD, WH, WW = window_size_tuple
    num_windows_d = D // WD
    num_windows_h = H // WH
    num_windows_w = W // WW
    B_times_num_windows = windows.shape[0]
    B = B_times_num_windows // (num_windows_d * num_windows_h * num_windows_w)
    x = windows.view(B, num_windows_d, num_windows_h, num_windows_w, WD, WH, WW, -1)
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, D, H, W, -1)
    return x
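
# A small sanity-check sketch (added for illustration; not in the original gist) showing that
# window_reverse_3d inverts window_partition_3d: a (B, D, H, W, C) tensor is split into
# (num_windows*B, WD, WH, WW, C) windows and reassembled without loss. The function name is
# illustrative only.
def _demo_window_roundtrip():
    x = torch.randn(2, 8, 8, 8, 16)                       # (B, D, H, W, C)
    windows = window_partition_3d(x, (4, 4, 4))           # (2 * 2*2*2, 4, 4, 4, 16)
    y = window_reverse_3d(windows, (4, 4, 4), 8, 8, 8)    # back to (2, 8, 8, 8, 16)
    assert torch.equal(x, y)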

class WindowAttention3D(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = to_3tuple(window_size)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1), num_heads))
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        coords = torch.stack(torch.meshgrid([coords_d, coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        B_, N_win, C = x.shape
        qkv = self.qkv(x).reshape(B_, N_win, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            N_win, N_win, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N_win, N_win) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N_win, N_win)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N_win, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
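
# Shape sketch for WindowAttention3D (added for illustration; not in the original gist):
# the module attends within each window independently, so its input is already windowed
# tokens of shape (num_windows * B, WD*WH*WW, dim), and the output keeps that shape.
# The function name and sizes are illustrative only.
def _demo_window_attention_shapes():
    attn = WindowAttention3D(dim=32, window_size=(4, 4, 4), num_heads=4)
    tokens = torch.randn(16, 4 * 4 * 4, 32)   # e.g. 16 windows of 64 tokens each
    out = attn(tokens)                        # (16, 64, 32)
    print(out.shape)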

class SwinTransformerBlock3D(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.num_heads = num_heads
        self.window_size_tuple = to_3tuple(window_size)
        self.shift_size_tuple = to_3tuple(shift_size)
        self.mlp_ratio = mlp_ratio
        if any(i_res < w_size for i_res, w_size in zip(self.input_resolution, self.window_size_tuple)):
            self.window_size_tuple = self.input_resolution
            self.shift_size_tuple = (0, 0, 0)
        if any(s_size >= w_size for s_size, w_size in zip(self.shift_size_tuple, self.window_size_tuple)):
            self.shift_size_tuple = (0, 0, 0)
        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(dim, self.window_size_tuple, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        if any(s > 0 for s in self.shift_size_tuple):
            D_p, H_p, W_p = self.input_resolution
            img_mask = torch.zeros((1, D_p, H_p, W_p, 1))
            d_slices = (slice(0, -self.window_size_tuple[0]),
                        slice(-self.window_size_tuple[0], -self.shift_size_tuple[0]),
                        slice(-self.shift_size_tuple[0], None))
            h_slices = (slice(0, -self.window_size_tuple[1]),
                        slice(-self.window_size_tuple[1], -self.shift_size_tuple[1]),
                        slice(-self.shift_size_tuple[1], None))
            w_slices = (slice(0, -self.window_size_tuple[2]),
                        slice(-self.window_size_tuple[2], -self.shift_size_tuple[2]),
                        slice(-self.shift_size_tuple[2], None))
            cnt = 0
            for d in d_slices:
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, d, h, w, :] = cnt
                        cnt += 1
            mask_windows = window_partition_3d(img_mask, self.window_size_tuple)
            mask_windows = mask_windows.view(-1, self.window_size_tuple[0] * self.window_size_tuple[1] * self.window_size_tuple[2])
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        B, L, C = x.shape
        D_p, H_p, W_p = self.input_resolution
        assert L == D_p * H_p * W_p, f"Input feature has wrong size L={L}, expected {D_p * H_p * W_p}"
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, D_p, H_p, W_p, C)
        if any(s > 0 for s in self.shift_size_tuple):
            x = torch.roll(x, shifts=(-self.shift_size_tuple[0], -self.shift_size_tuple[1], -self.shift_size_tuple[2]), dims=(1, 2, 3))
        x_windows = window_partition_3d(x, self.window_size_tuple)
        x_windows = x_windows.view(-1, self.window_size_tuple[0] * self.window_size_tuple[1] * self.window_size_tuple[2], C)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)
        attn_windows = attn_windows.view(-1, self.window_size_tuple[0], self.window_size_tuple[1], self.window_size_tuple[2], C)
        x = window_reverse_3d(attn_windows, self.window_size_tuple, D_p, H_p, W_p)
        if any(s > 0 for s in self.shift_size_tuple):
            x = torch.roll(x, shifts=(self.shift_size_tuple[0], self.shift_size_tuple[1], self.shift_size_tuple[2]), dims=(1, 2, 3))
        x = x.view(B, D_p * H_p * W_p, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
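
# Illustrative sketch (added here; not in the original gist): one shifted block operating on
# flattened tokens. shift_size is typically window_size // 2 for odd-indexed blocks, as done in
# BasicLayer3D below. All sizes are arbitrary small values chosen for the demo.
def _demo_swin_block():
    block = SwinTransformerBlock3D(dim=32, input_resolution=(8, 8, 8), num_heads=4,
                                   window_size=4, shift_size=2)
    tokens = torch.randn(2, 8 * 8 * 8, 32)   # (B, L, C) with L = D_p * H_p * W_p
    out = block(tokens)                      # same shape as the input
    print(out.shape)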

class BasicLayer3D(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(
                dim=dim, input_resolution=self.input_resolution, num_heads=num_heads, window_size=window_size,
                shift_size=0 if (i % 2 == 0) else tuple(ws // 2 for ws in to_3tuple(window_size)),
                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
            for i in range(depth)])
        if downsample is not None:
            self.downsample = downsample(self.input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                x = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

class RSTB3D(nn.Module):
    """Residual Swin Transformer Block (3D version)."""

    def __init__(self, dim, input_resolution, depth, num_heads, window_size, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
                 downsample=None, use_checkpoint=False, resi_connection='1conv'):
        super().__init__()
        self.dim = dim
        self.input_resolution = to_3tuple(input_resolution)
        self.residual_group = BasicLayer3D(dim=dim, input_resolution=self.input_resolution, depth=depth,
                                           num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio,
                                           qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                                           drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
                                           use_checkpoint=use_checkpoint)
        if resi_connection == '1conv':
            self.conv = nn.Conv3d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            self.conv = nn.Sequential(
                nn.Conv3d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(dim // 4, dim, 3, 1, 1))
        self.patch_unembed_res = PatchUnEmbed3D(img_size=self.input_resolution, embed_dim=dim)
        self.patch_embed_res = PatchEmbed3D(img_size=self.input_resolution, patch_size=(1, 1, 1), in_chans=dim, embed_dim=dim)

    def forward(self, x):
        residual = x
        x_processed_tokens = self.residual_group(x)
        x_spatial_grid = self.patch_unembed_res(x_processed_tokens, self.input_resolution)
        x_conv_out = self.conv(x_spatial_grid)
        x_tokens_from_conv = self.patch_embed_res(x_conv_out)
        return x_tokens_from_conv + residual
# -----------------------------------------------------------------------------------
# STAGE 3: HEAD - High-Quality Image Reconstruction
# -----------------------------------------------------------------------------------
class PixelShuffle3D(nn.Module):
    def __init__(self, upscale_factor):
        super(PixelShuffle3D, self).__init__()
        self.r = upscale_factor

    def forward(self, x):
        B, C, D, H, W = x.shape
        r = self.r
        assert C % (r ** 3) == 0, f"Channels {C} must be divisible by upscale_factor^3={r ** 3}"
        out_c = C // (r ** 3)
        x = x.view(B, out_c, r, r, r, D, H, W)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(B, out_c, D * r, H * r, W * r)
        return x
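
# Quick shape check for PixelShuffle3D (added for illustration; not in the original gist):
# channels are folded into space, so (B, C*r^3, D, H, W) becomes (B, C, D*r, H*r, W*r).
# The function name is illustrative only.
def _demo_pixel_shuffle_3d():
    shuffle = PixelShuffle3D(upscale_factor=2)
    x = torch.randn(1, 8, 4, 4, 4)     # 8 = 1 * 2**3 channels
    y = shuffle(x)                     # -> (1, 1, 8, 8, 8)
    print(y.shape)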

class Upsample3D(nn.Sequential):
    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0 and scale != 0:
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv3d(num_feat, (2 ** 3) * num_feat, 3, 1, 1))
                m.append(PixelShuffle3D(2))
        elif scale == 3:
            m.append(nn.Conv3d(num_feat, (3 ** 3) * num_feat, 3, 1, 1))
            m.append(PixelShuffle3D(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Use powers of 2 or 3.')
        super(Upsample3D, self).__init__(*m)

class UpsampleOneStep3D(nn.Sequential):
    def __init__(self, scale, num_feat, num_out_ch, input_resolution):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = [
            nn.Conv3d(num_feat, (scale ** 3) * num_out_ch, 3, 1, 1),
            PixelShuffle3D(scale)
        ]
        super(UpsampleOneStep3D, self).__init__(*m)
# -----------------------------------------------------------------------------------
# MAIN SwinIR MODEL
# -----------------------------------------------------------------------------------
class SwinIR3D(nn.Module):
    def __init__(self, img_size=(64, 64, 64), patch_size=(1, 1, 1), in_chans=1,
                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='pixelshuffle',
                 resi_connection='1conv', **kwargs):
        super(SwinIR3D, self).__init__()
        self.img_size_tuple = to_3tuple(img_size)
        self.patch_size_tuple = to_3tuple(patch_size)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.upscale = upscale
        self.upsampler = upsampler
        self.img_range = img_range
        self.mean = torch.zeros(1, in_chans, 1, 1, 1)
        self.conv_first = nn.Conv3d(in_chans, embed_dim, 3, 1, 1)
        self.patch_embed = PatchEmbed3D(
            img_size=self.img_size_tuple, patch_size=self.patch_size_tuple, in_chans=embed_dim,
            embed_dim=embed_dim, norm_layer=norm_layer if patch_norm else None)
        self.patches_resolution = self.patch_embed.patches_resolution
        self.patch_unembed = PatchUnEmbed3D(img_size=self.patches_resolution, embed_dim=embed_dim)
        self.pos_drop = nn.Dropout(p=drop_rate)
        if ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        else:
            self.absolute_pos_embed = None
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.layers = nn.ModuleList()
        for i_layer in range(len(depths)):
            layer = RSTB3D(
                dim=embed_dim, input_resolution=self.patches_resolution, depth=depths[i_layer],
                num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer, downsample=None, use_checkpoint=use_checkpoint,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm_after_body = norm_layer(embed_dim)
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv3d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            self.conv_after_body = nn.Sequential(
                nn.Conv3d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(0.2, inplace=True),
                nn.Conv3d(embed_dim // 4, embed_dim, 3, 1, 1))
        if self.upsampler == 'pixelshuffle':
            self.conv_before_upsample = nn.Sequential(nn.Conv3d(embed_dim, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample3D(upscale, embed_dim)
            self.conv_last = nn.Conv3d(embed_dim, in_chans, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            self.upsample = UpsampleOneStep3D(upscale, embed_dim, in_chans, input_resolution=self.patches_resolution)
            self.conv_last = None  # No final conv needed as UpsampleOneStep3D handles it
        else:  # No upsampling
            self.upsample = None
            self.conv_last = nn.Conv3d(embed_dim, in_chans, 3, 1, 1)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x_feat_in):
        x_tokens = self.patch_embed(x_feat_in)
        if self.absolute_pos_embed is not None:
            x_tokens = x_tokens + self.absolute_pos_embed
        x_tokens = self.pos_drop(x_tokens)
        for layer in self.layers:
            x_tokens = layer(x_tokens)
        x_tokens = self.norm_after_body(x_tokens)
        x_unembedded_main = self.patch_unembed(x_tokens, self.patches_resolution)
        return x_unembedded_main

    def forward(self, x_input):
        self.mean = self.mean.type_as(x_input)
        x_norm = (x_input - self.mean) * self.img_range
        x_shallow = self.conv_first(x_norm)
        res_deep_spatial = self.forward_features(x_shallow)
        res_deep_conv = self.conv_after_body(res_deep_spatial)
        # Long skip connection
        if x_shallow.shape == res_deep_conv.shape:
            x_body_out = x_shallow + res_deep_conv
        else:
            print(f"Warning: Skipping long skip connection due to shape mismatch. "
                  f"Shallow: {x_shallow.shape}, Deep: {res_deep_conv.shape}. "
                  f"This is expected if main patch_size > (1, 1, 1).")
            x_body_out = res_deep_conv
        if self.upsampler == 'pixelshuffle':
            x_upsample_in = self.conv_before_upsample(x_body_out)
            x_upsampled = self.upsample(x_upsample_in)
            x_final = self.conv_last(x_upsampled)
        elif self.upsampler == 'pixelshuffledirect':
            x_final = self.upsample(x_body_out)
        else:
            x_final = self.conv_last(x_body_out)
        x_final = x_final / self.img_range + self.mean
        return x_final

# # --- Example Instantiation ---
# try:
#     # --- 1. Define Model Architecture ---
#     # NOTE: in_chans is set to 1 to match the input tensor `x`
#     model = SwinIR3D(
#         img_size=32,
#         patch_size=1,
#         in_chans=1,
#         embed_dim=180,
#         depths=[2, 2, 2],  # Reduced depth for faster testing
#         num_heads=[6, 6, 6],
#         window_size=8,
#         mlp_ratio=4.,
#         upscale=2,
#         upsampler='pixelshuffledirect',
#         resi_connection='1conv'
#     )
#
#     # --- 2. Create Input Volume ---
#     # Input tensor with 1 channel, matching `in_chans=1` above
#     original_image = np.load('lq_image.npy')
#     original_shape = original_image.shape  # H, W, D
#     vol = original_image.astype(np.float32)
#     # print(f"Original image shape: {original_image.shape}, dtype: {original_image.dtype}")
#     y = torch.from_numpy(vol).unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions
#     y = y.float()  # Ensure float type
#     print(f"Input tensor shape after processing: {y.shape}, dtype: {y.dtype}")
#     x = y
#     print(f"Created input tensor with shape: {x.shape}")
#
#     # --- 3. (Optional) Load Pre-trained Weights ---
#     # Loading is kept inside the surrounding try-except so the script still reports
#     # a clear error if the weight file is missing.
#     state_dict_3d = torch.load('swinir_inflated3D.pth', weights_only=True)
#     model.load_state_dict(state_dict_3d, strict=False)
#     model = model.cuda()
#     x = x.cuda()
#     print("Model and data moved to GPU.")
#
#     # --- 4. Run Forward Pass ---
#     print("\nExecuting model forward pass...")
#     with torch.no_grad():  # Use no_grad for inference
#         output = model(x)
#     np.save('output_swinir.npy', output.cpu().numpy())
#     print("Model ran successfully!")
#     print(f"Input shape: {x.shape}")
#     print(f"Output shape: {output.shape}")
#     print("\nEach spatial dimension of the output should be 2x the input (upscale=2).")
# except Exception as e:
#     print("\n--- An Error Occurred ---")
#     print(f"Error during model execution: {e}")
#     import traceback
#     traceback.print_exc()
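
# A small self-contained smoke test (added for illustration; the hyperparameters below are
# arbitrary and not taken from the original gist). It runs the model on a random volume on CPU
# and checks that each spatial dimension is doubled when upscale=2 with 'pixelshuffledirect'.
if __name__ == "__main__":
    model = SwinIR3D(img_size=(16, 16, 16), patch_size=(1, 1, 1), in_chans=1,
                     embed_dim=48, depths=(2, 2), num_heads=(6, 6), window_size=4,
                     upscale=2, upsampler='pixelshuffledirect')
    x = torch.randn(1, 1, 16, 16, 16)
    with torch.no_grad():
        out = model(x)
    print(f"input {tuple(x.shape)} -> output {tuple(out.shape)}")  # expect (1, 1, 32, 32, 32)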