Skip to content

Commit

Permalink
[Tests] Add non optional packages tests (#974)
Browse files Browse the repository at this point in the history
* add non-peft tests

* change name

* test

* change

* fix test
  • Loading branch information
younesbelkada authored Nov 9, 2023
1 parent 2f726ce commit c2884b5
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
24 changes: 23 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,29 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test]
pip install -e ".[test, peft, diffusers]"
- name: Test with pytest
run: |
make test
tests_no_optional_dep:
needs: check_code_quality
runs-on: 'ubuntu-latest'
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v4
with:
python-version: '3.9'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install .[test]
- name: Test with pytest
run: |
make test
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "peft>=0.4.0", "diffusers>=0.18.0"],
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate"],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ddpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@

import torch

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
from trl import is_diffusers_available

from .testing_utils import require_diffusers


if is_diffusers_available():
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


def scorer_function(images, prompts, metadata):
Expand All @@ -27,6 +33,7 @@ def prompt_function():
return ("cabbages", {})


@require_diffusers
class DDPOTrainerTester(unittest.TestCase):
"""
Test the DDPOTrainer class.
Expand Down
11 changes: 10 additions & 1 deletion tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from trl import is_peft_available, is_wandb_available, is_xpu_available
from trl import is_diffusers_available, is_peft_available, is_wandb_available, is_xpu_available


def require_peft(test_case):
Expand All @@ -27,6 +27,15 @@ def require_peft(test_case):
return test_case


def require_diffusers(test_case):
"""
Decorator marking a test that requires diffusers. Skips the test if diffusers is not available.
"""
if not is_diffusers_available():
test_case = unittest.skip("test requires diffusers")(test_case)
return test_case


def require_wandb(test_case, required: bool = True):
"""
Decorator marking a test that requires wandb. Skips the test if wandb is not available.
Expand Down

0 comments on commit c2884b5

Please sign in to comment.