-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor RGBD Pixel Kernel #161
Conversation
b0cbc5e
to
e677b82
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes look good to me. I left some inline comments for some potential improvements, but I'm also okay with just merging this PR as-is so we can move fast :). Since we have the commit history anyways, we can also revert and make changes in the future if needed.
total_log_prob = 0.0 | ||
|
||
is_depth_non_return = observed_rgbd[3] == 0.0 | ||
|
||
# Is visible | ||
total_visible_log_prob = 0.0 | ||
# color term | ||
total_visible_log_prob += self.inlier_color_distribution.logpdf( | ||
observed_rgbd[:3], latent_rgbd[:3], color_scale | ||
) | ||
# depth term | ||
total_visible_log_prob += jnp.where( | ||
is_depth_non_return, | ||
jnp.log(depth_nonreturn_prob), | ||
jnp.log(1 - depth_nonreturn_prob) | ||
+ self.inlier_depth_distribution.logpdf( | ||
observed_rgbd[3], latent_rgbd[3], depth_scale | ||
), | ||
) | ||
|
||
# Is not visible | ||
total_not_visible_log_prob = 0.0 | ||
# color term | ||
outlier_color_log_prob = self.outlier_color_distribution.logpdf( | ||
observed_rgbd[:3], latent_rgbd[:3], color_scale | ||
) | ||
outlier_depth_log_prob = self.outlier_depth_distribution.logpdf( | ||
observed_rgbd[3], latent_rgbd[3], depth_scale | ||
) | ||
|
||
total_not_visible_log_prob += outlier_color_log_prob | ||
# depth term | ||
total_not_visible_log_prob += jnp.where( | ||
is_depth_non_return, | ||
jnp.log(depth_nonreturn_prob), | ||
jnp.log(1 - depth_nonreturn_prob) + outlier_depth_log_prob, | ||
) | ||
|
||
total_log_prob += jnp.logaddexp( | ||
jnp.log(visibility_prob) + total_visible_log_prob, | ||
jnp.log(1 - visibility_prob) + total_not_visible_log_prob, | ||
) | ||
depth_logpdf = self.depth_kernel.logpdf( | ||
observed_rgbd[3], | ||
latent_rgbd[3], | ||
depth_scale, | ||
visibility_prob, | ||
depth_nonreturn_prob, | ||
return jnp.where( | ||
jnp.any(is_unexplained(latent_rgbd)), | ||
outlier_color_log_prob + outlier_depth_log_prob, | ||
total_log_prob, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Niice. I think we might be able to implement some of these mixture terms using the mixture distribution class, so that we can get both logpdf
and sample
for free, but we can do that in a future PR.
(actually if you don't mind, I can also do some of the clean ups after you merge this)
from b3d.chisight.gen3d.pixel_kernels.pixel_depth_kernels import ( | ||
FullPixelDepthDistribution, | ||
) | ||
# import b3d.chisight.gen3d.inference_moves as im |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mind marking these tests as skipping (and maybe also pytest.importorskip) instead of commenting them out for now? Just so that we have a reminder to come back and clean these up in the future
@nishadgothoskar I made the inference changes we discussed, and the I noticed that the log q scores for one of the proposals are causing it to break (and they had been before as well), so I will work on debugging this this evening. But this should hopefully unblock you to continue developing your tests! If it would be helpful to discuss testing strategy or anything else this evening, in light of this changed inference code, please let me know! |
No description provided.