-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathkruskals.py
97 lines (77 loc) · 2.66 KB
/
kruskals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Implementation of Kruskal's algorithm to check whether the Distributed GHS
returned correctly or not."""
import sys

# Largest native integer; not referenced by the code in this module
# (presumably kept for importers -- confirm before removing).
INF = sys.maxsize
class Kruskals:
    """Class for implementation of Kruskal's algorithm.

    Connectivity between partial trees is tracked with a disjoint-set
    (union-find) structure: ``self.parent`` holds parent links and
    ``self.size`` holds the size of the tree rooted at each root node.
    """

    def __init__(self, num_nodes):
        """Ctor

        Arguments:
            num_nodes {Integer} -- Number of nodes in the graph; nodes are
                the integers 0 .. num_nodes - 1 (used as list indices)
        """
        self.num_nodes = num_nodes
        self.parent = list(range(num_nodes))
        # Tree size per node; only meaningful for root nodes.
        self.size = [1] * num_nodes

    def get_parent(self, node):
        """Traverse up the sub-tree to get the root (representative) of node.

        Iterative rather than recursive so that long parent chains cannot
        exceed Python's recursion limit.

        Arguments:
            node {Integer}

        Returns:
            Integer -- Root node of the sub-tree containing node
        """
        parent = self.parent
        while parent[node] != node:
            # Path halving: re-point node at its grandparent while climbing,
            # which flattens the tree for later lookups.
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    def union(self, parent1, parent2):
        """Combine the two sub-trees containing parent1 and parent2.

        Arguments:
            parent1 {Integer} -- A node (normally a root) of the first sub-tree
            parent2 {Integer} -- A node (normally a root) of the second sub-tree
        """
        root1 = self.get_parent(parent1)
        root2 = self.get_parent(parent2)
        if root1 == root2:
            return
        # Union by size: hang the smaller tree under the larger one so trees
        # stay shallow. On a tie, root1 goes under root2 (the original order).
        if self.size[root1] > self.size[root2]:
            root1, root2 = root2, root1
        self.parent[root1] = root2
        self.size[root2] += self.size[root1]

    def get_mst(self, edges):
        """Get the MST weight using Kruskal's algorithm.

        Arguments:
            edges {List} -- List of raw edges; each edge is a sequence
                (node1, node2, weight) whose items are converted with
                int(), int() and float() respectively. The list is not
                modified.

        Returns:
            Float -- Tree weight (the spanning-forest weight if the graph is
                disconnected; the int 0 if no edge is selected)
        """
        # Start from singleton sets so repeated calls give the same answer.
        self.parent = list(range(self.num_nodes))
        self.size = [1] * self.num_nodes
        # Sort the edges in order of increasing weight. sorted() is stable
        # like list.sort(), but it leaves the caller's list untouched.
        ordered_edges = sorted(edges, key=lambda x: float(x[2]))
        # Pick the least edge at a time, and check if it doesn't form a cycle
        mst_set = []
        for edge in ordered_edges:
            node1_parent = self.get_parent(int(edge[0]))
            node2_parent = self.get_parent(int(edge[1]))
            if node1_parent != node2_parent:
                # Endpoints lie in different sub-trees, so this edge joins them
                mst_set.append(edge)
                self.union(node1_parent, node2_parent)
                # A spanning tree has exactly num_nodes - 1 edges
                if len(mst_set) == self.num_nodes - 1:
                    break
        # Find and return the sum of weights of the tree edges
        return sum(float(tree_edge[2]) for tree_edge in mst_set)
if __name__ == '__main__':
    if len(sys.argv) < 2:
        # Usage error: report on stderr and exit non-zero so calling
        # scripts can detect the failure.
        print('To run the file: python kruskals.py <input-file>', file=sys.stderr)
        sys.exit(1)
    # Read from the input file. Expected layout (as parsed below): first line
    # is the node count; each following line is one edge whose first and last
    # characters are delimiters (e.g. "(0,1,2.5)" -- assumed) around
    # comma-separated node1,node2,weight. A line of length <= 1 ends the list.
    input_file = sys.argv[1]
    with open(input_file) as file:
        contents = [x.strip() for x in file.readlines()]
    if not contents:
        print('[ERROR]: Input file is empty: ' + input_file, file=sys.stderr)
        sys.exit(1)
    num_nodes = int(contents[0])
    raw_edges = []
    for line in contents[1:]:
        if len(line) <= 1:
            break
        # Drop the surrounding delimiters, then split into [node1, node2, weight]
        raw_edges.append(line[1:-1].split(','))
    k = Kruskals(num_nodes)
    weight = k.get_mst(raw_edges)
    print('[SUCCESS]: Completed Execution. MST Weight: ' + str(weight))