Skip to content

Commit

Permalink
Merge pull request #7 from ghchinoy/multimodal
Browse files Browse the repository at this point in the history
Multimodal
  • Loading branch information
ghchinoy authored Apr 21, 2024
2 parents 1223b11 + 12d7897 commit e567a0a
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 9 deletions.
6 changes: 4 additions & 2 deletions cmd/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
"github.com/spf13/cobra"
)

// systemInstructions is package-level state for the prompt command.
// NOTE(review): presumably bound to a command-line flag in init and forwarded
// to the model as system instructions — confirm against the flag setup.
var systemInstructions string

func init() {
rootCmd.AddCommand(promptCmd)

Expand Down Expand Up @@ -51,8 +55,6 @@ func generateContent(cmd *cobra.Command, args []string) {
log.Printf("prompt: %s", args)
}

//fmt.Printf("\nModel name: %s\n", modelName)

// Lookup the model based on the name
m, err := model.Get(modelName)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/ghchinoy/gen

go 1.22.1
go 1.22.2

require (
cloud.google.com/go/aiplatform v1.67.0
Expand All @@ -10,7 +10,7 @@ require (
github.com/spf13/cobra v1.8.0
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.18.2
google.golang.org/api v0.174.0
google.golang.org/api v0.175.0
google.golang.org/genproto/googleapis/api v0.0.0-20240415180920-8c6c420018be
google.golang.org/protobuf v1.33.0
)
Expand Down
5 changes: 3 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cloud.google.com/go/auth v0.2.2 h1:gmxNJs4YZYcw6YvKRtVBaF2fyUE6UrWPyzU8jHvYfmI=
cloud.google.com/go/auth v0.2.2/go.mod h1:2bDNJWtWziDT3Pu1URxHHbkHE/BbOCuyUiKIGcNvafo=
cloud.google.com/go/auth/oauth2adapt v0.2.1 h1:VSPmMmUlT8CkIZ2PzD9AlLN+R3+D1clXMWHHa6vG/Ag=
cloud.google.com/go/auth/oauth2adapt v0.2.1/go.mod h1:tOdK/k+D2e4GEwfBRA48dKNQiDsqIXxLh7VU319eV0g=
cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU=
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM=
Expand Down Expand Up @@ -187,8 +188,8 @@ golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.174.0 h1:zB1BWl7ocxfTea2aQ9mgdzXjnfPySllpPOskdnO+q34=
google.golang.org/api v0.174.0/go.mod h1:aC7tB6j0HR1Nl0ni5ghpx6iLasmAX78Zkh/wgxAAjLg=
google.golang.org/api v0.175.0 h1:9bMDh10V9cBuU8N45Wlc3cKkItfqMRV0Fi8UscLEtbY=
google.golang.org/api v0.175.0/go.mod h1:Rra+ltKu14pps/4xTycZfobMgLpbosoaaL7c+SEMrO8=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
Expand Down
96 changes: 93 additions & 3 deletions internal/model/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ import (
"fmt"
"io"
"log"
"mime"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"

"cloud.google.com/go/vertexai/genai"
Expand All @@ -16,9 +21,35 @@ import (
// UseGeminiModel calls Gemini's generate content method
func UseGeminiModel(ctx context.Context, modelName string, cfg Config, args []string) error {
log.Printf("Gemini [%s]", modelName)
prompt := genai.Text(args[0])

var promptParts []genai.Part
for _, arg := range args {
if argLooksLikeGCSURL(arg) {
part := genai.FileData{
MIMEType: mime.TypeByExtension(filepath.Ext(arg)),
FileURI: arg,
}
promptParts = append(promptParts, part)
} else if argLooksLikeURL(arg) {
part, err := getPartFromURL(arg)
if err != nil {
return err
}
promptParts = append(promptParts, part)
} else if argLooksLikeFilename(arg) {
part, err := getPartFromFile(arg)
if err != nil {
return err
}
promptParts = append(promptParts, part)
} else {

promptParts = append(promptParts, genai.Text(arg))
}
}

var buf bytes.Buffer
if err := GenerateContentGemini(ctx, modelName, cfg, &buf, []genai.Part{prompt}); err != nil {
if err := GenerateContentGemini(ctx, modelName, cfg, &buf, promptParts); err != nil {
log.Printf("error generating content: %v", err)
os.Exit(1)
}
Expand All @@ -35,10 +66,11 @@ func GenerateContentGemini(ctx context.Context, modelName string, cfg Config, w
// others be made public or should this one be made private.

client, err := genai.NewClient(ctx, cfg.ProjectID, cfg.RegionID)

if err != nil {
return fmt.Errorf("error creating a client: %v", err)
}
defer client.Close()

gemini := client.GenerativeModel(modelName)

if cfg.ConfigFile != "" {
Expand Down Expand Up @@ -68,6 +100,7 @@ func GenerateContentGemini(ctx context.Context, modelName string, cfg Config, w
}
return fmt.Errorf("error generating content: %w", err)
}

if cfg.OutputType == "json" {
rb, _ := json.MarshalIndent(resp, "", " ")
fmt.Fprintln(w, string(rb))
Expand All @@ -84,3 +117,60 @@ func GenerateContentGemini(ctx context.Context, modelName string, cfg Config, w
}
return nil
}

// Thanks to eliben's gemini-cli: https://github.com/eliben/gemini-cli/blob/main/internal/commands/prompt.go

// filenameExtPattern matches a trailing alphabetical extension such as ".png".
// It is compiled once at package scope rather than on every call.
var filenameExtPattern = regexp.MustCompile(`\.[a-zA-Z]+$`)

// argLooksLikeFilename reports whether a command-line argument looks like a
// filename: it must end in a dot followed by one or more ASCII letters and
// must not contain a "://" scheme separator (which would make it a URL).
func argLooksLikeFilename(arg string) bool {
	return !strings.Contains(arg, "://") && filenameExtPattern.MatchString(arg)
}

// argLooksLikeGCSURL reports whether arg begins with the "gs://" prefix.
func argLooksLikeGCSURL(arg string) bool {
	const gcsPrefix = "gs://"
	_, matched := strings.CutPrefix(arg, gcsPrefix)
	return matched
}

// argLooksLikeURL reports whether arg is a fetchable web URL, i.e. it parses
// as a URL with an http or https scheme and a non-empty host.
//
// A bare parse check is not enough: url.ParseRequestURI also accepts absolute
// filesystem paths ("/tmp/cat.png") and any "word:" prefix ("Question: ..."),
// which would misroute local files and plain-text prompts to an HTTP fetch.
func argLooksLikeURL(arg string) bool {
	u, err := url.Parse(arg)
	if err != nil {
		return false
	}
	// url.Parse lower-cases the scheme, so "HTTP://" is matched here too.
	switch u.Scheme {
	case "http", "https":
		return u.Host != ""
	default:
		return false
	}
}

// getPartFromFile loads the image at path and returns it as a prompt part.
//
// The image format is chosen from the file extension, compared
// case-insensitively: ".jpg"/".jpeg" map to "jpeg" and ".png" maps to "png".
// Any other extension yields an error. The extension is validated before the
// file is read so that unsupported files are rejected without any I/O.
func getPartFromFile(path string) (genai.Part, error) {
	ext := filepath.Ext(path)

	var format string
	switch strings.ToLower(strings.TrimSpace(ext)) {
	case ".jpg", ".jpeg":
		format = "jpeg"
	case ".png":
		format = "png"
	default:
		return nil, fmt.Errorf("invalid image file extension: %s", ext)
	}

	b, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading image file %q: %w", path, err)
	}
	return genai.ImageData(format, b), nil
}

// getPartFromURL downloads rawURL and returns the response body as an image
// prompt part whose format is the subtype of the response's Content-Type
// header (for example "jpeg" for "image/jpeg").
//
// It returns an error if the request fails, the server answers with a non-2xx
// status, the body cannot be read, or the Content-Type is missing or is not of
// the form "type/subtype".
func getPartFromURL(rawURL string) (genai.Part, error) {
	resp, err := http.Get(rawURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch image from url: %w", err)
	}
	defer resp.Body.Close()

	// Without this check an error page (e.g. a 404 HTML body) would be
	// forwarded to the model as if it were image data.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("failed to fetch image from url: unexpected status %s", resp.Status)
	}

	urlData, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read image bytes: %w", err)
	}

	// Strip optional parameters such as "; charset=binary" so that only the
	// bare subtype is used as the image format. ParseMediaType still returns
	// the media type when only the parameters are malformed, so fail only
	// when no media type could be extracted at all.
	contentType := resp.Header.Get("Content-Type")
	mediaType, _, err := mime.ParseMediaType(contentType)
	if err != nil && mediaType == "" {
		return nil, fmt.Errorf("invalid mime type %v", contentType)
	}
	_, subtype, ok := strings.Cut(mediaType, "/")
	if !ok || subtype == "" {
		return nil, fmt.Errorf("invalid mime type %v", contentType)
	}

	return genai.ImageData(subtype, urlData), nil
}
1 change: 1 addition & 0 deletions internal/model/models
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ gemini,text,gemini-1.0-pro
gemini,text,gemini-1.0-pro-001
gemini,text,gemini-1.0-pro-001
gemini,text,gemini-1.0-ultra-001
gemini,multimodal,gemini-1.0-pro-vision
gemini,multimodal,gemini-1.0-pro-vision-001
gemini,multimodal,gemini-1.0-ultra-vision-001
gemini,multimodal,gemini-1.5-pro-preview-0409
Expand Down

0 comments on commit e567a0a

Please sign in to comment.