Skip to content

Commit

Permalink
Refactor solve method to not be so monstrous
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmundt committed Feb 27, 2024
1 parent 13d80dd commit 4c93c5e
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 60 deletions.
136 changes: 77 additions & 59 deletions pyomo/contrib/solver/ipopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self, **kwds):
self._writer = NLWriter()
self._available_cache = None
self._version_cache = None
self.time_limit_buffer = {'min' : 1.0, 'max' : 100.0, 'tol' : 0.01}
self.time_limit_buffer = {'min': 1.0, 'max': 100.0, 'tol': 0.01}

def available(self, config=None):
if config is None:
Expand Down Expand Up @@ -323,6 +323,72 @@ def _run_subprocess(
)
return process, iters, ipopt_time_nofunc, ipopt_time_func, ipopt_total_time

def _generate_results_object(
self,
basename: str,
timer: HierarchicalTimer,
proven_infeasible: bool,
nl_info: NLWriterInfo,
subprocess: tuple,
):
if proven_infeasible:
results = Results()
results.termination_condition = TerminationCondition.provenInfeasible
results.solution_loader = SolSolutionLoader(None, None)
results.iteration_count = 0
results.timing_info.total_seconds = 0
elif len(nl_info.variables) == 0:
if len(nl_info.eliminated_vars) == 0:
results = Results()
results.termination_condition = TerminationCondition.emptyModel
results.solution_loader = SolSolutionLoader(None, None)
else:
results = Results()
results.termination_condition = (
TerminationCondition.convergenceCriteriaSatisfied
)
results.solution_status = SolutionStatus.optimal
results.solution_loader = SolSolutionLoader(None, nl_info=nl_info)
results.iteration_count = 0
results.timing_info.total_seconds = 0
else:
if os.path.isfile(basename + '.sol'):
with open(basename + '.sol', 'r') as sol_file:
timer.start('parse_sol')
results = self._parse_solution(sol_file, nl_info)
timer.stop('parse_sol')
else:
results = Results()
if subprocess[0].returncode != 0:
results.extra_info.return_code = subprocess[0].returncode
results.termination_condition = TerminationCondition.error
results.solution_loader = SolSolutionLoader(None, None)
else:
results.iteration_count = subprocess[1]
if subprocess[2] is not None:
results.timing_info.ipopt_excluding_nlp_functions = subprocess[2]

if subprocess[3] is not None:
results.timing_info.nlp_function_evaluations = subprocess[3]
if subprocess[4] is not None:
results.timing_info.total_seconds = subprocess[4]
return results

def _load_solutions(self, model, results):
results.solution_loader.load_vars()
if (
hasattr(model, 'dual')
and isinstance(model.dual, Suffix)
and model.dual.import_enabled()
):
model.dual.update(results.solution_loader.get_duals())
if (
hasattr(model, 'rc')
and isinstance(model.rc, Suffix)
and model.rc.import_enabled()
):
model.rc.update(results.solution_loader.get_reduced_costs())

@document_kwargs_from_configdict(CONFIG)
def solve(self, model, **kwds):
# Begin time tracking
Expand Down Expand Up @@ -374,55 +440,19 @@ def solve(self, model, **kwds):
proven_infeasible = False
except InfeasibleConstraintException:
proven_infeasible = True
nl_info = None
timer.stop('write_nl_file')
if not proven_infeasible and len(nl_info.variables) > 0:
process, iters, ipopt_time_nofunc, ipopt_time_func, ipopt_total_time = (
self._run_subprocess(basename, timer, config, nl_info)
subprocess_result = self._run_subprocess(
basename, timer, config, nl_info
)

if proven_infeasible:
results = Results()
results.termination_condition = TerminationCondition.provenInfeasible
results.solution_loader = SolSolutionLoader(None, None)
results.iteration_count = 0
results.timing_info.total_seconds = 0
elif len(nl_info.variables) == 0:
if len(nl_info.eliminated_vars) == 0:
results = Results()
results.termination_condition = TerminationCondition.emptyModel
results.solution_loader = SolSolutionLoader(None, None)
else:
results = Results()
results.termination_condition = (
TerminationCondition.convergenceCriteriaSatisfied
)
results.solution_status = SolutionStatus.optimal
results.solution_loader = SolSolutionLoader(None, nl_info=nl_info)
results.iteration_count = 0
results.timing_info.total_seconds = 0
else:
if os.path.isfile(basename + '.sol'):
with open(basename + '.sol', 'r') as sol_file:
timer.start('parse_sol')
results = self._parse_solution(sol_file, nl_info)
timer.stop('parse_sol')
else:
results = Results()
if process.returncode != 0:
results.extra_info.return_code = process.returncode
results.termination_condition = TerminationCondition.error
results.solution_loader = SolSolutionLoader(None, None)
else:
results.iteration_count = iters
if ipopt_time_nofunc is not None:
results.timing_info.ipopt_excluding_nlp_functions = (
ipopt_time_nofunc
)

if ipopt_time_func is not None:
results.timing_info.nlp_function_evaluations = ipopt_time_func
if ipopt_total_time is not None:
results.timing_info.total_seconds = ipopt_total_time
subprocess_result = None

results = self._generate_results_object(
basename, timer, proven_infeasible, nl_info, subprocess_result
)

if (
config.raise_exception_on_nonoptimal_result
and results.solution_status != SolutionStatus.optimal
Expand All @@ -444,19 +474,7 @@ def solve(self, model, **kwds):
)

if config.load_solutions:
results.solution_loader.load_vars()
if (
hasattr(model, 'dual')
and isinstance(model.dual, Suffix)
and model.dual.import_enabled()
):
model.dual.update(results.solution_loader.get_duals())
if (
hasattr(model, 'rc')
and isinstance(model.rc, Suffix)
and model.rc.import_enabled()
):
model.rc.update(results.solution_loader.get_reduced_costs())
self._load_solutions(model, results)

if (
results.solution_status in {SolutionStatus.feasible, SolutionStatus.optimal}
Expand Down
2 changes: 1 addition & 1 deletion pyomo/contrib/solver/tests/unit/test_ipopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_class_member_list(self):
'solve',
'version',
'name',
'time_limit_buffer'
'time_limit_buffer',
]
method_list = [method for method in dir(opt) if method.startswith('_') is False]
self.assertEqual(sorted(expected_list), sorted(method_list))
Expand Down

0 comments on commit 4c93c5e

Please sign in to comment.