From 360c1276c3e1aee0d05dc12d336c5fff365e78f0 Mon Sep 17 00:00:00 2001
From: Alp Dener
Date: Tue, 17 Dec 2024 20:47:43 +0000
Subject: [PATCH] improved error messages for internal failure to infer TP overlap options in te.Linear

Signed-off-by: Alp Dener
---
Review note: the third assertion keeps the parentheses around the whole
condition, exactly as the removed code had them. Without them `not` binds
only to `self.ub_overlap_rs_dgrad`, so the assert would fail whenever no bulk
overlap is enabled. The `index` blob hashes below were not regenerated.

 transformer_engine/pytorch/module/linear.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 48608b441d..bc36571fe8 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -928,11 +928,18 @@ def __init__(
             assert ub_name is not None, f"Userbuffer name [string] is not set."
         self.ub_name = ub_name
 
-        assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!"
-        assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!"
-        assert not (
-            self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
-        ), "Internal TE error!"
+        assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), (
+            "Failed to infer TP overlap options! "
+            + "Forward pass cannot do AG+GEMM and GEMM+RS at the same time."
+        )
+        assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), (
+            "Failed to infer TP overlap options! "
+            + "Backward pass cannot do AG+DGRAD and DGRAD+RS at the same time."
+        )
+        assert not (self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)), (
+            "Failed to infer TP overlap options! "
+            + "Backward pass cannot do DGRAD+RS and bulk overlaps at the same time."
+        )
 
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name