diff --git a/models/renderer.py b/models/renderer.py index a509029..cae704b 100644 --- a/models/renderer.py +++ b/models/renderer.py @@ -319,13 +319,17 @@ def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb with torch.no_grad(): pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) - + for i in range(self.up_sample_steps): + _n_importance = self.n_importance // self.up_sample_steps + # add padding to the last step when it is not divisible + if ((i+1) == self.up_sample_steps): + _n_importance += (self.n_importance - (self.n_importance // self.up_sample_steps) * self.up_sample_steps) new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, - self.n_importance // self.up_sample_steps, + _n_importance, 64 * 2**i) z_vals, sdf = self.cat_z_vals(rays_o, rays_d,