forked from google/or-tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdimacs_assignment.cc
205 lines (192 loc) · 8.05 KB
/
dimacs_assignment.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
// Copyright 2010-2021 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/flags/parse.h"
#include "absl/flags/usage.h"
#include "absl/strings/str_format.h"
#include "examples/cpp/parse_dimacs_assignment.h"
#include "examples/cpp/print_dimacs_assignment.h"
#include "ortools/algorithms/hungarian.h"
#include "ortools/base/commandlineflags.h"
#include "ortools/base/logging.h"
#include "ortools/base/timer.h"
#include "ortools/graph/ebert_graph.h"
#include "ortools/graph/linear_assignment.h"
// Command-line flags for this example binary.

// When set, SolveDimacsAssignment() also solves the instance with
// BuildAndSolveHungarianInstance() and reports an error if the two optimum
// costs differ.
ABSL_FLAG(bool, assignment_compare_hungarian, false,
"Compare result and speed against Hungarian method.");
// When non-empty, SolveDimacsAssignment() writes the parsed problem to this
// path via PrintDimacsAssignmentProblem() before solving.
ABSL_FLAG(std::string, assignment_problem_output_file, "",
"Print the problem to this file in DIMACS format (after layout "
"is optimized, if applicable).");
// Only consulted by main() when --assignment_static_graph is false: selects
// StarGraph (true) or ForwardStarGraph (false).
ABSL_FLAG(bool, assignment_reverse_arcs, false,
"Ignored if --assignment_static_graph=true. Use StarGraph "
"if true, ForwardStarGraph if false.");
// Checked first by main(): when true the solver is instantiated with
// ForwardStarStaticGraph regardless of --assignment_reverse_arcs.
ABSL_FLAG(bool, assignment_static_graph, true,
"Use the ForwardStarStaticGraph representation, "
"otherwise ForwardStarGraph or StarGraph according "
"to --assignment_reverse_arcs.");
namespace operations_research {
// Namespace-level default graph representation.
// NOTE(review): every function template in this namespace declares its own
// `GraphType` template parameter, which hides this alias inside those
// templates; main() chooses the concrete graph type explicitly.
using GraphType = ForwardStarStaticGraph;
// Re-solves the given assignment instance with MinimizeLinearAssignment() on a
// dense cost matrix and returns the total cost of the matching it reports.
//
// The matrix has one row per left node and one column per right node
// (NumNodes() - NumLeftNodes()). Cells with no corresponding arc are filled
// with a placeholder cost larger than NumLeftNodes() times the largest
// absolute arc cost. The time spent in the solver call is logged.
template <typename GraphType>
CostValue BuildAndSolveHungarianInstance(
    const LinearSumAssignment<GraphType>& assignment) {
  using CostRow = std::vector<double>;
  using CostMatrix = std::vector<CostRow>;

  const GraphType& graph = assignment.Graph();

  // One (initially empty) row per left node.
  CostMatrix cost_matrix(assignment.NumLeftNodes());

  // Find the largest absolute arc cost; it determines the placeholder value
  // used for cells that have no arc.
  CostValue max_abs_cost = 0;
  for (typename GraphType::ArcIterator it(graph); it.Ok(); it.Next()) {
    const ArcIndex arc = it.Index();
    const CostValue abs_cost = std::abs(assignment.ArcCost(arc));
    max_abs_cost = std::max(max_abs_cost, abs_cost);
  }
  const double placeholder_cost = static_cast<double>(
      (assignment.NumLeftNodes() * max_abs_cost) + 1);

  // Size every row to the number of right nodes, pre-filled with the
  // placeholder.
  for (CostRow& row : cost_matrix) {
    row.resize(assignment.NumNodes() - assignment.NumLeftNodes(),
               placeholder_cost);
  }

  // Fill in the real arc costs. Arcs are visited through each node's outgoing
  // adjacency list rather than through the generic ArcIterator: that way the
  // tail of every arc is simply the node being visited, and no separate array
  // of arc tails has to be built.
  for (typename GraphType::NodeIterator node_it(graph); node_it.Ok();
       node_it.Next()) {
    const NodeIndex source = node_it.Index();
    const NodeIndex row_index = source - GraphType::kFirstNode;
    for (typename GraphType::OutgoingArcIterator out_it(graph, source);
         out_it.Ok(); out_it.Next()) {
      const ArcIndex arc = out_it.Index();
      // Heads are right nodes; shift them down to a zero-based column.
      const NodeIndex column_index =
          graph.Head(arc) - assignment.NumLeftNodes() - GraphType::kFirstNode;
      cost_matrix[row_index][column_index] =
          static_cast<double>(assignment.ArcCost(arc));
    }
  }

  // The first output map is read below as row -> matched column. The second
  // output is not used here; a map is passed anyway.
  absl::flat_hash_map<int, int> row_to_column;
  absl::flat_hash_map<int, int> unused_reverse_map;

  WallTimer stopwatch;
  VLOG(1) << "Beginning Hungarian method.";
  stopwatch.Start();
  MinimizeLinearAssignment(cost_matrix, &row_to_column, &unused_reverse_map);
  const double seconds = stopwatch.GetInMs() / 1000.0;
  LOG(INFO) << "Hungarian result computed in " << seconds << " seconds.";

  // Add up the matrix cost of each row's matched column.
  double total_cost = 0.0;
  for (int row = 0; row < assignment.NumLeftNodes(); ++row) {
    const int column = row_to_column[row];
    total_cost += cost_matrix[row][column];
  }
  return static_cast<CostValue>(total_cost);
}
template <typename GraphType>
void DisplayAssignment(const LinearSumAssignment<GraphType>& assignment) {
for (typename LinearSumAssignment<GraphType>::BipartiteLeftNodeIterator
node_it(assignment);
node_it.Ok(); node_it.Next()) {
const NodeIndex left_node = node_it.Index();
const ArcIndex matching_arc = assignment.GetAssignmentArc(left_node);
const NodeIndex right_node = assignment.Head(matching_arc);
VLOG(5) << "assigned (" << left_node << ", " << right_node
<< "): " << assignment.ArcCost(matching_arc);
}
}
template <typename GraphType>
int SolveDimacsAssignment(int argc, char* argv[]) {
std::string error_message;
// Handle on the graph we will need to delete because the
// LinearSumAssignment object does not take ownership of it.
GraphType* graph = nullptr;
DimacsAssignmentParser<GraphType> parser(argv[1]);
LinearSumAssignment<GraphType>* assignment =
parser.Parse(&error_message, &graph);
if (assignment == nullptr) {
LOG(FATAL) << error_message;
}
if (!absl::GetFlag(FLAGS_assignment_problem_output_file).empty()) {
// The following tail array management stuff is done in a generic
// way so we can plug in different types of graphs for which the
// TailArrayManager template can be instantiated, even though we
// know the type of the graph explicitly. In this way, the type of
// the graph can be switched just by changing the graph type in
// this file and making no other changes to the code.
TailArrayManager<GraphType> tail_array_manager(graph);
PrintDimacsAssignmentProblem<GraphType>(
*assignment, tail_array_manager,
absl::GetFlag(FLAGS_assignment_problem_output_file));
tail_array_manager.ReleaseTailArrayIfForwardGraph();
}
CostValue hungarian_cost = 0.0;
bool hungarian_solved = false;
if (absl::GetFlag(FLAGS_assignment_compare_hungarian)) {
hungarian_cost = BuildAndSolveHungarianInstance(*assignment);
hungarian_solved = true;
}
WallTimer timer;
timer.Start();
bool success = assignment->ComputeAssignment();
double elapsed = timer.GetInMs() / 1000.0;
if (success) {
CostValue cost = assignment->GetCost();
DisplayAssignment(*assignment);
LOG(INFO) << "Cost of optimum assignment: " << cost;
LOG(INFO) << "Computed in " << elapsed << " seconds.";
LOG(INFO) << assignment->StatsString();
if (hungarian_solved && (cost != hungarian_cost)) {
LOG(ERROR) << "Optimum cost mismatch: " << cost << " vs. "
<< hungarian_cost << ".";
}
} else {
LOG(WARNING) << "Given problem is infeasible.";
}
delete assignment;
delete graph;
return EXIT_SUCCESS;
}
} // namespace operations_research
// Format string for the usage line; main() substitutes the program name for
// the %s placeholder.
static const char* const kUsageTemplate = "usage: %s <filename>";
// Names from operations_research that main() refers to unqualified.
using ::operations_research::ForwardStarGraph;
using ::operations_research::ForwardStarStaticGraph;
using ::operations_research::SolveDimacsAssignment;
using ::operations_research::StarGraph;
// Entry point. Builds the usage string, initializes logging, parses the
// command-line flags, requires one positional argument (the DIMACS input
// file), and runs SolveDimacsAssignment<> with the graph representation
// selected by --assignment_static_graph / --assignment_reverse_arcs.
int main(int argc, char* argv[]) {
  std::string usage;
  if (argc < 1) {
    usage = absl::StrFormat(kUsageTemplate, "solve_dimacs_assignment");
  } else {
    usage = absl::StrFormat(kUsageTemplate, argv[0]);
  }
  google::InitGoogleLogging(usage.c_str());
  // absl::ParseCommandLine() leaves argc/argv untouched and instead returns
  // the positional arguments, with the program name at index 0. Using the
  // returned vector keeps a flag such as --assignment_compare_hungarian from
  // being counted as, or mistaken for, the input file name.
  std::vector<char*> positional_args = absl::ParseCommandLine(argc, argv);
  if (positional_args.size() < 2) {
    LOG(FATAL) << usage;
  }
  const int positional_argc = static_cast<int>(positional_args.size());
  char** const positional_argv = positional_args.data();
  if (absl::GetFlag(FLAGS_assignment_static_graph)) {
    return SolveDimacsAssignment<ForwardStarStaticGraph>(positional_argc,
                                                         positional_argv);
  } else if (absl::GetFlag(FLAGS_assignment_reverse_arcs)) {
    return SolveDimacsAssignment<StarGraph>(positional_argc, positional_argv);
  } else {
    return SolveDimacsAssignment<ForwardStarGraph>(positional_argc,
                                                   positional_argv);
  }
}