diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 58d7c4a2c0..b38e0bc95b 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -65,7 +65,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() - script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( model_config.torch_dtype diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index b9492e6042..d4915e0347 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -70,7 +70,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() - script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( model_config.torch_dtype diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index 1c465d90e6..d2d3cb05f8 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -55,7 +55,7 @@ if __name__ == "__main__": parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) script_args, training_args, model_config = parser.parse_args_and_config() - script_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} torch_dtype = ( model_config.torch_dtype