Skip to content

Commit

Permalink
Add sqrt info matrix to some batch residuals. Mild max-mixture fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
vkorotkine committed Jan 15, 2025
1 parent eaeb96b commit 249ad9f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
1 change: 1 addition & 0 deletions navlie/batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
"""

from .estimator import BatchEstimator
from .problem import Problem
2 changes: 1 addition & 1 deletion navlie/batch/gaussian_mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def mix_errors(

nonlinear_part = np.array(np.log(alpha_max / alpha_k)).reshape(-1)
nonlinear_part = np.sqrt(2) * np.sqrt(nonlinear_part)
e_mix = np.concatenate([linear_part, nonlinear_part])
e_mix = np.concatenate([linear_part.reshape(-1), nonlinear_part])

reused_values = {"alphas": alphas, "res_values": res_values}

Expand Down
32 changes: 25 additions & 7 deletions navlie/batch/residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def jacobian_fd(self, states: List[State], step_size=1e-6) -> List[np.ndarray]:
Parameters
----------
states : List[State]
Evaluation point of Jacobians, a list of states that
Evaluation point of Jacobians, a list of states that
the residual is a function of.
Returns
-------
List[np.ndarray]
Expand All @@ -78,7 +78,7 @@ def jacobian_fd(self, states: List[State], step_size=1e-6) -> List[np.ndarray]:
w.r.t states[0], the second element is the Jacobian of the residual w.r.t states[1], etc.
"""
jac_list: List[np.ndarray] = [None] * len(states)

# Compute the Jacobian for each state via finite difference
for state_num, X_bar in enumerate(states):
e_bar = self.evaluate(states)
Expand All @@ -100,13 +100,14 @@ def jacobian_fd(self, states: List[State], step_size=1e-6) -> List[np.ndarray]:
jac_list[state_num] = jac_fd

return jac_list

def sqrt_info_matrix(self, states: List[State]):
"""
Returns the information matrix
"""
pass


class PriorResidual(Residual):
"""
A generic prior error.
Expand Down Expand Up @@ -157,6 +158,7 @@ def sqrt_info_matrix(self, states: List[State]):
"""
return self._L


class ProcessResidual(Residual):
"""
A generic process residual.
Expand Down Expand Up @@ -211,16 +213,24 @@ def evaluate(
if compute_jacobians:
jac_list = [None] * len(states)
if compute_jacobians[0]:
jac_list[0] = -L.T @ self._process_model.jacobian(
x_km1, self._u, dt
)
jac_list[0] = -L.T @ self._process_model.jacobian(x_km1, self._u, dt)
if compute_jacobians[1]:
jac_list[1] = L.T @ x_k.minus_jacobian(x_k_hat)

return e, jac_list

return e

def sqrt_info_matrix(self, states: List[State]):
"""
Returns the square root of the information matrix
"""
x_km1 = states[0]
x_k = states[1]
dt = x_k.stamp - x_km1.stamp
L = self._process_model.sqrt_information(x_km1, self._u, dt)
return L


class MeasurementResidual(Residual):
"""
Expand Down Expand Up @@ -268,3 +278,11 @@ def evaluate(
return e, jacobians

return e

def sqrt_info_matrix(self, states: List[State]):
"""
Returns the square root of the information matrix
"""
x = states[0]
L = self._y.model.sqrt_information(x)
return L

0 comments on commit 249ad9f

Please sign in to comment.