Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch.jit.script #91

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 71 additions & 64 deletions docs/notebooks/introduction.ipynb

Large diffs are not rendered by default.

37 changes: 13 additions & 24 deletions docs/notebooks/momentum.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "c216fa33-de09-4be2-82cc-83cb73db3a42",
"metadata": {},
"outputs": [],
Expand All @@ -230,7 +230,10 @@
" 1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n",
")\n",
"# Output log hazards\n",
"resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features, out_features=1)"
"resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features, out_features=1)\n",
"\n",
"# Compile model\n",
"# resnet = torch.compile(resnet)"
]
},
{
Expand Down Expand Up @@ -273,7 +276,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "1e7a2c7e-a1ef-42fa-ba74-1d33a1dcf2f3",
"metadata": {},
"outputs": [],
Expand All @@ -284,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "3f577acf-a821-41a4-8544-318617755d1e",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -314,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "430079cc-4fad-4da2-8ea5-aa904c41ec0e",
"metadata": {},
"outputs": [
Expand All @@ -323,10 +326,10 @@
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params\n",
"---------------------------------\n",
"0 | model | ResNet | 11.2 M\n",
"---------------------------------\n",
" | Name | Type | Params\n",
"------------------------------------------\n",
"0 | model | OptimizedModule | 11.2 M\n",
"------------------------------------------\n",
"11.2 M Trainable params\n",
"0 Non-trainable params\n",
"11.2 M Total params\n",
Expand All @@ -337,21 +340,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [02:19<00:00, 0.08it/s, loss_step=226.0, val_loss_step=257.0, cindex_step=0.672, val_loss_epoch=256.0, cindex_epoch=0.676, loss_epoch=234.0]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=2` reached.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1: 100%|██████████| 11/11 [02:19<00:00, 0.08it/s, loss_step=226.0, val_loss_step=257.0, cindex_step=0.672, val_loss_epoch=256.0, cindex_epoch=0.676, loss_epoch=234.0]\n"

]
}
],
Expand Down
223 changes: 83 additions & 140 deletions src/torchsurv/loss/cox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,88 @@

import torch

from torchsurv.tools.validate_data import validate_loss


@torch.jit.script
def _partial_likelihood_cox(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
) -> torch.Tensor:
    """Per-event partial log likelihood of the Cox model without tied event times.

    Position ``i`` of the denominator is the log-sum-exp of
    ``log_hz_sorted[i:]``; it is obtained by reversing the input, taking a
    running ``logcumsumexp`` and reversing the result again. The output keeps
    only the entries selected by ``event_sorted``.
    """
    reversed_running_total = torch.logcumsumexp(
        torch.flip(log_hz_sorted, dims=[0]), dim=0
    )
    log_risk_set_total = torch.flip(reversed_running_total, dims=[0])
    per_sample = log_hz_sorted - log_risk_set_total
    return per_sample[event_sorted]


@torch.jit.script
def _partial_likelihood_efron(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
    time_sorted: torch.Tensor,
    time_unique: torch.Tensor,
) -> torch.Tensor:
    """Calculate the partial log likelihood for the Cox proportional hazards model
    using Efron's method to handle ties in event time.

    Args:
        log_hz_sorted (torch.Tensor): Log relative hazards, ordered like ``time_sorted``.
        event_sorted (torch.Tensor): Event indicator (nonzero/True = event), same order.
        time_sorted (torch.Tensor): Time-to-event or censoring values, same order.
        time_unique (torch.Tensor): Distinct values of ``time_sorted``.

    Returns:
        torch.Tensor: One partial log likelihood term per unique time that has
        at least one event.
    """
    J = len(time_unique)

    # H[j]: indices of samples with an event exactly at time_unique[j].
    H = [
        torch.where((time_sorted == time_unique[j]) & (event_sorted == 1))[0]
        for j in range(J)
    ]
    # R[j]: indices of samples whose time is >= time_unique[j] (risk set).
    R = [torch.where(time_sorted >= time_unique[j])[0] for j in range(J)]

    # Unique times without any event contribute nothing and are dropped at the end.
    include = torch.tensor([len(h) > 0 for h in H])

    log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in H])

    denominator_naive = torch.stack([torch.sum(torch.exp(log_hz_sorted[r])) for r in R])
    denominator_ties = torch.stack([torch.sum(torch.exp(log_hz_sorted[h])) for h in H])

    # Allocate the accumulator in the dtype of the denominators so that
    # float64 inputs are not silently rounded to float32.
    log_denominator_efron = torch.zeros(
        J, dtype=denominator_naive.dtype, device=log_hz_sorted.device
    )
    for j in range(J):
        # Number of tied events at this time, read directly from the index list.
        mj = len(H[j])
        for l in range(1, mj + 1):
            log_denominator_efron[j] += torch.log(
                denominator_naive[j] - (l - 1) / float(mj) * denominator_ties[j]
            )
    return (log_nominator - log_denominator_efron)[include]


@torch.jit.script
def _partial_likelihood_breslow(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
    time_sorted: torch.Tensor,
):
    """
    Partial log likelihood of the Cox model using Breslow's handling of ties.

    Args:
        log_hz_sorted (torch.Tensor): Log hazards, ordered like ``time_sorted``.
        event_sorted (torch.Tensor): Mask selecting the samples with an event, same order.
        time_sorted (torch.Tensor): Event or censoring times.

    Returns:
        torch.Tensor: ``log_hz - log(sum of hazards in the risk set)`` for the
        samples selected by ``event_sorted``.
    """
    n_samples = len(time_sorted)
    log_risk_totals = []
    for i in range(n_samples):
        # Risk set of sample i: every sample whose time is >= its own time,
        # which includes samples tied with it.
        at_risk = torch.where(time_sorted >= time_sorted[i])[0]
        log_risk_totals.append(torch.logsumexp(log_hz_sorted[at_risk], dim=0))
    log_denominator = torch.stack(log_risk_totals)

    return (log_hz_sorted - log_denominator)[event_sorted]


@torch.jit.script
def neg_partial_log_likelihood(
log_hz: torch.Tensor,
event: torch.Tensor,
Expand Down Expand Up @@ -118,9 +199,9 @@ def neg_partial_log_likelihood(
"""

if checks:
_check_inputs(log_hz, event, time)
validate_loss(log_hz, event, time, model_type="cox")

if any([event.sum() == 0, len(log_hz.size()) == 0]):
if any([event.sum().item() == 0, len(log_hz.size()) == 0]):
warnings.warn("No events OR single sample. Returning zero loss for the batch")
return torch.tensor(0.0, requires_grad=True)

Expand Down Expand Up @@ -168,144 +249,6 @@ def neg_partial_log_likelihood(
return loss


def _partial_likelihood_cox(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
) -> torch.Tensor:
    r"""Partial log likelihood of the Cox proportional hazards model when
    event times have no ties.

    Args:
        log_hz_sorted (torch.Tensor, float):
            Log relative hazard of length n_samples, ordered by time-to-event or censoring.
        event_sorted (torch.Tensor, bool):
            Event indicator of length n_samples (= True if event occurred), ordered by time-to-event or censoring.

    Returns:
        (torch.tensor, float):
            Vector of the partial log likelihoods, one entry per observed event.

    Note:
        With ordered times :math:`\tau_1 < \tau_2 < \cdots < \tau_N` and risk sets
        :math:`R(\tau_i) = \{ j: \tau_j \geq \tau_i\}`, the partial log likelihood is

        .. math::

            pll = \sum_{i: \: \delta_i = 1} \left(\log \theta_i - \log\left(\sum_{j \in R(\tau_i)} \theta_j \right) \right)
    """
    # Running log-sum-exp taken from the end: entry i covers log_hz_sorted[i:].
    reversed_log_hz = torch.flip(log_hz_sorted, dims=[0])
    reversed_running_total = torch.logcumsumexp(reversed_log_hz, dim=0)
    log_risk_set_total = torch.flip(reversed_running_total, dims=[0])

    per_sample = log_hz_sorted - log_risk_set_total
    return per_sample[event_sorted]


def _partial_likelihood_efron(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
    time_sorted: torch.Tensor,
    time_unique: torch.Tensor,
) -> torch.Tensor:
    """Calculate the partial log likelihood for the Cox proportional hazards model
    using Efron's method to handle ties in event time.

    Args:
        log_hz_sorted (torch.Tensor, float):
            Log relative hazard of length n_samples, ordered by time-to-event or censoring.
        event_sorted (torch.Tensor, bool):
            Event indicator of length n_samples (= True if event occurred), ordered by time-to-event or censoring.
        time_sorted (torch.Tensor):
            Time-to-event values sorted in order.
        time_unique (torch.Tensor):
            Set of unique time-to-event values.
    Returns:
        (torch.tensor, float):
            Vector of partial log likelihood estimated using Efron's method,
            one entry per unique time that has at least one event.
    """
    # Indices of the samples with an event exactly at each unique time.
    tied_events = [
        torch.where((time_sorted == t) & (event_sorted == 1))[0] for t in time_unique
    ]
    # Indices of the samples still at risk at each unique time.
    risk_sets = [torch.where(time_sorted >= t)[0] for t in time_unique]

    log_nominator = torch.stack([torch.sum(log_hz_sorted[h]) for h in tied_events])

    hz_sorted = torch.exp(log_hz_sorted)
    denominator_naive = torch.stack([torch.sum(hz_sorted[r]) for r in risk_sets])
    denominator_ties = torch.stack([torch.sum(hz_sorted[h]) for h in tied_events])

    # For m tied events the Efron denominator is
    #   sum_{l=0}^{m-1} log(naive - l / m * ties).
    # It is built out-of-place and in the dtype of the inputs, so float64
    # hazards are not rounded through a float32 accumulator. With m == 0 the
    # sum is empty (zero) and the entry is dropped by ``include`` below.
    log_denominator_efron = torch.stack(
        [
            torch.log(
                naive
                - torch.arange(len(h), dtype=naive.dtype, device=naive.device)
                / max(len(h), 1)
                * ties
            ).sum()
            for h, naive, ties in zip(tied_events, denominator_naive, denominator_ties)
        ]
    )

    include = torch.tensor(
        [len(h) > 0 for h in tied_events], device=log_hz_sorted.device
    )

    return (log_nominator - log_denominator_efron)[include]


def _partial_likelihood_breslow(
    log_hz_sorted: torch.Tensor,
    event_sorted: torch.Tensor,
    time_sorted: torch.Tensor,
):
    """Calculate the partial log likelihood for the Cox proportional hazards model
    using Breslow's method to handle ties in event time.

    Args:
        log_hz_sorted (torch.Tensor, float):
            Log relative hazard of length n_samples, ordered by time-to-event or censoring.
        event_sorted (torch.Tensor, bool):
            Event indicator of length n_samples (= True if event occurred), ordered by time-to-event or censoring.
        time_sorted (torch.Tensor):
            Time-to-event values sorted in order.

    Returns:
        (torch.tensor, float):
            Vector containing partial log likelihood estimated using Breslow's method.
    """
    # One log-sum-exp per sample over its risk set (all samples whose time is
    # >= its own, so tied samples share the same denominator).
    # torch.stack keeps these values in the autograd graph and on the input's
    # device/dtype; torch.tensor([...]) would rebuild them from Python scalars
    # and silently drop the gradient of the denominator.
    log_denominator = torch.stack(
        [
            torch.logsumexp(log_hz_sorted[time_sorted >= t], dim=0)
            for t in time_sorted
        ]
    )

    return (log_hz_sorted - log_denominator)[event_sorted]


def _check_inputs(
    log_hz: torch.Tensor, event: torch.Tensor, time: torch.Tensor
) -> None:
    """Validate the inputs of the Cox loss.

    Args:
        log_hz (torch.Tensor): Log relative hazards.
        event (torch.Tensor): Event indicator; only 0/1 (or False/True) values are accepted.
        time (torch.Tensor): Time-to-event or censoring; must be non-negative.

    Raises:
        TypeError: If any input is not a ``torch.Tensor``.
        ValueError: If the lengths disagree, ``time`` holds a negative value,
            or ``event`` holds a value other than 0/1.
    """
    for name, value in (("log_hz", log_hz), ("event", event), ("time", time)):
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"Input '{name}' must be a tensor.")

    if len(log_hz) != len(event):
        raise ValueError(
            "Length mismatch: 'log_hz' and 'event' must have the same length."
        )

    if len(time) != len(event):
        raise ValueError(
            "Length mismatch: 'time' must have the same length as 'event'."
        )

    # Vectorised value checks: a single tensor reduction instead of a Python
    # loop over every element.
    if bool((time < 0).any()):
        raise ValueError("Invalid values: All elements in 'time' must be non-negative.")

    if bool(((event != 0) & (event != 1)).any()):
        raise ValueError(
            "Invalid values: 'event' must contain only boolean values (True/False or 1/0)"
        )


if __name__ == "__main__":
import doctest

Expand Down
32 changes: 19 additions & 13 deletions src/torchsurv/loss/momentum.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ def forward(

"""

estimate_q = self.online(inputs)
for estimate in zip(estimate_q, event, time):
self.memory_q.append(self.survtuple(*list(estimate)))
online_estimate = self.online(inputs)
for estimate in zip(online_estimate, event, time):
self.memory_q.append(self.survtuple(*estimate))
loss = self._bank_loss()
with torch.no_grad():
self._update_momentum_encoder()
estimate_k = self.target(inputs)
for estimate in zip(estimate_k, event, time):
self.memory_k.append(self.survtuple(*list(estimate)))
target_estimate = self.target(inputs)
for estimate in zip(target_estimate, event, time):
self.memory_k.append(self.survtuple(*estimate))
return loss

@torch.no_grad() # deactivates autograd
Expand Down Expand Up @@ -187,27 +187,33 @@ def infer(self, inputs: torch.Tensor) -> torch.Tensor:
return self.target(inputs)

def _bank_loss(self) -> torch.Tensor:
    """Compute the negative log-likelihood from the memory bank.

    The entries of ``self.memory_k`` and ``self.memory_q`` are concatenated,
    their ``estimate``, ``event`` and ``time`` fields are moved to the CPU,
    stacked and squeezed, and the three resulting tensors are passed to
    ``self.loss``.

    Returns:
        torch.Tensor: Whatever ``self.loss`` returns for the stacked bank.
    """
    # Combine current batch and momentum
    bank = self.memory_k + self.memory_q
    assert all(
        x in bank[0]._fields for x in ["estimate", "event", "time"]
    ), "All fields must be present"
    log_estimates = torch.stack([mem.estimate.cpu() for mem in bank]).squeeze()
    events = torch.stack([mem.event.cpu() for mem in bank]).squeeze()
    times = torch.stack([mem.time.cpu() for mem in bank]).squeeze()
    return self.loss(log_estimates, events, times)

@torch.no_grad()
def _update_momentum_encoder(self):
    """Exponential moving average update of the target parameters.

    Every parameter of ``self.target`` is replaced by
    ``rate * target + (1 - rate) * online``, pairing the parameters of
    ``self.online`` and ``self.target`` in iteration order.
    """
    for param_b, param_m in zip(self.online.parameters(), self.target.parameters()):
        param_m.data = param_m.data * self.rate + param_b.data * (1.0 - self.rate)

@torch.no_grad()
def _init_encoder_k(self):
    """
    Initialize the target network (encoder_k) with the parameters of the online network (encoder_q).

    Each online parameter is copied in place into the matching target parameter
    (paired in iteration order), and ``requires_grad`` of the target parameter is
    set to False so that it is excluded from gradient updates.
    """
    for param_q, param_k in zip(self.online.parameters(), self.target.parameters()):
        # In-place copy: the target parameter object is kept, only its values change.
        param_k.data.copy_(param_q.data)
        param_k.requires_grad = False
Expand Down
Loading
Loading