Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add viz in human_seg #23

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pyrightconfig.json
*.lo
*.o
*.obj
*.off
*.ply
*.zip

# Precompiled Headers
*.gch
Expand Down
10 changes: 10 additions & 0 deletions experiments/human_segmentation_original/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,13 @@ or, with heat kernel signature features
python human_segmentation_original.py --input_features=hks --evaluate
```

### results

##### gt lable

![gt_label](https://github.com/Yannnnnnnnnnnn/diffusion-net/blob/master/experiments/human_segmentation_original/fig/gt1_viz.png)

###### pr label

![pr_label](https://github.com/Yannnnnnnnnnnn/diffusion-net/blob/master/experiments/human_segmentation_original/fig/pr1_viz.png)

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import trimesh

sys.path.append(os.path.join(os.path.dirname(__file__), "../../src/")) # add the path to the DiffusionNet src
import diffusion_net
Expand Down Expand Up @@ -52,7 +53,7 @@
# === Load datasets

# Load the test dataset
test_dataset = HumanSegOrigDataset(dataset_path, train=False, k_eig=k_eig, use_cache=True, op_cache_dir=op_cache_dir)
test_dataset = HumanSegOrigDataset(dataset_path, train=False, k_eig=k_eig, use_cache=False, op_cache_dir=op_cache_dir)
test_loader = DataLoader(test_dataset, batch_size=None)

# Load the train dataset
Expand Down Expand Up @@ -105,7 +106,7 @@ def train_epoch(epoch):
for data in tqdm(train_loader):

# Get data
verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels = data
verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels, _ = data

# Move to device
verts = verts.to(device)
Expand Down Expand Up @@ -151,6 +152,42 @@ def train_epoch(epoch):
return train_acc


# borrow code from subdivnet
# https://github.com/lzhengning/SubdivNet/blob/master/subdivnet/utils.py
segment_colors = np.array([
[0, 114, 189],
[217, 83, 26],
[238, 177, 32],
[126, 47, 142],
[117, 142, 48],
[76, 190, 238],
[162, 19, 48],
[240, 166, 202],
])

def save_results(mesh_path, preds, labels):

if not os.path.exists('results'):
os.mkdir('results')

labels = labels.cpu().numpy()
preds = preds.cpu().numpy()

save_path = './results'

mesh = trimesh.load_mesh(mesh_path, process=False)
mesh_name = os.path.basename(mesh_path)[:-4]

# print('mesh_path',mesh_path)
# print('mesh_name',mesh_name)
# print(save_path + '/pr-' + mesh_name + '.ply')

mesh.visual.face_colors[:, :3] = segment_colors[preds]
mesh.export(save_path + '/pr-' + mesh_name + '.ply')

mesh.visual.face_colors[:, :3] = segment_colors[labels]
mesh.export(save_path + '/gt-' + mesh_name + '.ply')

# Do an evaluation pass on the test dataset
def test():

Expand All @@ -163,7 +200,7 @@ def test():
for data in tqdm(test_loader):

# Get data
verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels = data
verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels, mesh_file = data

# Move to device
verts = verts.to(device)
Expand All @@ -185,9 +222,16 @@ def test():

# Apply the model
preds = model(features, mass, L=L, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, faces=faces)
pred_labels = torch.max(preds, dim=1).indices

# save
print('mesh_file',mesh_file)
print('pred',pred_labels.shape)
print('labels',labels.shape)
save_results(mesh_file,pred_labels,labels)

# track accuracy
pred_labels = torch.max(preds, dim=1).indices

this_correct = pred_labels.eq(labels).sum().item()
this_num = labels.shape[0]
correct += this_correct
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
self.faces_list = []
self.labels_list = [] # per-face labels!!

# Get all the files
self.mesh_files = []
label_files = []

# check the cache
if use_cache:
train_cache = os.path.join(self.cache_dir, "train.pt")
Expand All @@ -39,17 +43,13 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
print("using dataset cache path: " + str(load_cache))
if os.path.exists(load_cache):
print(" --> loading dataset from cache")
self.verts_list, self.faces_list, self.frames_list, self.massvec_list, self.L_list, self.evals_list, self.evecs_list, self.gradX_list, self.gradY_list, self.labels_list = torch.load( load_cache)
self.verts_list, self.faces_list, self.frames_list, self.massvec_list, self.L_list, self.evals_list, self.evecs_list, self.gradX_list, self.gradY_list, self.labels_list, self.mesh_files = torch.load( load_cache)
return
print(" --> dataset not in cache, repopulating")


# Load the meshes & labels

# Get all the files
mesh_files = []
label_files = []

# Train test split
if self.train:

Expand All @@ -59,7 +59,7 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
for fname in os.listdir(mesh_dirpath):
mesh_fullpath = os.path.join(mesh_dirpath, fname)
label_fullpath = os.path.join(label_dirpath, fname[:-4] + ".txt")
mesh_files.append(mesh_fullpath)
self.mesh_files.append(mesh_fullpath)
label_files.append(label_fullpath)

# faust
Expand All @@ -68,7 +68,7 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
for fname in os.listdir(mesh_dirpath):
mesh_fullpath = os.path.join(mesh_dirpath, fname)
label_fullpath = os.path.join(label_dirpath, "faust_corrected.txt")
mesh_files.append(mesh_fullpath)
self.mesh_files.append(mesh_fullpath)
label_files.append(label_fullpath)

# mit
Expand All @@ -80,7 +80,7 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
for fname in os.listdir(mesh_dirpath):
mesh_fullpath = os.path.join(mesh_dirpath, fname)
label_fullpath = os.path.join(label_dirpath, "mit_{}_corrected.txt".format(pose))
mesh_files.append(mesh_fullpath)
self.mesh_files.append(mesh_fullpath)
label_files.append(label_fullpath)

# scape
Expand All @@ -89,7 +89,7 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
for fname in os.listdir(mesh_dirpath):
mesh_fullpath = os.path.join(mesh_dirpath, fname)
label_fullpath = os.path.join(label_dirpath, "scape_corrected.txt")
mesh_files.append(mesh_fullpath)
self.mesh_files.append(mesh_fullpath)
label_files.append(label_fullpath)

else:
Expand All @@ -106,17 +106,16 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
label_fname = "shrec_{}_full.txt".format(iShrec)
mesh_fullpath = os.path.join(mesh_dirpath, mesh_fname)
label_fullpath = os.path.join(label_dirpath, label_fname)
mesh_files.append(mesh_fullpath)
self.mesh_files.append(mesh_fullpath)
label_files.append(label_fullpath)

print("loading {} meshes".format(len(mesh_files)))
print("loading {} meshes".format(len(self.mesh_files)))

# Load the actual files
for iFile in range(len(mesh_files)):
for iFile in range(len(self.mesh_files)):

print("loading mesh " + str(mesh_files[iFile]))

verts, faces = pp3d.read_mesh(mesh_files[iFile])
verts, faces = pp3d.read_mesh(self.mesh_files[iFile])
labels = np.loadtxt(label_files[iFile]).astype(int)-1

# to torch
Expand All @@ -140,10 +139,10 @@ def __init__(self, root_dir, train, k_eig=128, use_cache=True, op_cache_dir=None
# save to cache
if use_cache:
diffusion_net.utils.ensure_dir_exists(self.cache_dir)
torch.save((self.verts_list, self.faces_list, self.frames_list, self.massvec_list, self.L_list, self.evals_list, self.evecs_list, self.gradX_list, self.gradY_list, self.labels_list), load_cache)
torch.save((self.verts_list, self.faces_list, self.frames_list, self.massvec_list, self.L_list, self.evals_list, self.evecs_list, self.gradX_list, self.gradY_list, self.labels_list, self.mesh_files), load_cache)

def __len__(self):
return len(self.verts_list)

def __getitem__(self, idx):
return self.verts_list[idx], self.faces_list[idx], self.frames_list[idx], self.massvec_list[idx], self.L_list[idx], self.evals_list[idx], self.evecs_list[idx], self.gradX_list[idx], self.gradY_list[idx], self.labels_list[idx]
return self.verts_list[idx], self.faces_list[idx], self.frames_list[idx], self.massvec_list[idx], self.L_list[idx], self.evals_list[idx], self.evecs_list[idx], self.gradX_list[idx], self.gradY_list[idx], self.labels_list[idx], self.mesh_files[idx]