Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚅⚡ Test Training with PyTorch Lightning #930

Merged
merged 19 commits into from
May 24, 2022
Merged

🚅⚡ Test Training with PyTorch Lightning #930

merged 19 commits into from
May 24, 2022

Conversation

mberr
Copy link
Member

@mberr mberr commented May 19, 2022

This PR adds tests for the PyTorch Lightning integration.

Issues

Related

tests/test_lightning.py Outdated Show resolved Hide resolved
@cthoyt
Copy link
Member

cthoyt commented May 20, 2022

todo: add Lightning [![Lightning](https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai) to the readme

@mberr
Copy link
Member Author

mberr commented May 20, 2022

There are some RuntimeErrors for CompGCN, RESCAL, and TransH which do not occur for not training with PyL.

src/pykeen/contrib/lightning.py:288: in lit_pipeline
    trainer.fit(model=lit)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:768: in fit
    self._call_and_handle_interrupt(
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:721: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:809: in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1234: in _run
    results = self._run_stage()
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1321: in _run_stage
    return self._run_train()
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1351: in _run_train
    self.fit_loop.run()
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204: in run
    self.advance(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:269: in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204: in run
    self.advance(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py:208: in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204: in run
    self.advance(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py:88: in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/base.py:204: in run
    self.advance(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:203: in advance
    result = self._run_optimization(
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:256: in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:369: in _optimizer_step
    self.trainer._call_lightning_module_hook(
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1593: in _call_lightning_module_hook
    output = fn(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py:1625: in optimizer_step
    optimizer.step(closure=optimizer_closure)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py:168: in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:193: in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:155: in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/torch/optim/optimizer.py:88: in wrapper
    return func(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/torch/optim/adam.py:100: in step
    loss = closure()
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:140: in _wrap_closure
    closure_result = closure()
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:148: in __call__
    self._result = self.closure(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:143: in closure
    self._backward_fn(step_output.closure_loss)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py:311: in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1763: in _call_strategy_hook
    output = fn(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py:168: in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py:80: in backward
    model.backward(closure_loss, optimizer, *args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py:1370: in backward
    loss.backward(*args, **kwargs)
venv/venv-gpu/lib/python3.8/site-packages/torch/_tensor.py:363: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

trigger ci
@mberr mberr marked this pull request as ready for review May 24, 2022 14:45
@mberr mberr requested a review from cthoyt May 24, 2022 14:46
@mberr mberr changed the title Test Training with PyTorch Lightning 🚅⚡ Test Training with PyTorch Lightning May 24, 2022
"""
lit = lit_module_resolver.make(training_loop, pos_kwargs=training_loop_kwargs)
trainer = pytorch_lightning.Trainer(**(trainer_kwargs or {}))
trainer.fit(model=lit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the point of this? it doesn't return anything

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It modifies the model's weights in-place

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a bit of extra description then? for poor fools like me ;)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

training_loop_kwargs=dict(
model=model,
# use a small configuration for testing
# TODO: this does not properly work for all models
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which ones doesn't it work for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1d00631 - this comment was outdated.

mberr added 2 commits May 24, 2022 16:55
@mberr mberr enabled auto-merge (squash) May 24, 2022 15:02
@cthoyt
Copy link
Member

cthoyt commented May 24, 2022

@PyKEEN-bot test

@mberr mberr merged commit 6d6e374 into master May 24, 2022
@mberr mberr deleted the pyl-tests branch May 24, 2022 15:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants