diff --git a/Examples/Python/deep_ssm.py b/Examples/Python/deep_ssm.py
index 7fa4e31dbd..5dc5abb31a 100644
--- a/Examples/Python/deep_ssm.py
+++ b/Examples/Python/deep_ssm.py
@@ -613,7 +613,7 @@ def Run_Pipeline(args):
         "use_best_model": True,
         "tl_net":{
             "enabled": False,
-            "ae_epochs": 10000,
+            "ae_epochs": 100,
             "tf_epochs":100,
             "joint_epochs":25,
             "alpha":1,
@@ -664,7 +664,7 @@ def Run_Pipeline(args):
     print("Validation world particle MSE: "+str(mean_MSE)+" +- "+str(std_MSE))
     template_mesh = train_mesh_files[ref_index]
     template_particles = train_local_particles[ref_index].replace("./", data_dir)
-    # Get distabce between clipped true and predicted meshes
+    # Get distance between clipped true and predicted meshes
     mean_dist = DeepSSMUtils.analyzeMeshDistance(predicted_val_local_particles, val_mesh_files,
                                                  template_particles, template_mesh, val_out_dir,
                                                  planes=val_planes)
diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py
index 056935be7f..7986a2683c 100644
--- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py
+++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/model.py
@@ -158,7 +158,10 @@ def __init__(self, num_latent, num_corr):
     def forward(self, z):
         pt_out = self.decoder(z)
         return pt_out
-        
+
+"""
+DeepSSM TL-Net Model
+"""
 class DeepSSMNet_TLNet(nn.Module):
     def __init__(self, conflict_file):
         super(DeepSSMNet_TLNet, self).__init__()
@@ -180,14 +183,15 @@ def __init__(self, conflict_file):
         self.ImageEncoder = DeterministicEncoder(self.num_latent, self.img_dims, self.loader_dir)
 
     def forward(self, pt, x):
-        # import pdb; pdb.set_trace()
+        # testing: predict correspondences from the image alone
         if len(pt.shape) < 3:
-            zt, aa = self.ImageEncoder(x)
+            zt, _ = self.ImageEncoder(x)
             pt_out = self.CorrespondenceDecoder(zt)
             return [zt, pt_out.reshape(-1, self.num_corr, 3)]
+        # training: encode both the given correspondences and the image
         else:
             pt1 = pt.view(-1, pt.shape[1]*pt.shape[2])
             z = self.CorrespondenceEncoder(pt1)
             pt_out = self.CorrespondenceDecoder(z)
-            zt, aa = self.ImageEncoder(x)
+            zt, _ = self.ImageEncoder(x)
             return [pt_out.view(-1, pt.shape[1], pt.shape[2]), z, zt]
\ No newline at end of file
diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
index e5fdb5ae99..ca84e630d1 100644
--- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
+++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/trainer.py
@@ -316,6 +316,12 @@ def supervised_train(config_file):
         json.dump(parameters, json_file, indent=2)
     print("Fine tuning complete, model saved. Best model after epoch " + str(best_ft_epoch))
 
+'''
+Network training method for the TL-Net model:
+    defines, initializes, and trains the models
+    logs training and validation errors
+    saves the model and returns the path it is saved to
+'''
 def supervised_train_tl(config_file):
     with open(config_file) as json_file:
         parameters = json.load(json_file)
diff --git a/docs/deep-learning/deep-ssm.md b/docs/deep-learning/deep-ssm.md
index 4223512ba8..ab34c7a9ab 100644
--- a/docs/deep-learning/deep-ssm.md
+++ b/docs/deep-learning/deep-ssm.md
@@ -42,10 +42,16 @@ The next step is to reformat the data (original and augmented) into PyTorch tens
 
 ### 3. Training
 
-PyTorch is used in constructing and training DeepSSM. The network architecture is defined to have five convolution layers followed by two fully connected layers, as illustrated in the figure below. Parametric ReLU activation is used, and the weights are initialized using Xavier initialization. The network is trained for the specified number of epochs using Adam optimization to minimize the L2 loss function with a learning rate of 0.0001. The average training and validation error are printed and logged each epoch to determine convergence.
+PyTorch is used in constructing and training DeepSSM. We have implemented two different network architectures:
+
+* **Base-DeepSSM:** The network architecture is defined to have five convolution layers followed by two fully connected layers, as illustrated in the figure below. Parametric ReLU activation is used, and the weights are initialized using Xavier initialization. The network is trained for the specified number of epochs using Adam optimization to minimize the L2 loss function with a learning rate of 0.0001. The average training and validation errors are printed and logged each epoch to determine convergence.
 
 ![DeepSSM Architecture](../img/deep-learning/Architecture.png)
 
+* **TL-DeepSSM:** In TL-DeepSSM, each input is a pair consisting of an image and its correspondence points. The network architecture consists of two parts: (i) an autoencoder that learns a latent representation of the correspondence points, and (ii) a network that predicts the same latent representation from the image (called the T-flank; it is similar to the Base-DeepSSM architecture). The training routine is broken into three stages. First, the correspondence autoencoder is trained. Next, the T-flank is trained while the correspondence autoencoder weights are kept frozen. Finally, the entire model is trained jointly. At inference time, correspondences are obtained for a test image directly via the T-flank and the correspondence decoder.
+
+![TL-DeepSSM Architecture](../img/deep-learning/TLNet-DeepSSM.jpg)
+
 ### 4. Testing
 
 The trained model is then used to predict the PCA score from the images in the test set. These PCA scores are then un-whitened and mapped back to the particle coordinates using the eigenvalues and eigenvectors from PCA. Thus a PDM is acquired for each test image.
diff --git a/docs/img/deep-learning/TLNet-DeepSSM.jpg b/docs/img/deep-learning/TLNet-DeepSSM.jpg
new file mode 100644
index 0000000000..7f86162d33
Binary files /dev/null and b/docs/img/deep-learning/TLNet-DeepSSM.jpg differ
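
For reference, a minimal sketch of the Base-DeepSSM setup described in the docs hunk above: five convolution layers, two fully connected layers, PReLU activations, Xavier initialization, Adam with learning rate 0.0001, and an L2 (MSE) loss. The channel counts, kernel sizes, hidden width, input volume size, and the class name `BaseDeepSSMSketch` are illustrative assumptions, not the values used in `DeepSSMUtils/model.py`.

```python
import torch
import torch.nn as nn

class BaseDeepSSMSketch(nn.Module):
    """Illustrative stand-in for the Base-DeepSSM architecture (not the real model)."""
    def __init__(self, num_pca=20):
        super().__init__()
        # five 3D convolution layers, each followed by PReLU (channel sizes assumed)
        layers, in_ch = [], 1
        for out_ch in [12, 24, 48, 96, 192]:
            layers += [nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                       nn.PReLU()]
            in_ch = out_ch
        self.features = nn.Sequential(*layers)
        # two fully connected layers mapping features to PCA scores;
        # a 64^3 input is downsampled to 2^3 by the five stride-2 convs
        self.fc = nn.Sequential(
            nn.Linear(192 * 2 * 2 * 2, 384),
            nn.PReLU(),
            nn.Linear(384, num_pca),
        )

    def forward(self, x):
        return self.fc(self.features(x).flatten(start_dim=1))

def xavier_init(m):
    # Xavier initialization for conv and linear weights, as the docs describe
    if isinstance(m, (nn.Conv3d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)

model = BaseDeepSSMSketch()
model.apply(xavier_init)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()  # L2 loss between predicted and true PCA scores
```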
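The three-stage TL-DeepSSM routine that `supervised_train_tl` implements can be sketched as below. The sub-module names (`CorrespondenceEncoder`, `CorrespondenceDecoder`, `ImageEncoder`) and the training-mode forward signature come from `DeepSSMNet_TLNet` in the model.py hunk, and the epoch defaults mirror the `tl_net` config set in deep_ssm.py (`"ae_epochs": 100`, `"tf_epochs": 100`, `"joint_epochs": 25`, `"alpha": 1`); the data loader, learning rate, and exact loss weighting are assumptions, not the trainer's actual code.

```python
import torch
import torch.nn as nn

def train_tl_sketch(model, loader, ae_epochs=100, tf_epochs=100,
                    joint_epochs=25, alpha=1.0):
    """Hypothetical three-stage trainer; `loader` yields (image, points) batches."""
    mse = nn.MSELoss()

    # Stage 1: train the correspondence autoencoder alone.
    ae_params = (list(model.CorrespondenceEncoder.parameters()) +
                 list(model.CorrespondenceDecoder.parameters()))
    opt = torch.optim.Adam(ae_params, lr=1e-4)
    for _ in range(ae_epochs):
        for img, pt in loader:
            pt_out, z, zt = model(pt, img)      # training-mode forward
            loss = mse(pt_out, pt)              # point reconstruction loss
            opt.zero_grad(); loss.backward(); opt.step()

    # Stage 2: train the T-flank (image encoder) to match the frozen
    # autoencoder's latent codes.
    for p in ae_params:
        p.requires_grad_(False)
    opt = torch.optim.Adam(model.ImageEncoder.parameters(), lr=1e-4)
    for _ in range(tf_epochs):
        for img, pt in loader:
            pt_out, z, zt = model(pt, img)
            loss = mse(zt, z.detach())          # latent-matching loss
            opt.zero_grad(); loss.backward(); opt.step()

    # Stage 3: unfreeze everything and train the whole model jointly.
    for p in ae_params:
        p.requires_grad_(True)
    opt = torch.optim.Adam(model.parameters(), lr=1e-4)
    for _ in range(joint_epochs):
        for img, pt in loader:
            pt_out, z, zt = model(pt, img)
            loss = mse(pt_out, pt) + alpha * mse(zt, z.detach())
            opt.zero_grad(); loss.backward(); opt.step()
```

Staging the optimizers this way matches the routine the docs describe: the autoencoder fixes a latent space first, the T-flank then learns to hit it from images, and joint training fine-tunes both ends together.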
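The un-whitening step mentioned under "4. Testing" amounts to inverting the PCA whitening transform: scale each predicted score by its mode's standard deviation, then map back through the eigenvectors and add the mean shape. A sketch under the assumption that scores were whitened by dividing by per-mode standard deviations; the function and variable names are illustrative, not the DeepSSMUtils API.

```python
import numpy as np

def scores_to_particles(scores, mean_particles, eigenvalues, eigenvectors):
    """Map whitened PCA scores back to a (num_particles, 3) point set.

    scores:         (num_modes,) whitened PCA scores predicted by the network
    mean_particles: (num_particles, 3) mean correspondence points
    eigenvalues:    (num_modes,) PCA eigenvalues
    eigenvectors:   (3 * num_particles, num_modes) eigenvectors as columns
    """
    unwhitened = scores * np.sqrt(eigenvalues)              # undo whitening
    flat = mean_particles.ravel() + eigenvectors @ unwhitened
    return flat.reshape(-1, 3)                              # the predicted PDM
```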