diff --git a/openfold/utils/all_atom_multimer.py b/openfold/utils/all_atom_multimer.py index 6ffd7ac1..e414d4d9 100644 --- a/openfold/utils/all_atom_multimer.py +++ b/openfold/utils/all_atom_multimer.py @@ -56,28 +56,32 @@ def atom14_to_atom37( def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): - """Convert Atom37 positions to Atom14 positions.""" + """Convert atom37 positions to atom14 positions.""" residx_atom14_to_atom37 = get_rc_tensor( - rc.RESTYPE_ATOM14_TO_ATOM37, aatype - ) - no_batch_dims = len(aatype.shape) + rc.RESTYPE_ATOM14_TO_ATOM37, + aatype # (..., num_residues) + ) # (..., num_residues, 14) + no_batch_dims = len(aatype.shape) - 1 atom14_mask = tensor_utils.batched_gather( - all_atom_mask, - residx_atom14_to_atom37, + all_atom_mask, # (..., num_residues, 37) + residx_atom14_to_atom37, # (..., num_residues, 14) dim=no_batch_dims + 1, no_batch_dims=no_batch_dims + 1, - ).to(all_atom_pos.dtype) + ).to(all_atom_pos.dtype) # (..., num_residues, 14) # create a mask for known groundtruth positions - atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype) + atom14_mask *= get_rc_tensor(rc.RESTYPE_ATOM14_MASK, aatype) # gather the groundtruth positions atom14_positions = tensor_utils.batched_gather( - all_atom_pos, - residx_atom14_to_atom37, + all_atom_pos, # (..., num_residues, 37, 3) + residx_atom14_to_atom37, dim=no_batch_dims + 1, no_batch_dims=no_batch_dims + 1, - ), - atom14_positions = atom14_mask * atom14_positions - return atom14_positions, atom14_mask + ) + atom14_positions = atom14_mask[..., None] * atom14_positions + return ( + atom14_positions, # (..., num_residues, 14, 3) + atom14_mask # (..., num_residues, 14) + ) def get_alt_atom14(aatype, positions: torch.Tensor, mask):