Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DATA-3467: Cloud Inference CLI #4748

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cli/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2748,6 +2748,52 @@ This won't work unless you have an existing installation of our GitHub app on yo
},
},
},
{
Name: "infer",
Usage: "run cloud hosted inference on an image",
UsageText: createUsageText("inference infer", []string{
generalFlagOrgID, inferenceFlagFileOrgID, inferenceFlagFileID,
inferenceFlagFileLocationID, inferenceFlagModelOrgID, inferenceFlagModelName, inferenceFlagModelVersion,
}, true, false),
Flags: []cli.Flag{
&cli.StringFlag{
Name: generalFlagOrgID,
Usage: "organization ID that is executing the inference job",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileOrgID,
Usage: "organization ID that owns the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileID,
Usage: "file ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagFileLocationID,
Usage: "location ID of the file to run inference on",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelOrgID,
Usage: "organization ID that hosts the model to use to run inference",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelName,
Usage: "name of the model to use to run inference",
Required: true,
},
&cli.StringFlag{
Name: inferenceFlagModelVersion,
Usage: "version of the model to use to run inference",
Required: true,
},
},
Action: createCommandWithT[mlInferenceInferArgs](MLInferenceInferAction),
},
{
Name: "version",
Usage: "print version info for this program",
Expand Down
2 changes: 2 additions & 0 deletions cli/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
Expand Down Expand Up @@ -544,6 +545,7 @@ func (c *viamClient) ensureLoggedInInner() error {
c.packageClient = packagepb.NewPackageServiceClient(conn)
c.datasetClient = datasetpb.NewDatasetServiceClient(conn)
c.mlTrainingClient = mltrainingpb.NewMLTrainingServiceClient(conn)
c.mlInferenceClient = mlinferencepb.NewMLInferenceServiceClient(conn)
c.buildClient = buildpb.NewBuildServiceClient(conn)

return nil
Expand Down
22 changes: 12 additions & 10 deletions cli/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
buildpb "go.viam.com/api/app/build/v1"
datapb "go.viam.com/api/app/data/v1"
datasetpb "go.viam.com/api/app/dataset/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
mltrainingpb "go.viam.com/api/app/mltraining/v1"
packagepb "go.viam.com/api/app/packages/v1"
apppb "go.viam.com/api/app/v1"
Expand Down Expand Up @@ -69,16 +70,17 @@ var errNoShellService = errors.New("shell service is not enabled on this machine
// viamClient wraps a cli.Context and provides all the CLI command functionality
// needed to talk to the app and data services but not directly to robot parts.
type viamClient struct {
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow
c *cli.Context
conf *Config
client apppb.AppServiceClient
dataClient datapb.DataServiceClient
packageClient packagepb.PackageServiceClient
datasetClient datasetpb.DatasetServiceClient
mlTrainingClient mltrainingpb.MLTrainingServiceClient
mlInferenceClient mlinferencepb.MLInferenceServiceClient
buildClient buildpb.BuildServiceClient
baseURL *url.URL
authFlow *authFlow

selectedOrg *apppb.Organization
selectedLoc *apppb.Location
Expand Down
122 changes: 122 additions & 0 deletions cli/ml_inference.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package cli

import (
"context"
"fmt"
"strings"

"github.com/pkg/errors"
"github.com/urfave/cli/v2"
v1 "go.viam.com/api/app/data/v1"
mlinferencepb "go.viam.com/api/app/mlinference/v1"
)

// Flag names shared between the 'inference infer' command registration in
// app.go and the argument struct below. The file-* flags identify the binary
// data file to run inference on; the model-* flags identify the registry
// model to run.
const (
	inferenceFlagFileOrgID      = "file-org-id"
	inferenceFlagFileID         = "file-id"
	inferenceFlagFileLocationID = "file-location-id"
	inferenceFlagModelOrgID     = "model-org-id"
	inferenceFlagModelName      = "model-name"
	inferenceFlagModelVersion   = "model-version"
)

// mlInferenceInferArgs holds the parsed CLI flag values for the
// 'inference infer' command (see the Flags list in app.go).
type mlInferenceInferArgs struct {
	OrgID          string // organization executing the inference job
	FileOrgID      string // organization that owns the file to run inference on
	FileID         string // ID of the file to run inference on
	FileLocationID string // location of the file to run inference on
	ModelOrgID     string // organization hosting the model
	ModelName      string // name of the model
	ModelVersion   string // version of the model
}

// MLInferenceInferAction is the corresponding action for 'inference infer'.
// It runs cloud-hosted inference on the file identified by the file-* flags
// using the registry model identified by the model-* flags; the response is
// printed by mlRunInference, so this action only surfaces errors to the CLI.
func MLInferenceInferAction(c *cli.Context, args mlInferenceInferArgs) error {
	client, err := newViamClient(c)
	if err != nil {
		return err
	}

	// The response itself is printed inside mlRunInference; only the error
	// matters here, so return it directly rather than branching on it.
	_, err = client.mlRunInference(
		args.OrgID, args.FileOrgID, args.FileID, args.FileLocationID,
		args.ModelOrgID, args.ModelName, args.ModelVersion)
	return err
}

// mlRunInference runs cloud-hosted inference on a binary data file with the
// specified model and prints the response to the CLI writer.
//
// The file is identified by (fileOrgID, fileID, fileLocation); the model is
// the registry item "<modelOrgID>:<modelName>" at modelVersion. orgID is the
// organization executing the inference job. Returns the raw service response
// so callers can inspect it beyond the printed summary.
func (c *viamClient) mlRunInference(orgID, fileOrgID, fileID, fileLocation, modelOrgID,
	modelName, modelVersion string,
) (*mlinferencepb.GetInferenceResponse, error) {
	if err := c.ensureLoggedIn(); err != nil {
		return nil, err
	}

	req := &mlinferencepb.GetInferenceRequest{
		OrganizationId: orgID,
		BinaryId: &v1.BinaryID{
			FileId:         fileID,
			OrganizationId: fileOrgID,
			LocationId:     fileLocation,
		},
		// Registry items are addressed as "<org>:<name>" plus a version.
		RegistryItemId:      fmt.Sprintf("%s:%s", modelOrgID, modelName),
		RegistryItemVersion: modelVersion,
	}

	resp, err := c.mlInferenceClient.GetInference(context.Background(), req)
	if err != nil {
		// errors.Wrap, not Wrapf: the message has no format verbs, and
		// Wrapf with zero args trips printf-style vet checks.
		return nil, errors.Wrap(err, "received error from server")
	}
	c.printInferenceResponse(resp)
	return resp, nil
}

// printInferenceResponse prints a neat representation of the GetInferenceResponse.
func (c *viamClient) printInferenceResponse(resp *mlinferencepb.GetInferenceResponse) {
printf(c.c.App.Writer, "Inference Response:")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would know better than I would but is this the kind of format that you would want? Meaning, normally from a CLI I would expect JSON or something configurable that I can do programmatically. But I also know absolutely nothing about ML and this could be the way ML people want the response.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, this format isn't very good for most usecases, but it's readable. I think the primary use case of this command will be to debug / easily check if the service is functioning.

printf(c.c.App.Writer, "Output Tensors:")
if resp.OutputTensors != nil {
for name, tensor := range resp.OutputTensors.Tensors {
printf(c.c.App.Writer, " Tensor Name: %s", name)
printf(c.c.App.Writer, " Shape: %v", tensor.Shape)
if tensor.Tensor != nil {
var sb strings.Builder
for i, value := range tensor.GetDoubleTensor().GetData() {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%.4f", value))
}
printf(c.c.App.Writer, " Values: [%s]", sb.String())
} else {
printf(c.c.App.Writer, " No values available.")
}
}
} else {
printf(c.c.App.Writer, " No output tensors.")
}

printf(c.c.App.Writer, "Annotations:")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just include here the format [x_min, y_min, x_max, y_max]

printf(c.c.App.Writer, "Bounding Box Format: [x_min, y_min, x_max, y_max]")
if resp.Annotations != nil {
for _, bbox := range resp.Annotations.Bboxes {
printf(c.c.App.Writer, " Bounding Box ID: %s, Label: %s",
bbox.Id, bbox.Label)
printf(c.c.App.Writer, " Coordinates: [%f, %f, %f, %f]",
bbox.XMinNormalized, bbox.YMinNormalized, bbox.XMaxNormalized, bbox.YMaxNormalized)
if bbox.Confidence != nil {
printf(c.c.App.Writer, " Confidence: %.4f", *bbox.Confidence)
}
}
for _, classification := range resp.Annotations.Classifications {
printf(c.c.App.Writer, " Classification Label: %s", classification.Label)
if classification.Confidence != nil {
printf(c.c.App.Writer, " Confidence: %.4f", *classification.Confidence)
}
}
} else {
printf(c.c.App.Writer, " No annotations.")
}
}
Loading