Add visualization #60

Open · wants to merge 2 commits into base: main
3 changes: 2 additions & 1 deletion configs/dynerf/cook_spinach.yaml
@@ -6,10 +6,11 @@ rot_4d: True
force_sh_3d: False
batch_size: 4
exhaust_test: True
start_checkpoint: "output/N3V/cook_spinach/chkpnt_best.pth"

ModelParams:
sh_degree: 3
source_path: "data/N3V/cook_spinach"
source_path: "/home/zhangyidi/Projects/4DGS/data/N3V/cook_spinach"
model_path: "output/N3V/cook_spinach"
images: "images"
resolution: 2
61 changes: 61 additions & 0 deletions configs/dynerf/cook_spinach_test.yaml
@@ -0,0 +1,61 @@
gaussian_dim: 4
time_duration: [0.0, 10.0]
num_pts: 300_000
num_pts_ratio: 1.0
rot_4d: True
force_sh_3d: False
batch_size: 4
exhaust_test: True
start_checkpoint: "output/N3V/cook_spinach/chkpnt_best.pth"
val: True

ModelParams:
sh_degree: 3
source_path: "/home/zhangyidi/Projects/4DGS/data/N3V/cook_spinach"
model_path: "output/N3V/cook_spinach"
images: "images"
resolution: 2
white_background: False
data_device: "cuda"
eval: True
extension: ".png"
num_extra_pts: 0
loaded_pth: ""
frame_ratio: 1
dataloader: True

PipelineParams:
convert_SHs_python: False
compute_cov3D_python: False
debug: False
env_map_res: 0
env_optimize_until: 1000000000
env_optimize_from: 0
eval_shfs_4d: True

OptimizationParams:
iterations: 30_000
position_lr_init: 0.00016
position_t_lr_init: -1.0
position_lr_final: 0.0000016
position_lr_delay_mult: 0.01
position_lr_max_steps: 30_000
feature_lr: 0.0025
opacity_lr: 0.05
scaling_lr: 0.005
rotation_lr: 0.001
percent_dense: 0.01
lambda_dssim: 0.2
thresh_opa_prune: 0.005
densification_interval: 100
opacity_reset_interval: 3000
densify_from_iter: 500
densify_until_iter: 15_000
densify_grad_threshold: 0.0002
densify_grad_t_threshold: 0.0002 / 40 # TODO: find a way to actually make use of this
densify_until_num_points: -1
final_prune_from_iter: -1
sh_increase_interval: 1000
lambda_opa_mask: 0.0
lambda_rigid: 0.0
lambda_motion: 0.0
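
The new cook_spinach_test.yaml mirrors the training config but adds val: True and points start_checkpoint at the best checkpoint, which is what routes main.py (renamed from train.py below) into the render-only validation path. A minimal sketch of that flow, assuming the config is loaded with OmegaConf as elsewhere in this repo; the print statements are placeholders:

from omegaconf import OmegaConf

# Load the new test config and branch on the added `val` flag,
# mirroring the dispatch added at the bottom of main.py below.
cfg = OmegaConf.load("configs/dynerf/cook_spinach_test.yaml")

if cfg.get("val", False):
    print("validation-only run, restoring", cfg.start_checkpoint)
else:
    print("regular training run")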
3 changes: 3 additions & 0 deletions gaussian_renderer/__init__.py
@@ -15,6 +15,7 @@
from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh, eval_shfs_4d
from collections import defaultdict

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
"""
@@ -145,6 +146,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
flow_2d = flow_2d[mask]

# Rasterize visible Gaussians to image, obtain their radii (on screen).

rendered_image, radii, depth, alpha, flow, covs_com = rasterizer(
means3D = means3D,
means2D = means2D,
@@ -158,6 +160,7 @@
rotations = rotations,
rotations_r = rotations_r,
cov3D_precomp = cov3D_precomp)


if pipe.env_map_res:
assert pc.env_map is not None
58 changes: 54 additions & 4 deletions train.py → main.py
@@ -28,12 +28,56 @@
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader

from utils.mesh_utils import GaussianExtractor, to_cam_open3d, post_process_mesh
from utils.render_utils import generate_path, create_videos
import open3d as o3d

try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
except ImportError:
TENSORBOARD_FOUND = False

def validation(dataset, opt, pipe, checkpoint, gaussian_dim, time_duration, rot_4d, force_sh_3d,
num_pts, num_pts_ratio):

bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=gaussian_dim, time_duration=time_duration,
rot_4d=rot_4d, force_sh_3d=force_sh_3d, sh_degree_t=2 if pipe.eval_shfs_4d else 0)

assert checkpoint, "No checkpoint provided for validation"
scene = Scene(dataset, gaussians, shuffle=False, num_pts=num_pts, num_pts_ratio=num_pts_ratio, time_duration=time_duration)

(model_params, first_iter) = torch.load(checkpoint)
train_dir = os.path.join(dataset.model_path, 'train', "ours_{}".format(first_iter))
test_dir = os.path.join(dataset.model_path, 'test', "ours_{}".format(first_iter))
gaussians.restore(model_params, None)
gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color)

######### 1. Validation and Rendering ############

print("export rendered testing images ...")
os.makedirs(test_dir, exist_ok=True)
gaussExtractor.reconstruction(scene.getTestCameras(), test_dir, stage="validation")
gaussExtractor.export_image(test_dir, mode="validation")

######### 2. Render Trajectory ############

print("rendering trajectory ...")
traj_dir = os.path.join(test_dir, 'traj')
os.makedirs(traj_dir, exist_ok=True)
n_frames = 480
cam_traj = generate_path(scene.getTrainCameras(), n_frames=n_frames)
gaussExtractor.reconstruction(cam_traj, test_dir, stage="trajectory")
gaussExtractor.export_image(traj_dir, mode="trajectory")
create_videos(base_dir=traj_dir, input_dir=traj_dir, out_name='render_traj', num_frames=n_frames)

def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint, debug_from,
gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d, batch_size):

@@ -373,6 +417,7 @@ def setup_seed(seed):
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--seed", type=int, default=6666)
parser.add_argument("--exhaust_test", action="store_true")
parser.add_argument("--val", action="store_true", default=False)

args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
@@ -399,8 +444,13 @@ def recursive_merge(key, host):
safe_state(args.quiet)

torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.start_checkpoint, args.debug_from,
args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d, args.batch_size)
if not cfg.val:
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.start_checkpoint, args.debug_from,
args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d, args.batch_size)

else:
validation(lp.extract(args), op.extract(args), pp.extract(args), args.start_checkpoint, args.gaussian_dim,
args.time_duration, args.rot_4d, args.force_sh_3d, args.num_pts, args.num_pts_ratio)


# All done
print("\nTraining complete.")
print("\nComplete.")
2 changes: 2 additions & 0 deletions scene/__init__.py
@@ -90,6 +90,8 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration

def save(self, iteration):
torch.save((self.gaussians.capture(), iteration), self.model_path + "/chkpnt" + str(iteration) + ".pth")
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))

def getTrainCameras(self, scale=1.0):
return CameraDataset(self.train_cameras[scale].copy(), self.white_background)
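
With this change, Scene.save(iteration) writes two artifacts per saved iteration: the full training checkpoint (the format consumed by --start_checkpoint and gaussians.restore) and a standalone PLY for external viewers. A small sketch of the paths it produces, using the model_path from the configs above (values are illustrative):

import os

model_path = "output/N3V/cook_spinach"  # ModelParams.model_path from the configs above
iteration = 30_000                      # whichever iteration was passed to Scene.save

checkpoint_path = os.path.join(model_path, "chkpnt{}.pth".format(iteration))
ply_path = os.path.join(model_path, "point_cloud/iteration_{}".format(iteration), "point_cloud.ply")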
42 changes: 42 additions & 0 deletions scene/gaussian_model.py
@@ -245,6 +245,48 @@ def get_current_covariance_and_mean_offset(self, scaling_modifier = 1, timestamp
self._rotation_r,
dt = timestamp - self.get_t)

def construct_list_of_attributes(self):
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
# All channels except the 3 DC
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
l.append('f_dc_{}'.format(i))

if self.active_sh_degree > 0:
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
l.append('f_rest_{}'.format(i))
l.append('opacity')
for i in range(self._scaling.shape[1]):
l.append('scale_{}'.format(i))
for i in range(self._rotation.shape[1]):
l.append('rot_{}'.format(i))
return l

def save_ply(self, path):
mkdir_p(os.path.dirname(path))
xyz = self.get_xyz.detach().cpu().numpy()
normals = np.zeros_like(xyz)
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
if self.active_sh_degree != 0:
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
opacities = self._opacity.detach().cpu().numpy()
scale = self._scaling.detach().cpu().numpy()
rotation = self._rotation.detach().cpu().numpy()

dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

elements = np.empty(xyz.shape[0], dtype=dtype_full)

# TODO: may need to add empty shs for SIBR_viewer?
if self.active_sh_degree > 0:
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
else:
attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
elements[:] = list(map(tuple, attributes))
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path)



def oneupSHdegree(self):
if self.active_sh_degree < self.max_sh_degree:
self.active_sh_degree += 1
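
Since the PR's goal is visualization and main.py already imports open3d, the PLY exported by the new save_ply can be previewed outside the training pipeline. A minimal sketch, assuming a point cloud written by the code above (the path is illustrative); the DC spherical-harmonic coefficients are mapped to RGB with the standard 0th-order SH constant:

import numpy as np
import open3d as o3d
from plyfile import PlyData

# Load the exported vertex data (path assumed, matching the configs above).
ply = PlyData.read("output/N3V/cook_spinach/point_cloud/iteration_30000/point_cloud.ply")
v = ply["vertex"]

xyz = np.stack([v["x"], v["y"], v["z"]], axis=1).astype(np.float64)
f_dc = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)

C0 = 0.28209479177387814                                    # 0th-order SH basis constant
rgb = np.clip(0.5 + C0 * f_dc, 0.0, 1.0).astype(np.float64)  # DC SH coefficients -> approximate base color

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(xyz)
pcd.colors = o3d.utility.Vector3dVector(rgb)
o3d.visualization.draw_geometries([pcd])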
1 change: 0 additions & 1 deletion utils/image_utils.py
@@ -26,4 +26,3 @@ def easy_cmap(x: torch.Tensor):
x_rgb[1] = torch.clamp(x_normalize, 0, 1)
x_rgb[2] = torch.clamp(x_normalize, 0, 1)
return x_rgb
