Skip to content

Instantly share code, notes, and snippets.

@mosure
Created January 5, 2025 04:34
Show Gist options
  • Select an option

  • Save mosure/d9d4d271e05a106157ce39db62ec4f84 to your computer and use it in GitHub Desktop.

Select an option

Save mosure/d9d4d271e05a106157ce39db62ec4f84 to your computer and use it in GitHub Desktop.
# compatible with https://github.com/fudan-zvg/4d-gaussian-splatting
import sys
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, OptimizationParams
import numpy as np
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from plyfile import PlyData, PlyElement
import torch
from scene import GaussianModel
def export_gaussians_to_ply(gaussians, ply_filename="output.ply4d", *, text=False):
    """Export Gaussian properties to a PLY file with a single "vertex" element.

    One vertex is written per Gaussian, with these float32 properties in order:

    - x, y, z, t                          from ``get_xyzt``          (N, 4)
    - opacity                             from ``get_opacity``       (N, 1)
    - rot_x, rot_y, rot_z, rot_w          from ``get_rotation``      (N, 4)
    - rot_r_x, rot_r_y, rot_r_z, rot_r_w  from ``get_rotation_r``    (N, 4)
    - sx, sy, sz, st                      from ``get_scaling_xyzt``  (N, 4)
    - feat_r_0, feat_g_0, feat_b_0, ..., feat_r_{K-1}, feat_g_{K-1}, feat_b_{K-1}
                                          from ``get_features``      (N, K, 3)

    K is taken from the features tensor (48 for the configuration this script
    was written against). Multi-column attributes are written column by column
    in index order, e.g. ``rot_x`` is column 0 of ``get_rotation``.

    Args:
        gaussians: Model exposing the ``get_*`` tensors listed above.
        ply_filename: Destination path.
        text: If True, write an ASCII PLY; otherwise (the default, matching
            the previous behavior) write a binary PLY.

    Raises:
        ValueError: If the features are not shaped (N, K, 3).
    """

    def to_numpy(tensor):
        # Detach from autograd and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    xyzt = to_numpy(gaussians.get_xyzt)
    opacity = to_numpy(gaussians.get_opacity)
    rotation = to_numpy(gaussians.get_rotation)
    rotation_r = to_numpy(gaussians.get_rotation_r)
    scaling_xyzt = to_numpy(gaussians.get_scaling_xyzt)
    features = to_numpy(gaussians.get_features)

    channel_labels = ("r", "g", "b")
    if features.ndim != 3 or features.shape[2] != len(channel_labels):
        # Extra channels used to be silently dropped and fewer channels raised
        # an opaque IndexError; fail early with a clear message instead.
        raise ValueError(
            f"expected features shaped (N, K, {len(channel_labels)}), "
            f"got {features.shape}"
        )
    num_feat_groups = features.shape[1]

    # (property names, source array): the i-th name receives column i of the
    # source array.
    # NOTE(review): rotation columns are labelled x, y, z, w in index order;
    # if the model stores quaternions real-part-first, column 0 is actually w.
    # Confirm against the consumer of this file before relying on the labels.
    base_columns = [
        (("x", "y", "z", "t"), xyzt),
        (("opacity",), opacity),
        (("rot_x", "rot_y", "rot_z", "rot_w"), rotation),
        (("rot_r_x", "rot_r_y", "rot_r_z", "rot_r_w"), rotation_r),
        (("sx", "sy", "sz", "st"), scaling_xyzt),
    ]

    field_names = [name for names, _ in base_columns for name in names]
    # Interleaved per group: feat_r_0, feat_g_0, feat_b_0, feat_r_1, ...
    field_names += [
        f"feat_{label}_{i}"
        for i in range(num_feat_groups)
        for label in channel_labels
    ]
    vertex_array = np.zeros(
        xyzt.shape[0], dtype=[(name, "f4") for name in field_names]
    )

    for names, source in base_columns:
        for col, name in enumerate(names):
            vertex_array[name] = source[:, col]
    for i in range(num_feat_groups):
        for ci, label in enumerate(channel_labels):
            vertex_array[f"feat_{label}_{i}"] = features[:, i, ci]

    ply_el = PlyElement.describe(vertex_array, "vertex")
    PlyData([ply_el], text=text).write(ply_filename)
def export(
    dataset,
    opt,
    pipe,
    testing_iterations,
    saving_iterations,
    checkpoint,
    debug_from,
    gaussian_dim,
    time_duration,
    num_pts,
    num_pts_ratio,
    rot_4d,
    force_sh_3d,
    batch_size,
    output_path="output.ply4d",
):
    """Restore a GaussianModel from a training checkpoint and export it to PLY.

    The signature mirrors the training entry point so the same parsed
    arguments can be forwarded unchanged. ``testing_iterations``,
    ``saving_iterations``, ``debug_from``, ``num_pts``, ``num_pts_ratio`` and
    ``batch_size`` are accepted but not used here.

    Args:
        dataset: Extracted ModelParams; ``sh_degree`` is read.
        opt: Extracted OptimizationParams, passed to ``training_setup`` and
            ``restore``.
        pipe: Extracted PipelineParams; ``sh_degree_t`` is 2 when
            ``pipe.eval_shfs_4d`` is truthy, else 0.
        checkpoint: Path to a file that ``torch.load`` returns as a
            ``(model_params, first_iter)`` pair.
        gaussian_dim, time_duration, rot_4d, force_sh_3d: Forwarded to
            ``GaussianModel``.
        output_path: Destination PLY path (default "output.ply4d").

    Raises:
        ValueError: If no checkpoint path was provided.
    """
    if not checkpoint:
        # --start_checkpoint defaults to None; torch.load(None) would only fail
        # after the model is built, with an unhelpful message.
        raise ValueError("a checkpoint path is required (pass --start_checkpoint)")

    gaussians = GaussianModel(
        dataset.sh_degree,
        gaussian_dim=gaussian_dim,
        time_duration=time_duration,
        rot_4d=rot_4d,
        force_sh_3d=force_sh_3d,
        sh_degree_t=2 if pipe.eval_shfs_4d else 0,
    )
    gaussians.training_setup(opt)

    print('trying to load from', checkpoint)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    (model_params, first_iter) = torch.load(checkpoint)
    gaussians.restore(model_params, opt)

    # Log the shape of every exported attribute as a quick sanity check.
    for name in ("xyzt", "opacity", "rotation", "rotation_r", "scaling_xyzt", "features"):
        print(name, getattr(gaussians, f"get_{name}").shape)

    export_gaussians_to_ply(gaussians, output_path)
if __name__ == "__main__":
    parser = ArgumentParser(description="Export a 4D Gaussian checkpoint to PLY")
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    parser.add_argument("--config", type=str)
    parser.add_argument('--debug_from', type=int, default=-1)
    parser.add_argument('--detect_anomaly', action='store_true', default=False)
    parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--start_checkpoint", type=str, default=None)
    parser.add_argument("--gaussian_dim", type=int, default=3)
    parser.add_argument("--time_duration", nargs=2, type=float, default=[-0.5, 0.5])
    parser.add_argument('--num_pts', type=int, default=100_000)
    parser.add_argument('--num_pts_ratio', type=float, default=1.0)
    parser.add_argument("--rot_4d", action="store_true")
    parser.add_argument("--force_sh_3d", action="store_true")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--seed", type=int, default=6666)
    parser.add_argument("--exhaust_test", action="store_true")
    args = parser.parse_args(sys.argv[1:])
    # export() receives save_iterations but does not use it; retained for
    # parity with the original argument handling.
    args.save_iterations.append(args.iterations)

    def recursive_merge(key, host):
        """Copy the leaf values under host[key] onto args, flattening nesting."""
        value = host[key]
        if isinstance(value, DictConfig):
            for child_key in value.keys():
                recursive_merge(child_key, value)
            return
        # Every config leaf must name an existing CLI option. Use parser.error
        # rather than assert so the check survives `python -O`.
        if not hasattr(args, key):
            parser.error(f"unknown config key: {key}")
        setattr(args, key, value)

    # Values from the config file override the command-line values above.
    # Skip merging when --config is omitted instead of crashing on load(None).
    if args.config is not None:
        cfg = OmegaConf.load(args.config)
        for k in cfg.keys():
            recursive_merge(k, cfg)

    export(
        lp.extract(args),
        op.extract(args),
        pp.extract(args),
        args.test_iterations,
        args.save_iterations,
        args.start_checkpoint,
        args.debug_from,
        args.gaussian_dim,
        args.time_duration,
        args.num_pts,
        args.num_pts_ratio,
        args.rot_4d,
        args.force_sh_3d,
        args.batch_size,
    )
    # "\e" was an invalid escape sequence; a leading newline was intended.
    print("\nExport complete.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment