diff --git a/configs/dynerf/cook_spinach.yaml b/configs/dynerf/cook_spinach.yaml index 08bd80d..15a6218 100644 --- a/configs/dynerf/cook_spinach.yaml +++ b/configs/dynerf/cook_spinach.yaml @@ -6,10 +6,11 @@ rot_4d: True force_sh_3d: False batch_size: 4 exhaust_test: True +start_checkpoint: "output/N3V/cook_spinach/chkpnt_best.pth" ModelParams: sh_degree: 3 - source_path: "data/N3V/cook_spinach" + source_path: "/home/zhangyidi/Projects/4DGS/data/N3V/cook_spinach" model_path: "output/N3V/cook_spinach" images: "images" resolution: 2 diff --git a/configs/dynerf/cook_spinach_test.yaml b/configs/dynerf/cook_spinach_test.yaml new file mode 100644 index 0000000..ce12fd6 --- /dev/null +++ b/configs/dynerf/cook_spinach_test.yaml @@ -0,0 +1,61 @@ +gaussian_dim: 4 +time_duration: [0.0, 10.0] +num_pts: 300_000 +num_pts_ratio: 1.0 +rot_4d: True +force_sh_3d: False +batch_size: 4 +exhaust_test: True +start_checkpoint: "output/N3V/cook_spinach/chkpnt_best.pth" +val: True + +ModelParams: + sh_degree: 3 + source_path: "/home/zhangyidi/Projects/4DGS/data/N3V/cook_spinach" + model_path: "output/N3V/cook_spinach" + images: "images" + resolution: 2 + white_background: False + data_device: "cuda" + eval: True + extension: ".png" + num_extra_pts: 0 + loaded_pth: "" + frame_ratio: 1 + dataloader: True + +PipelineParams: + convert_SHs_python: False + compute_cov3D_python: False + debug: False + env_map_res: 0 + env_optimize_until: 1000000000 + env_optimize_from: 0 + eval_shfs_4d: True + +OptimizationParams: + iterations: 30_000 + position_lr_init: 0.00016 + position_t_lr_init: -1.0 + position_lr_final: 0.0000016 + position_lr_delay_mult: 0.01 + position_lr_max_steps: 30_000 + feature_lr: 0.0025 + opacity_lr: 0.05 + scaling_lr: 0.005 + rotation_lr: 0.001 + percent_dense: 0.01 + lambda_dssim: 0.2 + thresh_opa_prune: 0.005 + densification_interval: 100 + opacity_reset_interval: 3000 + densify_from_iter: 500 + densify_until_iter: 15_000 + densify_grad_threshold: 0.0002 + densify_grad_t_threshold: 0.0002 / 40 # TODO: find a way to actually use this + densify_until_num_points: -1 + final_prune_from_iter: -1 + sh_increase_interval: 1000 + lambda_opa_mask: 0.0 + lambda_rigid: 0.0 + lambda_motion: 0.0 diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index 0576be8..535e288 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -15,6 +15,7 @@ from .diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer from scene.gaussian_model import GaussianModel from utils.sh_utils import eval_sh, eval_shfs_4d +from collections import defaultdict def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None): """ @@ -145,6 +146,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, flow_2d = flow_2d[mask] # Rasterize visible Gaussians to image, obtain their radii (on screen).
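+        # NOTE: besides the rendered color and screen-space radii, this rasterizer also returns depth, alpha, the per-pixel 2D flow and covs_com, which are unpacked below.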
+ rendered_image, radii, depth, alpha, flow, covs_com = rasterizer( means3D = means3D, means2D = means2D, @@ -158,6 +160,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, rotations = rotations, rotations_r = rotations_r, cov3D_precomp = cov3D_precomp) + if pipe.env_map_res: assert pc.env_map is not None diff --git a/train.py b/main.py similarity index 88% rename from train.py rename to main.py index 72546d4..f88be71 100644 --- a/train.py +++ b/main.py @@ -28,12 +28,56 @@ from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from torch.utils.data import DataLoader + +from utils.mesh_utils import GaussianExtractor, to_cam_open3d, post_process_mesh +from utils.render_utils import generate_path, create_videos +import open3d as o3d + try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_FOUND = True except ImportError: TENSORBOARD_FOUND = False +def validation(dataset, opt, pipe,checkpoint, gaussian_dim, time_duration, rot_4d, force_sh_3d, + num_pts, num_pts_ratio): + + bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=gaussian_dim, time_duration=time_duration, + rot_4d=rot_4d, force_sh_3d=force_sh_3d, sh_degree_t=2 if pipe.eval_shfs_4d else 0) + + assert checkpoint, "No checkpoint provided for validation" + scene = Scene(dataset, gaussians, shuffle=False,num_pts=num_pts, num_pts_ratio=num_pts_ratio, time_duration=time_duration) + + (model_params, first_iter) = torch.load(checkpoint) + train_dir = os.path.join(dataset.model_path, 'train', "ours_{}".format(first_iter)) + test_dir = os.path.join(dataset.model_path, 'test', "ours_{}".format(first_iter)) + gaussians.restore(model_params, None) + gaussExtractor = GaussianExtractor(gaussians, render, pipe, bg_color=bg_color) + + ######### 1. Validation and Rendering ############ + + print("export rendered testing images ...") + os.makedirs(test_dir, exist_ok=True) + gaussExtractor.reconstruction(scene.getTestCameras(),test_dir,stage = "validation") + gaussExtractor.export_image(test_dir,mode = "validation") + + ######### 2. 
Render Trajectory ############ + + print("rendering trajectory ...") + traj_dir = os.path.join(test_dir, 'traj') + os.makedirs(traj_dir, exist_ok=True) + n_frames = 480 + cam_traj = generate_path(scene.getTrainCameras(), n_frames=n_frames) + gaussExtractor.reconstruction(cam_traj, test_dir, stage="trajectory") + gaussExtractor.export_image(traj_dir, mode="trajectory") + create_videos(base_dir=traj_dir, + input_dir=traj_dir, + out_name='render_traj', + num_frames=n_frames) + def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint, debug_from, gaussian_dim, time_duration, num_pts, num_pts_ratio, rot_4d, force_sh_3d, batch_size): @@ -373,6 +417,7 @@ def setup_seed(seed): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--seed", type=int, default=6666) parser.add_argument("--exhaust_test", action="store_true") + parser.add_argument("--val", action="store_true", default=False) args = parser.parse_args(sys.argv[1:]) args.save_iterations.append(args.iterations) @@ -399,8 +444,13 @@ def recursive_merge(key, host): safe_state(args.quiet) torch.autograd.set_detect_anomaly(args.detect_anomaly) - training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.start_checkpoint, args.debug_from, - args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d, args.batch_size) + if not cfg.val: + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.start_checkpoint, args.debug_from, + args.gaussian_dim, args.time_duration, args.num_pts, args.num_pts_ratio, args.rot_4d, args.force_sh_3d, args.batch_size) + + else: + validation(lp.extract(args), op.extract(args), pp.extract(args), args.start_checkpoint, args.gaussian_dim, + args.time_duration, args.rot_4d, args.force_sh_3d, args.num_pts, args.num_pts_ratio) + - # All done - print("\nTraining complete.") + print("\nComplete.") diff --git a/scene/__init__.py b/scene/__init__.py index ccb10a5..145c0c3 100644 --- a/scene/__init__.py +++ b/scene/__init__.py @@ -90,6 +90,8 @@ def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration def save(self, iteration): torch.save((self.gaussians.capture(), iteration), self.model_path + "/chkpnt" + str(iteration) + ".pth") + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) def getTrainCameras(self, scale=1.0): return CameraDataset(self.train_cameras[scale].copy(), self.white_background) diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 91cdae0..32af901 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -245,6 +245,48 @@ def get_current_covariance_and_mean_offset(self, scaling_modifier = 1, timestamp self._rotation_r, dt = timestamp - self.get_t) + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + + if self.active_sh_degree > 0: + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path):
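+        """Write the Gaussian attributes (xyz, placeholder normals, SH features, opacity, scaling, rotation) to a 3DGS-style PLY file, matching construct_list_of_attributes()."""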
mkdir_p(os.path.dirname(path)) + xyz = self.get_xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + if self.active_sh_degree != 0: + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + + # TODO: may need to add empty shs for SIBR_viewer? + if self.active_sh_degree > 0: + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + else: + attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + + + def oneupSHdegree(self): if self.active_sh_degree < self.max_sh_degree: self.active_sh_degree += 1 diff --git a/utils/image_utils.py b/utils/image_utils.py index 5494fac..555298f 100644 --- a/utils/image_utils.py +++ b/utils/image_utils.py @@ -26,4 +26,3 @@ def easy_cmap(x: torch.Tensor): x_rgb[1] = torch.clamp(x_normalize, 0, 1) x_rgb[2] = torch.clamp(x_normalize, 0, 1) return x_rgb - diff --git a/utils/mesh_utils.py b/utils/mesh_utils.py new file mode 100644 index 0000000..2ee9c35 --- /dev/null +++ b/utils/mesh_utils.py @@ -0,0 +1,329 @@ +# +# Copyright (C) 2024, ShanghaiTech +# SVIP research group, https://github.com/svip-lab +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact huangbb@shanghaitech.edu.cn +# + +import torch +import numpy as np +import os +import math +from tqdm import tqdm +from utils.render_utils import save_img_f32, save_img_u8 +from functools import partial +import open3d as o3d +import trimesh +from collections import defaultdict +from utils.loss_utils import ssim +from utils.image_utils import psnr +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +import json + +def post_process_mesh(mesh, cluster_to_keep=1000): + """ + Post-process a mesh to filter out floaters and disconnected parts + """ + import copy + print("post processing the mesh to have {} clusters".format(cluster_to_keep)) + mesh_0 = copy.deepcopy(mesh) + with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm: + triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles()) + + triangle_clusters = np.asarray(triangle_clusters) + cluster_n_triangles = np.asarray(cluster_n_triangles) + cluster_area = np.asarray(cluster_area) + n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep] + n_cluster = max(n_cluster, 50) # filter meshes smaller than 50 + triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster + mesh_0.remove_triangles_by_mask(triangles_to_remove) + mesh_0.remove_unreferenced_vertices() + mesh_0.remove_degenerate_triangles() + print("num vertices raw {}".format(len(mesh.vertices))) + print("num vertices post {}".format(len(mesh_0.vertices))) + return mesh_0 + +def to_cam_open3d(viewpoint_stack): + camera_traj = [] + for i, viewpoint_cam in enumerate(viewpoint_stack): + W = viewpoint_cam.image_width + H = viewpoint_cam.image_height + ndc2pix = torch.tensor([ + [W / 2, 0, 0, (W-1) / 2], + [0, H / 2, 0, (H-1) / 2], + [0, 0, 0, 1]]).float().cuda().T + intrins = (viewpoint_cam.projection_matrix @ ndc2pix)[:3,:3].T + intrinsic=o3d.camera.PinholeCameraIntrinsic( + width=viewpoint_cam.image_width, + height=viewpoint_cam.image_height, + cx = intrins[0,2].item(), + cy = intrins[1,2].item(), + fx = intrins[0,0].item(), + fy = intrins[1,1].item() + ) + + extrinsic=np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy()) + camera = o3d.camera.PinholeCameraParameters() + camera.extrinsic = extrinsic + camera.intrinsic = intrinsic + camera_traj.append(camera) + + return camera_traj + + +class GaussianExtractor(object): + def __init__(self, gaussians, render, pipe, bg_color=None): + """ + a class that extracts attributes of a scene represented by 2DGS + + Usage example: + >>> gaussExtractor = GaussianExtractor(gaussians, render, pipe) + >>> gaussExtractor.reconstruction(view_points) + >>> mesh = gaussExtractor.export_mesh_bounded(...)
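+        (render is wrapped with functools.partial(pipe=pipe, bg_color=background), so reconstruction() only needs cameras and gaussians.)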
+ """ + if bg_color is None: + bg_color = [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + self.gaussians = gaussians + self.render = partial(render, pipe=pipe, bg_color=background) + self.clean() + + @torch.no_grad() + def clean(self): + self.depthmaps = [] + # self.alphamaps = [] + self.rgbmaps = [] + # self.normals = [] + # self.depth_normals = [] + self.viewpoint_stack = [] + self.flowmaps=[] + + @torch.no_grad() + def reconstruction(self, viewpoint_stack, model_path , stage = "validation"): + """ + reconstruct radiance field given cameras + """ + self.clean() + self.viewpoint_stack = viewpoint_stack + metrics = defaultdict(list) + + lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to("cuda") + + for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="reconstruct radiance fields"): + if stage == "validation": + render_pkg = self.render(viewpoint_cam[1].cuda(), self.gaussians) + gt_image = viewpoint_cam[0].cuda() + else: + render_pkg = self.render(viewpoint_cam.cuda(), self.gaussians) + + rgb = render_pkg['render'] + alpha = render_pkg['alpha'] + flow = render_pkg['flow'] + flow = torch.cat([flow, torch.zeros_like(flow[:1])], dim=0) + depth = render_pkg['depth'] + + self.rgbmaps.append(rgb.cpu()) + self.depthmaps.append(depth.cpu()) + self.flowmaps.append(flow.cpu()) + + if stage == "validation": + metrics["psnr"].append(psnr(gt_image, rgb)) + metrics["ssim"].append(ssim(gt_image, rgb)) + # metrics["lpips"].append(lpips(gt_image.unsqueeze(0), rgb.unsqueeze(0))) + if stage == "validation": + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "num_GS": self.gaussians.get_xyz.shape[0], + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f} " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + stats_dir = os.path.join(model_path, "stats") + os.makedirs(stats_dir, exist_ok=True) + with open(f"{stats_dir}/{stage}.json", "w") as f: + json.dump(stats, f) + + def estimate_bounding_sphere(self): + """ + Estimate the bounding sphere given camera pose + """ + from utils.render_utils import transform_poses_pca, focus_point_fn + torch.cuda.empty_cache() + c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack]) + poses = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1]) + center = (focus_point_fn(poses)) + self.radius = np.linalg.norm(c2ws[:,:3,3] - center, axis=-1).min() + self.center = torch.from_numpy(center).float().cuda() + print(f"The estimated bounding radius is {self.radius:.2f}") + print(f"Use at least {2.0 * self.radius:.2f} for depth_trunc") + + @torch.no_grad() + def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True): + """ + Perform TSDF fusion given a fixed depth range, used in the paper. 
+ + voxel_size: the voxel size of the volume + sdf_trunc: truncation value + depth_trunc: maximum depth range, should depend on the scene's scale + mask_backgrond: whether to mask the background; only works when the dataset has masks + + return o3d.mesh + """ + print("Running tsdf volume integration ...") + print(f'voxel_size: {voxel_size}') + print(f'sdf_trunc: {sdf_trunc}') + print(f'depth_trunc: {depth_trunc}') + + volume = o3d.pipelines.integration.ScalableTSDFVolume( + voxel_length= voxel_size, + sdf_trunc=sdf_trunc, + color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 + ) + + for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc="TSDF integration progress"): + rgb = self.rgbmaps[i] + depth = self.depthmaps[i] + # flow=self.flowmaps[i] + + # if we have mask provided, use it + if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None): + depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0 + + # make open3d rgbd + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + o3d.geometry.Image(np.asarray(rgb.permute(1,2,0).cpu().numpy() * 255, order="C", dtype=np.uint8)), + o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order="C")), + depth_trunc = depth_trunc, convert_rgb_to_intensity=False, + depth_scale = 1.0 + ) + + volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic) + + mesh = volume.extract_triangle_mesh() + return mesh + + @torch.no_grad() + def extract_mesh_unbounded(self, resolution=1024): + """ + Experimental feature: extract meshes from unbounded scenes; not fully tested across datasets. + return o3d.mesh + """ + def contract(x): + mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None] + return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag)) + + def uncontract(y): + mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None] + return torch.where(mag < 1, y, (1 / (2-mag) * (y/mag))) + + def compute_sdf_perframe(i, points, depthmap, rgbmap, viewpoint_cam): + """ + compute the per-frame SDF + """ + new_points = torch.cat([points, torch.ones_like(points[...,:1])], dim=-1) @ viewpoint_cam.full_proj_transform + z = new_points[..., -1:] + pix_coords = (new_points[..., :2] / new_points[..., -1:]) + mask_proj = ((pix_coords > -1. ) & (pix_coords < 1.) & (z > 0)).all(dim=-1) + sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(-1, 1) + sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3,-1).T + sdf = (sampled_depth-z) + return sdf, sampled_rgb, mask_proj + + def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False): + """ + Fuse all frames, performing adaptive SDF truncation in the contracted space.
+ """ + if inv_contraction is not None: + mask = torch.linalg.norm(samples, dim=-1) > 1 + # adaptive sdf_truncation + sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0]) + sdf_trunc[mask] *= 1/(2-torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9)) + samples = inv_contraction(samples) + else: + sdf_trunc = 5 * voxel_size + + tsdfs = torch.ones_like(samples[:,0]) * 1 + rgbs = torch.zeros((samples.shape[0], 3)).cuda() + + weights = torch.ones_like(samples[:,0]) + for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="TSDF integration progress"): + sdf, rgb, mask_proj = compute_sdf_perframe(i, samples, + depthmap = self.depthmaps[i], + rgbmap = self.rgbmaps[i], + viewpoint_cam=self.viewpoint_stack[i], + ) + + # volume integration + sdf = sdf.flatten() + mask_proj = mask_proj & (sdf > -sdf_trunc) + sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj] + w = weights[mask_proj] + wp = w + 1 + tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp + rgbs[mask_proj] = (rgbs[mask_proj] * w[:,None] + rgb[mask_proj]) / wp[:,None] + # update weight + weights[mask_proj] = wp + + if return_rgb: + return tsdfs, rgbs + + return tsdfs + + normalize = lambda x: (x - self.center) / self.radius + unnormalize = lambda x: (x * self.radius) + self.center + inv_contraction = lambda x: unnormalize(uncontract(x)) + + N = resolution + voxel_size = (self.radius * 2 / N) + print(f"Computing sdf gird resolution {N} x {N} x {N}") + print(f"Define the voxel_size as {voxel_size}") + sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size) + from utils.mcube_utils import marching_cubes_with_contraction + R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy() + R = np.quantile(R, q=0.95) + R = min(R+0.01, 1.9) + + mesh = marching_cubes_with_contraction( + sdf=sdf_function, + bounding_box_min=(-R, -R, -R), + bounding_box_max=(R, R, R), + level=0, + resolution=N, + inv_contraction=inv_contraction, + ) + + # coloring the mesh + torch.cuda.empty_cache() + mesh = mesh.as_open3d + print("texturing mesh ... 
") + _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, voxel_size=voxel_size, return_rgb=True) + mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy()) + return mesh + + @torch.no_grad() + def export_image(self, path,mode="validation"): + render_path = os.path.join(path, "renders") + gts_path = os.path.join(path, "gt") + vis_path = os.path.join(path, "vis") + flow_path=os.path.join(path,"flow") + os.makedirs(render_path, exist_ok=True) + os.makedirs(vis_path, exist_ok=True) + os.makedirs(gts_path, exist_ok=True) + os.makedirs(flow_path,exist_ok=True) + for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc="export images"): + if mode == "validation" and viewpoint_cam[0] is not None: + gt = viewpoint_cam[0][0:3, :, :] + save_img_u8(gt.permute(1,2,0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + save_img_u8(self.rgbmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) + save_img_u8(self.flowmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(flow_path, '{0:05d}'.format(idx) + ".png")) + save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + ".tiff")) \ No newline at end of file diff --git a/utils/render_utils.py b/utils/render_utils.py new file mode 100644 index 0000000..00db9e2 --- /dev/null +++ b/utils/render_utils.py @@ -0,0 +1,289 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import os +import enum +import types +from typing import List, Mapping, Optional, Text, Tuple, Union +import copy +from PIL import Image +import mediapy as media +from matplotlib import cm +from tqdm import tqdm + +import torch + +def normalize(x: np.ndarray) -> np.ndarray: + """Normalization helper function.""" + return x / np.linalg.norm(x) + +def pad_poses(p: np.ndarray) -> np.ndarray: + """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" + bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) + return np.concatenate([p[..., :3, :4], bottom], axis=-2) + + +def unpad_poses(p: np.ndarray) -> np.ndarray: + """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" + return p[..., :3, :4] + + +def recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Recenter poses around the origin.""" + cam2world = average_pose(poses) + transform = np.linalg.inv(pad_poses(cam2world)) + poses = transform @ pad_poses(poses) + return unpad_poses(poses), transform + + +def average_pose(poses: np.ndarray) -> np.ndarray: + """New pose using average position, z-axis, and up vector of input poses.""" + position = poses[:, :3, 3].mean(0) + z_axis = poses[:, :3, 2].mean(0) + up = poses[:, :3, 1].mean(0) + cam2world = viewmatrix(z_axis, up, position) + return cam2world + +def viewmatrix(lookdir: np.ndarray, up: np.ndarray, + position: np.ndarray) -> np.ndarray: + """Construct lookat view matrix.""" + vec2 = normalize(lookdir) + vec0 = normalize(np.cross(up, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, position], axis=1) + return m + +def focus_point_fn(poses: np.ndarray) -> np.ndarray: + """Calculate nearest point to all focal axes in poses.""" + directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] + m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) + mt_m = np.transpose(m, [0, 2, 1]) @ m + focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] + return focus_pt + +def transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Transforms poses so principal components lie on XYZ axes. + + Args: + poses: a (N, 3, 4) array containing the cameras' camera to world transforms. + + Returns: + A tuple (poses, transform), with the transformed poses and the applied + camera_to_world transforms. + """ + t = poses[:, :3, 3] + t_mean = t.mean(axis=0) + t = t - t_mean + + eigval, eigvec = np.linalg.eig(t.T @ t) + # Sort eigenvectors in order of largest to smallest eigenvalue. + inds = np.argsort(eigval)[::-1] + eigvec = eigvec[:, inds] + rot = eigvec.T + if np.linalg.det(rot) < 0: + rot = np.diag(np.array([1, 1, -1])) @ rot + + transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) + poses_recentered = unpad_poses(transform @ pad_poses(poses)) + transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) + + # Flip coordinate system if z component of y-axis is negative + if poses_recentered.mean(axis=0)[2, 1] < 0: + poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered + transform = np.diag(np.array([1, -1, -1, 1])) @ transform + + return poses_recentered, transform + # points = np.random.rand(3,100) + # points_h = np.concatenate((points,np.ones_like(points[:1])), axis=0) + # (poses_recentered @ points_h)[0] + # (transform @ pad_poses(poses) @ points_h)[0,:3] + # import pdb; pdb.set_trace() + + # # Just make sure it's it in the [-1, 1]^3 cube + # scale_factor = 1. 
/ np.max(np.abs(poses_recentered[:, :3, 3])) + # poses_recentered[:, :3, 3] *= scale_factor + # transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform + + # return poses_recentered, transform + + +def generate_ellipse_path(poses: np.ndarray, + n_frames: int = 120, + const_speed: bool = True, + z_variation: float = 0., + z_phase: float = 0.) -> np.ndarray: + """Generate an elliptical render path based on the given poses.""" + # Calculate the focal point for the path (cameras point toward this). + center = focus_point_fn(poses) + # Path height sits at z=0 (in middle of zero-mean capture pattern). + offset = np.array([center[0], center[1], 0]) + + # Calculate scaling for ellipse axes based on input camera positions. + sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) + # Use ellipse that is symmetric about the focal point in xy. + low = -sc + offset + high = sc + offset + # Optional height variation need not be symmetric + z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) + z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) + + def get_positions(theta): + # Interpolate between bounds with trig functions to get ellipse in x-y. + # Optionally also interpolate in z to change camera height along path. + return np.stack([ + low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5), + low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5), + z_variation * (z_low[2] + (z_high - z_low)[2] * + (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), + ], -1) + + theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) + positions = get_positions(theta) + + #if const_speed: + + # # Resample theta angles so that the velocity is closer to constant. + # lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) + # theta = stepfun.sample(None, theta, np.log(lengths), n_frames + 1) + # positions = get_positions(theta) + + # Throw away duplicated last position. + positions = positions[:-1] + + # Set path's up vector to axis closest to average of input pose up vectors. 
+ avg_up = poses[:, :3, 1].mean(0) + avg_up = avg_up / np.linalg.norm(avg_up) + ind_up = np.argmax(np.abs(avg_up)) + up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) + + return np.stack([viewmatrix(p - center, up, p) for p in positions]) + + +def generate_path(viewpoint_cameras, n_frames=480): + viewpoint_cameras = [viewpoint_cameras[i] for i in range(0, len(viewpoint_cameras), 300)] + + c2ws = np.array([np.linalg.inv(np.asarray((cam[1].world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras]) + pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1]) + pose_recenter, colmap_to_world_transform = transform_poses_pca(pose) + + # generate new poses + new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames) + # warp back to orignal scale + new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses) + + traj = [] + for i, c2w in enumerate(new_poses): + c2w = c2w @ np.diag([1, -1, -1, 1]) + cam = copy.deepcopy(viewpoint_cameras[0][1]) + cam.image_height = int(cam.image_height / 2) * 2 + cam.image_width = int(cam.image_width / 2) * 2 + cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda() + cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.cuda().unsqueeze(0))).squeeze(0) + cam.camera_center = cam.world_view_transform.inverse()[3, :3] + cam.timestamp =10.0/n_frames*i + traj.append((cam)) + + return traj + +def load_img(pth: str) -> np.ndarray: + """Load an image and cast to float32.""" + with open(pth, 'rb') as f: + image = np.array(Image.open(f), dtype=np.float32) + return image + + +def create_videos(base_dir, input_dir, out_name, num_frames=480): + """Creates videos out of the images saved to disk.""" + # Last two parts of checkpoint path are experiment name and scene name. + video_prefix = f'{out_name}' + zpad = max(5, len(str(num_frames - 1))) + idx_to_str = lambda idx: str(idx).zfill(zpad) + + os.makedirs(base_dir, exist_ok=True) + render_dist_curve_fn = np.log + + # Load one example frame to get image shape and depth range. 
+ depth_file = os.path.join(input_dir, 'vis', f'depth_{idx_to_str(0)}.tiff') + depth_frame = load_img(depth_file) + shape = depth_frame.shape + p = 3 + distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p]) + lo, hi = [render_dist_curve_fn(x) for x in distance_limits] + print(f'Video shape is {shape[:2]}') + + video_kwargs = { + 'shape': shape[:2], + 'codec': 'h264', + 'fps': 48, + 'crf': 18, + } + + for k in ['depth', 'normal', 'color','flow']: + video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') + input_format = 'gray' if k == 'alpha' else 'rgb' + + + file_ext = 'png' if k in ['color', 'normal','flow'] else 'tiff' + idx = 0 + + if k == 'color': + file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}') + elif k == 'flow': + file0 = os.path.join(input_dir, 'flow', f'{idx_to_str(0)}.{file_ext}') + else: + file0 = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(0)}.{file_ext}') + + if not os.path.exists(file0): + print(f'Images missing for tag {k}') + continue + print(f'Making video {video_file}...') + with media.VideoWriter( + video_file, **video_kwargs, input_format=input_format) as writer: + for idx in tqdm(range(num_frames)): + # img_file = os.path.join(input_dir, f'{k}_{idx_to_str(idx)}.{file_ext}') + if k == 'color': + img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}') + elif k == 'flow': + img_file = os.path.join(input_dir, 'flow', f'{idx_to_str(idx)}.{file_ext}') + else: + img_file = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(idx)}.{file_ext}') + + if not os.path.exists(img_file): + raise ValueError(f'Image file {img_file} does not exist.') + img = load_img(img_file) + + if k in ['color', 'normal','flow']: + img = img / 255. + elif k.startswith('depth'): + img = render_dist_curve_fn(img) + img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1) + img = cm.get_cmap('turbo')(img)[..., :3] + + frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8) + writer.add_image(frame) + idx += 1 + +def save_img_u8(img, pth): + """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" + with open(pth, 'wb') as f: + Image.fromarray( + (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save( + f, 'PNG') + + +def save_img_f32(depthmap, pth): + """Save an image (probably a depthmap) to disk as a float32 TIFF.""" + with open(pth, 'wb') as f: + Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF') \ No newline at end of file
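For reference, a minimal sketch of how the new validation/export path can be driven outside of main.py. It assumes the import paths of the original train.py (from scene import Scene, GaussianModel; from gaussian_renderer import render), that dataset and pipe are the extracted ModelParams/PipelineParams objects, and it omits several keyword arguments that validation() passes (num_pts, num_pts_ratio, sh_degree_t):

    import torch
    from scene import Scene, GaussianModel
    from gaussian_renderer import render
    from utils.mesh_utils import GaussianExtractor

    # Restore the best checkpoint referenced by the new configs.
    gaussians = GaussianModel(dataset.sh_degree, gaussian_dim=4, time_duration=[0.0, 10.0], rot_4d=True, force_sh_3d=False)
    scene = Scene(dataset, gaussians, shuffle=False, time_duration=[0.0, 10.0])
    model_params, first_iter = torch.load("output/N3V/cook_spinach/chkpnt_best.pth")
    gaussians.restore(model_params, None)

    # Render the held-out cameras, compute PSNR/SSIM, and dump images to disk.
    extractor = GaussianExtractor(gaussians, render, pipe, bg_color=[0, 0, 0])
    test_dir = "output/N3V/cook_spinach/test/ours_{}".format(first_iter)
    extractor.reconstruction(scene.getTestCameras(), test_dir, stage="validation")
    extractor.export_image(test_dir, mode="validation")

The same flow is reached through main.py by setting val: True and start_checkpoint in the config, as cook_spinach_test.yaml does above.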