diff --git a/boj/boj_3584.py b/boj/boj_3584.py index 582dd61..6baa4ca 100644 --- a/boj/boj_3584.py +++ b/boj/boj_3584.py @@ -8,57 +8,61 @@ setrecursionlimit(10**5) input = stdin.readline -for ____ in range(int(input())): - n = int(input()) - - a = [[] for _ in range(n+1)] - v = [0] * (n+1) - - parent = [[[0, maxsize, -maxsize] - for __ in range(MAX+1)] for _ in range(n+1)] - d = [0] * (n+1) - - for _ in range(n-1): - p, q = map(int, input().split()) - v[q] = 1 - a[p].append(q) - - p = 0 - for i in range(1, n+1): - if v[i] == 0: - p = i - break - - def dfs(cur, depth): - for next in a[cur]: - if d[next]: - continue - parent[next][0][0] = cur - d[next] = depth + 1 - dfs(next, depth+1) +n = int(input()) + +a = [[] for _ in range(n+1)] +parent = [[0] * (MAX+1) for _ in range(n+1)] +d = [0] * (n+1) +dis = [0] * (n+1) +v = [0] * (n+1) + +for _ in range(n-1): + p, q, c = map(int, input().split()) + a[p].append((q, c)) + a[q].append((p, c)) + + +def dfs(cur, depth, distance): + if v[cur]: return - d[p] = 0 - dfs(p, 0) - - for i in range(1, MAX): - for j in range(1, n+1): - parent[j][i][0] = parent[parent[j][i-1][0]][i-1][0] - - def solve(l, r): - if d[l] > d[r]: - l, r = r, l - for i in range(MAX, -1, -1): - if d[r] - d[l] >= (1 << i): - left = parent[r][i][1] - r = parent[r][i][0] - if l == r: - return l - for i in range(MAX, -1, -1): - if parent[l][i][0] != parent[r][i][0]: - l = parent[l][i][0] - r = parent[r][i][0] - return parent[l][0][0] + for next, next_cost in a[cur]: + if d[next] != 0: + continue + parent[next][0] = cur + dis[next] = distance + next_cost + d[next] = depth+1 + dfs(next, depth+1, dis[next]) + return + + +dfs(1, 1, 0) + + +for i in range(1, MAX): + for j in range(1, n+1): + parent[j][i] = parent[parent[j][i-1]][i-1] + +cache = {} + + +def solve(l, r): + if d[l] > d[r]: + l, r = r, l + for i in range(MAX, -1, -1): + if d[r] - d[l] >= (1 << i): + r = parent[r][i] + if l == r: + return l + for i in range(MAX, -1, -1): + if parent[l][i] != parent[r][i]: + l = parent[l][i] + r = parent[r][i] + return parent[l][0] + +m = int(input()) +for _ in range(m): p, q = map(int, input().split()) - print(solve(p, q)) + ancestor = solve(p, q) + print(dis[p] + dis[q] - dis[ancestor] * 2)