From 950b04547159322de091ded0f9a22056392244ee Mon Sep 17 00:00:00 2001 From: Aman Rusia Date: Sat, 1 Feb 2025 18:34:26 +0530 Subject: [PATCH] code-writer allowed globs with workspace dir being a file path interaction bug fix --- pyproject.toml | 2 +- src/wcgw/client/tools.py | 6 ++++ tests/test_initialize.py | 75 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 80d69a3..cc3d716 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] authors = [{ name = "Aman Rusia", email = "gapypi@arcfu.com" }] name = "wcgw" -version = "2.8.7" +version = "2.8.8" description = "Shell and coding agent on claude and chatgpt" readme = "README.md" requires-python = ">=3.11, <3.13" diff --git a/src/wcgw/client/tools.py b/src/wcgw/client/tools.py index 2e9ae03..f3c03ef 100644 --- a/src/wcgw/client/tools.py +++ b/src/wcgw/client/tools.py @@ -457,6 +457,12 @@ def initialize( folder_to_start = None if any_workspace_path: if os.path.exists(any_workspace_path): + if os.path.isfile(any_workspace_path): + # Set any_workspace_path to the directory containing the file + # Add the file to read_files_ only if empty to avoid duplicates + if not read_files_: + read_files_ = [any_workspace_path] + any_workspace_path = os.path.dirname(any_workspace_path) repo_context, folder_to_start = get_repo_context(any_workspace_path, 200) repo_context = f"---\n# Workspace structure\n{repo_context}\n---\n" diff --git a/tests/test_initialize.py b/tests/test_initialize.py index 321d5cd..bc931f1 100644 --- a/tests/test_initialize.py +++ b/tests/test_initialize.py @@ -248,3 +248,78 @@ def test_load_bash_state_error(self): # Verify task memory was still loaded despite state error self.assertIn("Following is the retrieved task:\ntest_memory", result) + + + def test_workspace_path_is_file(self): + """Test initialize when workspace path points to a file""" + # Create a test file + test_file = os.path.join(self.test_workspace, "test.py") + os.makedirs(os.path.dirname(test_file), exist_ok=True) + with open(test_file, "w") as f: + f.write("print('test')") + + with ( + patch("os.path.exists") as mock_exists, + patch("os.path.isfile") as mock_isfile, + patch("wcgw.client.tools.get_repo_context") as mock_get_context, + patch("wcgw.client.tools.read_files") as mock_read_files, + ): + mock_exists.return_value = True + mock_isfile.return_value = True + mock_get_context.return_value = (self.repo_context, os.path.dirname(test_file)) + mock_read_files.return_value = f"``` {test_file}print('test')```" + + # Call initialize with file path and no read_files_ + result = initialize( + any_workspace_path=test_file, + read_files_=[], + task_id_to_resume="", + max_tokens=None, + mode=Modes.wcgw, + ) + + # Verify read_files was called with the file path + mock_read_files.assert_called_once_with([test_file], None) + + # Verify repo context uses parent directory + mock_get_context.assert_called_once_with(os.path.dirname(test_file), 200) + + # Verify file content is in output + self.assertIn(f"``` {test_file}print('test')", result) + + # Verify cwd + self.assertIn(f"Initialized in directory (also cwd): {os.path.dirname(test_file)}\n", result) + + def test_workspace_path_is_file_with_read_files(self): + """Test initialize when workspace path points to a file and read_files_ is provided""" + test_file = os.path.join(self.test_workspace, "test.py") + extra_file = os.path.join(self.test_workspace, "extra.py") + + with ( + patch("os.path.exists") as mock_exists, + patch("os.path.isfile") as mock_isfile, + patch("wcgw.client.tools.get_repo_context") as mock_get_context, + patch("wcgw.client.tools.read_files") as mock_read_files, + ): + mock_exists.return_value = True + mock_isfile.return_value = True + mock_get_context.return_value = (self.repo_context, os.path.dirname(test_file)) + mock_read_files.return_value = f"``` {extra_file}extra content```" + + # Call initialize with file path and explicit read_files_ + result = initialize( + any_workspace_path=test_file, + read_files_=[extra_file], + task_id_to_resume="", + max_tokens=None, + mode=Modes.wcgw, + ) + + # Verify read_files was called with the provided read_files_ only + mock_read_files.assert_called_once_with([extra_file], None) + + # Verify repo context uses parent directory + mock_get_context.assert_called_once_with(os.path.dirname(test_file), 200) + + # Verify extra file content is in output + self.assertIn(f"``` {extra_file}extra content", result)