Skip to content

Commit

Permalink
solved boj 3584
Browse files Browse the repository at this point in the history
  • Loading branch information
ha4219 committed Apr 5, 2022
1 parent 5e6e83b commit b18023b
Showing 1 changed file with 54 additions and 50 deletions.
104 changes: 54 additions & 50 deletions boj/boj_3584.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,57 +8,61 @@
setrecursionlimit(10**5)
input = stdin.readline

for ____ in range(int(input())):
n = int(input())

a = [[] for _ in range(n+1)]
v = [0] * (n+1)

parent = [[[0, maxsize, -maxsize]
for __ in range(MAX+1)] for _ in range(n+1)]
d = [0] * (n+1)

for _ in range(n-1):
p, q = map(int, input().split())
v[q] = 1
a[p].append(q)

p = 0
for i in range(1, n+1):
if v[i] == 0:
p = i
break

def dfs(cur, depth):
for next in a[cur]:
if d[next]:
continue
parent[next][0][0] = cur
d[next] = depth + 1
dfs(next, depth+1)
n = int(input())

a = [[] for _ in range(n+1)]
parent = [[0] * (MAX+1) for _ in range(n+1)]
d = [0] * (n+1)
dis = [0] * (n+1)
v = [0] * (n+1)

for _ in range(n-1):
p, q, c = map(int, input().split())
a[p].append((q, c))
a[q].append((p, c))


def dfs(cur, depth, distance):
if v[cur]:
return

d[p] = 0
dfs(p, 0)

for i in range(1, MAX):
for j in range(1, n+1):
parent[j][i][0] = parent[parent[j][i-1][0]][i-1][0]

def solve(l, r):
if d[l] > d[r]:
l, r = r, l
for i in range(MAX, -1, -1):
if d[r] - d[l] >= (1 << i):
left = parent[r][i][1]
r = parent[r][i][0]
if l == r:
return l
for i in range(MAX, -1, -1):
if parent[l][i][0] != parent[r][i][0]:
l = parent[l][i][0]
r = parent[r][i][0]
return parent[l][0][0]
for next, next_cost in a[cur]:
if d[next] != 0:
continue
parent[next][0] = cur
dis[next] = distance + next_cost
d[next] = depth+1
dfs(next, depth+1, dis[next])
return


dfs(1, 1, 0)


for i in range(1, MAX):
for j in range(1, n+1):
parent[j][i] = parent[parent[j][i-1]][i-1]

cache = {}


def solve(l, r):
if d[l] > d[r]:
l, r = r, l
for i in range(MAX, -1, -1):
if d[r] - d[l] >= (1 << i):
r = parent[r][i]
if l == r:
return l
for i in range(MAX, -1, -1):
if parent[l][i] != parent[r][i]:
l = parent[l][i]
r = parent[r][i]
return parent[l][0]


m = int(input())
for _ in range(m):
p, q = map(int, input().split())
print(solve(p, q))
ancestor = solve(p, q)
print(dis[p] + dis[q] - dis[ancestor] * 2)

0 comments on commit b18023b

Please sign in to comment.