From f89e9b7dbdc2d45a6ded4cbb6fd74296dd136cef Mon Sep 17 00:00:00 2001 From: Ido Perlmuter Date: Thu, 9 Mar 2023 17:12:35 +0200 Subject: [PATCH] Switch to ChatGPT API by default, allow other models (#27) * Switch to ChatGPT API by default, allow other models aiac will now use the ChatGPT API (and thus the gpt-3.5-turbo model) instead of the text-davinci-003 model by default. It also allows selecting the model to use, with the aforementioned two being supported, along with the code-davinci-002 model, which is specifically designed to generate code. A model can be selected via the `--model` flag, and a list of all supported models is available via the `list-models` command. Since ChatGPT usually returns a Markdown-formatted description rather than just code, the library will extract the code from the response. A new flag, `--full`, is added to prevent that and echo or save the complete response from the API. The command line prompt will now also allow users to modify the prompt after receiving a response (using the "m" key), along with the previous option to retry with the same prompt. Also included are small fixes: entering Ctrl+D and Ctrl+C after receiving a response will quit the program as expected, rather than regenerate a response. --------- Co-authored-by: Liav Yona --- .golangci.yml | 3 + README.md | 19 +++- libaiac/libaiac.go | 230 +++++++++++++++++++++++++++++++++++++++++---- main.go | 57 ++++++++--- 4 files changed, 276 insertions(+), 33 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 454d924..48e60d9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -24,6 +24,9 @@ linters: - promlinter - testpackage - varnamelen + - nonamedreturns + - tagliatelle + - exhaustruct issues: new: true new-from-rev: HEAD diff --git a/README.md b/README.md index d4b1fb5..b798a7a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ Generator. * [Instructions](#instructions) * [Installation](#installation) * [Usage](#usage) + * [Choosing a Different Model](#choosing-a-different-model) * [Example Output](#example-output) * [Troubleshooting](#troubleshooting) * [Support Channels](#support-channels) @@ -33,7 +34,8 @@ Generator. via [OpenAI](https://openai.com/)'s API. The CLI allows you to ask the model to generate templates for different scenarios (e.g. "get terraform for AWS EC2"). It will make the request, and store the resulting code to a file, or simply print it to standard -output. +output. By default, `aiac` uses the same model used by ChatGPT, but allows using +different models. ## Use Cases and Example Prompts @@ -92,7 +94,7 @@ Using `docker`: Using `go install`: - go install github.com/gofireflyio/aiac/v2 + go install github.com/gofireflyio/aiac/v2@latest Alternatively, clone the repository and build from source: @@ -101,12 +103,12 @@ Alternatively, clone the repository and build from source: ### Usage -1. Create your OpenAI API key [here](https://beta.openai.com/account/api-keys). +1. Create your OpenAI API key [here](platform.openai.com/account/api-keys). 2. Click “Create new secret key” and copy it. 3. Provide the API key via the `OPENAI_API_KEY` environment variable or via the `--api-key` command line flag. By default, aiac prints the extracted code to standard output and asks if it -should save or regenerate the code: +should save the code, regenerate it, or modify the prompt: aiac get terraform for AWS EC2 @@ -121,6 +123,15 @@ To run using `docker`: -e OPENAI_API_KEY=[PUT YOUR KEY HERE] \ ghcr.io/gofireflyio/aiac get terraform for ec2 +If you want to receive and/or store the complete Markdown output from OpenAI, +including explanations (if any), use the `--full` flag. + +### Choosing a Different Model + +Use the `--model` flag to select a different model than the default (currently +"gpt-3.5-turbo"). Not all OpenAI models are supported, use `aiac list-models` +to get a list of all supported models. + ## Example Output Command line prompt: diff --git a/libaiac/libaiac.go b/libaiac/libaiac.go index 9b8a165..a4f237c 100644 --- a/libaiac/libaiac.go +++ b/libaiac/libaiac.go @@ -3,13 +3,16 @@ package libaiac import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" "os" + "regexp" "strings" "time" + "github.com/alecthomas/kong" "github.com/briandowns/spinner" "github.com/fatih/color" "github.com/ido50/requests" @@ -20,8 +23,79 @@ import ( type Client struct { *requests.HTTPClient apiKey string + model Model + full bool } +var ( + // ErrResultTruncated is returned when the OpenAI API returned a truncated + // result. The reason for the truncation will be appended to the error + // string. + ErrResultTruncated = errors.New("result was truncated") + + // ErrNoResults is returned if the OpenAI API returned an empty result. This + // should not generally happen. + ErrNoResults = errors.New("no results return from API") + + // ErrUnsupportedModel is returned if the SetModel method is provided with + // an unsupported model + ErrUnsupportedModel = errors.New("unsupported model") + + // ErrUnexpectedStatus is returned when the OpenAI API returned a response + // with an unexpected status code + ErrUnexpectedStatus = errors.New("OpenAI returned unexpected response") + + // ErrRequestFailed is returned when the OpenAI API returned an error for + // the request + ErrRequestFailed = errors.New("request failed") +) + +// Model is an enum used to select the language model to use +type Model string + +const ( + // ModelChatGPT represents the gpt-3.5-turbo model used by ChatGPT. + ModelChatGPT = "gpt-3.5-turbo" + + // ModelTextDaVinci3 represents the text-davinci-003 language generation + // model. + ModelTextDaVinci3 = "text-davinci-003" + + // ModelCodeDaVinci2 represents the code-davinci-002 code generation model. + ModelCodeDaVinci2 = "code-davinci-002" +) + +// Decode is used by the kong library to map CLI-provided values to the Model +// type +func (m *Model) Decode(ctx *kong.DecodeContext) error { + var provided string + + err := ctx.Scan.PopValueInto("string", &provided) + if err != nil { + return fmt.Errorf("failed getting model value: %w", err) + } + + for _, supported := range []Model{ + ModelChatGPT, + ModelTextDaVinci3, + ModelCodeDaVinci2, + } { + if string(supported) == provided { + *m = supported + return nil + } + } + + return fmt.Errorf("%w %s", ErrUnsupportedModel, provided) +} + +// SupportedModels is a list of all models supported by aiac +var SupportedModels = []string{ModelChatGPT, ModelTextDaVinci3, ModelCodeDaVinci2} + +// MaxTokens is the maximum amount of tokens supported by the model used. Newer +// OpenAI models support a maximum of 4096 tokens. +var MaxTokens = 4096 + // NewClient creates a new instance of the Client struct, with the provided // input options. Neither the OpenAI API nor ChatGPT are yet contacted at this // point. @@ -32,6 +106,7 @@ func NewClient(apiKey string) *Client { cli := &Client{ apiKey: strings.TrimPrefix(apiKey, "Bearer "), + model: ModelChatGPT, } cli.HTTPClient = requests.NewClient("https://api.openai.com/v1"). @@ -52,17 +127,37 @@ func NewClient(apiKey string) *Client { err := json.NewDecoder(body).Decode(&res) if err != nil { return fmt.Errorf( - "OpenAI returned response %s", + "%w %s", + ErrUnexpectedStatus, http.StatusText(httpStatus), ) } - return fmt.Errorf("[%s] %s", res.Error.Type, res.Error.Message) + return fmt.Errorf( + "%w: [%s]: %s", + ErrRequestFailed, + res.Error.Type, + res.Error.Message, + ) }) return cli } +// SetModel changes the language model to use with the OpenAI API +func (client *Client) SetModel(model Model) *Client { + client.model = model + return client +} + +// SetFull sets whether output is returned/stored in full, including +// explanations (if any), or if only the code is extracted. Defaults to false +// (meaning only code is extracted). +func (client *Client) SetFull(full bool) *Client { + client.full = full + return client +} + // Ask asks the OpenAI API to generate code based on the provided prompt. // It is only meant to be used in command line applications (see GenerateCode // for library usage). The generated code will always be printed to standard @@ -81,10 +176,12 @@ func (client *Client) Ask( outputPath string, ) (err error) { spin := spinner.New(spinner.CharSets[2], - 100*time.Millisecond, + 100*time.Millisecond, //nolint: gomnd spinner.WithWriter(color.Error), spinner.WithSuffix("\tGenerating code ...")) + spin.Start() + killed := false defer func() { @@ -99,6 +196,7 @@ func (client *Client) Ask( } spin.Stop() + killed = true fmt.Fprintln(os.Stdout, code) @@ -108,24 +206,49 @@ func (client *Client) Ask( } if shouldRetry { + errInvalidInput := errors.New("invalid input, please try again") //nolint: goerr113 + input := promptui.Prompt{ - Label: "Hit [S/s] to save the file or [R/r] to retry [Q/q] to quit", + Label: "Hit [S/s] to save the file, [R/r] to retry, [M/m] to modify prompt, [Q/q] to quit", Validate: func(s string) error { - if strings.ToLower(s) != "s" && strings.ToLower(s) != "r" && strings.ToLower(s) != "q" { - return fmt.Errorf("Invalid input. Try again please.") + switch strings.ToLower(s) { + case "s", "r", "m", "q": + return nil } - return nil + + return errInvalidInput }, } result, err := input.Run() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } - if strings.ToLower(result) == "q" { - // finish without saving - return nil - } else if err != nil || strings.ToLower(result) == "r" { + return fmt.Errorf("prompt failed: %w", err) + } + + switch strings.ToLower(result) { + case "r": // retry once more return client.Ask(ctx, prompt, shouldRetry, shouldQuit, outputPath) + case "m": + // let user modify prompt + input := promptui.Prompt{ + Label: "New prompt", + Default: prompt, + } + + prompt, err = input.Run() + if err != nil { + return fmt.Errorf("prompt failed: %w", err) + } + + return client.Ask(ctx, prompt, shouldRetry, shouldQuit, outputPath) + case "q": + // finish without saving + return nil } } @@ -136,7 +259,7 @@ func (client *Client) Ask( outputPath, err = input.Run() if err != nil { - return err + return fmt.Errorf("prompt failed: %w", err) } } @@ -159,12 +282,84 @@ func (client *Client) Ask( return nil } +var codeRegex = regexp.MustCompile("(?ms)^```(?:[^\n]*)\n(.*?)\n```$") + // GenerateCode sends the provided prompt to the OpenAI API and returns the // generated code. func (client *Client) GenerateCode(ctx context.Context, prompt string) ( code string, err error, ) { + if client.model == ModelChatGPT { + code, err = client.generateWithChatModel(ctx, prompt) + } else { + code, err = client.generateWithCompletionsModel(ctx, prompt) + } + + if err != nil { + return "", err + } + + if client.full { + return code, nil + } + + m := codeRegex.FindStringSubmatch(code) + if m == nil || m[1] == "" { + return code, nil + } + + return m[1], nil +} + +func (client *Client) generateWithChatModel(ctx context.Context, prompt string) ( + code string, + err error, +) { + var answer struct { + Choices []struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + Index int64 `json:"index"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + } + + err = client.NewRequest("POST", "/chat/completions"). + JSONBody(map[string]interface{}{ + "model": client.model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "max_tokens": MaxTokens + 1 - len(prompt), + }). + Into(&answer). + RunContext(ctx) + if err != nil { + return code, fmt.Errorf("failed sending prompt: %w", err) + } + + if len(answer.Choices) == 0 { + return code, ErrNoResults + } + + if answer.Choices[0].FinishReason != "stop" { + return code, fmt.Errorf( + "%w: %s", + ErrResultTruncated, + answer.Choices[0].FinishReason, + ) + } + + return strings.TrimSpace(answer.Choices[0].Message.Content), nil +} + +func (client *Client) generateWithCompletionsModel( + ctx context.Context, + prompt string, +) (code string, err error) { var answer struct { Choices []struct { Text string `json:"text"` @@ -173,27 +368,26 @@ func (client *Client) GenerateCode(ctx context.Context, prompt string) ( } `json:"choices"` } - var status int err = client.NewRequest("POST", "/completions"). JSONBody(map[string]interface{}{ - "model": "text-davinci-003", + "model": client.model, "prompt": prompt, - "max_tokens": 4097 - len(prompt), + "max_tokens": MaxTokens + 1 - len(prompt), }). Into(&answer). - StatusInto(&status). RunContext(ctx) if err != nil { return code, fmt.Errorf("failed sending prompt: %w", err) } if len(answer.Choices) == 0 { - return code, fmt.Errorf("no results returned from API") + return code, ErrNoResults } if answer.Choices[0].FinishReason != "stop" { return code, fmt.Errorf( - "result was truncated by API due to %s", + "%w: %s", + ErrResultTruncated, answer.Choices[0].FinishReason, ) } diff --git a/main.go b/main.go index 97dab1e..0d16445 100644 --- a/main.go +++ b/main.go @@ -11,35 +11,70 @@ import ( ) type flags struct { - APIKey string `help:"OpenAI API key" required:"" env:"OPENAI_API_KEY"` - OutputFile string `help:"Output file to push resulting code to, defaults to stdout" default:"-" type:"path" short:"o"` - Save bool `help:"Save AIaC response without retry prompt" default:false short:"s"` - Quiet bool `help:"Print AIaC response to stdout and exit (non-interactive mode)" default:false short:"q"` + APIKey string `help:"OpenAI API key" required:"" env:"OPENAI_API_KEY"` + ListModels struct{} `cmd:"" help:"List supported models"` + OutputFile string `help:"Output file to push resulting code to, defaults to stdout" default:"-" type:"path" short:"o"` //nolint: lll + Quiet bool `help:"Print AIaC response to stdout and exit (non-interactive mode)" default:"false" short:"q"` + Save bool `help:"Save AIaC response without retry prompt" default:"false" short:"s"` Get struct { - What []string `arg:"" help:"Which IaC template to generate"` + Full bool `help:"Return full output, including explanations, if any" default:"false" short:"f"` + Model libaiac.Model `help:"Model to use, default to \"gpt-3.5-turbo\""` + What []string `arg:"" help:"Which IaC template to generate"` } `cmd:"" help:"Generate IaC code" aliases:"generate"` } func main() { - if len(os.Args) < 2 { + if len(os.Args) < 2 { //nolint: gomnd os.Args = append(os.Args, "--help") } + var cli flags - cmd := kong.Parse(&cli) + parser := kong.Must( + &cli, + kong.Name("aiac"), + kong.Description("Artificial Intelligence Infrastructure-as-Code Generator."), + ) + + ctx, err := parser.Parse(os.Args[1:]) + if err != nil { + if err.Error() == "missing flags: --api-key=STRING" { + fmt.Fprintln(os.Stderr, `You must provide an OpenAI API key via the --api-key flag, or +the OPENAI_API_KEY environment variable. - if cmd.Command() != "get " { +Get your API key from https://platform.openai.com/account/api-keys.`) + } else { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + + os.Exit(1) + } + + if ctx.Command() == "list-models" { + for _, model := range libaiac.SupportedModels { + fmt.Println(model) + } + + os.Exit(0) + } + + if ctx.Command() != "get " { fmt.Fprintln(os.Stderr, "Unknown command") os.Exit(1) } - client := libaiac.NewClient(cli.APIKey) + client := libaiac.NewClient(cli.APIKey). + SetFull(cli.Get.Full) + + if cli.Get.Model != "" { + client.SetModel(cli.Get.Model) + } - err := client.Ask( + err = client.Ask( context.TODO(), // NOTE: we are prepending the word "generate" to the prompt, this // ensures the language model actually generates code. The word "get", // on the other hand, doesn't necessarily result in code being generated. - fmt.Sprintf("generate %s", strings.Join(cli.Get.What, " ")), + fmt.Sprintf("generate code for a %s", strings.Join(cli.Get.What, " ")), !cli.Save, cli.Quiet, cli.OutputFile,