Skip to content

Commit

Permalink
improved error messages for internal failure to infer TP overlap opti…
Browse files Browse the repository at this point in the history
…ons in te.Linear

Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 17, 2024
1 parent 1d9b943 commit 360c127
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,11 +928,18 @@ def __init__(
assert ub_name is not None, f"Userbuffer name [string] is not set."
self.ub_name = ub_name

assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!"
assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!"
assert not (
self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
), "Internal TE error!"
assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), (
"Failed to infer TP overlap options! "
+ "Forward pass cannot do AG+GEMM and GEMM+RS at the same time."
)
assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), (
"Failed to infer TP overlap options! "
+ "Backward pass cannot do AG+DGRAD and DGRAD+RS at the same time."
)
# DGRAD+RS overlap cannot be combined with either bulk overlap. The negation must wrap the
# whole conjunction (as in the assertion this one replaces); without the parentheses,
# `not` binds only to `ub_overlap_rs_dgrad` and the check fails whenever no bulk overlap
# is enabled.
assert not (
    self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
), (
    "Failed to infer TP overlap options! "
    + "Backward pass cannot do DGRAD+RS and bulk overlaps at the same time."
)

self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
Expand Down

0 comments on commit 360c127

Please sign in to comment.