[JAX] Add THD + SWA unit tests #1390

Merged 3 commits on Jan 8, 2025

@zlsh80826 (Collaborator) commented Jan 6, 2025

Description

  • Adds THD + SWA unit tests
  • Generalizes make_swa_mask by accepting pos_q and pos_kv instead of seqlen, making it work for THD/non-THD and both top-left and bottom-right alignments.
    • Previously, make_swa_mask took seqlen and relied on an attn_mask_type (e.g., CAUSAL_MASK or CAUSAL_BOTTOM_RIGHT_MASK) to differentiate top-left vs. bottom-right alignments.
    • Now, by providing pos_q and pos_kv directly, we can generate the same patterns without a separate attn_mask_type.

For example, consider the non-THD case with q_seqlen=2 and kv_seqlen=4.
Previously, bottom-right alignment required passing q_seqlen=2, kv_seqlen=4, window_size=(-1, 0), and attn_mask=AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, which generates the mask:

1 1 1 0
1 1 1 1

Passing attn_mask=AttnMaskType.CAUSAL_MASK instead gives top-left alignment:

1 0 0 0
1 1 0 0

Now the same masks can be generated without an attn_mask type by passing the appropriate pos_q and pos_kv:

# make_swa_mask(jnp.asarray([[2, 3]]), jnp.asarray([[0, 1, 2, 3]]), window_size=(-1, 0))
1 1 1 0
1 1 1 1
# make_swa_mask(jnp.asarray([[0, 1]]), jnp.asarray([[0, 1, 2, 3]]), window_size=(-1, 0))
1 0 0 0
1 1 0 0
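
The point that alignment is carried entirely by the positions can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: `causal_mask_from_positions` is a hypothetical helper that covers only the `window_size=(-1, 0)` case (unbounded to the left, nothing to the right), and the position arrays are built by hand.

```python
import jax.numpy as jnp


def causal_mask_from_positions(pos_q, pos_kv):
    """Hypothetical helper: 1 where the key position does not exceed the query position."""
    q = jnp.expand_dims(pos_q, axis=-1)    # [batch, q_len, 1]
    kv = jnp.expand_dims(pos_kv, axis=-2)  # [batch, 1, kv_len]
    return (kv <= q).astype(jnp.int32)     # [batch, q_len, kv_len]


q_len, kv_len = 2, 4
pos_kv = jnp.arange(kv_len)[None, :]                      # [[0, 1, 2, 3]]

# Top-left alignment: the first query sits at position 0.
pos_q_top_left = jnp.arange(q_len)[None, :]               # [[0, 1]]
# Bottom-right alignment: the last query sits at the last key position.
pos_q_bottom_right = pos_q_top_left + (kv_len - q_len)    # [[2, 3]]

top_left = causal_mask_from_positions(pos_q_top_left, pos_kv)
bottom_right = causal_mask_from_positions(pos_q_bottom_right, pos_kv)

print(top_left[0])      # rows: 1 0 0 0 / 1 1 0 0
print(bottom_right[0])  # rows: 1 1 1 0 / 1 1 1 1
```

Shifting the query positions by `kv_len - q_len` is all that distinguishes the two alignments, which is why a separate mask-type argument is no longer needed.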

In addition, the new make_swa_mask supports the more complicated THD case, even when the positions within a segment are reordered:

>>> segment_ids_q = jnp.asarray([[1, 1, 1, 2, 2, 2]])
>>> segment_ids_kv = jnp.asarray([[1, 1, 1, 1, 2, 2, 2, 2]])
>>> segment_pos_q = jnp.asarray([[0, 1, 2, 2, 1, 0]])
>>> segment_pos_kv = jnp.asarray([[0, 1, 2, 3, 3, 2, 1, 0]])
>>> swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size=(2, 0))
>>> segment_mask = make_attention_mask(segment_ids_q, segment_ids_kv, jnp.equal)
>>> mask = combine_masks(swa_mask, segment_mask)
>>> mask
Array([[[[1., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0., 0.],
         [0., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 1., 0.],
         [0., 0., 0., 0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.]]]], dtype=float32)
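
The way the two masks combine in the THD case can be reproduced with a small standalone sketch. `sketch_swa_mask` and `sketch_segment_mask` are hypothetical stand-ins written for illustration, not the library functions. In particular, the left window boundary is an assumption: it is treated as exclusive here (a left window of 2 covers the query's own position and one position before it) because that choice reproduces the mask printed above; the real make_swa_mask may define the boundary differently.

```python
import jax.numpy as jnp


def sketch_swa_mask(pos_q, pos_kv, window_size):
    """Hypothetical stand-in for a position-based sliding-window mask."""
    left, right = window_size
    # Negative sizes mean an unbounded window on that side.
    left = float("inf") if left < 0 else left
    right = float("inf") if right < 0 else right
    q = jnp.expand_dims(pos_q, axis=-1)    # [batch, q_len, 1]
    kv = jnp.expand_dims(pos_kv, axis=-2)  # [batch, 1, kv_len]
    # ASSUMPTION: exclusive left bound, chosen to match the printed mask.
    allowed = (kv > q - left) & (kv <= q + right)
    return jnp.expand_dims(allowed, axis=-3).astype(jnp.float32)  # [batch, 1, q, kv]


def sketch_segment_mask(ids_q, ids_kv):
    """Hypothetical stand-in: 1 where query and key belong to the same segment."""
    same = jnp.expand_dims(ids_q, axis=-1) == jnp.expand_dims(ids_kv, axis=-2)
    return jnp.expand_dims(same, axis=-3).astype(jnp.float32)


segment_ids_q = jnp.asarray([[1, 1, 1, 2, 2, 2]])
segment_ids_kv = jnp.asarray([[1, 1, 1, 1, 2, 2, 2, 2]])
# The second segment is stored in reverse order, so its positions count down.
segment_pos_q = jnp.asarray([[0, 1, 2, 2, 1, 0]])
segment_pos_kv = jnp.asarray([[0, 1, 2, 3, 3, 2, 1, 0]])

swa = sketch_swa_mask(segment_pos_q, segment_pos_kv, window_size=(2, 0))
seg = sketch_segment_mask(segment_ids_q, segment_ids_kv)
mask = swa * seg  # a position is attended only if both masks allow it

print(mask[0, 0])
```

Because the window is evaluated on positions rather than on array indices, the reversed second segment still gets a correct causal window; the segment mask then removes matches between tokens that share a position but belong to different segments.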

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

  • Adds THD + SWA unit tests
  • Generalizes make_swa_mask by accepting pos_q and pos_kv instead of seqlen

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@zlsh80826 (Collaborator, Author)

/te-ci jax L1

@cyanguwa (Collaborator) left a comment

LGTM! I liked the last example in your PR description. I suppose it'll be helpful for the ring attention or context parallelism implementation when SWA is also in place.

@huanghua1994 (Collaborator) left a comment

LGTM. The examples are helpful!

@zlsh80826 zlsh80826 merged commit b898cbe into NVIDIA:main Jan 8, 2025
23 checks passed