Skip to content

Commit

Permalink
harden: Only store registered repositories
Browse files Browse the repository at this point in the history
This changes the logic of the repo register to only store registered repositories
as opposed to registering everything and then verifying if they're registered or not.

This also uncovered that we weren't really checking for issues in registration... So,
this changed that logic.

This also adds a new provider called `repo-lister` which enables implementors to list
repositories. This is handy for the auto-enrollment flow, which requires us to list remote
repositories, compare them to the registered ones, and then enroll.

We might not have auto-enrollment (webhook creation) for other providers... but it would be
handy to be able to list repos so we could then build requests that manually register the repos
to mediator.
  • Loading branch information
JAORMX committed Oct 9, 2023
1 parent 657d372 commit c98c3c2
Show file tree
Hide file tree
Showing 33 changed files with 2,617 additions and 2,577 deletions.
79 changes: 36 additions & 43 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (

ghclient "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
"github.com/stacklok/mediator/internal/util/rand"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)
Expand All @@ -48,16 +49,10 @@ type Response struct {
// MAX_CALLS is the maximum number of calls to the gRPC server before stopping.
const MAX_CALLS = 300

// just syncs repos for the specific provider and project
func syncRepos(ctx context.Context, client pb.RepositoryServiceClient, provider string, project string) error {
_, err := client.SyncRepositories(ctx, &pb.SyncRepositoriesRequest{Provider: provider, ProjectId: project})
return err
}

// callBackServer starts a server and handler to listen for the OAuth callback.
// It will wait for either a success or failure response from the server.
func callBackServer(ctx context.Context, provider string, project string, port string,
wg *sync.WaitGroup, client pb.OAuthServiceClient, since int64, repos_client pb.RepositoryServiceClient) {
wg *sync.WaitGroup, client pb.OAuthServiceClient, since int64) {
server := &http.Server{
Addr: fmt.Sprintf(":%s", port),
ReadHeaderTimeout: time.Second * 10, // Set an appropriate timeout value
Expand Down Expand Up @@ -103,9 +98,7 @@ func callBackServer(ctx context.Context, provider string, project string, port s
res, err := client.VerifyProviderTokenFrom(clientCtx,
&pb.VerifyProviderTokenFromRequest{Provider: provider, ProjectId: project, Timestamp: timestamppb.New(t)})
if err == nil && res.Status == "OK" {
// we can sync repos
err := syncRepos(clientCtx, repos_client, provider, project)
util.ExitNicelyOnError(err, "Error syncing repos")
return
}
if err != nil || res.Status == "OK" || calls >= MAX_CALLS {
stopServer = true
Expand Down Expand Up @@ -141,7 +134,6 @@ actions such as adding repositories.`,
defer conn.Close()

client := pb.NewOAuthServiceClient(conn)
repos_client := pb.NewRepositoryServiceClient(conn)
ctx, cancel := util.GetAppContext()
defer cancel()
oAuthCallbackCtx, oAuthCancel := context.WithTimeout(context.Background(), MAX_CALLS*time.Second)
Expand All @@ -153,40 +145,41 @@ actions such as adding repositories.`,
&pb.StoreProviderTokenRequest{Provider: provider, ProjectId: project, AccessToken: pat, Owner: &owner})
util.ExitNicelyOnError(err, "Error storing token")

err = syncRepos(ctx, repos_client, provider, project)
util.ExitNicelyOnError(err, "Error syncing repos")
fmt.Println("Provider enrolled successfully")
} else {
// Get random port
port, err := rand.GetRandomPort()
util.ExitNicelyOnError(err, "Error getting random port")

resp, err := client.GetAuthorizationURL(ctx, &pb.GetAuthorizationURLRequest{
Provider: provider,
ProjectId: project,
Cli: true,
Port: int32(port),
Owner: &owner,
})
util.ExitNicelyOnError(err, "Error getting authorization URL")

fmt.Printf("Your browser will now be opened to: %s\n", resp.GetUrl())
fmt.Println("Please follow the instructions on the page to complete the OAuth flow.")
fmt.Println("Once the flow is complete, the CLI will close")
fmt.Println("If this is a headless environment, please copy and paste the URL into a browser on a different machine.")

if err := browser.OpenURL(resp.GetUrl()); err != nil {
fmt.Fprintf(os.Stderr, "Error opening browser: %s\n", err)
os.Exit(1)
}
openTime := time.Now().Unix()

var wg sync.WaitGroup
wg.Add(1)
cli.PrintCmd(cmd, "Provider enrolled successfully")
return
}

go callBackServer(oAuthCallbackCtx, provider, project, fmt.Sprintf("%d", port), &wg, client, openTime, repos_client)
wg.Wait()
// Get random port
port, err := rand.GetRandomPort()
util.ExitNicelyOnError(err, "Error getting random port")

resp, err := client.GetAuthorizationURL(ctx, &pb.GetAuthorizationURLRequest{
Provider: provider,
ProjectId: project,
Cli: true,
Port: int32(port),
Owner: &owner,
})
util.ExitNicelyOnError(err, "Error getting authorization URL")

fmt.Printf("Your browser will now be opened to: %s\n", resp.GetUrl())
fmt.Println("Please follow the instructions on the page to complete the OAuth flow.")
fmt.Println("Once the flow is complete, the CLI will close")
fmt.Println("If this is a headless environment, please copy and paste the URL into a browser on a different machine.")

if err := browser.OpenURL(resp.GetUrl()); err != nil {
fmt.Fprintf(os.Stderr, "Error opening browser: %s\n", err)
os.Exit(1)
}
openTime := time.Now().Unix()

var wg sync.WaitGroup
wg.Add(1)

go callBackServer(oAuthCallbackCtx, provider, project, fmt.Sprintf("%d", port), &wg, client, openTime)
wg.Wait()

cli.PrintCmd(cmd, "Provider enrolled successfully")
},
}

Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/app/repo/repo_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ var repo_getCmd = &cobra.Command{
defer cancel()

// check repo by id
var repository *pb.RepositoryRecord
var repository *pb.Repository
if repoid != "" {
resp, err := client.GetRepositoryById(ctx, &pb.GetRepositoryByIdRequest{
RepositoryId: repoid,
Expand Down
40 changes: 28 additions & 12 deletions cmd/cli/app/repo/repo_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ import (
"fmt"
"os"

"github.com/olekukonko/tablewriter"
"github.com/charmbracelet/bubbles/table"
"github.com/spf13/cobra"
"github.com/spf13/viper"

github "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)

Expand Down Expand Up @@ -68,7 +69,6 @@ var repo_listCmd = &cobra.Command{
resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{
Provider: provider,
ProjectId: projectID,
Filter: pb.RepoFilter_REPO_FILTER_SHOW_REGISTERED_ONLY,
})
if err != nil {
fmt.Fprintf(os.Stderr, "Error getting repo of repos: %s\n", err)
Expand All @@ -77,21 +77,37 @@ var repo_listCmd = &cobra.Command{

switch format {
case "", "table":
table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"Id", "Project ID", "Provider Id", "Name", "Is fork", "Is private"})
columns := []table.Column{
{Title: "ID", Width: 40},
{Title: "Project", Width: 40},
{Title: "Provider", Width: 15},
{Title: "Upstream ID", Width: 15},
{Title: "Owner", Width: 15},
{Title: "Name", Width: 15},
}

var rows []table.Row
for _, v := range resp.Results {
row := []string{
v.Id,
v.ProjectId,
row := table.Row{
*v.Id,
*v.Context.Project,
v.Context.Provider,
fmt.Sprintf("%d", v.GetRepoId()),
fmt.Sprintf("%s/%s", v.GetOwner(), v.GetName()),
fmt.Sprintf("%t", v.GetIsFork()),
fmt.Sprintf("%t", v.GetIsPrivate()),
v.GetOwner(),
v.GetName(),
}
table.Append(row)
rows = append(rows, row)
}
table.Render()

t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(false),
table.WithHeight(len(rows)),
table.WithStyles(cli.TableHiddenSelectStyles),
)

cli.PrintCmd(cmd, cli.TableRender(t))
case "json":
out, err := util.GetJsonFromProto(resp)
util.ExitNicelyOnError(err, "Error getting json from proto")
Expand Down
97 changes: 81 additions & 16 deletions cmd/cli/app/repo/repo_register.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ import (
"strings"

"github.com/AlecAivazis/survey/v2"
"github.com/charmbracelet/bubbles/table"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"k8s.io/utils/strings/slices"

github "github.com/stacklok/mediator/internal/providers/github"
"github.com/stacklok/mediator/internal/util"
"github.com/stacklok/mediator/internal/util/cli"
pb "github.com/stacklok/mediator/pkg/api/protobuf/go/mediator/v1"
)

Expand All @@ -53,7 +55,6 @@ var repo_registerCmd = &cobra.Command{
}
},
Run: func(cmd *cobra.Command, args []string) {

provider := util.GetConfigValue("provider", "provider", cmd, "").(string)
if provider != github.Github {
fmt.Fprintf(os.Stderr, "Only %s is supported at this time\n", github.Github)
Expand All @@ -69,21 +70,54 @@ var repo_registerCmd = &cobra.Command{
ctx, cancel := util.GetAppContext()
defer cancel()

req := &pb.ListRepositoriesRequest{
// Get the list of repos
listResp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{
Provider: provider,
ProjectId: projectID,
Filter: pb.RepoFilter_REPO_FILTER_SHOW_NOT_REGISTERED_ONLY,
})
if err != nil {
cli.PrintCmd(cmd, "Error getting list of repos: %s\n", err)
os.Exit(1)
}

// Get the list of repos
listResp, err := client.ListRepositories(ctx, req)
cli.PrintCmd(cmd, "Found %d registered repositories\n", len(listResp.Results))

// Get list of remot repos
remoteListResp, err := client.ListRemoteRepositoriesFromProvider(ctx, &pb.ListRemoteRepositoriesFromProviderRequest{
Provider: provider,
ProjectId: projectID,
})
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error getting list of repos: %s\n", err)
cli.PrintCmd(cmd, "Error getting list of remote repos: %s\n", err)
os.Exit(1)
}

cli.PrintCmd(cmd, "Found %d remote repositories\n", len(remoteListResp.Results))

// Unregistered repos are in remoteListResp but not in listResp
// build a list of unregistered repos
var unregisteredRepos []*pb.UpstreamRepositoryRef
for _, remoteRepo := range remoteListResp.Results {
found := false
for _, repo := range listResp.Results {
if remoteRepo.Owner == repo.Owner && remoteRepo.Name == repo.Name {
found = true
break
}
}
if !found {
unregisteredRepos = append(unregisteredRepos, &pb.UpstreamRepositoryRef{
Owner: remoteRepo.Owner,
Name: remoteRepo.Name,
RepoId: remoteRepo.RepoId,
})
}
}

cli.PrintCmd(cmd, "Found %d unregistered repositories\n", len(unregisteredRepos))

// Get the selected repos
selectedRepos, err := getSelectedRepositories(listResp, cfgFlagRepos)
selectedRepos, err := getSelectedRepositories(unregisteredRepos, cfgFlagRepos)
if err != nil {
if errors.Is(err, errNoRepositoriesSelected) {
_, _ = fmt.Fprintf(os.Stderr, "%v\n", err)
Expand Down Expand Up @@ -111,11 +145,42 @@ var repo_registerCmd = &cobra.Command{
os.Exit(1)
}

// Print the registered repos
for _, repo := range registerResp.Results {
fmt.Printf("Registered repository: %s/%s\n", repo.Owner, repo.Repository)
// The result gives a list of repositories with the registration status
// Let's parse the results and print the status
columns := []table.Column{
{Title: "Repository", Width: 35},
{Title: "Status", Width: 15},
{Title: "Message", Width: 60},
}

rows := make([]table.Row, len(registerResp.Results))
for i, result := range registerResp.Results {
rows[i] = table.Row{
fmt.Sprintf("%s/%s", result.Repository.Owner, result.Repository.Name),
}

if result.Status.Success {
rows[i] = append(rows[i], "Registered")
} else {
rows[i] = append(rows[i], "Failed")
}

if result.Status.Error != nil {
rows[i] = append(rows[i], *result.Status.Error)
} else {
rows[i] = append(rows[i], "")
}
}

t := table.New(
table.WithColumns(columns),
table.WithRows(rows),
table.WithFocused(false),
table.WithHeight(len(rows)),
table.WithStyles(cli.TableHiddenSelectStyles),
)

cli.PrintCmd(cmd, cli.TableRender(t))
},
}

Expand All @@ -129,20 +194,20 @@ func init() {
}
}

func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos string) ([]*pb.Repositories, error) {
func getSelectedRepositories(repoList []*pb.UpstreamRepositoryRef, flagRepos string) ([]*pb.UpstreamRepositoryRef, error) {
// If no repos are found, exit
if len(listResp.Results) == 0 {
if len(repoList) == 0 {
return nil, fmt.Errorf("no repositories found")
}

// Create a slice of strings to hold the repo names
repoNames := make([]string, len(listResp.Results))
repoNames := make([]string, len(repoList))

// Map of repo names to IDs
repoIDs := make(map[string]int32)

// Populate the repoNames slice and repoIDs map
for i, repo := range listResp.Results {
for i, repo := range repoList {
repoNames[i] = fmt.Sprintf("%s/%s", repo.Owner, repo.Name)
repoIDs[repoNames[i]] = repo.RepoId
}
Expand Down Expand Up @@ -184,7 +249,7 @@ func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos st
}

// Create a slice of Repositories protobufs
protoRepos := make([]*pb.Repositories, len(allSelectedRepos))
protoRepos := make([]*pb.UpstreamRepositoryRef, len(allSelectedRepos))

// Convert the selected repos into a slice of Repositories protobufs
for i, repo := range allSelectedRepos {
Expand All @@ -193,7 +258,7 @@ func getSelectedRepositories(listResp *pb.ListRepositoriesResponse, flagRepos st
_, _ = fmt.Fprintf(os.Stderr, "Unexpected repository name format: %s, skipping registration\n", repo)
continue
}
protoRepos[i] = &pb.Repositories{
protoRepos[i] = &pb.UpstreamRepositoryRef{
Owner: splitRepo[0],
Name: splitRepo[1],
RepoId: repoIDs[repo],
Expand Down
2 changes: 1 addition & 1 deletion cmd/dev/app/rule_type/rule_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func readEntityFromFile(fpath string, entType mediatorv1.Entity) (protoreflect.P

switch entType {
case mediatorv1.Entity_ENTITY_REPOSITORIES:
out = &mediatorv1.RepositoryResult{}
out = &mediatorv1.Repository{}
case mediatorv1.Entity_ENTITY_ARTIFACTS:
out = &mediatorv1.Artifact{}
case mediatorv1.Entity_ENTITY_PULL_REQUESTS:
Expand Down
9 changes: 9 additions & 0 deletions cmd/server/app/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ var serveCmd = &cobra.Command{
return err
}

// webhook config validation
webhookURL := cfg.WebhookConfig.ExternalWebhookURL
webhookping := cfg.WebhookConfig.ExternalPingURL
webhooksecret := cfg.WebhookConfig.WebhookSecret
if webhookURL == "" || webhookping == "" || webhooksecret == "" {
return fmt.Errorf("webhook configuration is not set")
}

// Identity
parsedURL, err := url.Parse(cfg.Identity.IssuerUrl)
if err != nil {
return fmt.Errorf("failed to parse issuer URL: %w\n", err)
Expand Down
Loading

0 comments on commit c98c3c2

Please sign in to comment.