diff --git a/go.mod b/go.mod index 6d4ffad..a3a1aba 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/go-git/go-billy/v5 v5.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/google/go-github/v68 v68.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect diff --git a/go.sum b/go.sum index b1f7a96..5903159 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-github/v68 v68.0.0 h1:ZW57zeNZiXTdQ16qrDiZ0k6XucrxZ2CGmoTvcCyQG6s= +github.com/google/go-github/v68 v68.0.0/go.mod h1:K9HAUBovM2sLwM408A18h+wd9vqdLOEqTUCbnRIcx68= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/internal/cli/patrol.go b/internal/cli/patrol.go index f5a2c71..94e13b8 100644 --- a/internal/cli/patrol.go +++ b/internal/cli/patrol.go @@ -5,9 +5,8 @@ import ( "fmt" "os/exec" "sheriff/internal/config" - "sheriff/internal/git" "sheriff/internal/patrol" - "sheriff/internal/repo" + "sheriff/internal/repository/provider" "sheriff/internal/scanner" "sheriff/internal/slack" "strings" @@ -34,6 +33,7 @@ const reportToSlackChannel = "report-to-slack-channel" const reportEnableProjectReportToFlag = "report-enable-project-report-to" const silentReportFlag = "silent" const gitlabTokenFlag = "gitlab-token" +const githubTokenFlag = "github-token" const slackTokenFlag = "slack-token" var necessaryScanners = []string{scanner.OsvCommandName} @@ -91,6 +91,13 @@ var PatrolFlags = []cli.Flag{ EnvVars: []string{"GITLAB_TOKEN"}, Category: string(Tokens), }, + &cli.StringFlag{ + Name: githubTokenFlag, + Usage: "Token to access the Github API.", + Required: true, + EnvVars: []string{"GITHUB_TOKEN"}, + Category: string(Tokens), + }, &cli.StringFlag{ Name: slackTokenFlag, Usage: "Token to access the Slack API.", @@ -122,12 +129,13 @@ func PatrolAction(cCtx *cli.Context) error { // Get tokens gitlabToken := cCtx.String(gitlabTokenFlag) + githubToken := cCtx.String(githubTokenFlag) slackToken := cCtx.String(slackTokenFlag) // Create services - gitlabService, err := repo.NewGitlabService(gitlabToken) + repositoryService, err := provider.NewProvider(gitlabToken, githubToken) if err != nil { - return errors.Join(errors.New("failed to create GitLab service"), err) + return errors.Join(errors.New("failed to create repository service"), err) } slackService, err := slack.New(slackToken, config.Verbose) @@ -135,10 +143,9 @@ func PatrolAction(cCtx *cli.Context) error { return errors.Join(errors.New("failed to create Slack service"), err) } - gitService := git.New(gitlabToken) osvService := scanner.NewOsvScanner() - patrolService := patrol.New(gitlabService, slackService, gitService, osvService) + patrolService := patrol.New(repositoryService, slackService, osvService) // Check whether the necessary scanners are available missingScanners := getMissingScanners(necessaryScanners) diff --git a/internal/config/patrol.go b/internal/config/patrol.go index ba1acb8..a0703d5 100644 --- a/internal/config/patrol.go +++ b/internal/config/patrol.go @@ -4,13 +4,13 @@ import ( "errors" "fmt" "net/url" - "sheriff/internal/repo" + "sheriff/internal/repository" zerolog "github.com/rs/zerolog/log" ) type ProjectLocation struct { - Type repo.PlatformType + Type repository.RepositoryType Path string } @@ -124,9 +124,7 @@ func parseTargets(targets []string) ([]ProjectLocation, error) { return nil, fmt.Errorf("target missing platform scheme %v", t) } - if parsed.Scheme == string(repo.Github) { - return nil, fmt.Errorf("github is currently unsupported, but is on our roadmap šŸ˜ƒ") // TODO #9 - } else if parsed.Scheme != string(repo.Gitlab) { + if parsed.Scheme != string(repository.Gitlab) && parsed.Scheme != string(repository.Github) { return nil, fmt.Errorf("unsupported platform %v", parsed.Scheme) } @@ -136,7 +134,7 @@ func parseTargets(targets []string) ([]ProjectLocation, error) { } locations[i] = ProjectLocation{ - Type: repo.PlatformType(parsed.Scheme), + Type: repository.RepositoryType(parsed.Scheme), Path: path, } } diff --git a/internal/config/patrol_test.go b/internal/config/patrol_test.go index 81b478d..46575fd 100644 --- a/internal/config/patrol_test.go +++ b/internal/config/patrol_test.go @@ -1,7 +1,7 @@ package config import ( - "sheriff/internal/repo" + "sheriff/internal/repository" "testing" "github.com/stretchr/testify/assert" @@ -9,7 +9,7 @@ import ( func TestGetPatrolConfiguration(t *testing.T) { want := PatrolConfig{ - Locations: []ProjectLocation{{Type: repo.Gitlab, Path: "group1"}, {Type: repo.Gitlab, Path: "group2/project1"}}, + Locations: []ProjectLocation{{Type: repository.Gitlab, Path: "group1"}, {Type: repository.Gitlab, Path: "group2/project1"}}, ReportToEmails: []string{"some-email@gmail.com"}, ReportToSlackChannels: []string{"report-slack-channel"}, ReportToIssue: true, @@ -29,7 +29,7 @@ func TestGetPatrolConfiguration(t *testing.T) { func TestGetPatrolConfigurationCLIOverridesFile(t *testing.T) { want := PatrolConfig{ - Locations: []ProjectLocation{{Type: repo.Gitlab, Path: "group1"}, {Type: repo.Gitlab, Path: "group2/project1"}}, + Locations: []ProjectLocation{{Type: repository.Gitlab, Path: "group1"}, {Type: repository.Gitlab, Path: "group2/project1"}}, ReportToEmails: []string{"email@gmail.com", "other@gmail.com"}, ReportToSlackChannels: []string{"other-slack-channel"}, ReportToIssue: false, @@ -87,8 +87,8 @@ func TestParseUrls(t *testing.T) { {[]string{"gitlab://namespace/project"}, &ProjectLocation{Type: "gitlab", Path: "namespace/project"}, false}, {[]string{"gitlab://namespace/subgroup/project"}, &ProjectLocation{Type: "gitlab", Path: "namespace/subgroup/project"}, false}, {[]string{"gitlab://namespace"}, &ProjectLocation{Type: "gitlab", Path: "namespace"}, false}, - {[]string{"github://organization"}, &ProjectLocation{Type: "github", Path: "organization"}, true}, - {[]string{"github://organization/project"}, &ProjectLocation{Type: "github", Path: "organization/project"}, true}, + {[]string{"github://organization"}, &ProjectLocation{Type: "github", Path: "organization"}, false}, + {[]string{"github://organization/project"}, &ProjectLocation{Type: "github", Path: "organization/project"}, false}, {[]string{"unknown://namespace/project"}, nil, true}, {[]string{"unknown://not a path"}, nil, true}, {[]string{"not a target"}, nil, true}, diff --git a/internal/git/client.go b/internal/git/client.go deleted file mode 100644 index 66b850b..0000000 --- a/internal/git/client.go +++ /dev/null @@ -1,19 +0,0 @@ -// This client is a thin wrapper around the go-git library. It provides an interface to the Git client -// The main purpose of this client is to provide an interface to the GitLab client which can be mocked in tests. -// As such this MUST be as thin as possible and MUST not contain any business logic, since it is not testable. -package git - -import ( - "github.com/go-git/go-git/v5" -) - -type iclient interface { - PlainClone(path string, isBare bool, o *git.CloneOptions) (*git.Repository, error) -} - -type client struct { -} - -func (c *client) PlainClone(path string, isBare bool, o *git.CloneOptions) (*git.Repository, error) { - return git.PlainClone(path, isBare, o) -} diff --git a/internal/git/git.go b/internal/git/git.go deleted file mode 100644 index 8101072..0000000 --- a/internal/git/git.go +++ /dev/null @@ -1,35 +0,0 @@ -package git - -import ( - "github.com/go-git/go-git/v5" - "github.com/go-git/go-git/v5/plumbing/transport/http" -) - -type IService interface { - Clone(dir string, url string) error -} - -type service struct { - client iclient - token string -} - -func New(token string) IService { - return &service{ - client: &client{}, - token: token, - } -} - -func (s *service) Clone(dir string, url string) (err error) { - _, err = s.client.PlainClone(dir, false, &git.CloneOptions{ - URL: url, - Auth: &http.BasicAuth{ - Username: "N/A", - Password: s.token, - }, - Depth: 1, - }) - - return -} diff --git a/internal/git/git_test.go b/internal/git/git_test.go deleted file mode 100644 index 3ed1ba8..0000000 --- a/internal/git/git_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package git - -import ( - "testing" - - "github.com/go-git/go-git/v5" - "github.com/go-git/go-git/v5/plumbing/transport/http" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -func TestNewService(t *testing.T) { - s := New("token") - - assert.NotNil(t, s) -} - -func TestClone(t *testing.T) { - path := "path/to/directory" - url := "https://gitlab.com/username/repo.git" - token := "token" - - mockGit := &mockGit{} - mockGit.On("PlainClone", path, false, &git.CloneOptions{ - URL: url, - Auth: &http.BasicAuth{ - Username: "N/A", - Password: token, - }, - Depth: 1, - }).Return(&git.Repository{}, nil) - - s := &service{client: mockGit, token: token} - - err := s.Clone(path, url) - - assert.Nil(t, err) - mockGit.AssertExpectations(t) -} - -type mockGit struct { - mock.Mock -} - -func (g *mockGit) PlainClone(path string, isBare bool, o *git.CloneOptions) (*git.Repository, error) { - args := g.Called(path, isBare, o) - return args.Get(0).(*git.Repository), args.Error(1) -} diff --git a/internal/patrol/patrol.go b/internal/patrol/patrol.go index 67b7138..d691278 100644 --- a/internal/patrol/patrol.go +++ b/internal/patrol/patrol.go @@ -7,9 +7,9 @@ import ( "fmt" "os" "sheriff/internal/config" - "sheriff/internal/git" "sheriff/internal/publish" - "sheriff/internal/repo" + "sheriff/internal/repository" + "sheriff/internal/repository/provider" "sheriff/internal/scanner" "sheriff/internal/slack" "sync" @@ -29,21 +29,19 @@ type securityPatroller interface { // sheriffService is the implementation of the SecurityPatroller interface. type sheriffService struct { - gitlabService repo.IService - slackService slack.IService - gitService git.IService - osvService scanner.VulnScanner[scanner.OsvReport] + repoService provider.IProvider + slackService slack.IService + osvService scanner.VulnScanner[scanner.OsvReport] } // New creates a new securityPatroller service. // It contains the main "loop" logic of this tool. // A "patrol" is defined as scanning GitLab groups for vulnerabilities and publishing reports where needed. -func New(gitlabService repo.IService, slackService slack.IService, gitService git.IService, osvService scanner.VulnScanner[scanner.OsvReport]) securityPatroller { +func New(repoService provider.IProvider, slackService slack.IService, osvService scanner.VulnScanner[scanner.OsvReport]) securityPatroller { return &sheriffService{ - gitlabService: gitlabService, - slackService: slackService, - gitService: gitService, - osvService: osvService, + repoService: repoService, + slackService: slackService, + osvService: osvService, } } @@ -65,7 +63,7 @@ func (s *sheriffService) Patrol(args config.PatrolConfig) (warn error, err error if args.ReportToIssue { log.Info().Msg("Creating issue in affected projects") - if gwarn := publish.PublishAsGitlabIssues(scanReports, s.gitlabService); gwarn != nil { + if gwarn := publish.PublishAsIssues(scanReports, s.repoService); gwarn != nil { gwarn = errors.Join(errors.New("errors occured when creating issues"), gwarn) warn = errors.Join(gwarn, warn) } @@ -107,13 +105,7 @@ func (s *sheriffService) scanAndGetReports(locations []config.ProjectLocation) ( defer os.RemoveAll(tempScanDir) log.Info().Str("path", tempScanDir).Msg("Created temporary directory") - gitlabLocs := pie.Map( - pie.Filter(locations, func(v config.ProjectLocation) bool { return v.Type == repo.Gitlab }), - func(v config.ProjectLocation) string { return v.Path }, - ) - log.Info().Strs("locations", gitlabLocs).Msg("Getting the list of projects to scan") - - projects, pwarn := s.gitlabService.GetProjectList(gitlabLocs) + projects, pwarn := s.getProjectList(locations) if pwarn != nil { pwarn = errors.Join(errors.New("errors occured when getting project list"), pwarn) warn = errors.Join(pwarn, warn) @@ -152,8 +144,41 @@ func (s *sheriffService) scanAndGetReports(locations []config.ProjectLocation) ( return } +func (s *sheriffService) getProjectList(locs []config.ProjectLocation) (projects []repository.Project, warn error) { + gitlabLocs := pie.Map( + pie.Filter(locs, func(loc config.ProjectLocation) bool { return loc.Type == repository.Gitlab }), + func(loc config.ProjectLocation) string { return loc.Path }, + ) + githubLocs := pie.Map( + pie.Filter(locs, func(loc config.ProjectLocation) bool { return loc.Type == repository.Github }), + func(loc config.ProjectLocation) string { return loc.Path }, + ) + + if len(gitlabLocs) > 0 { + log.Info().Strs("locations", gitlabLocs).Msg("Getting the list of projects from gitlab to scan") + gitlabProjects, err := s.repoService.Provide(repository.Gitlab).GetProjectList(gitlabLocs) + if err != nil { + warn = errors.Join(errors.New("non-critical errors encountered when scanning for gitlab projects"), err) + } + + projects = append(projects, gitlabProjects...) + } + + if len(githubLocs) > 0 { + log.Info().Strs("locations", githubLocs).Msg("Getting the list of projects from github to scan") + githubProjects, err := s.repoService.Provide(repository.Github).GetProjectList(githubLocs) + if err != nil { + warn = errors.Join(errors.New("non-critical errors encountered when scanning for github projects"), err) + } + + projects = append(projects, githubProjects...) + } + + return +} + // scanProject scans a project for vulnerabilities using the osv scanner. -func (s *sheriffService) scanProject(project repo.Project) (report *scanner.Report, err error) { +func (s *sheriffService) scanProject(project repository.Project) (report *scanner.Report, err error) { dir, err := os.MkdirTemp(tempScanDir, fmt.Sprintf("%v-", project.Name)) if err != nil { return nil, errors.Join(errors.New("failed to create project temporary directory"), err) @@ -161,9 +186,9 @@ func (s *sheriffService) scanProject(project repo.Project) (report *scanner.Repo defer os.RemoveAll(dir) // Clone the project - log.Info().Str("project", project.Path).Str("dir", dir).Msg("Cloning project") - if err = s.gitService.Clone(dir, project.RepoUrl); err != nil { - return nil, errors.Join(errors.New("failed to clone project"), err) + log.Info().Str("project", project.Path).Str("dir", dir).Str("url", project.RepoUrl).Msg("Cloning project") + if err := s.repoService.Provide(project.Repository).Clone(project.RepoUrl, dir); err != nil { + return nil, errors.Join(fmt.Errorf("failed to clone project %v", project.Path), err) } config := config.GetProjectConfiguration(project.Path, dir) diff --git a/internal/patrol/patrol_test.go b/internal/patrol/patrol_test.go index 5bdbc4b..5d8501a 100644 --- a/internal/patrol/patrol_test.go +++ b/internal/patrol/patrol_test.go @@ -2,7 +2,7 @@ package patrol import ( "sheriff/internal/config" - "sheriff/internal/repo" + "sheriff/internal/repository" "sheriff/internal/scanner" "testing" @@ -12,27 +12,27 @@ import ( ) func TestNewService(t *testing.T) { - s := New(&mockGitlabService{}, &mockSlackService{}, &mockGitService{}, &mockOSVService{}) + s := New(&mockRepoService{}, &mockSlackService{}, &mockOSVService{}) assert.NotNil(t, s) } func TestScanNoProjects(t *testing.T) { - mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]repo.Project{}, nil) + mockClient := &mockClient{} + mockClient.On("GetProjectList", []string{"group/to/scan"}).Return([]repository.Project{}, nil) - mockSlackService := &mockSlackService{} + mockRepoService := &mockRepoService{} + mockRepoService.On("Provide", repository.Gitlab).Return(mockClient) - mockGitService := &mockGitService{} - mockGitService.On("Clone", mock.Anything, "https://gitlab.com/group/to/scan.git").Return(nil) + mockSlackService := &mockSlackService{} mockOSVService := &mockOSVService{} mockOSVService.On("Scan", mock.Anything).Return(&scanner.OsvReport{}, nil) - svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) + svc := New(mockRepoService, mockSlackService, mockOSVService) warn, err := svc.Patrol(config.PatrolConfig{ - Locations: []config.ProjectLocation{{Type: repo.Gitlab, Path: "group/to/scan"}}, + Locations: []config.ProjectLocation{{Type: repository.Gitlab, Path: "group/to/scan"}}, ReportToEmails: []string{}, ReportToSlackChannels: []string{"channel"}, ReportToIssue: true, @@ -43,29 +43,31 @@ func TestScanNoProjects(t *testing.T) { assert.Nil(t, err) assert.Nil(t, warn) - mockGitlabService.AssertExpectations(t) + mockClient.AssertExpectations(t) + mockRepoService.AssertExpectations(t) mockSlackService.AssertExpectations(t) } func TestScanNonVulnerableProject(t *testing.T) { - mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]repo.Project{{Name: "Hello World", RepoUrl: "https://gitlab.com/group/to/scan.git"}}, nil) - mockGitlabService.On("CloseVulnerabilityIssue", mock.Anything).Return(nil) + mockClient := &mockClient{} + mockClient.On("GetProjectList", []string{"group/to/scan"}).Return([]repository.Project{{Name: "Hello World", RepoUrl: "https://gitlab.com/group/to/scan.git", Repository: repository.Gitlab}}, nil) + mockClient.On("CloseVulnerabilityIssue", mock.Anything).Return(nil) + mockClient.On("Clone", "https://gitlab.com/group/to/scan.git", mock.Anything).Return(nil) + + mockRepoService := &mockRepoService{} + mockRepoService.On("Provide", repository.Gitlab).Return(mockClient) mockSlackService := &mockSlackService{} mockSlackService.On("PostMessage", "channel", mock.Anything).Return("", nil) - mockGitService := &mockGitService{} - mockGitService.On("Clone", mock.Anything, "https://gitlab.com/group/to/scan.git").Return(nil) - mockOSVService := &mockOSVService{} mockOSVService.On("Scan", mock.Anything).Return(&scanner.OsvReport{}, nil) - mockOSVService.On("GenerateReport", mock.Anything, mock.Anything).Return(scanner.Report{}) + mockOSVService.On("GenerateReport", mock.Anything, mock.Anything).Return(scanner.Report{Project: repository.Project{Repository: repository.Gitlab}}) - svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) + svc := New(mockRepoService, mockSlackService, mockOSVService) warn, err := svc.Patrol(config.PatrolConfig{ - Locations: []config.ProjectLocation{{Type: repo.Gitlab, Path: "group/to/scan"}}, + Locations: []config.ProjectLocation{{Type: repository.Gitlab, Path: "group/to/scan"}}, ReportToEmails: []string{}, ReportToSlackChannels: []string{"channel"}, ReportToIssue: true, @@ -76,25 +78,28 @@ func TestScanNonVulnerableProject(t *testing.T) { assert.Nil(t, err) assert.Nil(t, warn) - mockGitlabService.AssertExpectations(t) + mockClient.AssertExpectations(t) + mockRepoService.AssertExpectations(t) mockSlackService.AssertExpectations(t) } func TestScanVulnerableProject(t *testing.T) { - mockGitlabService := &mockGitlabService{} - mockGitlabService.On("GetProjectList", []string{"group/to/scan"}).Return([]repo.Project{{Name: "Hello World", RepoUrl: "https://gitlab.com/group/to/scan.git"}}, nil) - mockGitlabService.On("OpenVulnerabilityIssue", mock.Anything, mock.Anything).Return(&repo.Issue{}, nil) + mockClient := &mockClient{} + mockClient.On("GetProjectList", []string{"group/to/scan"}).Return([]repository.Project{{Name: "Hello World", RepoUrl: "https://gitlab.com/group/to/scan.git", Repository: repository.Gitlab}}, nil) + mockClient.On("OpenVulnerabilityIssue", mock.Anything, mock.Anything).Return(&repository.Issue{}, nil) + mockClient.On("Clone", "https://gitlab.com/group/to/scan.git", mock.Anything).Return(nil) + + mockRepoService := &mockRepoService{} + mockRepoService.On("Provide", repository.Gitlab).Return(mockClient) mockSlackService := &mockSlackService{} mockSlackService.On("PostMessage", "channel", mock.Anything).Return("", nil) - mockGitService := &mockGitService{} - mockGitService.On("Clone", mock.Anything, "https://gitlab.com/group/to/scan.git").Return(nil) - mockOSVService := &mockOSVService{} report := &scanner.OsvReport{} mockOSVService.On("Scan", mock.Anything).Return(report, nil) mockOSVService.On("GenerateReport", mock.Anything, mock.Anything).Return(scanner.Report{ + Project: repository.Project{Repository: repository.Gitlab}, IsVulnerable: true, Vulnerabilities: []scanner.Vulnerability{ { @@ -103,10 +108,10 @@ func TestScanVulnerableProject(t *testing.T) { }, }) - svc := New(mockGitlabService, mockSlackService, mockGitService, mockOSVService) + svc := New(mockRepoService, mockSlackService, mockOSVService) warn, err := svc.Patrol(config.PatrolConfig{ - Locations: []config.ProjectLocation{{Type: repo.Gitlab, Path: "group/to/scan"}}, + Locations: []config.ProjectLocation{{Type: repository.Gitlab, Path: "group/to/scan"}}, ReportToEmails: []string{}, ReportToSlackChannels: []string{"channel"}, ReportToIssue: true, @@ -117,7 +122,8 @@ func TestScanVulnerableProject(t *testing.T) { assert.Nil(t, err) assert.Nil(t, warn) - mockGitlabService.AssertExpectations(t) + mockClient.AssertExpectations(t) + mockRepoService.AssertExpectations(t) mockSlackService.AssertExpectations(t) } @@ -186,23 +192,37 @@ func TestMarkOutdatedAcknowledgements(t *testing.T) { assert.Equal(t, []string{"CVE-3"}, report.OutdatedAcks) } -type mockGitlabService struct { +type mockRepoService struct { + mock.Mock +} + +func (c *mockRepoService) Provide(platform repository.RepositoryType) repository.IRepositoryService { + args := c.Called(platform) + return args.Get(0).(repository.IRepositoryService) +} + +type mockClient struct { mock.Mock } -func (c *mockGitlabService) GetProjectList(paths []string) ([]repo.Project, error) { +func (c *mockClient) GetProjectList(paths []string) ([]repository.Project, error) { args := c.Called(paths) - return args.Get(0).([]repo.Project), args.Error(1) + return args.Get(0).([]repository.Project), args.Error(1) } -func (c *mockGitlabService) CloseVulnerabilityIssue(project repo.Project) error { +func (c *mockClient) CloseVulnerabilityIssue(project repository.Project) error { args := c.Called(project) return args.Error(0) } -func (c *mockGitlabService) OpenVulnerabilityIssue(project repo.Project, report string) (*repo.Issue, error) { +func (c *mockClient) OpenVulnerabilityIssue(project repository.Project, report string) (*repository.Issue, error) { args := c.Called(project, report) - return args.Get(0).(*repo.Issue), args.Error(1) + return args.Get(0).(*repository.Issue), args.Error(1) +} + +func (c *mockClient) Clone(url string, dir string) error { + args := c.Called(url, dir) + return args.Error(0) } type mockSlackService struct { @@ -214,15 +234,6 @@ func (c *mockSlackService) PostMessage(channelName string, options ...slack.MsgO return args.String(0), args.Error(1) } -type mockGitService struct { - mock.Mock -} - -func (c *mockGitService) Clone(dir string, url string) (err error) { - args := c.Called(dir, url) - return args.Error(0) -} - type mockOSVService struct { mock.Mock } @@ -232,7 +243,7 @@ func (c *mockOSVService) Scan(dir string) (*scanner.OsvReport, error) { return args.Get(0).(*scanner.OsvReport), args.Error(1) } -func (c *mockOSVService) GenerateReport(p repo.Project, r *scanner.OsvReport) scanner.Report { +func (c *mockOSVService) GenerateReport(p repository.Project, r *scanner.OsvReport) scanner.Report { args := c.Called(p, r) return args.Get(0).(scanner.Report) } diff --git a/internal/publish/to_console_test.go b/internal/publish/to_console_test.go index e973d3e..73a8743 100644 --- a/internal/publish/to_console_test.go +++ b/internal/publish/to_console_test.go @@ -1,7 +1,7 @@ package publish import ( - "sheriff/internal/repo" + "sheriff/internal/repository" "sheriff/internal/scanner" "testing" @@ -11,7 +11,7 @@ import ( func TestFormatReportMessageForConsole(t *testing.T) { reports := []scanner.Report{ { - Project: repo.Project{ + Project: repository.Project{ Name: "project1", WebURL: "http://example.com", }, @@ -27,7 +27,7 @@ func TestFormatReportMessageForConsole(t *testing.T) { }, }, { - Project: repo.Project{ + Project: repository.Project{ Name: "project2", WebURL: "http://example2.com", }, diff --git a/internal/publish/to_gitlab.go b/internal/publish/to_issue.go similarity index 84% rename from internal/publish/to_gitlab.go rename to internal/publish/to_issue.go index 30b819b..23ec353 100644 --- a/internal/publish/to_gitlab.go +++ b/internal/publish/to_issue.go @@ -3,7 +3,7 @@ package publish import ( "errors" "fmt" - "sheriff/internal/repo" + "sheriff/internal/repository/provider" "sheriff/internal/scanner" "strconv" "sync" @@ -20,16 +20,17 @@ var severityScoreOrder = getSeverityScoreOrder(scanner.SeverityScoreThresholds) // now is a function that returns the current time var now = time.Now -// PublishAsGitlabIssues creates or updates GitLab Issue reports for the given reports +// PublishAsIssues creates or updates Issue reports for the given reports // It will add the Issue URL to the Report if it was created or updated successfully -func PublishAsGitlabIssues(reports []scanner.Report, s repo.IService) (warn error) { +func PublishAsIssues(reports []scanner.Report, s provider.IProvider) (warn error) { var wg sync.WaitGroup for i := 0; i < len(reports); i++ { wg.Add(1) go func() { defer wg.Done() - if reports[i].IsVulnerable { - if issue, err := s.OpenVulnerabilityIssue(reports[i].Project, formatGitlabIssue(reports[i])); err != nil { + report := reports[i] + if report.IsVulnerable { + if issue, err := s.Provide(report.Project.Repository).OpenVulnerabilityIssue(report.Project, formatIssue(report)); err != nil { log.Error().Err(err).Str("project", reports[i].Project.Path).Msg("Failed to open or update issue") err = fmt.Errorf("failed to open or update issue for project %v", reports[i].Project.Path) warn = errors.Join(err, warn) @@ -37,9 +38,9 @@ func PublishAsGitlabIssues(reports []scanner.Report, s repo.IService) (warn erro reports[i].IssueUrl = issue.WebURL } } else { - if err := s.CloseVulnerabilityIssue(reports[i].Project); err != nil { - log.Error().Err(err).Str("project", reports[i].Project.Path).Msg("Failed to close issue") - err = fmt.Errorf("failed to close issue for project %v", reports[i].Project.Path) + if err := s.Provide(report.Project.Repository).CloseVulnerabilityIssue(report.Project); err != nil { + log.Error().Err(err).Str("project", report.Project.Path).Msg("Failed to close issue") + err = fmt.Errorf("failed to close issue for project %v", report.Project.Path) warn = errors.Join(err, warn) } } @@ -74,8 +75,8 @@ func groupVulnReportsByMaxSeverityKind(reports []scanner.Report) map[scanner.Sev return groupedVulnerabilities } -// formatGitlabIssue formats the report as a GitLab issue -func formatGitlabIssue(r scanner.Report) (mdReport string) { +// formatIssue formats the report as an issue +func formatIssue(r scanner.Report) (mdReport string) { groupedVulnerabilities := pie.GroupBy(r.Vulnerabilities, func(v scanner.Vulnerability) scanner.SeverityScoreKind { return v.SeverityScoreKind }) mdReport = getVulnReportHeader() @@ -84,7 +85,7 @@ func formatGitlabIssue(r scanner.Report) (mdReport string) { sortedVulnsInGroup := pie.SortUsing(group, func(a, b scanner.Vulnerability) bool { return severityBiggerThan(a.Severity, b.Severity) }) - mdReport += formatGitlabIssueTable(groupName, sortedVulnsInGroup) + mdReport += formatIssueTable(groupName, sortedVulnsInGroup) } } @@ -108,9 +109,9 @@ func formatOutdatedAcks(outdatedAcks []string) (md string) { return } -// formatGitlabIssueTable formats a group of vulnerabilities as a markdown table -// for the GitLab issue report -func formatGitlabIssueTable(groupName scanner.SeverityScoreKind, vs []scanner.Vulnerability) (md string) { +// formatIssueTable formats a group of vulnerabilities as a markdown table +// for the issue report +func formatIssueTable(groupName scanner.SeverityScoreKind, vs []scanner.Vulnerability) (md string) { md = fmt.Sprintf("\n## Severity: %v\n", groupName) if groupName == scanner.Acknowledged { md += "\nšŸ’” These vulnerabilities have been acknowledged by the team and are not considered a risk.\n\n" diff --git a/internal/publish/to_gitlab_test.go b/internal/publish/to_issue_test.go similarity index 81% rename from internal/publish/to_gitlab_test.go rename to internal/publish/to_issue_test.go index eb64b29..1ece87b 100644 --- a/internal/publish/to_gitlab_test.go +++ b/internal/publish/to_issue_test.go @@ -1,7 +1,8 @@ package publish import ( - "sheriff/internal/repo" + "sheriff/internal/config" + "sheriff/internal/repository" "sheriff/internal/scanner" "testing" "time" @@ -61,7 +62,7 @@ func TestFormatGitlabIssue(t *testing.T) { }, } - got := formatGitlabIssue(scanner.Report{ + got := formatIssue(scanner.Report{ Vulnerabilities: mockVulnerabilities, }) @@ -132,7 +133,7 @@ func TestFormatGitlabIssueSortWithinGroup(t *testing.T) { }, } - got := formatGitlabIssue(scanner.Report{ + got := formatIssue(scanner.Report{ Vulnerabilities: mockVulnerabilities, }) @@ -164,9 +165,14 @@ func TestMarkdownBoolean(t *testing.T) { func TestPublishAsGitlabIssues(t *testing.T) { mockGitlabService := &mockGitlabService{} - mockGitlabService.On("OpenVulnerabilityIssue", mock.Anything, mock.Anything).Return(&repo.Issue{WebURL: "https://my-issue.com"}, nil) + mockGitlabService.On("OpenVulnerabilityIssue", mock.Anything, mock.Anything).Return(&repository.Issue{WebURL: "https://my-issue.com"}, nil) + + mockRepoService := &mockRepoService{} + mockRepoService.On("Provide", repository.Gitlab).Return(mockGitlabService) + reports := []scanner.Report{ { + Project: repository.Project{Repository: repository.Gitlab}, IsVulnerable: true, Vulnerabilities: []scanner.Vulnerability{ { @@ -176,8 +182,9 @@ func TestPublishAsGitlabIssues(t *testing.T) { }, } - _ = PublishAsGitlabIssues(reports, mockGitlabService) + _ = PublishAsIssues(reports, mockRepoService) mockGitlabService.AssertExpectations(t) + mockRepoService.AssertExpectations(t) t.Run("FillsTheIssueUrl", func(t *testing.T) { assert.Equal(t, "https://my-issue.com", reports[0].IssueUrl) @@ -202,21 +209,40 @@ func TestGitlabIssueReportHeader(t *testing.T) { } +type mockRepoService struct { + mock.Mock +} + +func (c *mockRepoService) GetProjectList(paths []config.ProjectLocation) ([]repository.Project, error) { + args := c.Called(paths) + return args.Get(0).([]repository.Project), args.Error(1) +} + +func (c *mockRepoService) Provide(platform repository.RepositoryType) repository.IRepositoryService { + args := c.Called(platform) + return args.Get(0).(repository.IRepositoryService) +} + type mockGitlabService struct { mock.Mock } -func (c *mockGitlabService) GetProjectList(paths []string) ([]repo.Project, error) { +func (c *mockGitlabService) GetProjectList(paths []string) ([]repository.Project, error) { args := c.Called(paths) - return args.Get(0).([]repo.Project), args.Error(1) + return args.Get(0).([]repository.Project), args.Error(1) } -func (c *mockGitlabService) CloseVulnerabilityIssue(project repo.Project) error { +func (c *mockGitlabService) CloseVulnerabilityIssue(project repository.Project) error { args := c.Called(project) return args.Error(0) } -func (c *mockGitlabService) OpenVulnerabilityIssue(project repo.Project, report string) (*repo.Issue, error) { +func (c *mockGitlabService) OpenVulnerabilityIssue(project repository.Project, report string) (*repository.Issue, error) { args := c.Called(project, report) - return args.Get(0).(*repo.Issue), args.Error(1) + return args.Get(0).(*repository.Issue), args.Error(1) +} + +func (c *mockGitlabService) Clone(url string, dir string) error { + args := c.Called(url, dir) + return args.Error(0) } diff --git a/internal/publish/to_slack_test.go b/internal/publish/to_slack_test.go index a3065e2..2d48e41 100644 --- a/internal/publish/to_slack_test.go +++ b/internal/publish/to_slack_test.go @@ -2,7 +2,7 @@ package publish import ( "sheriff/internal/config" - "sheriff/internal/repo" + "sheriff/internal/repository" "sheriff/internal/scanner" "testing" @@ -102,7 +102,7 @@ func TestFormatReportMessage(t *testing.T) { reportBySeverityKind := map[scanner.SeverityScoreKind][]scanner.Report{ scanner.Critical: { { - Project: repo.Project{ + Project: repository.Project{ Name: "project1", WebURL: "http://example.com", }, @@ -118,7 +118,7 @@ func TestFormatReportMessage(t *testing.T) { }, scanner.High: { { - Project: repo.Project{ + Project: repository.Project{ Name: "project2", WebURL: "http://example2.com", }, diff --git a/internal/repo/repo.go b/internal/repo/repo.go deleted file mode 100644 index 82fdeb7..0000000 --- a/internal/repo/repo.go +++ /dev/null @@ -1,34 +0,0 @@ -package repo - -const VulnerabilityIssueTitle = "Sheriff - šŸšØ Vulnerability report" - -type PlatformType string - -const ( - Gitlab PlatformType = "gitlab" - Github PlatformType = "github" -) - -type Project struct { - ID int - Name string - Path string - WebURL string - RepoUrl string - Platform string -} - -type Issue struct { - ID int - Title string - WebURL string - Open bool - Platform string -} - -// IService is the interface of the GitLab service as needed by sheriff -type IService interface { - GetProjectList(paths []string) (projects []Project, warn error) - CloseVulnerabilityIssue(project Project) error - OpenVulnerabilityIssue(project Project, report string) (*Issue, error) -} diff --git a/internal/repository/github/github.go b/internal/repository/github/github.go new file mode 100644 index 0000000..2c62451 --- /dev/null +++ b/internal/repository/github/github.go @@ -0,0 +1,206 @@ +package github + +import ( + "errors" + "fmt" + "sheriff/internal/repository" + "strings" + + "github.com/elliotchance/pie/v2" + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/transport/http" + "github.com/google/go-github/v68/github" + "github.com/rs/zerolog/log" + "golang.org/x/sync/errgroup" +) + +type githubService struct { + client iGithubClient + token string +} + +// newGithubRepo creates a new GitHub repository service +func New(token string) githubService { + client := github.NewClient(nil) + + s := githubService{client: &githubClient{client: client}, token: token} + + return s +} + +func (s githubService) GetProjectList(paths []string) (projects []repository.Project, warn error) { + g := new(errgroup.Group) + reposChan := make(chan []github.Repository, len(paths)) + for _, path := range paths { + g.Go(func() error { + repos, err := s.getPathRepos(path) + reposChan <- repos + if err != nil { + return err + } + + return nil + }) + } + warn = g.Wait() + + close(reposChan) + + // Collect repos + var allRepos []github.Repository + for repos := range reposChan { + allRepos = append(allRepos, repos...) + } + + projects = pie.Map(allRepos, mapGithubProject) + + return +} + +// CloseVulnerabilityIssue closes the vulnerability issue for the given project +func (s githubService) CloseVulnerabilityIssue(project repository.Project) (err error) { + return errors.New("CloseVulnerabilityIssue not yet implemented") // TODO #9 Add github support +} + +// OpenVulnerabilityIssue opens or updates the vulnerability issue for the given project +func (s githubService) OpenVulnerabilityIssue(project repository.Project, report string) (issue *repository.Issue, err error) { + return nil, errors.New("OpenVulnerabilityIssue not yet implemented") // TODO #9 Add github support +} + +func (s githubService) Clone(url string, dir string) (err error) { + _, err = git.PlainClone(dir, false, &git.CloneOptions{ + URL: url, + Auth: &http.BasicAuth{ + Username: "N/A", + Password: s.token, + }, + Depth: 1, + }) + + return err +} + +func (s githubService) getPathRepos(path string) (repositories []github.Repository, err error) { + parts := strings.Split(path, "/") + + if len(parts) == 1 { + return s.getOwnerRepos(parts[0]) + } else if len(parts) == 2 { + repo, err := s.getOwnerRepository(parts[0], parts[1]) + if err != nil { + return nil, errors.Join(fmt.Errorf("failed to get repository %s", path), err) + } else if repo == nil { + return nil, errors.New("repository unexpectedly nil") + } + + return []github.Repository{*repo}, err + } else { + return nil, fmt.Errorf("project %v path of unexpected length %v", path, len(parts)) + } +} + +func (s githubService) getOwnerRepos(owner string) (repos []github.Repository, err error) { + // Try first as `organization` + repoPtrs, err := s.getOrganizationRepos(owner) + if err != nil { + // Try again as `user` + repoPtrs, err = s.getUserRepos(owner) + if err != nil { + return nil, errors.Join(fmt.Errorf("could not fetch repos for owner %v", owner), err) + } + } + + repos = derefRepoPtrs(owner, repoPtrs) + + return +} + +func (s githubService) getOrganizationRepos(org string) (repos []*github.Repository, err error) { + repos, err = getGithubPaginatedResults(func(listOpts github.ListOptions) ([]*github.Repository, *github.Response, error) { + opts := &github.RepositoryListByOrgOptions{ + ListOptions: listOpts, + } + return s.client.GetOrganizationRepositories(org, opts) + }) + + return +} + +func (s githubService) getUserRepos(user string) (repos []*github.Repository, err error) { + repos, err = getGithubPaginatedResults(func(listOpts github.ListOptions) ([]*github.Repository, *github.Response, error) { + opts := &github.RepositoryListByUserOptions{ + Type: "owner", + ListOptions: listOpts, + } + return s.client.GetUserRepositories(user, opts) + }) + + return +} + +func getGithubPaginatedResults[T interface{}](paginatedFunc func(github.ListOptions) ([]T, *github.Response, error)) (results []T, err error) { + opts := github.ListOptions{ + PerPage: 100, + Page: 1, + } + for { + pageResults, resp, err := paginatedFunc(opts) + if err != nil { + return nil, err + } + + results = append(results, pageResults...) + if resp.NextPage == 0 { + break + } + + opts.Page = resp.NextPage + } + + return +} + +func (s *githubService) getOwnerRepository(owner string, name string) (repo *github.Repository, err error) { + repo, _, err = s.client.GetRepository(owner, name) + if err != nil { + return nil, err + } + + return +} + +func derefRepoPtrs(owner string, repoPtrs []*github.Repository) (repos []github.Repository) { + var errCount = 0 + for _, repo := range repoPtrs { + if repo == nil { + errCount++ + continue + } + repos = append(repos, *repo) + } + + if errCount > 0 { + log.Warn().Str("owner", owner).Int("count", errCount).Msg("Found nil repositories, skipping them.") + } + + return +} + +func mapGithubProject(r github.Repository) repository.Project { + return repository.Project{ + ID: int(valueOrEmpty(r.ID)), + Name: valueOrEmpty(r.Name), + Path: valueOrEmpty(r.FullName), + WebURL: valueOrEmpty(r.HTMLURL), + RepoUrl: valueOrEmpty(r.HTMLURL), + Repository: repository.Github, + } +} + +func valueOrEmpty[T interface{}](val *T) (r T) { + if val != nil { + return *val + } + + return r +} diff --git a/internal/repository/github/github_client.go b/internal/repository/github/github_client.go new file mode 100644 index 0000000..9912b3e --- /dev/null +++ b/internal/repository/github/github_client.go @@ -0,0 +1,34 @@ +// Package gitlab provides a GitLab service to interact with the GitLab API. +package github + +import ( + "context" + + "github.com/google/go-github/v68/github" +) + +// This client is a thin wrapper around the go-github library. It provides an interface to the GitHub client +// The main purpose of this client is to provide an interface to the GitHub client which can be mocked in tests. +// As such this MUST be as thin as possible and MUST not contain any business logic, since it is not testable. + +type iGithubClient interface { + GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) + GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) + GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) +} + +type githubClient struct { + client *github.Client +} + +func (c *githubClient) GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) { + return c.client.Repositories.Get(context.Background(), owner, repo) +} + +func (c *githubClient) GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) { + return c.client.Repositories.ListByOrg(context.Background(), org, opts) +} + +func (c *githubClient) GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) { + return c.client.Repositories.ListByUser(context.Background(), user, opts) +} diff --git a/internal/repository/github/github_test.go b/internal/repository/github/github_test.go new file mode 100644 index 0000000..20c84c2 --- /dev/null +++ b/internal/repository/github/github_test.go @@ -0,0 +1,113 @@ +package github + +import ( + "errors" + "testing" + + "github.com/google/go-github/v68/github" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetProjectListOrganizationRepos(t *testing.T) { + mockService := mockService{} + mockService.On("GetOrganizationRepositories", "org", mock.Anything).Return([]*github.Repository{{Name: github.Ptr("Hello World")}}, &github.Response{}, nil) + + svc := githubService{client: &mockService} + + projects, err := svc.GetProjectList([]string{"org"}) + + assert.Nil(t, err) + assert.NotEmpty(t, projects) + assert.Equal(t, "Hello World", projects[0].Name) + mockService.AssertExpectations(t) +} + +func TestGetProjectListUserRepos(t *testing.T) { + mockService := mockService{} + mockService.On("GetOrganizationRepositories", "user", mock.Anything).Return([]*github.Repository{}, &github.Response{}, errors.New("error")) + mockService.On("GetUserRepositories", "user", mock.Anything).Return([]*github.Repository{{Name: github.Ptr("Hello World")}}, &github.Response{}, nil) + + svc := githubService{client: &mockService} + + projects, err := svc.GetProjectList([]string{"user"}) + + assert.Nil(t, err) + assert.NotEmpty(t, projects) + assert.Equal(t, "Hello World", projects[0].Name) + mockService.AssertExpectations(t) +} + +func TestGetProjectSpecificRepo(t *testing.T) { + mockService := mockService{} + mockService.On("GetRepository", "owner", "repo").Return(&github.Repository{Name: github.Ptr("Hello World")}, &github.Response{}, nil) + + svc := githubService{client: &mockService} + + projects, err := svc.GetProjectList([]string{"owner/repo"}) + + assert.Nil(t, err) + assert.NotEmpty(t, projects) + assert.Equal(t, "Hello World", projects[0].Name) + mockService.AssertExpectations(t) +} + +func TestGetProjectListWithNextPage(t *testing.T) { + project1 := &github.Repository{ID: github.Ptr(int64(1))} + project2 := &github.Repository{ID: github.Ptr(int64(2))} + + mockService := mockService{} + mockService.On("GetOrganizationRepositories", "org", &github.RepositoryListByOrgOptions{ + ListOptions: github.ListOptions{ + Page: 1, + PerPage: 100, + }, + }, mock.Anything).Return([]*github.Repository{project1}, &github.Response{NextPage: 2}, nil) + mockService.On("GetOrganizationRepositories", "org", &github.RepositoryListByOrgOptions{ + ListOptions: github.ListOptions{ + Page: 2, + PerPage: 100, + }, + }, mock.Anything).Return([]*github.Repository{project2}, &github.Response{NextPage: 0}, nil) + + svc := githubService{client: &mockService} + + projects, err := svc.GetProjectList([]string{"org"}) + + assert.Nil(t, err) + assert.Len(t, projects, 2) + assert.Equal(t, int(*project1.ID), projects[0].ID) + assert.Equal(t, int(*project2.ID), projects[1].ID) + mockService.AssertExpectations(t) +} + +type mockService struct { + mock.Mock +} + +func (c *mockService) GetRepository(owner string, repo string) (*github.Repository, *github.Response, error) { + args := c.Called(owner, repo) + var r *github.Response + if resp := args.Get(1); resp != nil { + r = args.Get(1).(*github.Response) + } + return args.Get(0).(*github.Repository), r, args.Error(2) +} + +func (c *mockService) GetOrganizationRepositories(org string, opts *github.RepositoryListByOrgOptions) ([]*github.Repository, *github.Response, error) { + args := c.Called(org, opts) + var r *github.Response + if resp := args.Get(1); resp != nil { + r = args.Get(1).(*github.Response) + } + return args.Get(0).([]*github.Repository), r, args.Error(2) +} + +func (c *mockService) GetUserRepositories(user string, opts *github.RepositoryListByUserOptions) ([]*github.Repository, *github.Response, error) { + args := c.Called(user, opts) + var r *github.Response + if resp := args.Get(1); resp != nil { + r = args.Get(1).(*github.Response) + } + return args.Get(0).([]*github.Repository), r, args.Error(2) +} diff --git a/internal/repo/gitlab.go b/internal/repository/gitlab/gitlab.go similarity index 77% rename from internal/repo/gitlab.go rename to internal/repository/gitlab/gitlab.go index d2ec0ff..310ac7b 100644 --- a/internal/repo/gitlab.go +++ b/internal/repository/gitlab/gitlab.go @@ -1,46 +1,50 @@ -package repo +package gitlab import ( "errors" "fmt" + "sheriff/internal/repository" "sync" "github.com/elliotchance/pie/v2" + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/plumbing/transport/http" "github.com/rs/zerolog/log" "github.com/xanzy/go-gitlab" ) type gitlabService struct { client iclient + token string } -// NewGitlabService creates a new GitLab service -func NewGitlabService(gitlabToken string) (IService, error) { - gitlabClient, err := gitlab.NewClient(gitlabToken) +// newGitlabRepo creates a new GitLab repository service +func New(token string) (*gitlabService, error) { + c, err := gitlab.NewClient(token) if err != nil { return nil, err } - s := gitlabService{&client{client: gitlabClient}} + s := gitlabService{client: &client{client: c}, token: token} return &s, nil } -func (s *gitlabService) GetProjectList(paths []string) (projects []Project, warn error) { +func (s gitlabService) GetProjectList(paths []string) (projects []repository.Project, warn error) { projects, pwarn := s.gatherProjectsFromGroupsOrProjects(paths) if pwarn != nil { pwarn = errors.Join(errors.New("errors occured when gathering projects"), pwarn) warn = errors.Join(pwarn, warn) } - projectsNamespaces := pie.Map(projects, func(p Project) string { return p.Path }) + projectsNamespaces := pie.Map(projects, func(p repository.Project) string { return p.Path }) log.Info().Strs("projects", projectsNamespaces).Msg("Projects to scan") return projects, warn } // CloseVulnerabilityIssue closes the vulnerability issue for the given project -func (s *gitlabService) CloseVulnerabilityIssue(project Project) (err error) { +func (s gitlabService) CloseVulnerabilityIssue(project repository.Project) (err error) { issue, err := s.getVulnerabilityIssue(project) if err != nil { return errors.Join(errors.New("failed to fetch current list of issues"), err) @@ -73,7 +77,7 @@ func (s *gitlabService) CloseVulnerabilityIssue(project Project) (err error) { } // OpenVulnerabilityIssue opens or updates the vulnerability issue for the given project -func (s *gitlabService) OpenVulnerabilityIssue(project Project, report string) (issue *Issue, err error) { +func (s gitlabService) OpenVulnerabilityIssue(project repository.Project, report string) (issue *repository.Issue, err error) { gitlabIssue, err := s.getVulnerabilityIssue(project) if err != nil { return nil, errors.Join(fmt.Errorf("[%v] Failed to fetch current list of issues", project.Path), err) @@ -83,7 +87,7 @@ func (s *gitlabService) OpenVulnerabilityIssue(project Project, report string) ( log.Info().Str("project", project.Path).Msg("Creating new issue") gitlabIssue, _, err := s.client.CreateIssue(project.ID, &gitlab.CreateIssueOptions{ - Title: gitlab.Ptr(VulnerabilityIssueTitle), + Title: gitlab.Ptr(repository.VulnerabilityIssueTitle), Description: &report, }) if err != nil { @@ -111,9 +115,22 @@ func (s *gitlabService) OpenVulnerabilityIssue(project Project, report string) ( return } +func (s gitlabService) Clone(url string, dir string) (err error) { + _, err = git.PlainClone(dir, false, &git.CloneOptions{ + URL: url, + Auth: &http.BasicAuth{ + Username: "N/A", + Password: s.token, + }, + Depth: 1, + }) + + return err +} + // This function receives a list of paths which can be gitlab projects or groups // and returns the list of projects within those paths and the list of projects contained within those groups and their subgroups. -func (s *gitlabService) gatherProjectsFromGroupsOrProjects(paths []string) (projects []Project, warn error) { +func (s gitlabService) gatherProjectsFromGroupsOrProjects(paths []string) (projects []repository.Project, warn error) { for _, path := range paths { gp, gpwarn, gerr := s.getProjectsFromGroupOrProject(path) if gerr != nil { @@ -140,7 +157,7 @@ func (s *gitlabService) gatherProjectsFromGroupsOrProjects(paths []string) (proj // // If it succeeds then it returns all projects of that group & its subgroups. // If it fails then it tries to get the path as a project. -func (s *gitlabService) getProjectsFromGroupOrProject(path string) (projects []Project, warn error, err error) { +func (s gitlabService) getProjectsFromGroupOrProject(path string) (projects []repository.Project, warn error, err error) { gp, gpwarn, gperr := s.listGroupProjects(path) if gperr != nil { log.Debug().Str("path", path).Msg("failed to fetch as group. trying as project") @@ -151,7 +168,7 @@ func (s *gitlabService) getProjectsFromGroupOrProject(path string) (projects []P return nil, fmt.Errorf("unexpected nil project %v", path), nil } - return []Project{mapProject(*p)}, nil, nil + return []repository.Project{mapProject(*p)}, nil, nil } ps := pie.Map(gp, mapProject) @@ -160,9 +177,9 @@ func (s *gitlabService) getProjectsFromGroupOrProject(path string) (projects []P } // getVulnerabilityIssue returns the vulnerability issue for the given project -func (s *gitlabService) getVulnerabilityIssue(project Project) (issue *gitlab.Issue, err error) { +func (s gitlabService) getVulnerabilityIssue(project repository.Project) (issue *gitlab.Issue, err error) { issues, _, err := s.client.ListProjectIssues(project.ID, &gitlab.ListProjectIssuesOptions{ - Search: gitlab.Ptr(VulnerabilityIssueTitle), + Search: gitlab.Ptr(repository.VulnerabilityIssueTitle), In: gitlab.Ptr("title"), }) if err != nil { @@ -181,7 +198,7 @@ func (s *gitlabService) getVulnerabilityIssue(project Project) (issue *gitlab.Is } // listGroupProjects returns the list of projects for the given group ID -func (s *gitlabService) listGroupProjects(path string) (projects []gitlab.Project, warn error, err error) { +func (s gitlabService) listGroupProjects(path string) (projects []gitlab.Project, warn error, err error) { projectPtrs, response, err := s.client.ListGroupProjects(path, &gitlab.ListGroupProjectsOptions{ Archived: gitlab.Ptr(false), @@ -224,7 +241,7 @@ func ToChan[T any](s []T) <-chan T { } // listGroupNextProjects returns the list of projects for the given group ID from the next pages -func (s *gitlabService) listGroupNextProjects(path string, totalPages int) (projects []gitlab.Project, warn error) { +func (s gitlabService) listGroupNextProjects(path string, totalPages int) (projects []gitlab.Project, warn error) { var wg sync.WaitGroup nextProjectsChan := make(chan []gitlab.Project, totalPages) warnChan := make(chan error, totalPages) @@ -274,7 +291,7 @@ func (s *gitlabService) listGroupNextProjects(path string, totalPages int) (proj return } -func filterUniqueProjects(projects []Project) (filteredProjects []Project) { +func filterUniqueProjects(projects []repository.Project) (filteredProjects []repository.Project) { projectsNamespaces := make(map[int]bool) for _, project := range projects { @@ -299,26 +316,25 @@ func dereferenceProjectsPointers(projects []*gitlab.Project) (filteredProjects [ return } -func mapProject(p gitlab.Project) Project { - return Project{ - ID: p.ID, - Name: p.Name, - Path: p.PathWithNamespace, - WebURL: p.WebURL, - RepoUrl: p.HTTPURLToRepo, - Platform: string(Gitlab), +func mapProject(p gitlab.Project) repository.Project { + return repository.Project{ + ID: p.ID, + Name: p.Name, + Path: p.PathWithNamespace, + WebURL: p.WebURL, + RepoUrl: p.HTTPURLToRepo, + Repository: repository.Gitlab, } } -func mapIssue(i gitlab.Issue) Issue { - return Issue{ - Title: i.Title, - WebURL: i.WebURL, - Platform: string(Gitlab), +func mapIssue(i gitlab.Issue) repository.Issue { + return repository.Issue{ + Title: i.Title, + WebURL: i.WebURL, } } -func mapIssuePtr(i *gitlab.Issue) *Issue { +func mapIssuePtr(i *gitlab.Issue) *repository.Issue { if i == nil { return nil } diff --git a/internal/repo/gitlab_client.go b/internal/repository/gitlab/gitlab_client.go similarity index 99% rename from internal/repo/gitlab_client.go rename to internal/repository/gitlab/gitlab_client.go index ece98b0..eeb30b1 100644 --- a/internal/repo/gitlab_client.go +++ b/internal/repository/gitlab/gitlab_client.go @@ -1,5 +1,5 @@ // Package gitlab provides a GitLab service to interact with the GitLab API. -package repo +package gitlab // This client is a thin wrapper around the go-gitlab library. It provides an interface to the GitLab client // The main purpose of this client is to provide an interface to the GitLab client which can be mocked in tests. diff --git a/internal/repo/gitlab_test.go b/internal/repository/gitlab/gitlab_test.go similarity index 91% rename from internal/repo/gitlab_test.go rename to internal/repository/gitlab/gitlab_test.go index dc0978c..054126c 100644 --- a/internal/repo/gitlab_test.go +++ b/internal/repository/gitlab/gitlab_test.go @@ -1,7 +1,8 @@ -package repo +package gitlab import ( "errors" + "sheriff/internal/repository" "testing" "github.com/stretchr/testify/assert" @@ -10,7 +11,7 @@ import ( ) func TestNewService(t *testing.T) { - s, err := NewGitlabService("token") + s, err := New("token") assert.Nil(t, err) assert.NotNil(t, s) @@ -20,7 +21,7 @@ func TestGetProjectListWithTopLevelGroup(t *testing.T) { mockClient := mockClient{} mockClient.On("ListGroupProjects", "group", mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} projects, err := svc.GetProjectList([]string{"group"}) @@ -34,7 +35,7 @@ func TestGetProjectListWithSubGroup(t *testing.T) { mockClient := mockClient{} mockClient.On("ListGroupProjects", "group/subgroup", mock.Anything, mock.Anything).Return([]*gitlab.Project{{Name: "Hello World"}}, &gitlab.Response{}, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} projects, err := svc.GetProjectList([]string{"group/subgroup"}) @@ -49,7 +50,7 @@ func TestGetProjectListWithProjects(t *testing.T) { mockClient.On("ListGroupProjects", "group/subgroup/project", mock.Anything, mock.Anything).Return([]*gitlab.Project{}, &gitlab.Response{}, errors.New("no group")) mockClient.On("GetProject", "group/subgroup/project", mock.Anything, mock.Anything).Return(&gitlab.Project{Name: "Hello World", PathWithNamespace: "group/subgroup/project"}, &gitlab.Response{}, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} projects, err := svc.GetProjectList([]string{"group/subgroup/project"}) @@ -69,7 +70,7 @@ func TestGetProjectListWithGroupAndProjects(t *testing.T) { mockClient.On("ListGroupProjects", project1.PathWithNamespace, mock.Anything, mock.Anything).Return([]*gitlab.Project{}, &gitlab.Response{}, errors.New("no group")) mockClient.On("GetProject", project1.PathWithNamespace, mock.Anything, mock.Anything).Return(project1, &gitlab.Response{}, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} projects, err := svc.GetProjectList([]string{"group", "group/subgroup", project1.PathWithNamespace}) @@ -105,7 +106,7 @@ func TestGetProjectListWithNextPage(t *testing.T) { }, }, mock.Anything).Return([]*gitlab.Project{project2}, &gitlab.Response{NextPage: 0, TotalPages: 2}, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} projects, err := svc.GetProjectList([]string{"group/subgroup"}) @@ -121,9 +122,9 @@ func TestCloseVulnerabilityIssue(t *testing.T) { mockClient.On("ListProjectIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*gitlab.Issue{{State: "opened"}}, nil, nil) mockClient.On("UpdateIssue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&gitlab.Issue{State: "closed"}, nil, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} - err := svc.CloseVulnerabilityIssue(Project{}) + err := svc.CloseVulnerabilityIssue(repository.Project{}) assert.Nil(t, err) mockClient.AssertExpectations(t) @@ -133,9 +134,9 @@ func TestCloseVulnerabilityIssueAlreadyClosed(t *testing.T) { mockClient := mockClient{} mockClient.On("ListProjectIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*gitlab.Issue{{State: "closed"}}, nil, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} - err := svc.CloseVulnerabilityIssue(Project{}) + err := svc.CloseVulnerabilityIssue(repository.Project{}) assert.Nil(t, err) mockClient.AssertExpectations(t) @@ -145,9 +146,9 @@ func TestCloseVulnerabilityIssueNoIssue(t *testing.T) { mockClient := mockClient{} mockClient.On("ListProjectIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*gitlab.Issue{}, nil, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} - err := svc.CloseVulnerabilityIssue(Project{}) + err := svc.CloseVulnerabilityIssue(repository.Project{}) assert.Nil(t, err) mockClient.AssertExpectations(t) @@ -158,16 +159,16 @@ func TestOpenVulnerabilityIssue(t *testing.T) { mockClient.On("ListProjectIssues", mock.Anything, mock.Anything, mock.Anything).Return([]*gitlab.Issue{}, nil, nil) mockClient.On("CreateIssue", mock.Anything, mock.Anything, mock.Anything).Return(&gitlab.Issue{Title: "666"}, nil, nil) - svc := gitlabService{&mockClient} + svc := gitlabService{client: &mockClient} - i, err := svc.OpenVulnerabilityIssue(Project{}, "report") + i, err := svc.OpenVulnerabilityIssue(repository.Project{}, "report") assert.Nil(t, err) assert.NotNil(t, i) assert.Equal(t, "666", i.Title) } func TestFilterUniqueProjects(t *testing.T) { - projects := []Project{ + projects := []repository.Project{ {ID: 1}, {ID: 1}, {ID: 2}, diff --git a/internal/repository/provider/provider.go b/internal/repository/provider/provider.go new file mode 100644 index 0000000..f38b261 --- /dev/null +++ b/internal/repository/provider/provider.go @@ -0,0 +1,41 @@ +package provider + +import ( + "errors" + "fmt" + "sheriff/internal/repository" + "sheriff/internal/repository/github" + "sheriff/internal/repository/gitlab" +) + +// IProvider is the interface of the repository service as needed by sheriff +type IProvider interface { + Provide(repository.RepositoryType) repository.IRepositoryService +} + +type provider struct { + gitlabService repository.IRepositoryService + githubService repository.IRepositoryService +} + +func NewProvider(gitlabToken string, githubToken string) (IProvider, error) { + gitlabService, err := gitlab.New(gitlabToken) + if err != nil { + return nil, errors.Join(fmt.Errorf("failed to create gitlab provider"), err) + } + + githubService := github.New(githubToken) + + return provider{ + gitlabService: gitlabService, + githubService: githubService, + }, nil +} + +func (s provider) Provide(p repository.RepositoryType) repository.IRepositoryService { + if p == repository.Gitlab { + return s.gitlabService + } else { + return s.githubService + } +} diff --git a/internal/repository/repository.go b/internal/repository/repository.go new file mode 100644 index 0000000..6db11b0 --- /dev/null +++ b/internal/repository/repository.go @@ -0,0 +1,33 @@ +package repository + +const VulnerabilityIssueTitle = "Sheriff - šŸšØ Vulnerability report" + +type RepositoryType string + +const ( + Gitlab RepositoryType = "gitlab" + Github RepositoryType = "github" +) + +type Project struct { + ID int + Name string + Path string + WebURL string + RepoUrl string + Repository RepositoryType +} + +type Issue struct { + ID int + Title string + WebURL string + Open bool +} + +type IRepositoryService interface { + GetProjectList(paths []string) (projects []Project, warn error) + CloseVulnerabilityIssue(project Project) error + OpenVulnerabilityIssue(project Project, report string) (*Issue, error) + Clone(url string, dir string) error +} diff --git a/internal/scanner/osv.go b/internal/scanner/osv.go index 4d41a90..1e95edf 100644 --- a/internal/scanner/osv.go +++ b/internal/scanner/osv.go @@ -3,7 +3,7 @@ package scanner import ( "encoding/json" "path/filepath" - "sheriff/internal/repo" + "sheriff/internal/repository" "sheriff/internal/shell" "strconv" "time" @@ -136,7 +136,7 @@ func (s *osvScanner) Scan(dir string) (*OsvReport, error) { } // GenerateReport generates a Report struct from the OsvReport. -func (s *osvScanner) GenerateReport(p repo.Project, r *OsvReport) Report { +func (s *osvScanner) GenerateReport(p repository.Project, r *OsvReport) Report { if r == nil { return Report{ Project: p, diff --git a/internal/scanner/osv_test.go b/internal/scanner/osv_test.go index 3dc79a8..39e4420 100644 --- a/internal/scanner/osv_test.go +++ b/internal/scanner/osv_test.go @@ -1,7 +1,7 @@ package scanner import ( - "sheriff/internal/repo" + "sheriff/internal/repository" "sheriff/internal/shell" "testing" @@ -96,7 +96,7 @@ func (m *mockCommandRunner) Run(shell.CommandInput) (shell.CommandOutput, error) func TestGenerateReportOSV(t *testing.T) { mockReport := createMockReport("10.0") s := osvScanner{} - got := s.GenerateReport(repo.Project{}, mockReport) + got := s.GenerateReport(repository.Project{}, mockReport) assert.NotNil(t, got) assert.Len(t, got.Vulnerabilities, 1) @@ -132,7 +132,7 @@ func TestGenerateReportOSVHasCorrectSeverityKind(t *testing.T) { for input, want := range testCases { t.Run(input, func(t *testing.T) { mockReport := createMockReport(input) - got := s.GenerateReport(repo.Project{}, mockReport) + got := s.GenerateReport(repository.Project{}, mockReport) assert.NotNil(t, got) assert.Equal(t, want, got.Vulnerabilities[0].SeverityScoreKind) @@ -156,7 +156,7 @@ func TestReportContainsHasAvailableFix(t *testing.T) { }, }, }) - got := s.GenerateReport(repo.Project{}, mockReport) + got := s.GenerateReport(repository.Project{}, mockReport) assert.NotNil(t, got) assert.Len(t, got.Vulnerabilities, 1) diff --git a/internal/scanner/vulnscanner.go b/internal/scanner/vulnscanner.go index f11761e..29334ab 100644 --- a/internal/scanner/vulnscanner.go +++ b/internal/scanner/vulnscanner.go @@ -3,7 +3,7 @@ package scanner import ( "sheriff/internal/config" - "sheriff/internal/repo" + "sheriff/internal/repository" ) type SeverityScoreKind string @@ -47,7 +47,7 @@ type Vulnerability struct { // Report is the main report representation of a project vulnerability scan. type Report struct { - Project repo.Project + Project repository.Project ProjectConfig config.ProjectConfig // Contains the project-level configuration that users of sheriff may have in their repository IsVulnerable bool Vulnerabilities []Vulnerability @@ -61,5 +61,5 @@ type VulnScanner[T any] interface { // Scan runs a vulnerability scan on the given directory Scan(dir string) (*T, error) // GenerateReport maps the report from the scanner to our internal representation of vulnerability reports. - GenerateReport(p repo.Project, r *T) Report + GenerateReport(p repository.Project, r *T) Report } diff --git a/justfile b/justfile index e5bda49..c2361fe 100644 --- a/justfile +++ b/justfile @@ -2,7 +2,7 @@ set dotenv-load := true [no-exit-message] run *ARGS: - GITLAB_TOKEN=$GITLAB_TOKEN go run . {{ARGS}} + go run . {{ARGS}} alias t := test test *ARGS="./...":