forked from AFumis/BCT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdraw_tree.py
119 lines (97 loc) · 3.46 KB
/
draw_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from igraph import Graph
import plotly.graph_objects as go
import plotly.io as io
def _collect_vertices(edges):
    """Assign vertex ids to the labels of a tree given as a father -> children map.

    Vertex 0 is the root, labelled ``""``; every child label found in ``edges``
    then receives the next id in iteration order, so ``labels[i]`` is the label
    of vertex ``i``. When a label occurs more than once, lookups resolve to its
    first occurrence (the same rule as ``list.index``).

    Args:
        edges: mapping from a father label to an iterable of child labels.
            Labels must be hashable (they key a lookup table).

    Returns:
        (labels, edge_list): the vertex labels, and the ``(father_id, child_id)``
        pairs in the order ``edges`` is iterated.

    Raises:
        ValueError: if a father that has children is neither the root ``""``
            nor listed as a child of some other node.
    """
    labels = [""]
    for children in edges.values():
        labels.extend(children)

    # A dict lookup keeps this linear instead of one list.index() scan per edge;
    # setdefault preserves list.index's "first occurrence wins" behaviour.
    vertex_of = {}
    for vid, label in enumerate(labels):
        vertex_of.setdefault(label, vid)

    edge_list = []
    for father, children in edges.items():
        father_id = vertex_of.get(father)
        for child in children:
            if father_id is None:
                raise ValueError(f"father label {father!r} is not a vertex of the tree")
            edge_list.append((father_id, vertex_of[child]))
    return labels, edge_list


def _make_annotations(xs, ys, text, font_size=15, font_color='rgb(0,0,0)'):
    """Build one Plotly text annotation per node, placed just below the node.

    Args:
        xs, ys: node coordinates in plot space.
        text: annotation text for each node.
        font_size, font_color: annotation font settings.

    Returns:
        list of annotation dicts suitable for ``fig.update_layout(annotations=...)``.

    Raises:
        ValueError: if ``xs``, ``ys`` and ``text`` differ in length.
    """
    if not len(xs) == len(ys) == len(text):
        raise ValueError('The lists xs, ys and text must have the same len')
    return [
        dict(
            text=label,
            x=x,
            y=y - 0.2,  # small downward offset so the label sits under its marker
            xref='x1', yref='y1',
            font=dict(color=font_color, size=font_size),
            showarrow=False,
        )
        for x, y, label in zip(xs, ys, text)
    ]


def draw_tree(tree, title=""):
    '''
    Plot the context tree of a variable length Markov chain as a PNG figure.

    The tree is read from ``tree.edges``, a mapping from a father label to an
    iterable of child labels. The root is the label ``""``. Every other
    father must itself appear as a child somewhere in the mapping.

    Side effects: sets Plotly's global default renderer to ``'png'`` and
    displays the figure via ``fig.show()``.

    Args:
    -----
    tree: Tree object exposing an ``edges`` mapping (father -> children)
    title: String, figure title

    Returns:
    -----
    None: NoneType

    Raises:
    -----
    ValueError: if a father label with children is not a vertex of the tree.
    '''
    io.renderers.default = 'png'

    # Vertex count comes from the labels actually present, so trees where
    # fathers have differing numbers of children are drawn correctly.
    labels, edge_list = _collect_vertices(tree.edges)
    n_vertices = len(labels)

    G = Graph(directed=True)
    G.add_vertices(n_vertices)
    if edge_list:
        G.add_edges(edge_list)  # one batched call instead of one call per edge

    # Pin vertex 0 (the "" root) as the Reingold-Tilford root so the real root
    # is drawn at the top rather than an automatically chosen vertex.
    lay = G.layout('rt', root=[0])
    coords = [lay[k] for k in range(n_vertices)]

    # Flip y so that the root (y == 0 in the layout) ends up at the top.
    M = max(y for _, y in coords)
    Xn = [x for x, _ in coords]
    Yn = [2 * M - y for _, y in coords]

    # Edge polyline: each segment is followed by None to break the line.
    Xe = []
    Ye = []
    for src, dst in edge_list:
        Xe += [Xn[src], Xn[dst], None]
        Ye += [Yn[src], Yn[dst], None]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=Xe,
                             y=Ye,
                             mode='lines',
                             line=dict(color='rgb(0,0,0)', width=1),
                             hoverinfo='none'
                             ))
    fig.add_trace(go.Scatter(x=Xn,
                             y=Yn,
                             mode='markers',
                             name='bla',
                             marker=dict(symbol='circle-dot',
                                         size=5,
                                         color='rgb(250,50,50)',
                                         line=dict(color='rgb(250,50,50)', width=1)
                                         ),
                             text=labels,
                             hoverinfo='text',
                             opacity=0.8
                             ))

    axis = dict(showline=False,  # hide axis line, grid, ticklabels and title
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                )
    fig.update_layout(title=title,
                      annotations=_make_annotations(Xn, Yn, labels),
                      font_size=12,
                      showlegend=False,
                      xaxis=axis,
                      yaxis=axis,
                      margin=dict(l=40, r=40, b=85, t=100),
                      hovermode='closest',
                      plot_bgcolor='rgb(255,255,255)'
                      )
    fig.show()