Update MoCo Version and MDLM params #632

Merged
merged 18 commits into from
Jan 29, 2025
Changes from 8 commits
1 change: 0 additions & 1 deletion sub-packages/bionemo-moco/VERSION

This file was deleted.

1 change: 1 addition & 0 deletions sub-packages/bionemo-moco/VERSION
@@ -0,0 +1 @@
0.0.1
136 changes: 26 additions & 110 deletions sub-packages/bionemo-moco/documentation.md
Expand Up @@ -18,7 +18,6 @@
* [bionemo.moco.distributions.time](#mocodistributionstime)
* [bionemo.moco.distributions.time.beta](#mocodistributionstimebeta)
* [bionemo.moco.distributions.time.utils](#mocodistributionstimeutils)
* [bionemo.moco.schedules.discrete\_noise\_schedules](#mocoschedulesdiscrete_noise_schedules)
* [bionemo.moco.schedules.noise.continuous\_snr\_transforms](#mocoschedulesnoisecontinuous_snr_transforms)
* [bionemo.moco.schedules.noise.discrete\_noise\_schedules](#mocoschedulesnoisediscrete_noise_schedules)
* [bionemo.moco.schedules.noise](#mocoschedulesnoise)
Expand Down Expand Up @@ -903,113 +902,6 @@ Convert a float time value to a time index.

- `torch.Tensor` - A tensor of time indices corresponding to the input float time values.

<a id="mocoschedulesdiscrete_noise_schedules"></a>

# bionemo.moco.schedules.discrete\_noise\_schedules

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedule"></a>

## DiscreteNoiseSchedule Objects

```python
class DiscreteNoiseSchedule(ABC)
```

A base class for discrete schedules. No matter the definition this class returns objects using a unified direction of time.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, direction: TimeDirection)
```

Initialize the DiscreteNoiseSchedule.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `direction` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedulegenerate_schedule"></a>

#### generate\_schedule

```python
def generate_schedule(nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
```

Public wrapper to generate the time schedule as a tensor.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).


**Returns**:

- `Tensor` - A tensor of time steps + 1 unless full is False.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteNoiseSchedulecalculate_derivative"></a>

#### calculate\_derivative

```python
def calculate_derivative(
nsteps: Optional[int] = None,
device: Union[str, torch.device] = "cpu",
synchronize: Optional[TimeDirection] = None) -> Tensor
```

Calculate the time derivative of the schedule.

**Arguments**:

- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").
- `synchronize` _Optional[str]_ - TimeDirection to synchronize the schedule with. If the schedule is defined with a different direction, this parameter allows to flip the direction to match the specified one (default is None).


**Returns**:

- `Tensor` - A tensor representing the time derivative of the schedule.


**Raises**:

- `NotImplementedError` - If the derivative calculation is not implemented for this schedule.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteCosineNoiseSchedule"></a>

## DiscreteCosineNoiseSchedule Objects

```python
class DiscreteCosineNoiseSchedule(DiscreteNoiseSchedule)
```

A cosine noise schedule for Diffusion Models.

<a id="mocoschedulesdiscrete_noise_schedulesDiscreteCosineNoiseSchedule__init__"></a>

#### \_\_init\_\_

```python
def __init__(nsteps: int, nu: Float = 1.0, s: Float = 0.008)
```

Initialize the CosineNoiseSchedule.

**Arguments**:

- `nsteps` _int_ - Number of time steps.
- `nu` _Optional[Float]_ - Hyperparameter for the cosine schedule (default is 1.0).
- `s` _Optional[Float]_ - Hyperparameter for the cosine schedule (default is 0.008).
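
For readers unfamiliar with the cosine schedule, the following is a minimal, torch-free sketch of the widely used cosine formulation (Nichol and Dhariwal). It is not taken from the BioNeMo source: the function name `cosine_alpha_bar` is hypothetical, and treating `nu` as an exponent on the normalized time is an assumption, since the documentation above only calls `nu` and `s` hyperparameters.

```python
import math
from typing import List


def cosine_alpha_bar(nsteps: int, nu: float = 1.0, s: float = 0.008) -> List[float]:
    """Return the cumulative signal level at steps 0..nsteps (nsteps + 1 values).

    Assumption: `nu` warps the normalized time t / nsteps as an exponent and
    `s` is the small offset that keeps the schedule from being degenerate at t = 0.
    """

    def f(t: int) -> float:
        frac = (t / nsteps) ** nu
        return math.cos((frac + s) / (1.0 + s) * math.pi / 2.0) ** 2

    f0 = f(0)
    # Normalize so that the schedule starts at exactly 1.0 (no noise).
    return [f(t) / f0 for t in range(nsteps + 1)]


alpha_bar = cosine_alpha_bar(nsteps=100)
```

The returned list starts at 1.0 and decays smoothly towards 0, which is the qualitative behavior the cosine schedule is chosen for.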

<a id="mocoschedulesnoisecontinuous_snr_transforms"></a>

# bionemo.moco.schedules.noise.continuous\_snr\_transforms
Expand Down Expand Up @@ -2227,6 +2119,28 @@ Perform a single step of MDLM DDPM step.

- `Tensor` - The updated state.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMget_num_steps_confidence"></a>

#### get\_num\_steps\_confidence

```python
def get_num_steps_confidence(xt: Tensor)
```

Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
along the last dimension (-1). The result is returned as a single float value.

**Arguments**:

- `xt` _Tensor_ - Input tensor to evaluate against the mask index.


**Returns**:

- `float` - The maximum number of steps with confidence (i.e., matching the mask index).
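
As a rough illustration of what this helper computes, here is a torch-free sketch that mirrors the expression `(xt == self.mask_index).sum(-1).max().item()` from the diff, using plain Python lists. The names `MASK_INDEX`, `max_masked_per_sequence`, and `steps_needed` are hypothetical, and deriving a step count by ceiling division with `num_tokens_unmask` is an assumption about intended usage, not something the PR states. Note also that in PyTorch, `.item()` on an integer count tensor returns a Python `int`, even though the docstring describes the result as a float.

```python
import math
from typing import List

MASK_INDEX = 0  # hypothetical mask token id, chosen only for this sketch


def max_masked_per_sequence(xt: List[List[int]], mask_index: int) -> int:
    """Return the largest number of masked positions found in any one sequence."""
    return max(sum(1 for token in seq if token == mask_index) for seq in xt)


def steps_needed(max_masked: int, num_tokens_unmask: int) -> int:
    """Steps required to clear every mask when a fixed number is unmasked per step.

    Assumption: each confidence step unmasks exactly `num_tokens_unmask` tokens.
    """
    return math.ceil(max_masked / num_tokens_unmask)


batch = [
    [0, 5, 0, 7],  # two masked positions
    [0, 0, 0, 9],  # three masked positions
]
max_masked = max_masked_per_sequence(batch, MASK_INDEX)
```

Under these assumptions, `steps_needed(max_masked, 1)` equals the masked count itself, while a larger `num_tokens_unmask` reduces the number of sampling steps proportionally.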

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_confidence"></a>

#### step\_confidence
Expand All @@ -2238,7 +2152,8 @@ def step_confidence(logits: Tensor,
num_steps: int,
logit_temperature: float = 1.0,
randomness: float = 1.0,
- confidence_temperature: float = 1.0) -> Tensor
+ confidence_temperature: float = 1.0,
+ num_tokens_unmask: int = 1) -> Tensor
```

Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.
Expand All @@ -2254,11 +2169,12 @@ Method taken from GenMol Seul et al.
- `logit_temperature` - Temperature for softmax over logits
- `randomness` - Scale for Gumbel noise
- `confidence_temperature` - Temperature for Gumbel confidence
- `num_tokens_unmask` - number of tokens to unmask each step


**Returns**:

- Updated input sequence xt
+ Updated input sequence xt unmasking num_tokens_unmask token each step.

<a id="mocointerpolantscontinuous_timediscretemdlmMDLMstep_argmax"></a>

Expand Down
Expand Up @@ -252,6 +252,20 @@ def _sample_categorical(self, categorical_probs: Tensor) -> Tensor:
scaled_proability = categorical_probs / gumbel_norm
return scaled_proability.argmax(dim=-1)

def get_num_steps_confidence(self, xt: Tensor):
"""Calculate the maximum number of steps with confidence.

This method computes the maximum count of occurrences where the input tensor `xt` matches the `mask_index`
along the last dimension (-1). The result is returned as a single float value.

Args:
xt (Tensor): Input tensor to evaluate against the mask index.

Returns:
float: The maximum number of steps with confidence (i.e., matching the mask index).
"""
return (xt == self.mask_index).sum(-1).max().item()

def step_confidence(
self,
logits: Tensor,
Expand All @@ -261,6 +275,7 @@ def step_confidence(
logit_temperature: float = 1.0,
randomness: float = 1.0,
confidence_temperature: float = 1.0,
num_tokens_unmask: int = 1,
) -> Tensor:
"""Update the input sequence xt by sampling from the predicted logits and adding Gumbel noise.

Expand All @@ -274,9 +289,10 @@ def step_confidence(
logit_temperature: Temperature for softmax over logits
randomness: Scale for Gumbel noise
confidence_temperature: Temperature for Gumbel confidence
num_tokens_unmask: number of tokens to unmask each step

Returns:
- Updated input sequence xt
+ Updated input sequence xt unmasking num_tokens_unmask token each step.
"""
if xt.ndim > 3:
raise NotImplementedError(
Expand Down Expand Up @@ -304,7 +320,7 @@ def step_confidence(
confidence[~mask] = -torch.inf

# choose the predicted token with the highest confidence
- confidence_threshold, idx_mask = torch.topk(confidence, k=1, dim=-1)
+ confidence_threshold, idx_mask = torch.topk(confidence, k=num_tokens_unmask, dim=-1)
confidence_threshold = confidence_threshold[:, -1].unsqueeze(-1)

# replace the chosen tokens
Expand Down
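
The effect of replacing `k=1` with `k=num_tokens_unmask` in the hunk above can be sketched for a single sequence without torch. This is an illustrative reconstruction, not the library code: the function name `select_positions_to_unmask` is hypothetical, and the final thresholding step (keep masked positions whose confidence is at least the k-th largest value) is an assumption about the collapsed lines that follow `# replace the chosen tokens`.

```python
from typing import List

NEG_INF = float("-inf")


def select_positions_to_unmask(confidence: List[float], mask: List[bool], k: int) -> List[int]:
    """Return indices of the masked positions with the k highest confidences."""
    # Already-unmasked positions are excluded, mirroring `confidence[~mask] = -torch.inf`.
    scored = [c if m else NEG_INF for c, m in zip(confidence, mask)]
    # Top-k: indices sorted by descending confidence, truncated to k entries.
    top = sorted(range(len(scored)), key=lambda i: scored[i], reverse=True)[:k]
    # The k-th largest value acts as the threshold, like `confidence_threshold[:, -1]`.
    threshold = scored[top[-1]]
    # Assumed replacement rule: keep masked positions at or above the threshold.
    return [i for i, c in enumerate(scored) if c != NEG_INF and c >= threshold]


confidence = [0.9, 0.2, 0.7, 0.4]
mask = [False, True, True, True]  # position 0 is already unmasked
one_token = select_positions_to_unmask(confidence, mask, k=1)
two_tokens = select_positions_to_unmask(confidence, mask, k=2)
```

With `k=1` only the single most confident masked position is chosen, which matches the previous behavior; with `k=2` the two most confident masked positions are chosen in the same step. Ties at the threshold would select more than `k` positions under this rule.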

This file was deleted.

Expand Up @@ -112,6 +112,7 @@ def test_mdlm_step_confidence(mdlm, device):
xt = data.clone()
xt[:, 0] = noise[:, 0]
time = time * 0 + 2 / 100
_ = mdlm.get_num_steps_confidence(xt)
next_xt = mdlm.step_confidence(model_out, xt, curr_step=90, num_steps=100)
# Assert shapes
assert next_xt.shape == data.shape
Expand Down