Add set_encoder_decoder method
* Add examples/transfer learning for the AutoClassifier.ipynb, a notebook demonstrating the transfer-learning functionality
* Update tools.py
jzsmoreno committed Feb 3, 2025
1 parent ccd42f2 commit aff011a
Showing 3 changed files with 535 additions and 24 deletions.
436 changes: 436 additions & 0 deletions examples/transfer learning for the AutoClassifier.ipynb

Large diffs are not rendered by default.
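Since the notebook diff is not rendered, here is a minimal sketch of the transfer-learning flow this commit enables. Only set_encoder_decoder and unfreeze_encoder_decoder come from autoencoders.py below; the import path is inferred from the file location, the constructor and compile arguments are illustrative assumptions, the data is synthetic, and AutoClassifier is assumed to behave as a trainable tf.keras.Model.

import numpy as np
import tensorflow as tf

from likelihood.models.deep.autoencoders import AutoClassifier  # path assumed from the diff

# Synthetic stand-ins for the source (data-rich) and target (data-poor) tasks.
x_source = np.random.rand(256, 64).astype("float32")
y_source = np.random.randint(0, 10, size=(256,))
x_target = np.random.rand(64, 64).astype("float32")
y_target = np.random.randint(0, 3, size=(64,))

# 1. Pre-train a source model.
source = AutoClassifier(input_shape_parm=64, num_classes=10, units=32, activation="relu")
source.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
source.fit(x_source, y_source, epochs=5, verbose=0)

# 2. Copy its encoder/decoder into a model for the new task; input_shape_parm
#    and units must match (the method validates this).
target = AutoClassifier(input_shape_parm=64, num_classes=3, units=32, activation="relu")
target.set_encoder_decoder(source)

# 3. Freeze the copied layers while the classifier head trains, then release
#    them for end-to-end fine-tuning. Keras applies trainable changes on recompile.
for layer in target.encoder.layers + target.decoder.layers:
    layer.trainable = False
target.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
target.fit(x_target, y_target, epochs=5, verbose=0)
target.unfreeze_encoder_decoder()
target.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
target.fit(x_target, y_target, epochs=5, verbose=0)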

117 changes: 95 additions & 22 deletions likelihood/models/deep/autoencoders.py
@@ -101,31 +101,43 @@ def __init__(self, input_shape_parm, num_classes, units, activation, **kwargs):

    def build(self, input_shape):
        # Encoder with L2 regularization
-        self.encoder = tf.keras.Sequential(
-            [
-                tf.keras.layers.Dense(
-                    units=self.units, activation=self.activation, kernel_regularizer=l2(self.l2_reg)
-                ),
-                tf.keras.layers.Dense(
-                    units=int(self.units / 2),
-                    activation=self.activation,
-                    kernel_regularizer=l2(self.l2_reg),
-                ),
-            ]
+        self.encoder = (
+            tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        units=self.units,
+                        activation=self.activation,
+                        kernel_regularizer=l2(self.l2_reg),
+                    ),
+                    tf.keras.layers.Dense(
+                        units=int(self.units / 2),
+                        activation=self.activation,
+                        kernel_regularizer=l2(self.l2_reg),
+                    ),
+                ]
+            )
+            if not self.encoder
+            else self.encoder
        )

        # Decoder with L2 regularization
-        self.decoder = tf.keras.Sequential(
-            [
-                tf.keras.layers.Dense(
-                    units=self.units, activation=self.activation, kernel_regularizer=l2(self.l2_reg)
-                ),
-                tf.keras.layers.Dense(
-                    units=self.input_shape_parm,
-                    activation=self.activation,
-                    kernel_regularizer=l2(self.l2_reg),
-                ),
-            ]
+        self.decoder = (
+            tf.keras.Sequential(
+                [
+                    tf.keras.layers.Dense(
+                        units=self.units,
+                        activation=self.activation,
+                        kernel_regularizer=l2(self.l2_reg),
+                    ),
+                    tf.keras.layers.Dense(
+                        units=self.input_shape_parm,
+                        activation=self.activation,
+                        kernel_regularizer=l2(self.l2_reg),
+                    ),
+                ]
+            )
+            if not self.decoder
+            else self.decoder
        )

# Classifier with L2 regularization
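The conditional wrapping is the substantive change in this hunk: build() now creates the encoder and decoder only when they are still unset, so layers installed beforehand by set_encoder_decoder (added below) are not overwritten when Keras builds the model. A toy standalone illustration of the pattern (LazyBlock and its contents are hypothetical, not part of this library):

import tensorflow as tf

class LazyBlock(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.body = None  # may be injected before the first call

    def build(self, input_shape):
        # Create a default body only if none was injected earlier; this mirrors
        # the `if not self.encoder else self.encoder` guard above.
        if not self.body:
            self.body = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="relu")])

    def call(self, inputs):
        return self.body(inputs)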
@@ -174,6 +186,66 @@ def unfreeze_encoder_decoder(self):
for layer in self.decoder.layers:
layer.trainable = True

+    def set_encoder_decoder(self, source_model):
+        """
+        Sets the encoder and decoder layers from another AutoClassifier instance,
+        ensuring compatibility in dimensions.
+        Parameters:
+        -----------
+        source_model : AutoClassifier
+            The source model to copy the encoder and decoder layers from.
+        Raises:
+        -------
+        ValueError
+            If the input shape or units of the source model do not match.
+        """
+        if not isinstance(source_model, AutoClassifier):
+            raise ValueError("Source model must be an instance of AutoClassifier.")
+
+        # Check compatibility in input shape and units
+        if self.input_shape_parm != source_model.input_shape_parm:
+            raise ValueError(
+                f"Incompatible input shape. Expected {self.input_shape_parm}, got {source_model.input_shape_parm}."
+            )
+        if self.units != source_model.units:
+            raise ValueError(
+                f"Incompatible number of units. Expected {self.units}, got {source_model.units}."
+            )
+        self.encoder, self.decoder = tf.keras.Sequential(), tf.keras.Sequential()
+        # Copy the encoder layers
+        for i, layer in enumerate(source_model.encoder.layers):
+            if isinstance(layer, tf.keras.layers.Dense):  # Make sure it's a Dense layer
+                dummy_input = tf.convert_to_tensor(tf.random.normal([1, layer.input_shape[1]]))
+                dense_layer = tf.keras.layers.Dense(
+                    units=layer.units,
+                    activation=self.activation,
+                    kernel_regularizer=l2(self.l2_reg),
+                )
+                dense_layer.build(dummy_input.shape)
+                self.encoder.add(dense_layer)
+                # Set the weights correctly
+                self.encoder.layers[i].set_weights(layer.get_weights())
+            else:
+                raise ValueError(f"Layer type {type(layer)} not supported for copying.")
+
+        # Copy the decoder layers
+        for i, layer in enumerate(source_model.decoder.layers):
+            if isinstance(layer, tf.keras.layers.Dense):  # Ensure it's a Dense layer
+                dummy_input = tf.convert_to_tensor(tf.random.normal([1, layer.input_shape[1]]))
+                dense_layer = tf.keras.layers.Dense(
+                    units=layer.units,
+                    activation=self.activation,
+                    kernel_regularizer=l2(self.l2_reg),
+                )
+                dense_layer.build(dummy_input.shape)
+                self.decoder.add(dense_layer)
+                # Set the weights correctly
+                self.decoder.layers[i].set_weights(layer.get_weights())
+            else:
+                raise ValueError(f"Layer type {type(layer)} not supported for copying.")

def get_config(self):
config = {
"input_shape_parm": self.input_shape_parm,
@@ -198,6 +270,7 @@ def from_config(cls, config):
classifier_activation=config["classifier_activation"],
num_layers=config["num_layers"],
dropout=config["dropout"],
+            l2_reg=config["l2_reg"],
)
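The one-line from_config change closes a serialization gap: a model restored from its config previously lost its l2_reg value and fell back to the default regularization. A quick round-trip sketch (arguments illustrative):

model = AutoClassifier(
    input_shape_parm=64, num_classes=3, units=32, activation="relu", l2_reg=0.01
)
restored = AutoClassifier.from_config(model.get_config())
assert restored.l2_reg == model.l2_reg  # preserved as of this commit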


6 changes: 4 additions & 2 deletions likelihood/tools/tools.py
@@ -979,6 +979,7 @@ def __init__(self) -> None:
def f_mean(self, y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> float:
F_vec = self._f1_score(y_true, y_pred, labels)
mean_f_measure = np.mean(F_vec)
+        mean_f_measure = np.around(mean_f_measure, decimals=4)

for label, f_measure in zip(labels, F_vec):
print(f"F-measure of label {label} -> {f_measure}")
@@ -1005,9 +1006,9 @@ def resp(self, y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> float:

def _summary_pred(self, y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> None:
count_mat = self._confu_mat(y_true, y_pred, labels)
print(" ", " | ".join(f"--{label}--" for label in labels))
print(" " * 6, " | ".join(f"--{label}--" for label in labels))
for i, label_i in enumerate(labels):
row = [f" {int(count_mat[i, j])} " for j in range(len(labels))]
row = [f" {int(count_mat[i, j]):5d} " for j in range(len(labels))]
print(f"--{label_i}--|", " | ".join(row))

def _f1_score(self, y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> np.ndarray:
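The two formatting changes above give the header a fixed indent and every cell a fixed width, so the confusion-matrix columns stay aligned once counts reach multiple digits. A standalone rendering of the new layout (the count matrix is made up):

import numpy as np

labels = [0, 1]
count_mat = np.array([[120, 3], [8, 95]])

print(" " * 6, " | ".join(f"--{label}--" for label in labels))
for i, label_i in enumerate(labels):
    row = [f" {int(count_mat[i, j]):5d} " for j in range(len(labels))]
    print(f"--{label_i}--|", " | ".join(row))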
@@ -1023,6 +1024,7 @@ def _f1_score(self, y_true: np.ndarray, y_pred: np.ndarray, labels: List[int]) -> np.ndarray:
count_mat.diagonal(), sum_rows, out=np.zeros_like(sum_rows), where=sum_rows != 0
)
f1_vec = 2 * ((precision * recall) / (precision + recall))
+        f1_vec = np.around(f1_vec, decimals=4)

return f1_vec
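For context, the where= guard in np.divide above keeps labels with an empty row (or column) from producing NaN: where the denominator is zero, the preallocated zero in out is kept. A condensed standalone version of the computation, assuming precision uses column sums and recall row sums (the precision line sits above the fold of this hunk):

import numpy as np

count_mat = np.array([[120.0, 3.0], [8.0, 95.0]])
diag = count_mat.diagonal()
sum_cols, sum_rows = count_mat.sum(axis=0), count_mat.sum(axis=1)

# Zero-safe divisions: a zero denominator keeps the preallocated 0 instead of NaN.
precision = np.divide(diag, sum_cols, out=np.zeros_like(sum_cols), where=sum_cols != 0)
recall = np.divide(diag, sum_rows, out=np.zeros_like(sum_rows), where=sum_rows != 0)

f1_vec = 2 * ((precision * recall) / (precision + recall))
f1_vec = np.around(f1_vec, decimals=4)  # the rounding added in this commit
print(f1_vec)  # [0.9562 0.9453]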

