Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

harden: Only store registered repositories #1155

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 36 additions & 43 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

ghclient "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
"github.com/stacklok/mediator/internal/util/rand"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)
Expand All @@ -48,16 +49,10 @@ type Response struct {
// MAX_CALLS is the maximum number of calls to the gRPC server before stopping.
const MAX_CALLS = 300

// just syncs repos for the specific provider and project
func syncRepos(ctx context.Context, client pb.RepositoryServiceClient, provider string, project string) error {
_, err := client.SyncRepositories(ctx, &pb.SyncRepositoriesRequest{Provider: provider, ProjectId: project})
return err
}

// callBackServer starts a server and handler to listen for the OAuth callback.
// It will wait for either a success or failure response from the server.
func callBackServer(ctx context.Context, provider string, project string, port string,
wg *sync.WaitGroup, client pb.OAuthServiceClient, since int64, repos_client pb.RepositoryServiceClient) {
wg *sync.WaitGroup, client pb.OAuthServiceClient, since int64) {
server := &http.Server{
Addr: fmt.Sprintf(":%s", port),
ReadHeaderTimeout: time.Second * 10, // Set an appropriate timeout value
Expand Down Expand Up @@ -103,9 +98,7 @@ func callBackServer(ctx context.Context, provider string, project string, port s
res, err := client.VerifyProviderTokenFrom(clientCtx,
&pb.VerifyProviderTokenFromRequest{Provider: provider, ProjectId: project, Timestamp: timestamppb.New(t)})
if err == nil && res.Status == "OK" {
// we can sync repos
err := syncRepos(clientCtx, repos_client, provider, project)
util.ExitNicelyOnError(err, "Error syncing repos")
return
}
if err != nil || res.Status == "OK" || calls >= MAX_CALLS {
stopServer = true
Expand Down Expand Up @@ -141,7 +134,6 @@ actions such as adding repositories.`,
defer conn.Close()

client := pb.NewOAuthServiceClient(conn)
repos_client := pb.NewRepositoryServiceClient(conn)
ctx, cancel := util.GetAppContext()
defer cancel()
oAuthCallbackCtx, oAuthCancel := context.WithTimeout(context.Background(), MAX_CALLS*time.Second)
Expand All @@ -153,40 +145,41 @@ actions such as adding repositories.`,
&pb.StoreProviderTokenRequest{Provider: provider, ProjectId: project, AccessToken: pat, Owner: &owner})
util.ExitNicelyOnError(err, "Error storing token")

err = syncRepos(ctx, repos_client, provider, project)
util.ExitNicelyOnError(err, "Error syncing repos")
fmt.Println("Provider enrolled successfully")
} else {
// Get random port
port, err := rand.GetRandomPort()
util.ExitNicelyOnError(err, "Error getting random port")

resp, err := client.GetAuthorizationURL(ctx, &pb.GetAuthorizationURLRequest{
Provider: provider,
ProjectId: project,
Cli: true,
Port: int32(port),
Owner: &owner,
})
util.ExitNicelyOnError(err, "Error getting authorization URL")

fmt.Printf("Your browser will now be opened to: %s\n", resp.GetUrl())
fmt.Println("Please follow the instructions on the page to complete the OAuth flow.")
fmt.Println("Once the flow is complete, the CLI will close")
fmt.Println("If this is a headless environment, please copy and paste the URL into a browser on a different machine.")

if err := browser.OpenURL(resp.GetUrl()); err != nil {
fmt.Fprintf(os.Stderr, "Error opening browser: %s\n", err)
os.Exit(1)
}
openTime := time.Now().Unix()

var wg sync.WaitGroup
wg.Add(1)
cli.PrintCmd(cmd, "Provider enrolled successfully")
return
}

go callBackServer(oAuthCallbackCtx, provider, project, fmt.Sprintf("%d", port), &wg, client, openTime, repos_client)
wg.Wait()
// Get random port
port, err := rand.GetRandomPort()
util.ExitNicelyOnError(err, "Error getting random port")

resp, err := client.GetAuthorizationURL(ctx, &pb.GetAuthorizationURLRequest{
Provider: provider,
ProjectId: project,
Cli: true,
Port: int32(port),
Owner: &owner,
})
util.ExitNicelyOnError(err, "Error getting authorization URL")

fmt.Printf("Your browser will now be opened to: %s\n", resp.GetUrl())
fmt.Println("Please follow the instructions on the page to complete the OAuth flow.")
fmt.Println("Once the flow is complete, the CLI will close")
fmt.Println("If this is a headless environment, please copy and paste the URL into a browser on a different machine.")

if err := browser.OpenURL(resp.GetUrl()); err != nil {
fmt.Fprintf(os.Stderr, "Error opening browser: %s\n", err)
os.Exit(1)
}
openTime := time.Now().Unix()

var wg sync.WaitGroup
wg.Add(1)

go callBackServer(oAuthCallbackCtx, provider, project, fmt.Sprintf("%d", port), &wg, client, openTime)
wg.Wait()

cli.PrintCmd(cmd, "Provider enrolled successfully")
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/app/repo/repo_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var repo_getCmd = &cobra.Command{
defer cancel()

// check repo by id
var repository *pb.RepositoryRecord
var repository *pb.Repository
if repoid != "" {
resp, err := client.GetRepositoryById(ctx, &pb.GetRepositoryByIdRequest{
RepositoryId: repoid,
Expand Down
40 changes: 28 additions & 12 deletions cmd/cli/app/repo/repo_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import (
"fmt"
"os"

"github.com/olekukonko/tablewriter"
"github.com/charmbracelet/bubbles/table"
"github.com/spf13/cobra"
"github.com/spf13/viper"

github "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)

Expand Down Expand Up @@ -68,7 +69,6 @@ var repo_listCmd = &cobra.Command{
resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{
Provider: provider,
ProjectId: projectID,
Filter: pb.RepoFilter_REPO_FILTER_SHOW_REGISTERED_ONLY,
})
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting repo of repos: %s\n", err)
Expand All @@ -77,21 +77,37 @@ var repo_listCmd = &cobra.Command{

switch format {
case "", "table":
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"Id", "Project ID", "Provider Id", "Name", "Is fork", "Is private"})
columns := []table.Column{
{Title: "ID", Width: 40},
{Title: "Project", Width: 40},
{Title: "Provider", Width: 15},
{Title: "Upstream ID", Width: 15},
{Title: "Owner", Width: 15},
{Title: "Name", Width: 15},
}

var rows []table.Row
for _, v := range resp.Results {
row := []string{
v.Id,
v.ProjectId,
row := table.Row{
*v.Id,
*v.Context.Project,
v.Context.Provider,
fmt.Sprintf("%d", v.GetRepoId()),
fmt.Sprintf("%s/%s", v.GetOwner(), v.GetName()),
fmt.Sprintf("%t", v.GetIsFork()),
fmt.Sprintf("%t", v.GetIsPrivate()),
v.GetOwner(),
v.GetName(),
}
table.Append(row)
rows = append(rows, row)
}
table.Render()

t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(false),
table.WithHeight(len(rows)),
table.WithStyles(cli.TableHiddenSelectStyles),
)

cli.PrintCmd(cmd, cli.TableRender(t))
case "json":
out, err := util.GetJsonFromProto(resp)
util.ExitNicelyOnError(err, "Error getting json from proto")
Expand Down
97 changes: 81 additions & 16 deletions cmd/cli/app/repo/repo_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ import (
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/charmbracelet/bubbles/table"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"k8s.io/utils/strings/slices"

github "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)

Expand All @@ -53,7 +55,6 @@ var repo_registerCmd = &cobra.Command{
}
},
Run: func(cmd *cobra.Command, args []string) {

provider := util.GetConfigValue("provider", "provider", cmd, "").(string)
if provider != github.Github {
fmt.Fprintf(os.Stderr, "Only %s is supported at this time\n", github.Github)
Expand All @@ -69,21 +70,54 @@ var repo_registerCmd = &cobra.Command{
ctx, cancel := util.GetAppContext()
defer cancel()

req := &pb.ListRepositoriesRequest{
// Get the list of repos
listResp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{
Provider: provider,
ProjectId: projectID,
Filter: pb.RepoFilter_REPO_FILTER_SHOW_NOT_REGISTERED_ONLY,
})
if err != nil {
cli.PrintCmd(cmd, "Error getting list of repos: %s\n", err)
os.Exit(1)
}

// Get the list of repos
listResp, err := client.ListRepositories(ctx, req)
cli.PrintCmd(cmd, "Found %d registered repositories\n", len(listResp.Results))

// Get list of remot repos
remoteListResp, err := client.ListRemoteRepositoriesFromProvider(ctx, &pb.ListRemoteRepositoriesFromProviderRequest{
Provider: provider,
ProjectId: projectID,
})
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error getting list of repos: %s\n", err)
cli.PrintCmd(cmd, "Error getting list of remote repos: %s\n", err)
os.Exit(1)
}

cli.PrintCmd(cmd, "Found %d remote repositories\n", len(remoteListResp.Results))

// Unregistered repos are in remoteListResp but not in listResp
// build a list of unregistered repos
var unregisteredRepos []*pb.UpstreamRepositoryRef
for _, remoteRepo := range remoteListResp.Results {
found := false
for _, repo := range listResp.Results {
if remoteRepo.Owner == repo.Owner && remoteRepo.Name == repo.Name {
found = true
break
}
}
if !found {
unregisteredRepos = append(unregisteredRepos, &pb.UpstreamRepositoryRef{
Owner: remoteRepo.Owner,
Name: remoteRepo.Name,
RepoId: remoteRepo.RepoId,
})
}
}

cli.PrintCmd(cmd, "Found %d unregistered repositories\n", len(unregisteredRepos))

// Get the selected repos
selectedRepos, err := getSelectedRepositories(listResp, cfgFlagRepos)
selectedRepos, err := getSelectedRepositories(unregisteredRepos, cfgFlagRepos)
if err != nil {
if errors.Is(err, errNoRepositoriesSelected) {
_, _ = fmt.Fprintf(os.Stderr, "%v\n", err)
Expand Down Expand Up @@ -111,11 +145,42 @@ var repo_registerCmd = &cobra.Command{
os.Exit(1)
}

// Print the registered repos
for _, repo := range registerResp.Results {
fmt.Printf("Registered repository: %s/%s\n", repo.Owner, repo.Repository)
// The result gives a list of repositories with the registration status
// Let's parse the results and print the status
columns := []table.Column{
{Title: "Repository", Width: 35},
{Title: "Status", Width: 15},
{Title: "Message", Width: 60},
}

rows := make([]table.Row, len(registerResp.Results))
for i, result := range registerResp.Results {
rows[i] = table.Row{
fmt.Sprintf("%s/%s", result.Repository.Owner, result.Repository.Name),
}

if result.Status.Success {
rows[i] = append(rows[i], "Registered")
} else {
rows[i] = append(rows[i], "Failed")
}

if result.Status.Error != nil {
rows[i] = append(rows[i], *result.Status.Error)
} else {
rows[i] = append(rows[i], "")
}
}

t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(false),
table.WithHeight(len(rows)),
table.WithStyles(cli.TableHiddenSelectStyles),
)

cli.PrintCmd(cmd, cli.TableRender(t))
},
}

Expand All @@ -129,20 +194,20 @@ func init() {
}
}

func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos string) ([]*pb.Repositories, error) {
func getSelectedRepositories(repoList []*pb.UpstreamRepositoryRef, flagRepos string) ([]*pb.UpstreamRepositoryRef, error) {
// If no repos are found, exit
if len(listResp.Results) == 0 {
if len(repoList) == 0 {
return nil, fmt.Errorf("no repositories found")
}

// Create a slice of strings to hold the repo names
repoNames := make([]string, len(listResp.Results))
repoNames := make([]string, len(repoList))

// Map of repo names to IDs
repoIDs := make(map[string]int32)

// Populate the repoNames slice and repoIDs map
for i, repo := range listResp.Results {
for i, repo := range repoList {
repoNames[i] = fmt.Sprintf("%s/%s", repo.Owner, repo.Name)
repoIDs[repoNames[i]] = repo.RepoId
}
Expand Down Expand Up @@ -184,7 +249,7 @@ func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos st
}

// Create a slice of Repositories protobufs
protoRepos := make([]*pb.Repositories, len(allSelectedRepos))
protoRepos := make([]*pb.UpstreamRepositoryRef, len(allSelectedRepos))

// Convert the selected repos into a slice of Repositories protobufs
for i, repo := range allSelectedRepos {
Expand All @@ -193,7 +258,7 @@ func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos st
_, _ = fmt.Fprintf(os.Stderr, "Unexpected repository name format: %s, skipping registration\n", repo)
continue
}
protoRepos[i] = &pb.Repositories{
protoRepos[i] = &pb.UpstreamRepositoryRef{
Owner: splitRepo[0],
Name: splitRepo[1],
RepoId: repoIDs[repo],
Expand Down
2 changes: 1 addition & 1 deletion cmd/dev/app/rule_type/rule_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func readEntityFromFile(fpath string, entType mediatorv1.Entity) (protoreflect.P

switch entType {
case mediatorv1.Entity_ENTITY_REPOSITORIES:
out = &mediatorv1.RepositoryResult{}
out = &mediatorv1.Repository{}
case mediatorv1.Entity_ENTITY_ARTIFACTS:
out = &mediatorv1.Artifact{}
case mediatorv1.Entity_ENTITY_PULL_REQUESTS:
Expand Down
9 changes: 9 additions & 0 deletions cmd/server/app/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ var serveCmd = &cobra.Command{
return err
}

// webhook config validation
webhookURL := cfg.WebhookConfig.ExternalWebhookURL
webhookping := cfg.WebhookConfig.ExternalPingURL
webhooksecret := cfg.WebhookConfig.WebhookSecret
if webhookURL == "" || webhookping == "" || webhooksecret == "" {
return fmt.Errorf("webhook configuration is not set")
}

// Identity
parsedURL, err := url.Parse(cfg.Identity.IssuerUrl)
if err != nil {
return fmt.Errorf("failed to parse issuer URL: %w\n", err)
Expand Down
Loading