From bc76eeebc87be20369addeae9fd04dcc7310dded Mon Sep 17 00:00:00 2001 From: Eric Joanis Date: Thu, 30 Jan 2025 12:01:03 -0500 Subject: [PATCH] feat: better wizard resume message for various sw and versions (#629) cases handled: - recent versions that are known to be compatible - older version - back from the future - wrong software name --- everyvoice/tests/test_wizard.py | 106 +++++++++++++++++++++++++++----- everyvoice/wizard/tour.py | 43 +++++++++++-- pyproject.toml | 1 + 3 files changed, 129 insertions(+), 21 deletions(-) diff --git a/everyvoice/tests/test_wizard.py b/everyvoice/tests/test_wizard.py index 2fe032ac..ba360da4 100644 --- a/everyvoice/tests/test_wizard.py +++ b/everyvoice/tests/test_wizard.py @@ -8,17 +8,20 @@ from copy import deepcopy from enum import Enum from pathlib import Path +from textwrap import dedent from types import MethodType from typing import Callable, Iterable, NamedTuple, Optional, Sequence from unittest import TestCase from anytree import PreOrderIter, RenderTree +from packaging.version import Version # [Unit testing questionary](https://github.com/prompt-toolkit/python-prompt-toolkit/blob/master/docs/pages/advanced_topics/unit_testing.rst) from prompt_toolkit.application import create_app_session from prompt_toolkit.input import create_pipe_input from prompt_toolkit.output import DummyOutput +from everyvoice._version import VERSION from everyvoice.tests.stubs import ( Say, capture_stderr, @@ -1845,6 +1848,19 @@ def test_control_c_display_tree(self): self.assertRegex(output.getvalue(), r"Contact Name: *Jane Doe") self.assertEqual(tour.state, self.trivial_tour_results) + progress_template = dedent( + """\ + - - EveryVoice Wizard + - {version} + - - Root + - null + - - Name Step + - project_name + - - Contact Name Step + - Jane Doe + """ + ) + def test_control_c_save_progress(self): # Ctrl-C plus option 3 saves progress to file with tempfile.TemporaryDirectory() as tmpdirname: @@ -1865,8 +1881,20 @@ def test_control_c_save_progress(self): ): with patch_menu_prompt(3): tour.run() - self.assertTrue(progress_file.exists()) + self.assertTrue(progress_file.exists()) + with open(progress_file, encoding="utf8") as f: + progress_contents = f.read() + # print(progress_contents) + self.assertEqual( + progress_contents, self.progress_template.format(version=VERSION) + ) + def test_resume_from(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + progress_file = tmpdir / "saved-progress" + with open(progress_file, "w") as f: + f.write(self.progress_template.format(version=VERSION)) # resume works tour = make_trivial_tour() with patch_input("email@mail.com"), capture_stdout() as out: @@ -1874,26 +1902,66 @@ def test_control_c_save_progress(self): self.assertIn("Applying saved response", out.getvalue()) self.assertEqual(tour.state, self.trivial_tour_results) - with open(progress_file, encoding="utf8") as f: - progress_lines = f.readlines() + def test_resume_from_the_future(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + # resume from a future version works with a warning + changed_version = tmpdir / "changed-version" + with open(changed_version, "w", encoding="utf8") as f: + v = Version(VERSION) + f.write( + self.progress_template.format( + version=f"{v.major + 1}.{v.minor}.{v.micro}" + ) + ) + tour = make_trivial_tour() + with patch_input("email@mail.com"), capture_stdout() as out: + tour.run(resume_from=changed_version) + self.assertRegex(out.getvalue(), r"(?s)Proceeding.*anyway") + self.assertRegex(out.getvalue(), r"(?s)consider.*updating.*your.*software") + self.assertIn("Applying saved response", out.getvalue()) + self.assertEqual(tour.state, self.trivial_tour_results) + + def test_resume_from_near_past(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) # resume from a changed version works with a warning changed_version = tmpdir / "changed-version" with open(changed_version, "w", encoding="utf8") as f: - f.write(progress_lines[0]) - f.write(progress_lines[1].replace("\n", "changed\n")) - f.write("".join(progress_lines[2:])) + f.write(self.progress_template.format(version=VERSION + ".dev0")) + tour = make_trivial_tour() with patch_input("email@mail.com"), capture_stdout() as out: tour.run(resume_from=changed_version) + self.assertRegex(out.getvalue(), r"(?s)expected.*to.*be.*compatible") self.assertRegex(out.getvalue(), r"(?s)Proceeding.*anyway") self.assertIn("Applying saved response", out.getvalue()) self.assertEqual(tour.state, self.trivial_tour_results) + def test_resume_from_far_past(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + # resume from a potentially incompatible older version + changed_version = tmpdir / "changed-version" + with open(changed_version, "w", encoding="utf8") as f: + f.write(self.progress_template.format(version="0.1.2")) + tour = make_trivial_tour() + with patch_input("email@mail.com"), capture_stdout() as out: + tour.run(resume_from=changed_version) + self.assertRegex(out.getvalue(), r"(?s)not.*fully.*compatible") + self.assertRegex(out.getvalue(), r"(?s)Proceeding.*anyway") + self.assertIn("Applying saved response", out.getvalue()) + self.assertEqual(tour.state, self.trivial_tour_results) + + def test_resume_with_invalid_progress_files(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdir = Path(tmpdirname) + # This one has an invalid response but lets the user recover invalid_response = tmpdir / "invalid-response" with open(invalid_response, "w", encoding="utf8") as f: - f.write("".join(progress_lines)) + f.write(self.progress_template.format(version=VERSION)) f.write("- - Contact Email Step\n - invalid email\n") tour = make_trivial_tour() with patch_input("email@mail.com"), capture_stdout() as out: @@ -1915,6 +1983,10 @@ def test_control_c_save_progress(self): with self.assertRaises(SystemExit), capture_stdout(): tour.run(resume_from=bad_progress_file2) + progress_lines = self.progress_template.format(version=VERSION).splitlines( + keepends=True + ) + truncated_progress_file = tmpdir / "truncated-progress" with open(truncated_progress_file, "w", encoding="utf8") as f: f.write("".join(progress_lines[:-1])) @@ -1932,15 +2004,9 @@ def test_control_c_save_progress(self): questions_out_of_order = tmpdir / "questions-out-of-order" with open(questions_out_of_order, "w", encoding="utf8") as f: - f.write( - "".join( - [ - *progress_lines[:-4], - *progress_lines[-2:], - *progress_lines[-4:-2], - ] - ) - ) + f.write("".join(progress_lines[:-4])) + f.write("".join(progress_lines[-2:])) + f.write("".join(progress_lines[-4:-2])) with self.assertRaises(SystemExit), capture_stdout() as out: tour.run(resume_from=questions_out_of_order) self.assertIn("out of sync", out.getvalue()) @@ -1955,6 +2021,14 @@ def test_control_c_save_progress(self): tour.run(resume_from=extra_question_not_in_tour) self.assertIn("saved responses left", out.getvalue()) + wrong_software_name = tmpdir / "wrong-software-name" + with open(wrong_software_name, "w", encoding="utf8") as f: + f.write("- - Wrong Software\n") + f.write("".join(progress_lines[1:])) + with self.assertRaises(SystemExit), capture_stdout() as out: + tour.run(resume_from=wrong_software_name) + self.assertRegex(out.getvalue(), r"(?s)it.*is.*for.*software") + def test_control_c_exit(self): # Ctrl-C plus option 4 (Exit) exits tour = make_trivial_tour() diff --git a/everyvoice/wizard/tour.py b/everyvoice/wizard/tour.py index 8f40d4f0..0437b628 100644 --- a/everyvoice/wizard/tour.py +++ b/everyvoice/wizard/tour.py @@ -6,6 +6,7 @@ import questionary import yaml from anytree import PreOrderIter, RenderTree +from packaging.version import Version from rich import print as rich_print from rich.panel import Panel @@ -170,6 +171,9 @@ def validate(self, response): return response is None +SOFTWARE_NAME = "EveryVoice Wizard" + + class Tour: def __init__( self, @@ -305,12 +309,41 @@ def resume(self, resume_from: Path) -> Optional[Step]: q_and_a_iter = iter(q_and_a_list) software, version = next(q_and_a_iter) - if software != "EveryVoice Wizard" or version != VERSION: + if software != SOFTWARE_NAME: rich_print( - f"[yellow]Warning: saved progress file is for {software} version '{version}', " - f"but this is version '{VERSION}'. Proceeding anyway, but be aware that " - "the saved responses may not be compatible.[/yellow]" + f"Error loading progress from {resume_from}: it is for software " + f"{software}, but this is {SOFTWARE_NAME}." ) + sys.exit(1) + + # When we introduce breaking changes to the wizard question sequence, code + # is to be added here to automatically fix resume-from files, adding defaults + # for new questions if possible, or else giving a warning explaining what needs + # to be changed if auto upgrade is not possible. + # Regression testing should warn us when such auto-upgrade code is required here. + compatible_since = Version("0.2.0a0") + if version != VERSION: + if Version(version) >= Version(VERSION): + rich_print( + "[red]Warning: saved progress file is from the future, for " + f"{software} version '{version}', but this is version '{VERSION}'. " + "Proceeding anyway, but we can't tell if they'll be compatible. " + "Please consider updating your software.[/red]" + ) + elif Version(version) >= compatible_since: + rich_print( + f"[yellow]Warning: saved progress file is for {software} version '{version}', " + f"but this is version '{VERSION}', which is expected to be compatible. " + "Proceeding anyway, but be aware that some things may have changed " + "between versions.[/yellow]" + ) + else: + rich_print( + f"[yellow]Warning: saved progress file is for {software} version '{version}', " + f"but this is version '{VERSION}', which is not fully compatible. " + "Proceeding anyway, but be aware that some saved responses may no " + "longer be compatible.[/yellow]" + ) q_and_a = next(q_and_a_iter, None) node = self.root while node is not None and q_and_a is not None: @@ -426,7 +459,7 @@ def save_progress(self, current_node: Step): try: with open(filename, "w", encoding="utf8") as f: yaml.dump( - [["EveryVoice Wizard", VERSION]] + self.get_progress(current_node), + [[SOFTWARE_NAME, VERSION]] + self.get_progress(current_node), f, allow_unicode=True, ) diff --git a/pyproject.toml b/pyproject.toml index 17d21cd5..af5350bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "merge-args", "nltk==3.9.1", "numpy<2", # torch < 2.4.1 requires numpy < 2 but fails to declare it + "packaging>=20.9", "pandas~=2.0", "panphon==0.20.0", "protobuf~=4.25", # https://github.com/EveryVoiceTTS/EveryVoice/issues/387