[BUG] Running examples/scripts/dpo.py crashes #914

Closed
JeanKaddour opened this issue Oct 24, 2023 · 4 comments

Comments

@JeanKaddour

JeanKaddour commented Oct 24, 2023

Error

Running examples/scripts/dpo.py with no changes to the args crashes with the following error message:

  File "/home/jean/trl/examples/scripts/dpo.py", line 170, in <module>
    dpo_trainer.train()
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/trainer.py", line 1984, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/trainer.py", line 2339, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/trainer.py", line 2471, in _save_checkpoint
    self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/site-packages/transformers/trainer_callback.py", line 106, in save_to_json
    json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1238, in asdict
    return _asdict_inner(obj, dict_factory)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1245, in _asdict_inner
    value = _asdict_inner(getattr(obj, f.name), dict_factory)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1273, in _asdict_inner
    return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1273, in <genexpr>
    return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1275, in _asdict_inner
    return type(obj)((_asdict_inner(k, dict_factory),
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1276, in <genexpr>
    _asdict_inner(v, dict_factory))
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/dataclasses.py", line 1279, in _asdict_inner
    return copy.deepcopy(obj)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/jean/anaconda3/envs/rlhf/lib/python3.10/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle '_thread.lock' object

Installed Environment

I installed the dependencies as provided in requirements.txt. Specifically, I use:

transformers==4.34.1
tokenizers==0.14.1
peft==0.5.0
accelerate==0.23.0
torch==2.1.0
tyro==0.5.10

and a source installation of this repository at commit 7de7db6.

mnoukhov added a commit to mnoukhov/trl that referenced this issue Oct 27, 2023
@mnoukhov
Contributor

This seems to be caused by saving the trainer state while saving a checkpoint. We log a wandb.Table with self.log, which then adds it to self.state.log_history. During saving, the trainer attempts to convert log_history to JSON but can't pickle the wandb.Table, which causes this error.

I submitted a PR to remove the wandb.Table from the log history, which should still log it to wandb but keep it out of the trainer's state.
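
For readers who want to see the mechanism in isolation, here is a minimal sketch that needs neither wandb nor transformers. `FakeTable`, `FakeState`, `try_save`, `json_safe`, and the `samples_table` key are hypothetical stand-ins invented for illustration; the filtering shown is one possible way to keep rich objects out of the state, not necessarily what the PR does.

```python
import dataclasses
import json
import threading
from typing import Any, Dict, List


class FakeTable:
    """Hypothetical stand-in for a wandb.Table: an object that holds a lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()


@dataclasses.dataclass
class FakeState:
    """Hypothetical stand-in for the trainer state and its log history."""

    log_history: List[Dict[str, Any]] = dataclasses.field(default_factory=list)


def try_save(state: FakeState) -> str:
    """Mimic the save step: asdict() deep-copies every value it cannot recurse into."""
    try:
        return json.dumps(dataclasses.asdict(state), indent=2, sort_keys=True)
    except TypeError as err:
        return f"TypeError: {err}"


def json_safe(logs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only values that json can serialize directly (an assumed policy)."""
    plain = (int, float, str, bool, type(None))
    return {key: value for key, value in logs.items() if isinstance(value, plain)}


logs = {"loss": 0.5, "samples_table": FakeTable()}

# Storing the raw logs reproduces the failure: deepcopy reaches the lock.
broken = FakeState()
broken.log_history.append(logs)
failure = try_save(broken)

# Storing only the plain values lets the state serialize cleanly.
clean = FakeState()
clean.log_history.append(json_safe(logs))
saved = try_save(clean)
```

Here `failure` holds the TypeError text instead of JSON, while `saved` contains only the `loss` entry.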

lvwerra pushed a commit that referenced this issue Oct 31, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

@lvwerra
Member

lvwerra commented Nov 24, 2023

This should be fixed in #919.

@lvwerra lvwerra closed this as completed Nov 24, 2023
lapp0 pushed a commit to lapp0/trl that referenced this issue May 10, 2024
@fozziethebeat

I think this problem has resurfaced at some point. I'm running TRL indirectly through Axolotl, and I'm seeing this line trigger the ProgressCallback.

Then, pretty naturally, when that callback does logs = copy.deepcopy(logs), the WandB table in the logs breaks things with the same failure.
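
When it is not obvious which logged value breaks the copy, a small diagnostic can help. The sketch below is only an illustration: `FakeTable`, `uncopyable_keys`, and the `samples_table` key are hypothetical names, and the stand-in object merely holds a lock the way the failing object appears to.

```python
import copy
import threading
from typing import Any, Dict, List


class FakeTable:
    """Hypothetical stand-in for a logged table object that holds a lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()


def uncopyable_keys(logs: Dict[str, Any]) -> List[str]:
    """Return the keys whose values raise TypeError under copy.deepcopy."""
    offenders = []
    for key, value in logs.items():
        try:
            copy.deepcopy(value)
        except TypeError:
            offenders.append(key)
    return offenders


logs = {"eval_loss": 0.25, "epoch": 1.0, "samples_table": FakeTable()}
offenders = uncopyable_keys(logs)
```

Running a check like this on the dictionary passed to the log call would point at the entry to drop or reroute.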

The key parts of my stacktrace are:

  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 2221, in train
    return inner_training_loop(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 2728, in _inner_training_loop
    self._maybe_log_save_evaluate(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3299, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3240, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 4277, in evaluate
    output = eval_loop(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1814, in evaluation_loop
    self.log(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1856, in log
    return super().log(logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer.py", line 3802, in log
    self.control = self.callback_handler.on_log(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 628, in on_log
    return self.call_event("on_log", args, state, control, logs=logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 641, in call_event
    result = getattr(callback, event)(
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/site-packages/transformers/trainer_callback.py", line 775, in on_log
    logs = copy.deepcopy(logs)
  File "/home/fozziethebeat/anaconda3/envs/axolotl-dev/lib/python3.10/copy.py", line 146, in deepcopy
    y = copier(x, memo)

Important versions of libraries are:

transformers==4.42.3
trl==0.9.6
wandb==0.17.4

The only workarounds I've found are to drop wandb logging or to delete the log line in the DPO trainer.

I feel like the right fix is that when the DPO trainer calls self.log(...), it should not trigger the ProgressCallback, and it's kinda weird that it does.
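
One shape such a fix could take is to split the logs before they reach the callbacks, so that only plain values go through the normal path. This is a rough sketch under stated assumptions, not TRL's or transformers' actual code: `BaseTrainer`, `PatchedTrainer`, `split_logs`, `FakeTable`, and the `samples_table` key are all hypothetical, and `BaseTrainer.log` only imitates the deep copy that a callback performs.

```python
import copy
import threading
from typing import Any, Dict, List, Tuple

SCALAR_TYPES = (int, float, str, bool)


class FakeTable:
    """Hypothetical stand-in for a wandb.Table; holds a lock like the failing object."""

    def __init__(self) -> None:
        self._lock = threading.Lock()


def split_logs(logs: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Separate plain values from rich objects that a callback may fail to copy."""
    scalars = {k: v for k, v in logs.items() if isinstance(v, SCALAR_TYPES)}
    rich = {k: v for k, v in logs.items() if k not in scalars}
    return scalars, rich


class BaseTrainer:
    """Hypothetical parent trainer whose log() deep-copies, as a callback would."""

    def __init__(self) -> None:
        self.history: List[Dict[str, Any]] = []

    def log(self, logs: Dict[str, Any]) -> None:
        self.history.append(copy.deepcopy(logs))


class PatchedTrainer(BaseTrainer):
    """Routes rich objects around the callback path instead of through it."""

    def __init__(self) -> None:
        super().__init__()
        self.sent_to_tracker: List[Dict[str, Any]] = []

    def log(self, logs: Dict[str, Any]) -> None:
        scalars, rich = split_logs(logs)
        if rich:
            # Assumption: a real run would hand these to the experiment tracker here.
            self.sent_to_tracker.append(rich)
        super().log(scalars)


trainer = PatchedTrainer()
trainer.log({"eval_loss": 0.25, "samples_table": FakeTable()})
```

With this split, the deep copy only ever sees `eval_loss`, and the table is kept aside for the tracker rather than deleted.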
