diff --git a/src/chatdbg/assistant/assistant.py b/src/chatdbg/assistant/assistant.py new file mode 100644 index 0000000..f051869 --- /dev/null +++ b/src/chatdbg/assistant/assistant.py @@ -0,0 +1,227 @@ +import atexit +import inspect +import json +import textwrap +import time + +import llm_utils +from openai import * +from pydantic import BaseModel + + +class Assistant: + """ + An Assistant is a wrapper around OpenAI's assistant API. Example usage: + + assistant = Assistant("Assistant Name", instructions, + model='gpt-4-1106-preview', debug=True) + assistant.add_function(my_func) + response = assistant.run(user_prompt) + + Name can be any name you want. + + If debug is True, it will create a log of all messages and JSON responses in + json.txt. + """ + + def __init__(self, name, instructions, model="gpt-3.5-turbo-1106", debug=True): + + if debug: + self.json = open(f'json.txt', 'w') + else: + self.json = None + + try: + self.client = OpenAI(timeout=30) + except OpenAIError: + print(textwrap.dedent("""\ + You need an OpenAI key to use this tool. + You can get a key here: https://platform.openai.com/api-keys + Set the environment variable OPENAI_API_KEY to your key value. + """)) + return + + + self.assistants = self.client.beta.assistants + self.threads = self.client.beta.threads + self.functions = dict() + + self.assistant = self.assistants.create(name=name, + instructions=instructions, + model=model) + + self._log(self.assistant) + + atexit.register(self._delete_assistant) + + self.thread = self.threads.create() + self._log(self.thread) + + def _delete_assistant(self): + if self.assistant != None: + try: + id = self.assistant.id + response = self.assistants.delete(id) + self._log(response) + assert response.deleted + except Exception as e: + print(f'Assistant {id} was not deleted ({e}).\nYou can do so at https://platform.openai.com/assistants.') + + def add_function(self, function): + """ + Add a new function to the list of function tools for the assistant. + The function should have the necessary json spec as is pydoc string. + """ + function_json = json.loads(function.__doc__) + assert 'name' in function_json, "Bad JSON in pydoc for function tool." + try: + name = function_json['name'] + self.functions[name] = function + + tools = [ + { + "type": "function", + "function": json.loads(function.__doc__) + } for function in self.functions.values() + ] + + assistant = self.assistants.update(self.assistant.id, + tools=tools) + self._log(assistant) + except OpenAIError as e: + print(f"*** OpenAI Error: {e}") + + + def _make_call(self, tool_call): + name = tool_call.function.name + args = tool_call.function.arguments + + # There is a sketchy case that happens occasionally because + # the API produces a bad call... 
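+        # If the arguments fail to parse as JSON (or the call itself raises),
+        # return the error text so the model can see what went wrong.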
+ try: + args = json.loads(args) + function = self.functions[name] + result = function(**args) + except Exception as e: + result = f"Ill-formed function call ({e})\n" + + return result + + def _print_messages(self, messages, client_print): + client_print() + for i,m in enumerate(messages): + message_text = m.content[0].text.value + if i == 0: + message_text = '(Message) ' + message_text + client_print(message_text) + + + + def _wait_on_run(self, run, thread, client_print): + try: + while run.status == "queued" or run.status == "in_progress": + run = self.threads.runs.retrieve( + thread_id=thread.id, + run_id=run.id, + ) + time.sleep(0.5) + return run + finally: + if run.status == 'in_progress': + client_print("Cancelling message that's in progress.") + self.threads.runs.cancel(thread_id=thread.id, run_id=run.id) + + def run(self, prompt, client_print = print): + """ + Give the prompt to the assistant and get the response, which may included + intermediate function calls. + All output is printed to the given file. + """ + try: + if self.assistant == None: + return + + assert len(prompt) <= 32768 + + message = self.threads.messages.create(thread_id=self.thread.id, + role="user", + content=prompt) + self._log(message) + + last_printed_message_id = message.id + + + run = self.threads.runs.create(thread_id=self.thread.id, + assistant_id=self.assistant.id) + self._log(run) + + run = self._wait_on_run(run, self.thread, client_print) + self._log(run) + + while run.status == "requires_action": + + messages = self.threads.messages.list(thread_id=self.thread.id, + after=last_printed_message_id, + order='asc') + + mlist = list(messages) + if len(mlist) > 0: + self._print_messages(mlist, client_print) + last_printed_message_id = mlist[-1].id + client_print() + + + outputs = [] + for tool_call in run.required_action.submit_tool_outputs.tool_calls: + output = self._make_call(tool_call) + self._log(output) + outputs += [ { 'tool_call_id' : tool_call.id, 'output' : output } ] + + try: + run = self.threads.runs.submit_tool_outputs(thread_id=self.thread.id, + run_id=run.id, + tool_outputs=outputs) + self._log(run) + except Exception as e: + self._log(run, f'FAILED to submit tool call results: {e}') + + run = self._wait_on_run(run, self.thread, client_print) + self._log(run) + + if run.status == 'failed': + message = f"\n**Internal Failure ({run.last_error.code}):** {run.last_error.message}" + client_print(message) + return + + messages = self.threads.messages.list(thread_id=self.thread.id, + after=last_printed_message_id, + order='asc') + self._print_messages(messages, client_print) + + cost = llm_utils.calculate_cost(run.usage.prompt_tokens, + run.usage.completion_tokens, + self.assistant.model) + client_print() + client_print(f'[Cost: ~${cost:.2f} USD]') + except OpenAIError as e: + client_print(f"*** OpenAI Error: {e}") + + + + + def _log(self, obj, title=''): + if self.json != None: + stack = inspect.stack() + caller_frame_record = stack[1] + lineno, function = caller_frame_record[2:4] + loc = f'{function}:{lineno}' + + print('-' * 70, file=self.json) + print(f'{loc} {title}', file=self.json) + if isinstance(obj, BaseModel): + json_obj = json.loads(obj.model_dump_json()) + else: + json_obj = obj + print(f'\n{json.dumps(json_obj, indent=2)}\n', file=self.json) + self.json.flush() + return obj \ No newline at end of file diff --git a/src/chatdbg/chatdbg.py b/src/chatdbg/chatdbg.py index a507867..1a8a34a 100644 --- a/src/chatdbg/chatdbg.py +++ b/src/chatdbg/chatdbg.py @@ -1,26 +1,419 @@ #! 
/usr/bin/env python3 -import importlib +import importlib.metadata +import inspect +import os import pdb +import pydoc import sys +import textwrap +import traceback +from io import StringIO -from . import chatdbg_why +import llm_utils + +from .assistant.assistant import Assistant + +_config = { + 'model' : 'gpt-4-1106-preview', + 'debug' : False +} + +_basic_instructions=f"""\ +You are a debugging assistant. You will be given a Python stack trace for an +error and answer questions related to the root cause of the error. + +Call the `pdb` function to run Pdb debugger commands on the stopped program. The +Pdb debugger keeps track of a current frame. You may call the `pdb` function +with the following strings: + + bt + Print a stack trace, with the most recent frame at the bottom. + An arrow indicates the "current frame", which determines the + context of most commands. + + up + Move the current frame count one level up in the + stack trace (to an older frame). + down + Move the current frame count one level down in the + stack trace (to a newer frame). + + p expression + Print the value of the expression. + + whatis expression + Print the type of the expression. + + list + List the source code for the current frame. + The current line in the current frame is indicated by "->". + + info expression + Print the documentation and source code for the given expression, + which should be callable. + +Call the `info` function to get the documentation and source code for any +function that is visible in the current frame. + +Call the `pdb` and `info` functions as many times as you would like. + +Call `pdb` to print any variable value or expression that you believe may +contribute to the error. + +Unless it is in a common, widely-used library, you MUST call `info` on any +function that is called in the code, that apppears in the argument list for a +function call in the code, or that appears on the call stack. + +The root cause of any error is likely due to a problem in the source code within +the {os.path.dirname(sys.argv[0])} directory. + +Keep your answers under about 8-10 sentences. Conclude each response with +either a propopsed fix if you have identified the root cause or a bullet list of +1-3 suggestions for how to continue debugging. +""" + +class CopyingTextIOWrapper: + """ + File wrapper that will stash a copy of everything written. + """ + def __init__(self, file): + self.file = file + self.buffer = StringIO() + + def write(self, data): + self.buffer.write(data) + return self.file.write(data) + + def getvalue(self): + return self.buffer.getvalue() + + def getfile(self): + return self.file + + def __getattr__(self, attr): + # Delegate attribute access to the file object + return getattr(self.file, attr) class ChatDBG(pdb.Pdb): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.prompt = "(ChatDBG Pdb) " + + self.prompt = '(ChatDBG pdb) ' + self.chat_prefix = ' ' + self.text_width = 80 + self._assistant = None + self._history = [] + self._error_specific_prompt = '' + + def _is_user_file(self, file_name): + return file_name.startswith(os.path.dirname(sys.argv[0])) + + def grab_active_call_from_frame(self, tb): + """ + Extract the text for the function call currently running in the + top-most frame in tb. 
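+        Falls back to returning the whole source line when column position
+        information is unavailable (older Python versions).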
+ """ + frame = tb.tb_frame + lineno = pdb.lasti2lineno(frame.f_code, tb.tb_lasti) + lines = inspect.getsourcelines(frame)[0] + for index, line in enumerate(lines, frame.f_code.co_firstlineno): + if index == lineno: + leading_spaces = len(line) - len(line.lstrip()) + # Degrade gracefully when using older Python versions that don't have column info. + try: + positions = inspect.getframeinfo(frame).positions + return line[positions.col_offset:positions.end_col_offset] + except: + return line + assert False + + def _hide_lib_calls(self, tb): + """ + Remove all frames from the stack that are not part of + the user code. Return a tuple (tb,lib_entry) where + - tb is the new traceback + - lib_entry is the most recent user frame calling into + a library. + """ + head = tb + if head != None: + while not self._is_user_file(head.tb_frame.f_code.co_filename): + head = head.tb_next + + tail = head + lib_entry = None + while tail.tb_next != None: + tail_file_name = tail.tb_next.tb_frame.f_code.co_filename + if self._is_user_file(tail_file_name): + tail = tail.tb_next + lib_entry = None + else: + if lib_entry == None: + lib_entry = tail.tb_next + tail.tb_next = tail.tb_next.tb_next + return head, tail if lib_entry != None else None + + def interaction(self, frame, tb): + """ + Override to remove all lib code from the stack and create more + precise details about where to look for the error. + """ + if tb != None: + exc_type, exc_value, _ = sys.exc_info() + tb, lib_entry_point = self._hide_lib_calls(tb) + + if lib_entry_point != None: + tb_str = ''.join(traceback.format_tb(tb)) + details = textwrap.dedent(f"""\ + An exception was raised during the call to + {self.grab_active_call_from_frame(lib_entry_point)}. The + root cause is most likely related to the arguments passed + into that function. You MUST look at the values passed in as + arguments and the specification for the function. You MUST + consider the order that the arguments are listed.\n""") + else: + tb_str = ''.join(traceback.format_exception(exc_type, exc_value, tb)) + details = '' + prompt = f"Here is the stack trace for the error:\n{tb_str}\n{details}\n" + self._error_specific_prompt = prompt + + super().interaction(frame, tb) + + + def onecmd(self, line: str) -> bool: + """ + Override to stash the results in our history. + """ + if not line: + # blank -- let super call back to into onecmd + return super().onecmd(line) + else: + hist_file = CopyingTextIOWrapper(self.stdout) + self.stdout = hist_file + try: + return super().onecmd(line) + finally: + if line not in [ 'hist', 'test_prompt' ]: + self._history += [ (line, hist_file.getvalue()) ] + self.stdout = hist_file.getfile() + + def message(self, msg) -> None: + """ + Override to remove tabs for messages so we can indent them. + """ + return super().message(str(msg).expandtabs()) + + def error(self, msg) -> None: + """ + Override to remove tabs for messages so we can indent them. + """ + return super().error(str(msg).expandtabs()) + + def _capture_onecmd(self, line): + """ + Run one Pdb command, but capture and return stdout. 
+ """ + stdout = self.stdout + lastcmd = self.lastcmd + try: + self.stdout = StringIO() + super().onecmd(line) + result = self.stdout.getvalue().rstrip() + return result + finally: + self.stdout = stdout + self.lastcmd = lastcmd + + def format_history_entry(self, entry, indent = ''): + line, output = entry + output = llm_utils.word_wrap_except_code_blocks(output, + self.text_width) + if output: + entry = f"(ChatDBG pdb) {line}\n{output}" + else: + entry = f"(ChatDBG pdb) {line}" + return textwrap.indent(entry, indent, lambda _ : True) + + def _clear_history(self): + self._history = [ ] + + # override to make lines starting with : be chat lines. + def default(self, line): + if line[:1] == ':': + line = line[1:].strip() + self.do_chat(line) + else: + super().default(line) + + def do_hist(self, arg): + """hist + Print the history of user-issued commands since the last chat. + """ + entry_strs = [ self.format_history_entry(x) for x in self._history ] + history_str = "\n".join(entry_strs) + self.message(history_str) + + def do_pydoc(self, arg): + """pydoc name + Print the pydoc string for a name. + """ + try: + obj = self._getval(arg) + if obj.__doc__ != None: + pydoc.doc(obj, output = self.stdout) + else: + self.message(f'No documentation is available.') + except NameError: + # message already printed in _getval + pass + + def do_info(self, arg): + """info name + Print the pydoc string (and source code, if available) for a name. + """ + try: + obj = self._getval(arg) + if self._is_user_file(inspect.getfile(obj)): + self.do_source(arg) + else: + self.do_pydoc(arg) + self.message(f'You MUST assume that `{arg}` is specified and implemented correctly.') + except NameError: + # message already printed in _getval + pass + except Exception: + self.do_pydoc(arg) + self.message(f'You MUST assume that `{arg}` is specified and implemented correctly.') + + def do_test_prompt(self, arg): + """test_prompt + [For debugging] Prints the prompts to be sent to the assistant. + """ + self.message('Instructions:') + self.message(self._instructions()) + self.message('-' * 80) + self.message('Prompt:') + self.message(self._prompt(arg)) + + def _instructions(self): + return _basic_instructions + '\n' + self._error_specific_prompt + + def _stack_prompt(self): + stack_frames = textwrap.indent(self._capture_onecmd('bt'), '') + stack = textwrap.dedent(f""" + This is the current stack. + The current frame is indicated by an arrow '>' at + the start of the line. + ```""") + f'\n{stack_frames}\n```' + return stack + + def _prompt(self, arg): + if arg == 'why': + arg = "Explain the root cause of the error." + + user_prompt = '' + if len(self._history) > 0: + hist = textwrap.indent(self._capture_onecmd('hist'), '') + self._clear_history() + hist = f"This is the history of some pdb commands I ran and the results.\n```\n{hist}\n```\n" + user_prompt += hist + + stack = self._stack_prompt() + user_prompt += stack + '\n' + arg + + return user_prompt + + def do_chat(self, arg): + """chat/: + Send a chat message. 
+ """ + + prompt = self._prompt(arg) + + if self._assistant == None: + self._make_assistant() + + def client_print(line=''): + line = llm_utils.word_wrap_except_code_blocks(line, + self.text_width - 10) + line = textwrap.indent(line, + self.chat_prefix, + lambda _ : True) + print(line, file=self.stdout, flush=True) + + self._assistant.run(prompt, client_print) + + + def _make_assistant(self): + + def info(name): + """ + { + "name": "info", + "description": "Get the documentation and source code (if available) for any function visible in the current frame", + "parameters": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the function to get the information for" + } + }, + "required": [ "name" ] + } + } + """ + command = f'info {name}' + result = self._capture_onecmd(command) + self.message(self.format_history_entry((command, result), + indent = self.chat_prefix)) + return result + + def pdb(command): + """ + { + "name": "pdb", + "description": "Run a pdb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The pdb command to run." + } + }, + "required": [ "command" ] + } + } + """ + cmd = command if command != 'list' else 'll' + result = self._capture_onecmd(cmd) + self.message(self.format_history_entry((command, result), + indent = self.chat_prefix)) + + # help the LLM know where it is... + result += self._stack_prompt() + + return result + + self._clear_history() + self._assistant = Assistant("ChatDBG", + self._instructions(), + model=_config['model'], + debug=_config['debug']) + self._assistant.add_function(pdb) + self._assistant.add_function(info) + - def do_why(self, arg): - chatdbg_why.why(self, arg) _usage = f"""\ usage: chatdbg [-c command] ... [-m module | pyfile] [arg] ... A Python debugger that uses AI to tell you `why`. -(version {importlib.metadata.metadata('ChatDBG')['Version']}) +(version {importlib.metadata.metadata('chatdbg')['Version']}) https://github.com/plasma-umass/ChatDBG @@ -33,21 +426,42 @@ def do_why(self, arg): -c are executed after commands from .pdbrc files. To let the script run until an exception occurs, use "-c continue". -You can then type `why` to get an explanation of the root cause of +You can then type `:why` to get an explanation of the root cause of the exception, along with a suggested fix. NOTE: you must have an OpenAI key saved as the environment variable OPENAI_API_KEY. You can get a key here: https://openai.com/api/ +You may also ask any other question that starts with the word `:`. + To let the script run up to a given line X in the debugged file, use -"-c 'until X'".""" +"-c 'until X'". 
+ +ChatDBG supports the following configuration flags before the +pyfile or -m switch: + --chat.model= + --chat.debug +""" + +_valid_models = [ + 'gpt-4-turbo-preview', + 'gpt-4-0125-preview', + 'gpt-4-1106-preview', + 'gpt-3.5-turbo-0125', + 'gpt-3.5-turbo-1106', + 'gpt-4', # no parallel calls + 'gpt-3.5-turbo' # no parallel calls +] def main(): + import getopt - opts, args = getopt.getopt(sys.argv[1:], "mhc:", ["help", "command="]) + opts, args = getopt.getopt(sys.argv[1:], + "mhc:", + ["help", "command=","chat.model=","chat.debug"]) - if any(opt in ["-h", "--help"] for opt, optarg in opts): + if any(opt in ["-h", "--help"] for opt, _ in opts): print(_usage) sys.exit() @@ -55,5 +469,21 @@ def main(): print(_usage) sys.exit(2) + for o, a in opts: + if o == '--chat.model': + if a not in _valid_models: + print(f'{a} is not supported.') + print(f'The supported models are {_valid_models}.') + _config['model'] = a + elif o == '--chat.debug': + _config['debug'] = True + elif o.startswith('--chat.'): + print(f'{o} not defined.') + print(_usage) + sys.exit(2) + + # drop all --chat options + sys.argv[:] = [x for x in sys.argv if not x.startswith('--chat.')] + pdb.Pdb = ChatDBG - pdb.main() \ No newline at end of file + pdb.main() diff --git a/src/chatdbg/chatdbg_lldb.py b/src/chatdbg/chatdbg_lldb.py index 4386d8b..6b6f6e0 100644 --- a/src/chatdbg/chatdbg_lldb.py +++ b/src/chatdbg/chatdbg_lldb.py @@ -9,6 +9,8 @@ import llm_utils import openai +import textwrap +from assistant.assistant import Assistant sys.path.append(os.path.abspath(pathlib.Path(__file__).parent.resolve())) import chatdbg_utils @@ -416,3 +418,199 @@ def converse( sys.exit(1) print(conversation.converse(client, args)) + + +_assistant = None + +def _format_history_entry(entry, indent = ''): + line, output = entry + output = llm_utils.word_wrap_except_code_blocks(output, 120) + if output: + entry = f"(ChatDBG lldb) {line}\n{output}" + else: + entry = f"(ChatDBG lldb) {line}" + return textwrap.indent(entry, indent, lambda _ : True) + +def _capture_onecmd(debugger, cmd): + # Get the command interpreter from the debugger + interpreter = debugger.GetCommandInterpreter() + + # Create an object to hold the result of the command execution + result = lldb.SBCommandReturnObject() + + # Execute a command (e.g., "version" to get the LLDB version) + interpreter.HandleCommand(cmd, result) + + # Check if the command was executed successfully + if result.Succeeded(): + # Get the output of the command + output = result.GetOutput() + return output + else: + # Get the error message if the command failed + error = result.GetError() + return f"Command Error: {error}" + +def _stack_prompt(debugger: lldb.SBDebugger): + stack_frames = textwrap.indent(_capture_onecmd(debugger, 'bt'), '') + stack = textwrap.dedent(f""" + This is the current stack. + The current frame is indicated by a the text ' * ' at + the start of the line. + ```""") + f'\n{stack_frames}\n```' + return stack + +_basic_instructions=f"""\ +You are a debugging assistant. You will be given a Python stack trace for an +error and answer questions related to the root cause of the error. + +Call the `lldb` function to run lldb debugger commands on the stopped program. The +lldb debugger keeps track of a current frame. You may call the `lldb` function +with the following strings: + + bt + Print a stack trace, with the most recent frame at the bottom. + An arrow indicates the "current frame", which determines the + context of most commands. 
+ + up + Move the current frame count one level up in the + stack trace (to an older frame). + down + Move the current frame count one level down in the + stack trace (to a newer frame). + + p expression + Print the value of the expression. + + f + List the source code for the current frame. + The current line in the current frame is indicated by "->". + + info expression + Print the documentation and source code for the given expression, + which should be callable. + +Call the `info` function to get the documentation and source code for any +function that is visible in the current frame. + +Call the `lldb` and `info` functions as many times as you would like. + +Call `ldb` to print any variable value or expression that you believe may +contribute to the error. + +Unless it is in a common, widely-used library, you MUST call `info` on any +function that is called in the code, that apppears in the argument list for a +function call in the code, or that appears on the call stack. + +The root cause of any error is likely due to a problem in the source code within +the {os.path.dirname(sys.argv[0])} directory. + +Keep your answers under about 8-10 sentences. Conclude each response with +either a propopsed fix if you have identified the root cause or a bullet list of +1-3 suggestions for how to continue debugging. +""" + +def _instructions(debugger: lldb.SBDebugger): + source_code, traceback, exception = buildPrompt(debugger) + return f""" + +{_basic_instructions} + +In your response, never refer to the frames given below (as in, 'frame 0'). Instead, +always refer only to specific lines and filenames of source code. + +Source code for each stack frame: +``` +{source_code} +``` + +Traceback: +{traceback} + +Stop reason: {exception} + """.strip() + + +def _make_assistant(debugger: lldb.SBDebugger): + global _assistant + _assistant = Assistant("ChatDBG", + _instructions(debugger), + debug=True) + + def lldb(command): + """ + { + "name": "lldb", + "description": "Run a lldb command and get the response.", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The lldb command to run." + } + }, + "required": [ "command" ] + } + } + """ + cmd = command # any special modifications + + result = _capture_onecmd(debugger, cmd) + + print(_format_history_entry((command, result), + indent = ' ')) + + # help the LLM know where it is... 
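+        # The current stack is appended to every tool result, mirroring the
+        # pdb version of this tool.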
+ result += _stack_prompt() + + return result + + _assistant.add_function(lldb) + + +@lldb.command("chat") +def chat( + debugger: lldb.SBDebugger, + command: str, + result: str, + internal_dict: dict): + + if _assistant == None: + _make_assistant(debugger) + + def client_print(line=''): + line = llm_utils.word_wrap_except_code_blocks(line, 115) + line = textwrap.indent(line, + ' ', + lambda _ : True) + print(line, file=sys.stdout, flush=True) + + _assistant.run(command, client_print) + +@lldb.command("test") +def test( + debugger: lldb.SBDebugger, + command: str, + result: str, + internal_dict: dict): + + # Get the command interpreter from the debugger + interpreter = debugger.GetCommandInterpreter() + + # Create an object to hold the result of the command execution + result = lldb.SBCommandReturnObject() + + # Execute a command (e.g., "version" to get the LLDB version) + interpreter.HandleCommand(command, result) + + # Check if the command was executed successfully + if result.Succeeded(): + # Get the output of the command + output = result.GetOutput() + print("Command Output:", output) + else: + # Get the error message if the command failed + error = result.GetError() + print("Command Error:", error) diff --git a/test/a.out b/test/a.out index ebeda25..cdce2ab 100755 Binary files a/test/a.out and b/test/a.out differ diff --git a/test/python/README.md b/test/python/README.md new file mode 100644 index 0000000..79e0a86 --- /dev/null +++ b/test/python/README.md @@ -0,0 +1,8 @@ +# Python Tests + +* Install the packages in `requirements.txt`. Some of the tests use them +* Run, for example: + + ``` + chatpdb -c continue marbles.py + ``` diff --git a/test/python/bootstrap.py b/test/python/bootstrap.py new file mode 100644 index 0000000..7d57ce5 --- /dev/null +++ b/test/python/bootstrap.py @@ -0,0 +1,19 @@ +from datascience import * +from cs104 import * + +def make_marble_bag(): + table = Table().read_table('marble-sample.csv') + return table.column('color') + +def percent_blue(sample): + return sample + +def main(observed_marbles): + num_trials = 5 + stats = bootstrap_statistic(observed_marbles, + percent_blue, + num_trials) + assert len(stats) == 5 + +observed_marbles = make_marble_bag() +main(observed_marbles) diff --git a/test/python/marble-sample.csv b/test/python/marble-sample.csv new file mode 100644 index 0000000..19c7a7c --- /dev/null +++ b/test/python/marble-sample.csv @@ -0,0 +1,31 @@ +color +R +R +R +R +R +R +R +R +R +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B +B diff --git a/test/python/marbles.py b/test/python/marbles.py new file mode 100644 index 0000000..35fd38a --- /dev/null +++ b/test/python/marbles.py @@ -0,0 +1,18 @@ +from datascience import * +from cs104 import * + +def make_marble_bag(): + table = Table().read_table('marble-sample.csv') + return table.column('color') + +def ratio(x,y): + return x / y + +def ratio_blue_to_red(sample): + blues = np.count_nonzero(sample == 'B') + reds = np.count_nonzero(sample == 'r') + return ratio(blues, reds) + +marbles = make_marble_bag() +if 'R' in marbles: + print(ratio_blue_to_red(marbles)) diff --git a/test/python/mean.py b/test/python/mean.py new file mode 100644 index 0000000..1358969 --- /dev/null +++ b/test/python/mean.py @@ -0,0 +1,25 @@ + +import numpy as np +from datascience import * +from cs104 import * + +def make_marble_bag(): + table = Table().read_table('marble-sample.csv') + return table.column('color') + +observed_marbles = make_marble_bag() + +def percent_blue(sample): + return np.count_nonzero(sample == 
'B') / len(sample) + +def main(): + + num_trials = 5 + + stats = bootstrap_statistic(observed_marbles, + percent_blue, + num_trials) + + assert np.isclose(np.mean(stats), 0.7) + +main() \ No newline at end of file diff --git a/test/python/requirements.txt b/test/python/requirements.txt new file mode 100644 index 0000000..51214f5 --- /dev/null +++ b/test/python/requirements.txt @@ -0,0 +1,3 @@ +# packages necessary for the test files +cs104 @ git+https://github.com/cs104williams/cs104-toolbox@standalone +datascience @ git+https://github.com/cs104williams/cs104-datascience diff --git a/test/python/sample.py b/test/python/sample.py new file mode 100644 index 0000000..4691679 --- /dev/null +++ b/test/python/sample.py @@ -0,0 +1,11 @@ +import numpy as np + +red_blue_proportions = np.array([0.3, 0.7]) + +def make_sample(sample_size, probabilities): + return np.random.multinomial(sample_size, probabilities) + +def make_marble_bag(size): + return make_sample(red_blue_proportions, size) + +make_marble_bag(10)
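
A minimal sketch of how the new `Assistant` wrapper and its JSON-docstring function tools fit together. The `weather` tool, its docstring, and the prompt are made-up illustrations, and an `OPENAI_API_KEY` must be set in the environment:

```
from chatdbg.assistant.assistant import Assistant

def weather(city):
    """
    {
        "name": "weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": { "type": "string", "description": "The city to look up." }
            },
            "required": [ "city" ]
        }
    }
    """
    # Stand-in for a real lookup; add_function() parses the JSON docstring
    # above into the OpenAI tool spec.
    return f"The weather in {city} is sunny."

assistant = Assistant("Example", "Answer weather questions.",
                      model="gpt-4-1106-preview", debug=True)
assistant.add_function(weather)
assistant.run("What is the weather in Boston?")
```

For the debugger itself, the new flags can be exercised along the lines of `chatdbg --chat.model=gpt-4-1106-preview --chat.debug -c continue test/python/marbles.py`, assuming the `chatdbg` entry point provided by this package.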