diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 9b97a9f19ad..a93fb7f31cc 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -411,7 +411,10 @@ def git_checkout( if not is_subpath_of(tmp_dir, source_path): err = f"'{subpath}' is a path outside of the cloned repository." msg.fail(err, repo, exits=1) - shutil.copytree(str(source_path), str(dest)) + if os.path.isdir(source_path): + shutil.copytree(str(source_path), str(dest)) + else: + shutil.copyfile(source_path, dest) except FileNotFoundError: err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')" msg.fail(err, repo, exits=1) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 3104b49ff27..f3abd5d587c 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -18,7 +18,7 @@ from spacy.cli._util import parse_config_overrides, string_to_list from spacy.cli._util import substitute_project_variables from spacy.cli._util import validate_project_commands -from spacy.cli._util import upload_file, download_file +from spacy.cli._util import upload_file, download_file, git_checkout from spacy.cli.debug_data import _compile_gold, _get_labels_from_model from spacy.cli.debug_data import _get_labels_from_spancat from spacy.cli.debug_data import _get_distribution, _get_kl_divergence @@ -140,7 +140,32 @@ def test_issue11235(): assert os.path.exists(d / "cfg") assert os.path.exists(d / f"{lang_var}_model") assert cfg["commands"][0]["script"][0] == f"hello {lang_var}" - + + +def test_project_git_dir_asset(): + with make_tempdir() as d: + # Use a very small repo. + git_checkout( + "https://github.com/explosion/os-signpost.git", + "os_signpost", + d / "signpost", + branch="v0.0.3", + ) + assert os.path.isdir(d / "signpost") + + +@pytest.mark.issue(12168) +def test_project_git_file_asset(): + with make_tempdir() as d: + # Use a very small repo. + git_checkout( + "https://github.com/explosion/os-signpost.git", + "README.md", + d / "readme.md", + branch="v0.0.3", + ) + assert os.path.isfile(d / "readme.md") + def test_cli_info(): nlp = Dutch()