From 425bac7537a9d778aa6c109a7cadc1b0677141ed Mon Sep 17 00:00:00 2001
From: Jimin Ha
Date: Wed, 18 Dec 2024 10:56:38 -0800
Subject: [PATCH] Modify Qwen2 TRL command to avoid OOM. (#1630)

Add --use_flash_attention to avoid OOM for Qwen2
---
 examples/text-generation/utils.py     |  3 ++-
 examples/trl/README.md                |  3 ++-
 tests/test_text_generation_example.py | 19 +++++++------------
 3 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 545fd6edbb..4fe6567f64 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -442,7 +442,8 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
     # Construct model with fake meta tensors, later will be replaced on devices during ds-inference ckpt load
     with deepspeed.OnDevice(dtype=model_dtype, device="meta"):
         if (
-            hasattr(config, 'rope_scaling') and config.rope_scaling
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling
             and config.rope_scaling["rope_type"] == "llama3"
             and config.max_position_embeddings > 8192
         ):
diff --git a/examples/trl/README.md b/examples/trl/README.md
index 750fc82b08..18fb0fc0fa 100644
--- a/examples/trl/README.md
+++ b/examples/trl/README.md
@@ -39,7 +39,8 @@ $ pip install -U -r requirements.txt
         --lora_dropout=0.05 \
         --lora_target_modules "q_proj" "v_proj" "k_proj" "o_proj" \
         --max_seq_length 512 \
-        --adam_epsilon 1e-08
+        --adam_epsilon 1e-08 \
+        --use_flash_attention
 ```
 
 2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards:
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index b7349b5f02..10ac6b7adb 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -383,8 +383,8 @@ def test_text_generation_bf16_1x(
         check_output=check_output,
     )
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline", MODELS_TO_TEST["fp8"]
 )
@@ -413,8 +413,7 @@ def test_text_generation_fp8(
     )
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize(
     "model_name, world_size, batch_size, reuse_cache, input_len, output_len, baseline",
     MODELS_TO_TEST["load_quantized_model_with_autogptq"],
 )
@@ -450,23 +449,20 @@ def test_text_generation_deepspeed(model_name: str, baseline: float, world_size: int, batch_size: int, token: str):
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, batch_size=batch_size)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile"])
 def test_text_generation_torch_compile(model_name: str, baseline: float, token: str):
     _test_text_generation(model_name, baseline, token, torch_compile=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["torch_compile_distributed"])
 def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str):
     world_size = 8
     _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"])
 def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str):
     world_size = 8
@@ -489,8 +485,7 @@ def test_text_generation_contrastive_search(
     _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, contrastive_search=True)
 
 
-@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))),
-                    reason="Skipping test for G1")
+@pytest.mark.skipif(condition=not bool(int(os.environ.get("GAUDI2_CI", "0"))), reason="Skipping test for G1")
 @pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["beam_search"])
 def test_text_generation_beam_search(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
     _test_text_generation(model_name, baseline, token, batch_size, reuse_cache, num_beams=3)