Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MoCo Version and MDLM params #632

Merged
merged 18 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sub-packages/bionemo-moco/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ MoCo currently supports the following continuous data interpolants:
### Discrete Data Interpolants
MoCo also supports the following discrete data interpolants:
- D3PM (Discrete Denoising Diffusion Probabilistic Models)
- MDLM (Markov Diffusion Language Models)
- MDLM (Masked Diffusion Language Models)
- DFM (Discrete Flow Matching)

### Useful Abstractions
Expand Down
1 change: 0 additions & 1 deletion sub-packages/bionemo-moco/VERSION

This file was deleted.

1 change: 1 addition & 0 deletions sub-packages/bionemo-moco/VERSION
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0.0.1
162 changes: 42 additions & 120 deletions sub-packages/bionemo-moco/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
* [bionemo.moco.distributions.time](#mocodistributionstime)
* [bionemo.moco.distributions.time.beta](#mocodistributionstimebeta)
* [bionemo.moco.distributions.time.utils](#mocodistributionstimeutils)
* [bionemo.moco.schedules.discrete\_noise\_schedules](#mocoschedulesdiscrete_noise_schedules)
* [bionemo.moco.schedules.noise.continuous\_snr\_transforms](#mocoschedulesnoisecontinuous_snr_transforms)
* [bionemo.moco.schedules.noise.discrete\_noise\_schedules](#mocoschedulesnoisediscrete_noise_schedules)
* [bionemo.moco.schedules.noise](#mocoschedulesnoise)
Expand Down Expand Up @@ -903,113 +902,6 @@ Convert a float time value to a time index.

- `torch.Tensor` - A tensor of time indices corresponding to the input float time values.

<a id="mocoschedulesdiscrete_noise_schedules"></a>

# bionemo.moco.schedules.discrete\_noise\_schedules

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedule"></a>

## DiscreteNoiseSchedule Objects

```python
class DiscreteNoiseSchedule(ABC)
```

A base class for discrete schedules. No matter the definition this class returns objects using a unified direction of time.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, direction: TimeDirection)
```

Initialize the DiscreteNoiseSchedule.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
```

Public wrapper to generate the time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).


**Returns**:

- `Tensor` - A tensor of time steps + 1 unless full is False.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedulecalculate_derivative"></a>

#### calculate\_derivative

```python
def calculate_derivative(
nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
```

Calculate the time derivative of the schedule.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).


**Returns**:

- `Tensor` - A tensor representing the time derivative of the schedule.


**Raises**:

- `NotImplementedError` - If the derivative calculation is not implemented for this schedule.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteCosineNoiseSchedule"></a>

## DiscreteCosineNoiseSchedule Objects

```python
class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)
```

A cosine noise schedule for Diffusion Models.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteCosineNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)
```

Initialize the CosineNoiseSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `nu` _Optional[Float]_ - Hyperparameter for the cosine schedule (default is 1.0).
- `s` _Optional[Float]_ - Hyperparameter for the cosine schedule (default is 0.008).

<a id="mocoschedulesnoisecontinuous_snr_transforms"></a>

# bionemo.moco.schedules.noise.continuous\_snr\_transforms
Expand Down Expand Up @@ -2210,23 +2102,50 @@ certain times in the diffusion process.
#### step

```python
def step(logits, t, xt, dt) -> Tensor
def step(logits: Tensor,
t: Tensor,
xt: Tensor,
dt: Tensor,
temperature: float = 1.0) -> Tensor
```

Perform a single step of MDLM DDPM step.

**Arguments**:

- `logits` _Tensor_ - The input logits.
- `t` _float_ - The current time step.
- `t` _Tensor_ - The current time step.
- `xt` _Tensor_ - The current state.
- `dt` _float_ - The time step increment.
- `dt` _Tensor_ - The time step increment.
- `temperature` _float_ - Softmax temperature defaults to 1.0.


**Returns**:

- `Tensor` - The updated state.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMget_num_steps_confidence"></a>

#### get\_num\_steps\_confidence

```python
def get_num_steps_confidence(xt: Tensor)
```

Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
along the last dimension (-1). The result is returned as a single float value.

**Arguments**:

- `xt` _Tensor_ - Input tensor to evaluate against the mask index.


**Returns**:

- `float` - The maximum number of steps with confidence (i.e., matching the mask index).

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_confidence"></a>

#### step\_confidence
Expand All @@ -2238,12 +2157,13 @@ def step_confidence(logits: Tensor,
num_steps: int,
logit_temperature: float = 1.0,
randomness: float = 1.0,
confidence_temperature: float = 1.0) -> Tensor
confidence_temperature: float = 1.0,
num_tokens_unmask: int = 1) -> Tensor
```

Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Method taken from GenMol Seul et al.
Method taken from GenMol Lee et al. https://arxiv.org/abs/2501.06158

**Arguments**:

Expand All @@ -2254,11 +2174,12 @@ Method taken from GenMol Seul et al.
- `logit_temperature` - Temperature for softmax over logits
- `randomness` - Scale for Gumbel noise
- `confidence_temperature` - Temperature for Gumbel confidence
- `num_tokens_unmask` - number of tokens to unmask each step


**Returns**:

Updated input sequence xt
Updated input sequence xt unmasking num_tokens_unmask token each step.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_argmax"></a>

Expand Down Expand Up @@ -3630,11 +3551,9 @@ def step_score_stochastic(model_out: Tensor,
center: Bool = False)
```

Perform a single ODE step integration using Euler method.

d x_t = [v(x_t, t) + g(t) * s(x_t, t) * sc_score_scale] dt + \sqrt{2 * g(t) * temperature} dw_t.
Perform a single SDE step integration using a score-based Langevin update.

At the moment we do not scale the vector field v but this can be added with sc_score_scale.
d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

**Arguments**:

Expand All @@ -3643,7 +3562,7 @@ At the moment we do not scale the vector field v but this can be added with sc_s
- `dt` _Tensor_ - The time step size.
- `t` _Tensor, optional_ - The current time. Defaults to None.
- `mask` _Optional[Tensor], optional_ - A mask to apply to the model output. Defaults to None.
- `gt_mode` _str, optional_ - The mode for the gt function. Defaults to "1/t".
- `gt_mode` _str, optional_ - The mode for the gt function. Defaults to "tan".
- `gt_p` _Float, optional_ - The parameter for the gt function. Defaults to 1.0.
- `gt_clamp` - (Float, optional): Upper limit of gt term. Defaults to None.
- `score_temperature` _Float, optional_ - The temperature for the score part of the step. Defaults to 1.0.
Expand Down Expand Up @@ -4293,7 +4212,10 @@ Do one step integration.

**Notes**:

The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.
The temperature parameter controls the level of randomness in the sampling process.
A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
result in less random and more deterministic samples. This can be useful for tasks
that require more control over the generation process.

Note for discrete time we sample from [T-1, ..., 1, 0] for T steps so we sample t = 0 hence the mask.
For continuous time we start from [1, 1 -dt, ..., dt] for T steps where s = t - 1 when t = 0 i.e dt is then 0
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -334,19 +334,17 @@ def step_score_stochastic(
t_lim_ode: Float = 0.99,
center: Bool = False,
):
r"""Perform a single ODE step integration using Euler method.
r"""Perform a single SDE step integration using a score-based Langevin update.

d x_t = [v(x_t, t) + g(t) * s(x_t, t) * sc_score_scale] dt + \sqrt{2 * g(t) * temperature} dw_t.

At the moment we do not scale the vector field v but this can be added with sc_score_scale.
d x_t = [v(x_t, t) + g(t) * s(x_t, t) * score_temperature] dt + \sqrt{2 * g(t) * noise_temperature} dw_t.

Args:
model_out (Tensor): The output of the model at the current time step.
xt (Tensor): The current intermediate state.
dt (Tensor): The time step size.
t (Tensor, optional): The current time. Defaults to None.
mask (Optional[Tensor], optional): A mask to apply to the model output. Defaults to None.
gt_mode (str, optional): The mode for the gt function. Defaults to "1/t".
gt_mode (str, optional): The mode for the gt function. Defaults to "tan".
gt_p (Float, optional): The parameter for the gt function. Defaults to 1.0.
gt_clamp: (Float, optional): Upper limit of gt term. Defaults to None.
score_temperature (Float, optional): The temperature for the score part of the step. Defaults to 1.0.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,15 @@ def _subs_parameterization(self, logits: Tensor, xt: Tensor) -> Tensor:
logprobs[unmasked_indices, xt[unmasked_indices]] = 0 # Unmasked token remains unchanged
return logprobs

def step(self, logits, t, xt, dt) -> Tensor:
def step(self, logits: Tensor, t: Tensor, xt: Tensor, dt: Tensor, temperature: float = 1.0) -> Tensor:
nvdreidenbach marked this conversation as resolved.
Show resolved Hide resolved
"""Perform a single step of MDLM DDPM step.

Parameters:
logits (Tensor): The input logits.
t (float): The current time step.
t (Tensor): The current time step.
xt (Tensor): The current state.
dt (float): The time step increment.
dt (Tensor): The time step increment.
temperature (float): Softmax temperature defaults to 1.0.

Returns:
Tensor: The updated state.
Expand All @@ -223,11 +224,13 @@ def step(self, logits, t, xt, dt) -> Tensor:
alpha_s = pad_like(alpha_s, logits)
p_mask_s = pad_like(p_mask_s, logits)
# Apply subs parameterization
log_p_x0 = self._subs_parameterization(logits, xt)
log_p_x0 = self._subs_parameterization(logits, xt) / temperature
if p_mask_s.ndim != log_p_x0.ndim:
raise ValueError(f"Dimension Mistmatch {p_mask_s.shape} {log_p_x0.shape}")
# Equation 6 from MDLM
prob_s_given_t = log_p_x0.exp() * (alpha_s - alpha_t) # righthand side (alpha_s - alpha_t)*x
# Equation 7 from MDLM
prob_s_given_t = log_p_x0.exp() * (
alpha_s - alpha_t
) # righthand side (alpha_s - alpha_t)*x = (1 - alpha_t - (1 - alpha_s)) * x
prob_s_given_t[..., self.mask_index] = p_mask_s[..., 0] # lefthand side (1 - alpha_s)*M
sampled_x = self._sample_categorical(prob_s_given_t)
carry_over_unmask = (xt != self.mask_index).to(xt.dtype)
Expand All @@ -252,6 +255,20 @@ def _sample_categorical(self, categorical_probs: Tensor) -> Tensor:
scaled_proability = categorical_probs / gumbel_norm
return scaled_proability.argmax(dim=-1)

def get_num_steps_confidence(self, xt: Tensor):
"""Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
along the last dimension (-1). The result is returned as a single float value.

Args:
xt (Tensor): Input tensor to evaluate against the mask index.

Returns:
float: The maximum number of steps with confidence (i.e., matching the mask index).
"""
return (xt == self.mask_index).sum(-1).max().item()

def step_confidence(
self,
logits: Tensor,
Expand All @@ -261,10 +278,11 @@ def step_confidence(
logit_temperature: float = 1.0,
randomness: float = 1.0,
confidence_temperature: float = 1.0,
num_tokens_unmask: int = 1,
) -> Tensor:
"""Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Method taken from GenMol Seul et al.
Method taken from GenMol Lee et al. https://arxiv.org/abs/2501.06158

Args:
logits: Predicted logits
Expand All @@ -274,9 +292,10 @@ def step_confidence(
logit_temperature: Temperature for softmax over logits
randomness: Scale for Gumbel noise
confidence_temperature: Temperature for Gumbel confidence
num_tokens_unmask: number of tokens to unmask each step

Returns:
Updated input sequence xt
Updated input sequence xt unmasking num_tokens_unmask token each step.
"""
if xt.ndim > 3:
raise NotImplementedError(
Expand Down Expand Up @@ -304,7 +323,7 @@ def step_confidence(
confidence[~mask] = -torch.inf

# choose the predicted token with the highest confidence
confidence_threshold, idx_mask = torch.topk(confidence, k=1, dim=-1)
confidence_threshold, idx_mask = torch.topk(confidence, k=num_tokens_unmask, dim=-1)
confidence_threshold = confidence_threshold[:, -1].unsqueeze(-1)

# replace the chosen tokens
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,10 @@ def step_noise(
temperature (Float, optional): The temperature parameter for low temperature sampling. Defaults to 1.0.

Note:
The temperature parameter controls the level of randomness in the sampling process. A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2) result in less random and more deterministic samples. This can be useful for tasks that require more control over the generation process.
The temperature parameter controls the level of randomness in the sampling process.
A temperature of 1.0 corresponds to standard diffusion sampling, while lower temperatures (e.g. 0.5, 0.2)
result in less random and more deterministic samples. This can be useful for tasks
that require more control over the generation process.

Note for discrete time we sample from [T-1, ..., 1, 0] for T steps so we sample t = 0 hence the mask.
For continuous time we start from [1, 1 -dt, ..., dt] for T steps where s = t - 1 when t = 0 i.e dt is then 0
Expand All @@ -385,9 +388,10 @@ def step_noise(
recip_sqrt_alpha_t = pad_like(recip_sqrt_alpha_t, xt)
var = pad_like(var, xt)

x_next = recip_sqrt_alpha_t * (xt - eps_factor * eps_hat) + nonzero_mask * var.sqrt() * torch.randn_like(
eps_hat
).to(model_out.device)
x_next = (
nvdreidenbach marked this conversation as resolved.
Show resolved Hide resolved
recip_sqrt_alpha_t * (xt - eps_factor * eps_hat)
+ nonzero_mask * var.sqrt() * torch.randn_like(eps_hat).to(model_out.device) * temperature
)
x_next = self.clean_mask_center(x_next, mask, center)
return x_next

Expand Down
Loading
Loading