Skip to content

Commit

Permalink
add preconditioner to sparse solver
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Feb 14, 2025
1 parent 1403da3 commit b60325d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
49 changes: 49 additions & 0 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2899,6 +2899,8 @@ def sparse_linear_fit_2D(
atol: float = 1e-10,
btol: float = 1e-10,
iter_lim: int = None,
precondition_solver: bool = False,
eig_scaling_factor: float = 1e-1,
**kwargs
) -> np.ndarray:
"""
Expand Down Expand Up @@ -2927,6 +2929,20 @@ def sparse_linear_fit_2D(
flattened `data` array, x is the solution, and r is the residual.
iter_lim : int, optional
Maximum number of iterations for `lsqr`, default is None
precondition_solver : bool, optional, default False
If True, the solver will apply a preconditioner to the basis matrices before
solving the least-squares problem. The preconditioner is computed using the
inverse of the regularized Gramian matrix of the basis matrices. Prior to
computing the inverse, the eigenvalues of the Gramian matrix are regularized
by adding a small value proportional to the smallest eigenvalue. This helps
to stabilize the computation of the inverse. The regularization factor is
computed as the minimum eigenvalue of the Gramian matrix multiplied by the
`eig_scaling_factor` parameter.
eig_scaling_factor : float, optional, default 1e-1
Regularization factor for the eigenvalues of the Gramian matrix. The factor
is computed as the minimum eigenvalue of the Gramian matrix multiplied by
`eig_scaling_factor`. Reasonable values are typically in the range of 1e-1
to 1e-3.
**kwargs : dict
Additional keyword arguments passed to `scipy.sparse.linalg.lsqr`.
Expand Down Expand Up @@ -2960,6 +2976,36 @@ def sparse_linear_fit_2D(
axis_1_basis.shape[-1] * axis_2_basis.shape[-1], # i * j
)

if precondition_solver:
# Compute separate preconditioners for the two axes
# Start by computing separable weights for the two axes
u, s, v = np.linalg.svd(weights, full_matrices=False)
axis_1_wgts = np.abs(u[:, 0] * np.sqrt(s[0]))
axis_2_wgts = np.abs(v[0] * np.sqrt(s[0]))

# Compute the preconditioner for the first axis
XTX_axis_1 = np.dot(axis_1_basis.T.conj() * axis_1_wgts, axis_1_basis)
eigenval, _ = np.linalg.eig(XTX_axis_1)
axis_1_lambda = np.min(
eigenval[eigenval.real > np.finfo(eigenval.dtype).eps] * eig_scaling_factor
)
axis_1_pcond = np.linalg.pinv(
XTX_axis_1 + np.eye(XTX_axis_1.shape[0]) * axis_1_lambda
)

# Compute the preconditioner for the second axis
XTX_axis_2 = np.dot(axis_2_basis.T.conj() * axis_2_wgts, axis_2_basis)
eigenval, _ = np.linalg.eig(XTX_axis_2)
axis_2_lambda = np.min(
eigenval[eigenval.real > np.finfo(eigenval.dtype).eps] * eig_scaling_factor
)
axis_2_pcond = np.linalg.pinv(
XTX_axis_2 + np.eye(XTX_axis_2.shape[0]) * axis_2_lambda
)

axis_1_basis = np.dot(axis_1_basis, axis_1_pcond)
axis_2_basis = np.dot(axis_2_basis, axis_2_pcond)

# Define the implicit LinearOperator representing the Kronecker product
linear_operator = sparse.linalg.LinearOperator(
full_operator_shape,
Expand All @@ -2985,6 +3031,9 @@ def sparse_linear_fit_2D(
# Reshape output
x = x.reshape(axis_1_basis.shape[-1], axis_2_basis.shape[-1])

if precondition_solver:
x = np.dot(axis_1_pcond, x).dot(axis_2_pcond)

return x, meta

def separable_linear_fit_2D(
Expand Down
55 changes: 55 additions & 0 deletions hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,3 +1535,58 @@ def test_sparse_linear_fit_2d_non_binary_wgts():

# Check that the fit closely matches to the separable fit
np.testing.assert_allclose(sol, sol_sparse, atol=1e-9, rtol=1e-6)

def test_precondition_sparse_solver():
    """Check the preconditioned sparse 2D fit against the un-preconditioned one.

    Noise-free data are built as ``time_basis @ x_true @ freq_basis.T``. The
    data are then fit three ways: with ``separable_linear_fit_2D`` (used as
    the reference), with ``sparse_linear_fit_2D`` without preconditioning,
    and with ``sparse_linear_fit_2D`` with ``precondition_solver=True``.

    The test asserts that both sparse solutions match the reference, and that
    the preconditioned solve reports a strictly smaller ``iter_num`` in its
    metadata than the un-preconditioned solve.
    """
    ntimes, nfreqs = 100, 50
    rng = np.random.default_rng(42)

    # Build the sample grids once so the basis and the frequency-dependent
    # weights below are guaranteed to use the same frequency axis.
    freqs = np.linspace(100e6, 200e6, nfreqs)
    times = np.linspace(0, ntimes * 10, ntimes)

    freq_basis, _ = dspec.dpss_operator(
        freqs, [0], [20e-9], eigenval_cutoff=[1e-12]
    )
    time_basis, _ = dspec.dpss_operator(
        times, [0], [1e-3], eigenval_cutoff=[1e-12]
    )

    # Generate some data/flags.
    # By construction the data is separable in the time and frequency
    # directions, and so are the flags (one flag vector per axis).
    # NOTE: the order of the rng calls is kept fixed so the seeded draws
    # stay reproducible.
    time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1))
    freq_flags = rng.choice([True, False], p=[0.1, 0.9], size=(1, nfreqs))
    x_true = rng.normal(
        0, 1, size=(time_basis.shape[-1], freq_basis.shape[-1])
    )
    data = np.dot(time_basis, x_true).dot(freq_basis.T)

    # Binary (unflagged == 1) weights along each axis.
    time_unflagged = (~time_flags[:, 0]).astype(float)
    freq_unflagged = (~freq_flags[0]).astype(float)

    # Non-binary 2D weights: integer scaling (1..9) of the unflagged times,
    # combined with the binary frequency weights.
    axis_1_weights = time_unflagged * rng.integers(1, 10, size=(ntimes,))
    axis_2_weights = freq_unflagged
    wgts = np.outer(axis_1_weights, axis_2_weights)

    # Add frequency dependence to the weights to make the problem more
    # ill-conditioned.
    wgts *= (freqs / 150e6) ** -3.5

    # Reference solution: separable fit using the plain binary weights.
    sol = dspec.separable_linear_fit_2D(
        data=data,
        axis_1_weights=time_unflagged,
        axis_2_weights=freq_unflagged,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
    )

    # Sparse fit without the preconditioner.
    sol_sparse, meta = dspec.sparse_linear_fit_2D(
        data=data,
        weights=wgts,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
        precondition_solver=False,
    )

    # Sparse fit with the preconditioner enabled.
    sol_sparse_precond, meta_precond = dspec.sparse_linear_fit_2D(
        data=data,
        weights=wgts,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
        precondition_solver=True,
    )

    # Both sparse fits must closely match the separable fit.
    np.testing.assert_allclose(sol, sol_sparse, atol=1e-8, rtol=1e-6)
    np.testing.assert_allclose(sol, sol_sparse_precond, atol=1e-8, rtol=1e-6)

    # The preconditioned solve must take strictly fewer iterations.
    np.testing.assert_array_less(meta_precond['iter_num'], meta['iter_num'])

0 comments on commit b60325d

Please sign in to comment.