-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path eval_orb_shape.py
114 lines (94 loc) · 4.29 KB
/
eval_orb_shape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import trimesh
from scipy.spatial import cKDTree as KDTree
import logging
import argparse
import os
from utils.base_utils import load_cfg
logger = logging.getLogger(__name__)

# Number of points sampled from the predicted mesh surface for the Chamfer
# distance (passed as num_mesh_samples by compute_shape_score).
num_points_chamfer = 30_000

# Target volume for calculate_scale; roughly corresponds to a width of 10cm.
target_volume = 900

# NOTE(review): the following are not referenced by the code visible in this
# file; presumably parameters of a mesh-alignment step -- confirm before removing.
num_points_align = 1000
max_iterations = 100
cost_threshold = 0
def sample_surface_point(mesh, num_points, even=False):
    """Sample ``num_points`` points from the surface of ``mesh``.

    Args:
        mesh: a ``trimesh.Trimesh`` to sample from.
        num_points: number of points to return.
        even: if True, use ``trimesh.sample.sample_surface_even``. That sampler
            may return fewer points than requested, so it is called repeatedly
            and the batches are concatenated until enough points exist.
            Otherwise use ``trimesh.sample.sample_surface``.

    Returns:
        Array of at most ``num_points`` surface points (one row per point).

    Raises:
        ValueError: if, in ``even`` mode, a sampling round produces no points
            (e.g. a degenerate mesh). Without this check the loop would never
            terminate.
    """
    if not even:
        samples, _ = trimesh.sample.sample_surface(mesh, count=num_points)
        return samples[:num_points]

    samples, _ = trimesh.sample.sample_surface_even(mesh, count=num_points)
    while len(samples) < num_points:
        extra, _ = trimesh.sample.sample_surface_even(mesh, count=num_points)
        if len(extra) == 0:
            # No progress is possible; bail out instead of spinning forever.
            raise ValueError(
                f'sample_surface_even returned no points; collected '
                f'{len(samples)} of {num_points} requested'
            )
        samples = np.concatenate([samples, extra], axis=0)
    # The final round may overshoot; trim to the requested count.
    return samples[:num_points]
def load_mesh(fpath: str) -> trimesh.Trimesh:
    """Load a mesh from ``fpath``.

    ``.npz`` files must contain ``'verts'`` and ``'faces'`` arrays. Every face
    is appended a second time with reversed vertex order, so the resulting
    mesh is double-sided. Any other path is handed to ``trimesh.load_mesh``.

    Args:
        fpath: path to a ``.npz`` archive or any mesh format trimesh can read.

    Returns:
        The loaded mesh.
    """
    if fpath.endswith('.npz'):
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager guarantees it is closed. Indexing reads each
        # array fully into memory, so the arrays stay valid after closing.
        with np.load(fpath) as mesh_npz:
            verts = mesh_npz['verts']
            faces = mesh_npz['faces']
        # Duplicate every face with reversed winding -> double-sided mesh.
        faces = np.concatenate((faces, faces[:, ::-1]), axis=0)
        return trimesh.Trimesh(vertices=verts, faces=faces)
    # NOTE(review): for multi-geometry files trimesh.load_mesh may return a
    # Scene rather than a Trimesh -- confirm inputs are single meshes.
    return trimesh.load_mesh(fpath)
# Adapted from https://github.com/facebookresearch/DeepSDF/blob/main/deep_sdf/metrics/chamfer.py
def compute_trimesh_chamfer(gt_points, gen_mesh, num_mesh_samples=30000):
    """Compute the two directional Chamfer terms between GT points and a mesh.

    Args:
        gt_points: ground-truth points. Either an object exposing ``.vertices``
            (e.g. ``trimesh.points.PointCloud`` or ``trimesh.Trimesh``, whose
            vertices are then used directly) or an array-like of shape (N, 3).
        gen_mesh: reconstructed ``trimesh.Trimesh``, or None. When None,
            ``num_mesh_samples`` points at the origin are used as a placeholder,
            so a missing reconstruction still yields a score instead of an error.
        num_mesh_samples: number of points sampled from ``gen_mesh``'s surface.

    Returns:
        ``(gt_to_gen_chamfer, gen_to_gt_chamfer)``:
        - the mean squared distance from each GT point to its nearest generated
          sample, and
        - the mean squared distance from each generated sample to its nearest
          GT point.
        The two terms are returned separately; combining them is up to the
        caller.
    """
    if gen_mesh is None:
        gen_points_sampled = np.zeros((num_mesh_samples, 3))
    else:
        gen_points_sampled = trimesh.sample.sample_surface(gen_mesh, num_mesh_samples)[0]

    # Accept a trimesh-like object (use its vertices) or a plain point array.
    gt_points_np = np.asarray(getattr(gt_points, 'vertices', gt_points))

    # GT -> generated direction.
    gt_to_gen_dist, _ = KDTree(gen_points_sampled).query(gt_points_np)
    gt_to_gen_chamfer = np.mean(np.square(gt_to_gen_dist))

    # Generated -> GT direction.
    gen_to_gt_dist, _ = KDTree(gt_points_np).query(gen_points_sampled)
    gen_to_gt_chamfer = np.mean(np.square(gen_to_gt_dist))

    return gt_to_gen_chamfer, gen_to_gt_chamfer
def compute_shape_score(output_mesh, target_mesh):
    """Score a reconstructed mesh against a ground-truth mesh by Chamfer distance.

    Args:
        output_mesh: path to the reconstructed mesh, or None.
        target_mesh: path to the ground-truth mesh.

    Returns:
        ``{'bidir_chamfer': value}``, where value is the mean of the two
        directional terms returned by ``compute_trimesh_chamfer``.
        Returns ``{}`` when ``output_mesh`` is None.
    """
    if output_mesh is None:
        logger.error('output mesh not found')
        return {}
    try:
        mesh_result = load_mesh(output_mesh)
    except (ValueError, OSError):
        # Unreadable or missing reconstruction (e.g. a missing .npz raises
        # FileNotFoundError from np.load). Log it with its traceback and
        # continue with None, so compute_trimesh_chamfer uses its placeholder.
        logger.exception('failed to load output mesh %s', output_mesh)
        mesh_result = None
    mesh_scan = load_mesh(target_mesh)
    # The GT mesh object itself is passed, so its vertices (not surface
    # samples) serve as the GT point set inside compute_trimesh_chamfer.
    gt_to_gen_chamfer, gen_to_gt_chamfer = compute_trimesh_chamfer(
        mesh_scan, mesh_result, num_points_chamfer)
    bidir_chamfer = (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2.
    return {'bidir_chamfer': bidir_chamfer}
def calculate_scale(mesh, target_volume, method='volume'):
    """Return the uniform scale factor that brings ``mesh`` to ``target_volume``.

    Volume grows with the cube of a uniform scale, so the factor is
    ``(target_volume / current_volume) ** (1/3)``.

    Args:
        mesh: a ``trimesh.Trimesh``.
        target_volume: desired volume after scaling.
        method: how the current volume is measured:
            - ``'bounding_box'``: product of the three ``mesh.extents``.
            - ``'volume'``: volume of the filled voxelization, using a voxel
              pitch of 1/100 of the smallest extent.

    Returns:
        The scale factor as a float.

    Raises:
        ValueError: if ``method`` is not ``'bounding_box'`` or ``'volume'``.
    """
    if method == 'bounding_box':
        width, height, length = mesh.extents
        current_volume = width * height * length
    elif method == 'volume':
        # Tie the voxel pitch to the thinnest dimension so that it spans
        # about 100 voxels.
        voxel_length = mesh.extents.min() / 100
        current_volume = mesh.voxelized(voxel_length).fill().volume
    else:
        # Fail with a clear message instead of reaching the return with no
        # volume computed.
        raise ValueError(
            f"unknown method {method!r}; expected 'bounding_box' or 'volume'")
    return (target_volume / current_volume) ** (1 / 3)
def main(out_mesh_path, target_mesh_path, shape_metrics_dir='data/meshes/orb'):
    """Evaluate one reconstruction, print the metrics and append them to a log.

    A line ``"<out_mesh_path>: <metrics>"`` is appended to
    ``<shape_metrics_dir>/orb_records.txt``; the directory is created if
    missing.

    Args:
        out_mesh_path: path to the reconstructed mesh.
        target_mesh_path: path to the ground-truth mesh.
        shape_metrics_dir: directory holding the records file.

    Returns:
        The metrics dict produced by ``compute_shape_score``.
    """
    cd = compute_shape_score(out_mesh_path, target_mesh_path)
    print(cd)
    os.makedirs(shape_metrics_dir, exist_ok=True)
    records_path = os.path.join(shape_metrics_dir, 'orb_records.txt')
    with open(records_path, 'a', encoding='utf-8') as f:
        f.write(f'{out_mesh_path}: {cd}\n')
    return cd
if __name__ == '__main__':
    # Command-line entry: evaluate one output mesh against one GT mesh.
    # NOTE(review): the default target path is machine-specific.
    cli = argparse.ArgumentParser()
    for option, fallback in (
        ('--out_mesh_path', 'data/meshes/orb/cactus_scene001_shape-180000.ply'),
        ('--target_mesh_path', '/home/riga/NeRF/nerf_data/ground_truth/cactus_scene001/mesh_blender/mesh.obj'),
    ):
        cli.add_argument(option, type=str, default=fallback)
    args = cli.parse_args()
    main(args.out_mesh_path, args.target_mesh_path)