From 87e0aaeb197343937886dae992cf2ca08b37f560 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Tue, 28 Nov 2023 17:25:20 -0800
Subject: [PATCH 01/10] update init_weights_path before training model

---
 element_deeplabcut/train.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py
index b4f2765..6cde106 100644
--- a/element_deeplabcut/train.py
+++ b/element_deeplabcut/train.py
@@ -11,6 +11,7 @@
 from pathlib import Path
 from element_interface.utils import find_full_path, dict_to_uuid
 from .readers import dlc_reader
+import yaml
 
 schema = dj.schema()
 _linking_module = None
@@ -241,8 +242,7 @@ class ModelTraining(dj.Computed):
     # https://github.com/DeepLabCut/DeepLabCut/issues/70
 
     def make(self, key):
-        from deeplabcut import train_network  # isort:skip
-
+        import deeplabcut
         try:
             from deeplabcut.utils.auxiliaryfunctions import (
                 get_model_folder,
@@ -288,13 +288,26 @@ def make(self, key):
         )
         model_train_folder = project_path / model_folder / "train"
 
+        # update init_weight
+        with open(model_train_folder / "pose_cfg.yaml", "r") as f:
+            pose_cfg = yaml.safe_load(f)
+        init_weights_path = Path(pose_cfg["init_weights"])
+
+        if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
+            # this is the res_net models, construct new path here
+            init_weights_path = Path(deeplabcut.__file__).parent / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
+        else:
+            # this is existing snapshot weights, update path here
+            init_weights_path = model_train_folder / init_weights_path.name
+        
         edit_config(
             model_train_folder / "pose_cfg.yaml",
-            {"project_path": project_path.as_posix()},
+            {"project_path": project_path.as_posix(), 
+             "init_weights": init_weights_path.as_posix()},
         )
 
         # ---- Trigger DLC model training job ----
-        train_network_input_args = list(inspect.signature(train_network).parameters)
+        train_network_input_args = list(inspect.signature(deeplabcut.train_network).parameters)
         train_network_kwargs = {
             k: int(v) if k in ("shuffle", "trainingsetindex", "maxiters") else v
             for k, v in dlc_config.items()
@@ -304,7 +317,7 @@ def make(self, key):
             train_network_kwargs[k] = int(train_network_kwargs[k])
 
         try:
-            train_network(dlc_cfg_filepath, **train_network_kwargs)
+            deeplabcut.train_network(dlc_cfg_filepath, **train_network_kwargs)
         except KeyboardInterrupt:  # Instructions indicate to train until interrupt
             print("DLC training stopped via Keyboard Interrupt")
 

From 935938e12a5892aba91951a516b5631221c270ff Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Tue, 28 Nov 2023 17:36:04 -0800
Subject: [PATCH 02/10] update to use deeplabcut.__path__

---
 element_deeplabcut/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py
index 6cde106..2f02193 100644
--- a/element_deeplabcut/train.py
+++ b/element_deeplabcut/train.py
@@ -295,7 +295,7 @@ def make(self, key):
 
         if "pose_estimation_tensorflow/models/pretrained" in init_weights_path.as_posix():
             # this is the res_net models, construct new path here
-            init_weights_path = Path(deeplabcut.__file__).parent / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
+            init_weights_path = Path(deeplabcut.__path__[0]) / "pose_estimation_tensorflow/models/pretrained" / init_weights_path.name
         else:
             # this is existing snapshot weights, update path here
             init_weights_path = model_train_folder / init_weights_path.name

From 4543d087ccf51540d77a8b04facd4664dc8fae3d Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Tue, 28 Nov 2023 19:31:35 -0800
Subject: [PATCH 03/10] update version

---
 element_deeplabcut/version.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/element_deeplabcut/version.py b/element_deeplabcut/version.py
index 1719e56..1ce9d0b 100644
--- a/element_deeplabcut/version.py
+++ b/element_deeplabcut/version.py
@@ -1,4 +1,4 @@
 """
 Package metadata
 """
-__version__ = "0.2.10"
+__version__ = "0.2.11"

From fcc700737d6c1ffa7f167a2ed2b76b6e85443dc2 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Tue, 28 Nov 2023 19:34:42 -0800
Subject: [PATCH 04/10] update changelog

---
 CHANGELOG.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f7f5241..b9a9e10 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,10 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and 
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [0.2.11] - 2023-11-28
+
++ Fix - Modify training to update init_weights path in pose_cfg.yaml
+
 ## [0.2.10] - 2023-11-20
 
 + Fix - Revert fixing of networkx version in setup 

From 474269e94176c61a3b536c48916ef4fac1a887d9 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Wed, 29 Nov 2023 15:59:22 -0800
Subject: [PATCH 05/10] modify snapshotindex assignment in model training

---
 element_deeplabcut/train.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py
index 2f02193..8f61ddf 100644
--- a/element_deeplabcut/train.py
+++ b/element_deeplabcut/train.py
@@ -322,22 +322,19 @@ def make(self, key):
             print("DLC training stopped via Keyboard Interrupt")
 
         snapshots = list(model_train_folder.glob("*index*"))
-        max_modified_time = 0
         # DLC goes by snapshot magnitude when judging 'latest' for evaluation
         # Here, we mean most recently generated
-        for snapshot in snapshots:
-            modified_time = os.path.getmtime(snapshot)
-            if modified_time > max_modified_time:
-                latest_snapshot = int(snapshot.stem[9:])
-                max_modified_time = modified_time
+        
+        # `snapshotindex` refers to the file index of the snapshot generated
+        # The most recent snapshot index will be the number of snapshots
 
         # update snapshotindex in the config
-        dlc_config["snapshotindex"] = latest_snapshot
+        dlc_config["snapshotindex"] = len(snapshots)
         edit_config(
             dlc_cfg_filepath,
-            {"snapshotindex": latest_snapshot},
+            {"snapshotindex": len(snapshots)},
         )
 
         self.insert1(
-            {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config}
+            {**key, "latest_snapshot": len(snapshots), "config_template": dlc_config}
         )

From 627814833d2005755cc011228dfc44faafbb0e17 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Wed, 29 Nov 2023 17:54:42 -0800
Subject: [PATCH 06/10] modify latest snapshot determination

---
 element_deeplabcut/train.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py
index 8f61ddf..782ad8b 100644
--- a/element_deeplabcut/train.py
+++ b/element_deeplabcut/train.py
@@ -326,15 +326,17 @@ def make(self, key):
         # Here, we mean most recently generated
         
         # `snapshotindex` refers to the file index of the snapshot generated
-        # The most recent snapshot index will be the number of snapshots
+        # The most recent snapshot index will be the number of snapshots - 1
+        # because it is an index
 
         # update snapshotindex in the config
-        dlc_config["snapshotindex"] = len(snapshots)
+        latest_snapshot = len(snapshots)-1
+        dlc_config["snapshotindex"] = latest_snapshot
         edit_config(
             dlc_cfg_filepath,
-            {"snapshotindex": len(snapshots)},
+            {"snapshotindex": latest_snapshot},
         )
 
         self.insert1(
-            {**key, "latest_snapshot": len(snapshots), "config_template": dlc_config}
+            {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config}
         )

From 556556bb5751e9b1d6d80e8655776c3fded6f412 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Thu, 30 Nov 2023 11:55:36 -0800
Subject: [PATCH 07/10] modify snapshot index determination

---
 element_deeplabcut/train.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/element_deeplabcut/train.py b/element_deeplabcut/train.py
index 782ad8b..750f3bd 100644
--- a/element_deeplabcut/train.py
+++ b/element_deeplabcut/train.py
@@ -321,22 +321,28 @@ def make(self, key):
         except KeyboardInterrupt:  # Instructions indicate to train until interrupt
             print("DLC training stopped via Keyboard Interrupt")
 
-        snapshots = list(model_train_folder.glob("*index*"))
+        snapshots = sorted(list(model_train_folder.glob("*index*")))
+        max_modified_time = 0
         # DLC goes by snapshot magnitude when judging 'latest' for evaluation
         # Here, we mean most recently generated
-        
-        # `snapshotindex` refers to the file index of the snapshot generated
-        # The most recent snapshot index will be the number of snapshots - 1
-        # because it is an index
+        for snapshot in snapshots:
+            modified_time = os.path.getmtime(snapshot)
+            if modified_time > max_modified_time:
+                latest_snapshot_file = snapshot
+                latest_snapshot = int(latest_snapshot_file.stem[9:])
+                max_modified_time = modified_time
 
         # update snapshotindex in the config
-        latest_snapshot = len(snapshots)-1
-        dlc_config["snapshotindex"] = latest_snapshot
+        snapshotindex = snapshots.index(latest_snapshot_file)
+        
+        dlc_config["snapshotindex"] = snapshotindex
         edit_config(
             dlc_cfg_filepath,
-            {"snapshotindex": latest_snapshot},
+            {"snapshotindex": snapshotindex},
         )
 
         self.insert1(
             {**key, "latest_snapshot": latest_snapshot, "config_template": dlc_config}
         )
+
+

From 7c8ccdd5e421e3f218fb7a1d47459d344e580c80 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Wed, 6 Dec 2023 14:54:07 -0600
Subject: [PATCH 08/10] update yaml safe loading

---
 element_deeplabcut/model.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py
index 1aa3902..88104eb 100644
--- a/element_deeplabcut/model.py
+++ b/element_deeplabcut/model.py
@@ -7,7 +7,7 @@
 import os
 import cv2
 import csv
-import yaml
+from ruamel.yaml import YAML
 import inspect
 import importlib
 import numpy as np
@@ -239,8 +239,9 @@ def extract_new_body_parts(cls, dlc_config: dict, verbose: bool = True):
                 ".yaml",
             ), f"dlc_config is neither dict nor filepath\n Check: {dlc_config_fp}"
             if dlc_config_fp.suffix in (".yml", ".yaml"):
+                yaml = YAML(typ="safe", pure=True)
                 with open(dlc_config_fp, "rb") as f:
-                    dlc_config = yaml.safe_load(f)
+                    dlc_config = yaml.load(f)
         # -- Check and insert new BodyPart --
         assert "bodyparts" in dlc_config, f"Found no bodyparts section in {dlc_config}"
         tracked_body_parts = cls.fetch("body_part")
@@ -381,8 +382,9 @@ def insert_new_model(
             "dlc_config is not a filepath" + f"\n Check: {dlc_config_fp}"
         )
         if dlc_config_fp.suffix in (".yml", ".yaml"):
+            yaml = YAML(typ="safe", pure=True)
             with open(dlc_config_fp, "rb") as f:
-                dlc_config = yaml.safe_load(f)
+                dlc_config = yaml.load(f)
         if isinstance(params, dict):
             dlc_config.update(params)
 

From b4a6ae0685b669741b2380c920b2df11305b02e5 Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Wed, 6 Dec 2023 16:09:48 -0600
Subject: [PATCH 09/10] update yaml safe load in dlc_reader

---
 element_deeplabcut/readers/dlc_reader.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/element_deeplabcut/readers/dlc_reader.py b/element_deeplabcut/readers/dlc_reader.py
index a7f6a32..d726454 100644
--- a/element_deeplabcut/readers/dlc_reader.py
+++ b/element_deeplabcut/readers/dlc_reader.py
@@ -4,7 +4,7 @@
 import pandas as pd
 from pathlib import Path
 import pickle
-import ruamel.yaml as yaml
+from ruamel.yaml import YAML
 from element_interface.utils import find_root_directory, dict_to_uuid
 from .. import model
 from ..model import get_dlc_root_data_dir
@@ -145,7 +145,8 @@ def yml(self):
         """json-structured config.yaml file contents"""
         if self._yml is None:
             with open(self.yml_path, "rb") as f:
-                self._yml = yaml.safe_load(f)
+                yaml = YAML(typ="safe", pure=True)
+                self._yml = yaml.load(f)
         return self._yml
 
     @property

From f53852dc096bcf3d72ad5a6ffc3ada0c479ad9ae Mon Sep 17 00:00:00 2001
From: Sidharth Hulyalkar <sidharth.hulyalkar@datajoint.com>
Date: Fri, 5 Jan 2024 13:20:14 -0800
Subject: [PATCH 10/10] update to create output_dir during pose estimation and
 update task

---
 element_deeplabcut/model.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/element_deeplabcut/model.py b/element_deeplabcut/model.py
index 8a6e47f..c06f617 100644
--- a/element_deeplabcut/model.py
+++ b/element_deeplabcut/model.py
@@ -706,7 +706,14 @@ def make(self, key):
         task_mode, output_dir = (PoseEstimationTask & key).fetch1(
             "task_mode", "pose_estimation_output_dir"
         )
-
+        if not output_dir:
+            output_dir = PoseEstimationTask.infer_output_dir(
+                key, relative=True, mkdir=True
+            )
+            # update pose_estimation_output_dir
+            PoseEstimationTask.update1(
+                {**key, "pose_estimation_output_dir": output_dir.as_posix()}
+            )
         output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)
 
         # Triger PoseEstimation