Skip to content

Commit

Permalink
Merge pull request #9 from munechika-koyo/develop
Browse files Browse the repository at this point in the history
Refactor MFR loop
  • Loading branch information
munechika-koyo authored Nov 18, 2023
2 parents a409dc8 + c9dc517 commit bdfea95
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cherab/inversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .mfr import Mfr

__all__ = ["compute_svd", "_SVDBase", "Lcurve", "GCV", "Mfr"]
__version__ = "0.1.2.dev0"
__version__ = "0.1.2"
94 changes: 57 additions & 37 deletions cherab/inversion/mfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def solve(
if not issubclass(regularizer, _SVDBase):
raise TypeError("regularizer must be a subclass of _SVDBase")

# check data attribute
if self._data is None:
raise ValueError("data attribute is not set")

# check initial solution
if x0 is None:
x0 = np.ones(self._gmat.shape[1])
Expand All @@ -212,50 +216,66 @@ def solve(
else:
path: Path = Path(path)

# MFR loop
# set iteration counter and status
niter = 0
status = {}
self._converged = False
errors = []
reg = None
x = None

# set timer
start_time = time()

# start MFR iteration
while niter < miter and not self._converged:
with Spinner(f"{niter:02}-th MFR iteration", timer=True) as sp:
sp_base_text = sp.text + " "

# compute regularization matrix
hmat = self.regularization_matrix(
x0, eps=eps, derivative_weights=derivative_weights
)

# compute SVD components
spinner = sp if verbose else None
singular, u_vecs, basis = compute_svd(self._gmat, hmat, use_gpu=use_gpu, sp=spinner)

# find optimal solution using regularizer class
sp.text = sp_base_text + " (Solving regularizer)"
reg = regularizer(singular, u_vecs, basis, data=self._data)
x, _ = reg.solve(bounds=bounds, **kwargs)

# check convergence
diff = np.linalg.norm(x - x0)
errors.append(diff)
self._converged = bool(diff < tol)

# update solution
x0 = x

# store regularizer object at each iteration
if store_regularizers:
with (path / f"regularizer_{niter}.pickle").open("wb") as f:
pickle.dump(reg, f)

# print iteration information
_text = f"(Diff: {diff:.3e}, Tolerance: {tol:.3e}, lambda: {reg.lambda_opt:.3e})"
sp.text = sp_base_text + _text
sp.ok()

# update iteration counter
niter += 1
try:
sp_base_text = sp.text + " "

# compute regularization matrix
hmat = self.regularization_matrix(
x0, eps=eps, derivative_weights=derivative_weights
)

# compute SVD components
spinner = sp if verbose else None
singular, u_vecs, basis = compute_svd(
self._gmat, hmat, use_gpu=use_gpu, sp=spinner
)

# find optimal solution using regularizer class
sp.text = sp_base_text + " (Solving regularizer)"
reg = regularizer(singular, u_vecs, basis, data=self._data)
x, _ = reg.solve(bounds=bounds, **kwargs)

# check convergence
diff = np.linalg.norm(x - x0)
errors.append(diff)
self._converged = bool(diff < tol)

# update solution
x0 = x

# store regularizer object at each iteration
if store_regularizers:
with (path / f"regularizer_{niter}.pickle").open("wb") as f:
pickle.dump(reg, f)

# print iteration information
_text = (
f"(Diff: {diff:.3e}, Tolerance: {tol:.3e}, lambda: {reg.lambda_opt:.3e})"
)
sp.text = sp_base_text + _text
sp.ok()

# update iteration counter
niter += 1

except Exception as e:
sp.fail()
print(e)
break

elapsed_time = time() - start_time

Expand Down
4 changes: 2 additions & 2 deletions docs/source/_static/switcher.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
},
{
"name": "0.1",
"version": "v0.1.1",
"url": "https://cherab-inversion.readthedocs.io/en/v0.1.1/",
"version": "v0.1.2",
"url": "https://cherab-inversion.readthedocs.io/en/v0.1.2/",
"preferred": true
}
]
2 changes: 1 addition & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ project(
'cherab-inversion',
'cython',
# Note that version cannot dinamically changed now.
version: '0.1.2.dev0',
version: '0.1.2',
meson_version: '>= 0.64.0',
default_options: [
'cython_args=-3',
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ build-backend = "mesonpy"
[project]
name = "cherab-inversion"
description = "Cherab inversion framework"
version = "0.1.2.dev0"
version = "0.1.2"
readme = "README.md"
authors = [
{ name = "Koyo Munechika", email = "[email protected]" },
Expand Down

0 comments on commit bdfea95

Please sign in to comment.