diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3e544201..52e7a233 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -2,6 +2,7 @@ import math import os import time +import struct from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -84,6 +85,8 @@ class Config: eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model as ply + ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -164,6 +167,7 @@ class Config: def adjust_steps(self, factor: float): self.eval_steps = [int(i * factor) for i in self.eval_steps] self.save_steps = [int(i * factor) for i in self.save_steps] + self.ply_steps = [int(i * factor) for i in self.ply_steps] self.max_steps = int(self.max_steps * factor) self.sh_degree_interval = int(self.sh_degree_interval * factor) @@ -181,6 +185,97 @@ def adjust_steps(self, factor: float): assert_never(strategy) +def save_ply(splats: torch.nn.ParameterDict, dir: str, colors: torch.Tensor = None): + # Convert all tensors to numpy arrays in one go + print(f"Saving ply to {dir}") + numpy_data = {k: v.detach().cpu().numpy() for k, v in splats.items()} + + means = numpy_data["means"] + scales = numpy_data["scales"] + quats = numpy_data["quats"] + opacities = numpy_data["opacities"] + + sh0 = numpy_data["sh0"].transpose(0, 2, 1).reshape(means.shape[0], -1) + shN = numpy_data["shN"].transpose(0, 2, 1).reshape(means.shape[0], -1) + + # Create a mask to identify rows with NaN or Inf in any of the numpy_data arrays + invalid_mask = ( + np.isnan(means).any(axis=1) + | np.isinf(means).any(axis=1) + | np.isnan(scales).any(axis=1) + | np.isinf(scales).any(axis=1) + | np.isnan(quats).any(axis=1) + | np.isinf(quats).any(axis=1) + | np.isnan(opacities).any(axis=0) + | np.isinf(opacities).any(axis=0) + | np.isnan(sh0).any(axis=1) + | np.isinf(sh0).any(axis=1) + | np.isnan(shN).any(axis=1) + | np.isinf(shN).any(axis=1) + ) + + # Filter out rows with NaNs or Infs from all data arrays + means = means[~invalid_mask] + scales = scales[~invalid_mask] + quats = quats[~invalid_mask] + opacities = opacities[~invalid_mask] + sh0 = sh0[~invalid_mask] + shN = shN[~invalid_mask] + + num_points = means.shape[0] + + with open(dir, "wb") as f: + # Write PLY header + f.write(b"ply\n") + f.write(b"format binary_little_endian 1.0\n") + f.write(f"element vertex {num_points}\n".encode()) + f.write(b"property float x\n") + f.write(b"property float y\n") + f.write(b"property float z\n") + f.write(b"property float nx\n") + f.write(b"property float ny\n") + f.write(b"property float nz\n") + + if colors is not None: + for j in range(colors.shape[1]): + f.write(f"property float f_dc_{j}\n".encode()) + else: + for i, data in enumerate([sh0, shN]): + prefix = "f_dc" if i == 0 else "f_rest" + for j in range(data.shape[1]): + f.write(f"property float {prefix}_{j}\n".encode()) + + f.write(b"property float opacity\n") + + for i in range(scales.shape[1]): + f.write(f"property float scale_{i}\n".encode()) + for i in range(quats.shape[1]): + f.write(f"property float rot_{i}\n".encode()) + + f.write(b"end_header\n") + + # Write vertex data + for i in range(num_points): + f.write(struct.pack("