Skip to content

Commit

Permalink
Added documentation for TL-Net models
Browse files Browse the repository at this point in the history
  • Loading branch information
zahidemon committed Aug 3, 2023
1 parent 5985657 commit 1225227
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Examples/Python/deep_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def Run_Pipeline(args):
"use_best_model": True,
"tl_net":{
"enabled": False,
"ae_epochs": 10000,
"ae_epochs": 100,
"tf_epochs":100,
"joint_epochs":25,
"alpha":1,
Expand Down Expand Up @@ -664,7 +664,7 @@ def Run_Pipeline(args):
print("Validation world particle MSE: "+str(mean_MSE)+" +- "+str(std_MSE))
template_mesh = train_mesh_files[ref_index]
template_particles = train_local_particles[ref_index].replace("./", data_dir)
# Get distabce between clipped true and predicted meshes
# Get distance between clipped true and predicted meshes
mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files,
template_particles, template_mesh, val_out_dir,
planes=val_planes)
Expand Down
12 changes: 8 additions & 4 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ def __init__(self, num_latent, num_corr):
def forward(self, z):
pt_out = self.decoder(z)
return pt_out


"""
DeepSSM TL-Net Model
"""
class DeepSSMNet_TLNet(nn.Module):
def __init__(self, conflict_file):
super(DeepSSMNet_TLNet, self).__init__()
Expand All @@ -180,14 +183,15 @@ def __init__(self, conflict_file):
self.ImageEncoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir)

def forward(self, pt, x):
# import pdb; pdb.set_trace()
# for testing
if len(pt.shape) < 3:
zt, aa = self.ImageEncoder(x)
zt, _ = self.ImageEncoder(x)
pt_out = self.CorrespondenceDecoder(zt)
return [zt, pt_out.reshape(-1, self.num_corr, 3)]
# for training
else:
pt1 = pt.view(-1, pt.shape[1]*pt.shape[2])
z = self.CorrespondenceEncoder(pt1)
pt_out = self.CorrespondenceDecoder(z)
zt, aa = self.ImageEncoder(x)
zt, _ = self.ImageEncoder(x)
return [pt_out.view(-1, pt.shape[1], pt.shape[2]), z, zt]
6 changes: 6 additions & 0 deletions Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,12 @@ def supervised_train(config_file):
json.dump(parameters, json_file, indent=2)
print("Fine tuning complete, model saved. Best model after epoch " + str(best_ft_epoch))

'''
Network training method for TL-Net model
defines, initializes, and trains the models
logs training and validation errors
saves the model and returns the path it is saved to
'''
def supervised_train_tl(config_file):
with open(config_file) as json_file:
parameters = json.load(json_file)
Expand Down
8 changes: 7 additions & 1 deletion docs/deep-learning/deep-ssm.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,16 @@ The next step is to reformat the data (original and augmented) into PyTorch tens

### 3. Training

PyTorch is used in constructing and training DeepSSM. The network architecture is defined to have five convolution layers followed by two fully connected layers, as illustrated in the figure below. Parametric ReLU activation is used, and the weights are initialized using Xavier initialization. The network is trained for the specified number of epochs using Adam optimization to minimize the L2 loss function with a learning rate of 0.0001. The average training and validation error are printed and logged each epoch to determine convergence.
PyTorch is used in constructing and training DeepSSM. We have implemented two different network architectures:

* **Base-DeepSSM:** The network architecture is defined to have five convolution layers followed by two fully connected layers, as illustrated in the figure below. Parametric ReLU activation is used, and the weights are initialized using Xavier initialization. The network is trained for the specified number of epochs using Adam optimization to minimize the L2 loss function with a learning rate of 0.0001. The average training and validation error are printed and logged each epoch to determine convergence.

![DeepSSM Architecture](../img/deep-learning/Architecture.png)

* **TL-DeepSSM:** In TL-DeepSSM, the input is an image and correspondence pair. The network architecture of the TL-DeepSSM consists of two parts: (i) the autoencoder that learns the latent dimension for each correspondence, and (ii) the network that learns the latent dimension from the image (this is called the T-flank and it is similar to the Base-DeepSSM architecture). The training routine is broken into three parts. First, the correspondence autoencoder is trained. Next, the T-flank is trained while the correspondence autoencoder weights are kept frozen. Finally, the entire model is trained jointly. For inference using a testing sample, one can directly obtain the correspondences from an image via the T-flank and decoder.

![DeepSSM Architecture](../img/deep-learning/TLNet-DeepSSM.jpg)

### 4. Testing

The trained model is then used to predict the PCA score from the images in the test set. These PCA scores are then un-whitened and mapped back to the particle coordinates using the eigenvalues and eigenvectors from PCA. Thus a PDM is acquired for each test image.
Expand Down
Binary file added docs/img/deep-learning/TLNet-DeepSSM.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 1225227

Please sign in to comment.