diff --git a/.mockery.yaml b/.mockery.yaml index 52b2cdbe7..84c5318bb 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -191,4 +191,22 @@ packages: interfaces: GroupRepository: config: - filename: "group_repository.go" \ No newline at end of file + filename: "group_repository.go" + github.com/goto/shield/internal/proxy/envoy/xds/ads: + config: + dir: "internal/proxy/envoy/xds/ads/mocks" + outpkg: "mocks" + mockname: "{{.InterfaceName}}" + interfaces: + Repository: + config: + filename: "repository.go" + github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3: + config: + dir: "internal/proxy/envoy/xds/ads/mocks" + outpkg: "mocks" + mockname: "{{.InterfaceName}}" + interfaces: + AggregatedDiscoveryService_StreamAggregatedResourcesServer: + config: + filename: "ads_stream.go" \ No newline at end of file diff --git a/cmd/proxy.go b/cmd/proxy.go new file mode 100644 index 000000000..c7f443fd6 --- /dev/null +++ b/cmd/proxy.go @@ -0,0 +1,94 @@ +package cmd + +import ( + "github.com/MakeNowJust/heredoc" + "github.com/goto/shield/config" + "github.com/goto/shield/internal/proxy/envoy/xds" + "github.com/goto/shield/internal/store/postgres" + shieldlogger "github.com/goto/shield/pkg/logger" + "github.com/spf13/cobra" + cli "github.com/spf13/cobra" +) + +func ProxyCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "proxy ", + Short: "Proxy management", + Long: "Server management commands.", + Example: heredoc.Doc(` + $ shield proxy envoy start -c ./config.yaml + `), + } + + cmd.AddCommand(proxyEnvoyXDSCommand()) + + return cmd +} + +func proxyEnvoyXDSCommand() *cobra.Command { + c := &cli.Command{ + Use: "envoy", + Short: "Envoy Agent xDS management", + Long: "Envoy Agent xDS management commands.", + Example: heredoc.Doc(` + $ shield proxy envoy start + `), + } + + c.AddCommand(envoyXDSStartCommand()) + + return c +} + +func envoyXDSStartCommand() *cobra.Command { + var configFile string + + c := &cli.Command{ + Use: "start", + Short: "Start Envoy Agent xDS server", + Long: "Start Envoy Agent xDS server commands.", + Example: "shield proxy envoy start", + RunE: func(cmd *cli.Command, args []string) error { + appConfig, err := config.Load(configFile) + if err != nil { + panic(err) + } + + logger := shieldlogger.InitLogger(shieldlogger.Config{Level: appConfig.Log.Level}) + + dbClient, err := setupDB(appConfig.DB, logger) + if err != nil { + return err + } + defer func() { + logger.Info("cleaning up db") + dbClient.Close() + }() + + ctx := cmd.Context() + + pgRuleRepository := postgres.NewRuleRepository(dbClient) + if err := pgRuleRepository.InitCache(ctx); err != nil { + return err + } + + cbs, repositories, err := buildXDSDependencies(ctx, logger, appConfig.Proxy, pgRuleRepository) + if err != nil { + return err + } + defer func() { + logger.Info("cleaning up rules proxy blob") + for _, f := range cbs { + if err := f(); err != nil { + logger.Warn("error occurred during shutdown rules proxy blob storages", "err", err) + } + } + }() + + return xds.Serve(ctx, logger, appConfig.Proxy, repositories) + }, + } + + c.Flags().StringVarP(&configFile, "config", "c", "", "Config file path") + return c +} diff --git a/cmd/root.go b/cmd/root.go index 73b910920..9f2f3237d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -40,6 +40,7 @@ func New(cliConfig *Config) *cli.Command { } cmd.AddCommand(ServerCommand()) + cmd.AddCommand(ProxyCommand()) cmd.AddCommand(NamespaceCommand(cliConfig)) cmd.AddCommand(UserCommand(cliConfig)) cmd.AddCommand(OrganizationCommand(cliConfig)) diff --git a/cmd/serve.go b/cmd/serve.go index 4da2ead2b..f2e263f28 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -34,6 +34,7 @@ import ( "github.com/goto/shield/core/user" "github.com/goto/shield/internal/adapter" "github.com/goto/shield/internal/api" + proxycfg "github.com/goto/shield/internal/proxy" "github.com/goto/shield/internal/schema" "github.com/goto/shield/internal/server" "github.com/goto/shield/internal/store/blob" @@ -179,7 +180,13 @@ func StartServer(logger *log.Zap, cfg *config.Shield) error { } // serving proxies - cbs, cps, err := serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.App.UserIDHeader, cfg.Proxy, pgRuleRepository, deps.ResourceService, deps.RelationService, deps.UserService, deps.GroupService, deps.ProjectService, deps.RelationAdapter) + var cbs []func() error + var cps []func(context.Context) error + if cfg.Proxy.Type == proxycfg.ENVOY_PROXY { + cbs, err = serveXDS(ctx, logger, cfg.Proxy, pgRuleRepository) + } else { + cbs, cps, err = serveProxies(ctx, logger, cfg.App.IdentityProxyHeader, cfg.App.UserIDHeader, cfg.Proxy, pgRuleRepository, deps.ResourceService, deps.RelationService, deps.UserService, deps.GroupService, deps.ProjectService, deps.RelationAdapter) + } if err != nil { return err } diff --git a/cmd/serve_xds.go b/cmd/serve_xds.go new file mode 100644 index 000000000..f478f6724 --- /dev/null +++ b/cmd/serve_xds.go @@ -0,0 +1,70 @@ +package cmd + +import ( + "context" + "errors" + "net/url" + + "github.com/goto/salt/log" + "github.com/goto/shield/core/rule" + "github.com/goto/shield/internal/proxy" + "github.com/goto/shield/internal/proxy/envoy/xds" + "github.com/goto/shield/internal/proxy/envoy/xds/ads" + "github.com/goto/shield/internal/store/blob" + "github.com/goto/shield/internal/store/postgres" +) + +func serveXDS(ctx context.Context, logger *log.Zap, cfg proxy.ServicesConfig, pgRuleRepository *postgres.RuleRepository) ([]func() error, error) { + cleanUpBlobs, repositories, err := buildXDSDependencies(ctx, logger, cfg, pgRuleRepository) + if err != nil { + return nil, err + } + + errChan := make(chan error) + go func() { + err := xds.Serve(ctx, logger, cfg, repositories) + if err != nil { + errChan <- err + logger.Error("error while running envoy xds server", "error", err) + } + }() + + return cleanUpBlobs, nil +} + +func buildXDSDependencies(ctx context.Context, logger *log.Zap, cfg proxy.ServicesConfig, pgRuleRepository *postgres.RuleRepository) ([]func() error, map[string]ads.Repository, error) { + var cleanUpBlobs []func() error + repositories := make(map[string]ads.Repository) + + for _, svcConfig := range cfg.Services { + parsedRuleConfigURL, err := url.Parse(svcConfig.RulesPath) + if err != nil { + return nil, nil, err + } + + var repository ads.Repository + switch parsedRuleConfigURL.Scheme { + case rule.RULES_CONFIG_STORAGE_PG: + repository = pgRuleRepository + case rule.RULES_CONFIG_STORAGE_GS, + rule.RULES_CONFIG_STORAGE_FILE, + rule.RULES_CONFIG_STORAGE_MEM: + ruleBlobFS, err := blob.NewStore(ctx, svcConfig.RulesPath, svcConfig.RulesPathSecret) + if err != nil { + return nil, nil, err + } + + blobRuleRepository := blob.NewRuleRepository(logger, ruleBlobFS) + if err := blobRuleRepository.InitCache(ctx, ruleCacheRefreshDelay); err != nil { + return nil, nil, err + } + cleanUpBlobs = append(cleanUpBlobs, blobRuleRepository.Close) + repository = blobRuleRepository + default: + return nil, nil, errors.New("invalid rule config storage") + } + repositories[svcConfig.Name] = repository + } + + return cleanUpBlobs, repositories, nil +} diff --git a/config/config.yaml b/config/config.yaml index 4992752df..4c39aade6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -37,6 +37,16 @@ spicedb: # proxy configuration proxy: + # proxy type configuration + # valid values are "shield" and "envoy", with the default set to "shield" + type: shield + # envoy proxy configuration, will be ignored if proxy type set to "shield" + envoy: + xds: + host: 127.0.0.1 + port: 8082 + refresh_interval: 10s + # proxy services configuration services: - name: test host: 0.0.0.0 diff --git a/go.mod b/go.mod index 6c05ceb73..b667828e2 100644 --- a/go.mod +++ b/go.mod @@ -12,8 +12,7 @@ require ( github.com/authzed/spicedb v1.15.0 github.com/dgraph-io/ristretto v0.1.1 github.com/doug-martin/goqu/v9 v9.18.0 - github.com/envoyproxy/protoc-gen-validate v1.0.4 - github.com/ghodss/yaml v1.0.0 + github.com/envoyproxy/protoc-gen-validate v1.1.0 github.com/golang-migrate/migrate/v4 v4.16.0 github.com/golang/protobuf v1.5.4 github.com/google/go-cmp v0.6.0 @@ -54,18 +53,21 @@ require ( golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.27.0 golang.org/x/oauth2 v0.21.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 - google.golang.org/grpc v1.64.0 + google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 + google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) require ( + cel.dev/expr v0.15.0 // indirect + github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect + github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/oklog/run v1.1.0 // indirect - github.com/planetscale/vtprotobuf v0.6.0 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/samber/lo v1.39.0 // indirect ) @@ -97,6 +99,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect github.com/emirpasic/gods v1.18.1 // indirect + github.com/envoyproxy/go-control-plane v0.13.1 github.com/fatih/color v1.17.0 // indirect github.com/felixge/fgprof v0.9.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -105,7 +108,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect - github.com/golang/glog v1.2.0 // indirect + github.com/golang/glog v1.2.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/cel-go v0.13.0 // indirect github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect diff --git a/go.sum b/go.sum index 3278c83b8..6518dcd0a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ bazil.org/fuse v0.0.0-20160811212531-371fbbdaa898/go.mod h1:Xbm+BRKSBEpa4q4hTSxohYNQpsxXPbPry4JJWOB3LB8= bazil.org/fuse v0.0.0-20200407214033-5883e5a4b512/go.mod h1:FbcW6z/2VytnFDhZfumh8Ss8zxHE6qpMP5sHTRe0EaM= +cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= +cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -621,6 +623,7 @@ github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK3 github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= @@ -664,6 +667,8 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20220314180256-7f1daf1720fc/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= +github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/cockroachdb/datadriven v0.0.0-20190809214429-80d97fb3cbaa/go.mod h1:zn76sxSg3SzpJ0PPJaLDCu+Bu0Lg3sKTORVIj19EIF8= @@ -880,11 +885,13 @@ github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go. github.com/envoyproxy/go-control-plane v0.10.1/go.mod h1:AY7fTTXNdv/aJ2O5jwpxAPOWUZ7hQAEvzN5Pf27BkQQ= github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/go-control-plane v0.10.3/go.mod h1:fJJn/j26vwOu972OllsvAgJJM//w9BV6Fxbg2LuVd34= +github.com/envoyproxy/go-control-plane v0.13.1 h1:vPfJZCkob6yTMEgS+0TwfTUfbHjfy/6vOJ8hUWX/uXE= +github.com/envoyproxy/go-control-plane v0.13.1/go.mod h1:X45hY0mufo6Fd0KW3rqsGvQMw58jvjymeCzBU3mWyHw= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.6.7/go.mod h1:dyJXwwfPK2VSqiB9Klm1J6romD608Ba7Hij42vrOBCo= github.com/envoyproxy/protoc-gen-validate v0.6.13/go.mod h1:qEySVqXrEugbHKvmhI8ZqtQi75/RHSSRNpffvB4I6Bw= -github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= -github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= +github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= +github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.11.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= @@ -920,7 +927,6 @@ github.com/garyburd/redigo v0.0.0-20150301180006-535138d7bcd7/go.mod h1:NR3MbYis github.com/getkin/kin-openapi v0.76.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= @@ -1058,8 +1064,8 @@ github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2V github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= -github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= -github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/golang/glog v1.2.1 h1:OptwRhECazUx5ix5TTWC3EZhsZEHWcYWY4FQHTIubm4= +github.com/golang/glog v1.2.1/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= github.com/golang/groupcache v0.0.0-20160516000752-02826c3e7903/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190129154638-5b532d6fd5ef/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -1719,8 +1725,8 @@ github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= github.com/pkg/sftp v1.10.1/go.mod h1:lYOWFsE0bwd1+KfKJaKeuokY15vzFx25BLbzYYoAxZI= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= -github.com/planetscale/vtprotobuf v0.6.0 h1:nBeETjudeJ5ZgBHUz1fVHvbqUKnYOXNhsIEabROxmNA= -github.com/planetscale/vtprotobuf v0.6.0/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -2870,8 +2876,8 @@ google.golang.org/genproto v0.0.0-20221201164419-0e50fba7f41c/go.mod h1:rZS5c/ZV google.golang.org/genproto v0.0.0-20221201204527-e3fa12d562f3/go.mod h1:rZS5c/ZVYMaOGBfO68GWtjOw/eLaZM1X6iVtgjZ+EWg= google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 h1:9+tzLLstTlPTRyJTh+ah5wIMsBW5c4tQwGTN3thOW9Y= google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:mqHbVIp48Muh7Ywss/AD6I5kNVKZMmAa/QEW58Gxp2s= -google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 h1:W5Xj/70xIA4x60O/IFyXivR5MGqblAb8R3w26pnD6No= -google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8/go.mod h1:vPrPUTsDCYxXWjP7clS81mZ6/803D8K4iM9Ma27VKas= +google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157 h1:7whR9kGa5LUwFtpLm2ArCEejtnxlGeLbAyjFY8sGNFw= +google.golang.org/genproto/googleapis/api v0.0.0-20240528184218-531527333157/go.mod h1:99sLkeliLXfdj2J75X3Ho+rrVCaJze0uwN7zDDkjPVU= google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117 h1:1GBuWVLM/KMVUv1t1En5Gs+gFZCNd360GGb4sSxtrhU= google.golang.org/genproto/googleapis/rpc v0.0.0-20240604185151-ef581f913117/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= google.golang.org/grpc v0.0.0-20160317175043-d3ddb4469d5a/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= @@ -2919,8 +2925,8 @@ google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCD google.golang.org/grpc v1.50.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/grpc v1.50.1/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww= -google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= -google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= +google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= diff --git a/internal/proxy/config.go b/internal/proxy/config.go index 7b01afbf7..be6ce57a6 100644 --- a/internal/proxy/config.go +++ b/internal/proxy/config.go @@ -1,7 +1,26 @@ package proxy +import "time" + +const ( + SHIELD_PROXY = "shield" + ENVOY_PROXY = "envoy" +) + type ServicesConfig struct { - Services []Config `yaml:"services" mapstructure:"services"` + Type string `yaml:"type" mapstructure:"type"` + EnvoyAgent EnvoyAgent `yaml:"envoy" mapstructure:"envoy"` + Services []Config `yaml:"services" mapstructure:"services"` +} + +type EnvoyAgent struct { + XDS XDS `yaml:"xds" mapstructure:"xds"` +} + +type XDS struct { + Host string `yaml:"host" mapstructure:"host" default:"shield"` + Port int `yaml:"port" mapstructure:"port"` + RefreshInterval time.Duration `yaml:"refresh_interval" mapstructure:"refresh_interval" default:"60s"` } type Config struct { diff --git a/internal/proxy/envoy/xds/ads/ads.go b/internal/proxy/envoy/xds/ads/ads.go new file mode 100644 index 000000000..d2f5347b2 --- /dev/null +++ b/internal/proxy/envoy/xds/ads/ads.go @@ -0,0 +1,21 @@ +package ads + +import ( + "context" + "time" + + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/goto/shield/core/rule" +) + +const ( + HTTP_CONNECTION_MANAGER_TYPE_URL = resource.APITypePrefix + "envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager" + ROUTER_TYPE_URL = resource.APITypePrefix + "envoy.extensions.filters.http.router.v3.Router" + URI_TEMPLATE_TYPE_URL = resource.APITypePrefix + "envoy.extensions.path.match.uri_template.v3.UriTemplateMatchConfig" + STDOUT_LOGGER_TYPE_URL = resource.APITypePrefix + "envoy.extensions.access_loggers.stream.v3.StdoutAccessLog" +) + +type Repository interface { + Fetch(ctx context.Context) ([]rule.Ruleset, error) + IsUpdated(ctx context.Context, since time.Time) bool +} diff --git a/internal/proxy/envoy/xds/ads/mocks/ads_stream.go b/internal/proxy/envoy/xds/ads/mocks/ads_stream.go new file mode 100644 index 000000000..1b270bd02 --- /dev/null +++ b/internal/proxy/envoy/xds/ads/mocks/ads_stream.go @@ -0,0 +1,406 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + discoveryv3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + metadata "google.golang.org/grpc/metadata" + + mock "github.com/stretchr/testify/mock" +) + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer is an autogenerated mock type for the AggregatedDiscoveryService_StreamAggregatedResourcesServer type +type AggregatedDiscoveryService_StreamAggregatedResourcesServer struct { + mock.Mock +} + +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter struct { + mock *mock.Mock +} + +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) EXPECT() *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter{mock: &_m.Mock} +} + +// Context provides a mock function with given fields: +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) Context() context.Context { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Context") + } + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Context' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call struct { + *mock.Call +} + +// Context is a helper method to define mock.On call +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) Context() *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call{Call: _e.mock.On("Context")} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call) Run(run func()) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call) Return(_a0 context.Context) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call) RunAndReturn(run func() context.Context) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Context_Call { + _c.Call.Return(run) + return _c +} + +// Recv provides a mock function with given fields: +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) Recv() (*discoveryv3.DiscoveryRequest, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Recv") + } + + var r0 *discoveryv3.DiscoveryRequest + var r1 error + if rf, ok := ret.Get(0).(func() (*discoveryv3.DiscoveryRequest, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() *discoveryv3.DiscoveryRequest); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*discoveryv3.DiscoveryRequest) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Recv' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call struct { + *mock.Call +} + +// Recv is a helper method to define mock.On call +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) Recv() *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call{Call: _e.mock.On("Recv")} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call) Run(run func()) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call) Return(_a0 *discoveryv3.DiscoveryRequest, _a1 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call) RunAndReturn(run func() (*discoveryv3.DiscoveryRequest, error)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Recv_Call { + _c.Call.Return(run) + return _c +} + +// RecvMsg provides a mock function with given fields: m +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) RecvMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for RecvMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecvMsg' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call struct { + *mock.Call +} + +// RecvMsg is a helper method to define mock.On call +// - m interface{} +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) RecvMsg(m interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call{Call: _e.mock.On("RecvMsg", m)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call) Run(run func(m interface{})) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call) Return(_a0 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call) RunAndReturn(run func(interface{}) error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_RecvMsg_Call { + _c.Call.Return(run) + return _c +} + +// Send provides a mock function with given fields: _a0 +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) Send(_a0 *discoveryv3.DiscoveryResponse) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Send") + } + + var r0 error + if rf, ok := ret.Get(0).(func(*discoveryv3.DiscoveryResponse) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Send' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call struct { + *mock.Call +} + +// Send is a helper method to define mock.On call +// - _a0 *discoveryv3.DiscoveryResponse +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) Send(_a0 interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call{Call: _e.mock.On("Send", _a0)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call) Run(run func(_a0 *discoveryv3.DiscoveryResponse)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*discoveryv3.DiscoveryResponse)) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call) Return(_a0 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call) RunAndReturn(run func(*discoveryv3.DiscoveryResponse) error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Send_Call { + _c.Call.Return(run) + return _c +} + +// SendHeader provides a mock function with given fields: _a0 +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) SendHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SendHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendHeader' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call struct { + *mock.Call +} + +// SendHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) SendHeader(_a0 interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call{Call: _e.mock.On("SendHeader", _a0)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call) Run(run func(_a0 metadata.MD)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call) Return(_a0 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call) RunAndReturn(run func(metadata.MD) error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendHeader_Call { + _c.Call.Return(run) + return _c +} + +// SendMsg provides a mock function with given fields: m +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) SendMsg(m interface{}) error { + ret := _m.Called(m) + + if len(ret) == 0 { + panic("no return value specified for SendMsg") + } + + var r0 error + if rf, ok := ret.Get(0).(func(interface{}) error); ok { + r0 = rf(m) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendMsg' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call struct { + *mock.Call +} + +// SendMsg is a helper method to define mock.On call +// - m interface{} +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) SendMsg(m interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call{Call: _e.mock.On("SendMsg", m)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call) Run(run func(m interface{})) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(interface{})) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call) Return(_a0 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call) RunAndReturn(run func(interface{}) error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SendMsg_Call { + _c.Call.Return(run) + return _c +} + +// SetHeader provides a mock function with given fields: _a0 +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) SetHeader(_a0 metadata.MD) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for SetHeader") + } + + var r0 error + if rf, ok := ret.Get(0).(func(metadata.MD) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHeader' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call struct { + *mock.Call +} + +// SetHeader is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) SetHeader(_a0 interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call{Call: _e.mock.On("SetHeader", _a0)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call) Run(run func(_a0 metadata.MD)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call) Return(_a0 error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call) RunAndReturn(run func(metadata.MD) error) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetHeader_Call { + _c.Call.Return(run) + return _c +} + +// SetTrailer provides a mock function with given fields: _a0 +func (_m *AggregatedDiscoveryService_StreamAggregatedResourcesServer) SetTrailer(_a0 metadata.MD) { + _m.Called(_a0) +} + +// AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetTrailer' +type AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call struct { + *mock.Call +} + +// SetTrailer is a helper method to define mock.On call +// - _a0 metadata.MD +func (_e *AggregatedDiscoveryService_StreamAggregatedResourcesServer_Expecter) SetTrailer(_a0 interface{}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call { + return &AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call{Call: _e.mock.On("SetTrailer", _a0)} +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call) Run(run func(_a0 metadata.MD)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(metadata.MD)) + }) + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call) Return() *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call { + _c.Call.Return() + return _c +} + +func (_c *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call) RunAndReturn(run func(metadata.MD)) *AggregatedDiscoveryService_StreamAggregatedResourcesServer_SetTrailer_Call { + _c.Call.Return(run) + return _c +} + +// NewAggregatedDiscoveryService_StreamAggregatedResourcesServer creates a new instance of AggregatedDiscoveryService_StreamAggregatedResourcesServer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAggregatedDiscoveryService_StreamAggregatedResourcesServer(t interface { + mock.TestingT + Cleanup(func()) +}) *AggregatedDiscoveryService_StreamAggregatedResourcesServer { + mock := &AggregatedDiscoveryService_StreamAggregatedResourcesServer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proxy/envoy/xds/ads/mocks/repository.go b/internal/proxy/envoy/xds/ads/mocks/repository.go new file mode 100644 index 000000000..ab74b2daa --- /dev/null +++ b/internal/proxy/envoy/xds/ads/mocks/repository.go @@ -0,0 +1,144 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + rule "github.com/goto/shield/core/rule" + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Fetch provides a mock function with given fields: ctx +func (_m *Repository) Fetch(ctx context.Context) ([]rule.Ruleset, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Fetch") + } + + var r0 []rule.Ruleset + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]rule.Ruleset, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []rule.Ruleset); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]rule.Ruleset) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Fetch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Fetch' +type Repository_Fetch_Call struct { + *mock.Call +} + +// Fetch is a helper method to define mock.On call +// - ctx context.Context +func (_e *Repository_Expecter) Fetch(ctx interface{}) *Repository_Fetch_Call { + return &Repository_Fetch_Call{Call: _e.mock.On("Fetch", ctx)} +} + +func (_c *Repository_Fetch_Call) Run(run func(ctx context.Context)) *Repository_Fetch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Repository_Fetch_Call) Return(_a0 []rule.Ruleset, _a1 error) *Repository_Fetch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Fetch_Call) RunAndReturn(run func(context.Context) ([]rule.Ruleset, error)) *Repository_Fetch_Call { + _c.Call.Return(run) + return _c +} + +// IsUpdated provides a mock function with given fields: ctx, since +func (_m *Repository) IsUpdated(ctx context.Context, since time.Time) bool { + ret := _m.Called(ctx, since) + + if len(ret) == 0 { + panic("no return value specified for IsUpdated") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, time.Time) bool); ok { + r0 = rf(ctx, since) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Repository_IsUpdated_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsUpdated' +type Repository_IsUpdated_Call struct { + *mock.Call +} + +// IsUpdated is a helper method to define mock.On call +// - ctx context.Context +// - since time.Time +func (_e *Repository_Expecter) IsUpdated(ctx interface{}, since interface{}) *Repository_IsUpdated_Call { + return &Repository_IsUpdated_Call{Call: _e.mock.On("IsUpdated", ctx, since)} +} + +func (_c *Repository_IsUpdated_Call) Run(run func(ctx context.Context, since time.Time)) *Repository_IsUpdated_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(time.Time)) + }) + return _c +} + +func (_c *Repository_IsUpdated_Call) Return(_a0 bool) *Repository_IsUpdated_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_IsUpdated_Call) RunAndReturn(run func(context.Context, time.Time) bool) *Repository_IsUpdated_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/proxy/envoy/xds/ads/pubsub.go b/internal/proxy/envoy/xds/ads/pubsub.go new file mode 100644 index 000000000..92140351e --- /dev/null +++ b/internal/proxy/envoy/xds/ads/pubsub.go @@ -0,0 +1,27 @@ +package ads + +import "errors" + +type Message struct { + NodeID string + VersionInfo string + Nonce string + TypeUrl string +} + +type MessageChan chan Message + +var ( + ErrChannelClosed = errors.New("can't send message on closed channel") +) + +func (m MessageChan) Push(message Message) (err error) { + defer func() { + if recover() != nil { + err = ErrChannelClosed + } + }() + + m <- message + return nil +} diff --git a/internal/proxy/envoy/xds/ads/pubsub_test.go b/internal/proxy/envoy/xds/ads/pubsub_test.go new file mode 100644 index 000000000..3ace9ee9e --- /dev/null +++ b/internal/proxy/envoy/xds/ads/pubsub_test.go @@ -0,0 +1,29 @@ +package ads_test + +import ( + "errors" + "testing" + + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/goto/shield/internal/proxy/envoy/xds/ads" + "github.com/stretchr/testify/assert" +) + +func TestPush(t *testing.T) { + message := ads.Message{ + NodeID: "node-1", + VersionInfo: "v1", + Nonce: "test", + TypeUrl: resource.ClusterType, + } + messageChan := make(ads.MessageChan, 1) + + err := messageChan.Push(message) + recv := <-messageChan + assert.NoError(t, err) + assert.Equal(t, message, recv) + + close(messageChan) + err = messageChan.Push(message) + assert.True(t, errors.Is(err, ads.ErrChannelClosed)) +} diff --git a/internal/proxy/envoy/xds/ads/response.go b/internal/proxy/envoy/xds/ads/response.go new file mode 100644 index 000000000..092deabae --- /dev/null +++ b/internal/proxy/envoy/xds/ads/response.go @@ -0,0 +1,109 @@ +package ads + +import ( + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + + xds "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" +) + +type ResponseStream struct { + stream xds.AggregatedDiscoveryService_StreamAggregatedResourcesServer + versionInfo string + nonce string +} + +func (s ResponseStream) StreamCDS(clusters []*cluster.Cluster) error { + if len(clusters) == 0 { + return nil + } + + var resources []*anypb.Any + for _, cls := range clusters { + res, err := proto.Marshal(cls) + if err != nil { + return err + } + + resources = append(resources, &anypb.Any{ + TypeUrl: resource.ClusterType, + Value: res, + }) + } + + resp := &xds.DiscoveryResponse{ + VersionInfo: s.versionInfo, + Nonce: s.nonce, + Resources: resources, + TypeUrl: resource.ClusterType, + } + + return s.stream.Send(resp) +} + +func (s ResponseStream) StreamLDS(listeners []*listener.Listener) error { + if len(listeners) == 0 { + return nil + } + + var resources []*anypb.Any + for _, ls := range listeners { + res, err := proto.Marshal(ls) + if err != nil { + return err + } + + resources = append(resources, &anypb.Any{ + TypeUrl: resource.ListenerType, + Value: res, + }) + } + + resp := &xds.DiscoveryResponse{ + VersionInfo: s.versionInfo, + Nonce: s.nonce, + Resources: resources, + TypeUrl: resource.ListenerType, + } + return s.stream.Send(resp) +} + +func (s ResponseStream) StreamRDS(routes []*route.RouteConfiguration) error { + if len(routes) == 0 { + return nil + } + + var resources []*anypb.Any + for _, r := range routes { + res, err := proto.Marshal(r) + if err != nil { + return err + } + + resources = append(resources, &anypb.Any{ + TypeUrl: resource.RouteType, + Value: res, + }) + } + + resp := &xds.DiscoveryResponse{ + VersionInfo: s.versionInfo, + Nonce: s.nonce, + Resources: resources, + TypeUrl: resource.RouteType, + } + + return s.stream.Send(resp) +} + +func NewResponseStream(stream xds.AggregatedDiscoveryService_StreamAggregatedResourcesServer, versionInfo, nonce string) ResponseStream { + return ResponseStream{ + stream: stream, + versionInfo: versionInfo, + nonce: nonce, + } +} diff --git a/internal/proxy/envoy/xds/ads/response_test.go b/internal/proxy/envoy/xds/ads/response_test.go new file mode 100644 index 000000000..b2cd73e73 --- /dev/null +++ b/internal/proxy/envoy/xds/ads/response_test.go @@ -0,0 +1,162 @@ +package ads_test + +import ( + "testing" + + clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + xds "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/goto/shield/internal/proxy/envoy/xds/ads" + "github.com/goto/shield/internal/proxy/envoy/xds/ads/mocks" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" +) + +var ( + testClusterStream = &clusterv3.Cluster{} + testClusterBytes, _ = proto.Marshal(testClusterStream) + testClusterResources = &anypb.Any{ + TypeUrl: resource.ClusterType, + Value: testClusterBytes, + } + + testListenerStream = &listenerv3.Listener{} + testListenerBytes, _ = proto.Marshal(testListenerStream) + testListenerResources = &anypb.Any{ + TypeUrl: resource.ListenerType, + Value: testListenerBytes, + } + + testRouteStream = &routev3.RouteConfiguration{} + testRouteBytes, _ = proto.Marshal(testRouteStream) + testRouteResources = &anypb.Any{ + TypeUrl: resource.RouteType, + Value: testRouteBytes, + } +) + +func TestStreamCDS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cluster []*clusterv3.Cluster + setup func(t *testing.T) ads.ResponseStream + wantErr error + }{ + { + name: "should return error from stream send", + cluster: []*clusterv3.Cluster{testClusterStream}, + setup: func(t *testing.T) ads.ResponseStream { + t.Helper() + stream := mocks.AggregatedDiscoveryService_StreamAggregatedResourcesServer{} + stream.EXPECT().Send(&xds.DiscoveryResponse{ + VersionInfo: "v1", + Nonce: "test", + Resources: []*anypb.Any{testClusterResources}, + TypeUrl: resource.ClusterType, + }).Return(nil) + return ads.NewResponseStream(&stream, "v1", "test") + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resp := tt.setup(t) + + assert.NotNil(t, resp) + got := resp.StreamCDS(tt.cluster) + + assert.Equal(t, tt.wantErr, got) + }) + } +} + +func TestStreamLDS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + listener []*listenerv3.Listener + setup func(t *testing.T) ads.ResponseStream + wantErr error + }{ + { + name: "should return error from stream send", + listener: []*listenerv3.Listener{testListenerStream}, + setup: func(t *testing.T) ads.ResponseStream { + t.Helper() + stream := mocks.AggregatedDiscoveryService_StreamAggregatedResourcesServer{} + stream.EXPECT().Send(&xds.DiscoveryResponse{ + VersionInfo: "v1", + Nonce: "test", + Resources: []*anypb.Any{testListenerResources}, + TypeUrl: resource.ListenerType, + }).Return(nil) + return ads.NewResponseStream(&stream, "v1", "test") + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resp := tt.setup(t) + + assert.NotNil(t, resp) + got := resp.StreamLDS(tt.listener) + + assert.Equal(t, tt.wantErr, got) + }) + } +} + +func TestStreamRDS(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + route []*routev3.RouteConfiguration + setup func(t *testing.T) ads.ResponseStream + wantErr error + }{ + { + name: "should return error from stream send", + route: []*routev3.RouteConfiguration{testRouteStream}, + setup: func(t *testing.T) ads.ResponseStream { + t.Helper() + stream := mocks.AggregatedDiscoveryService_StreamAggregatedResourcesServer{} + stream.EXPECT().Send(&xds.DiscoveryResponse{ + VersionInfo: "v1", + Nonce: "test", + Resources: []*anypb.Any{testRouteResources}, + TypeUrl: resource.RouteType, + }).Return(nil) + return ads.NewResponseStream(&stream, "v1", "test") + }, + wantErr: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + resp := tt.setup(t) + + assert.NotNil(t, resp) + got := resp.StreamRDS(tt.route) + + assert.Equal(t, tt.wantErr, got) + }) + } +} diff --git a/internal/proxy/envoy/xds/ads/server.go b/internal/proxy/envoy/xds/ads/server.go new file mode 100644 index 000000000..8fa70a15d --- /dev/null +++ b/internal/proxy/envoy/xds/ads/server.go @@ -0,0 +1,32 @@ +package ads + +import ( + "errors" + "time" + + xds "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/goto/salt/log" +) + +type Server struct { + Logger log.Logger + Services map[string]Service + RefreshInterval time.Duration +} + +func (a *Server) DeltaAggregatedResources(xds.AggregatedDiscoveryService_DeltaAggregatedResourcesServer) error { + return errors.New("not implemented") +} + +func (a *Server) StreamAggregatedResources(stream xds.AggregatedDiscoveryService_StreamAggregatedResourcesServer) error { + err := NewStream(a.Logger, a.RefreshInterval, stream, a.Services).Stream() + return err +} + +func New(logger log.Logger, services map[string]Service, refreshInterval time.Duration) *Server { + return &Server{ + Logger: logger, + Services: services, + RefreshInterval: refreshInterval, + } +} diff --git a/internal/proxy/envoy/xds/ads/service.go b/internal/proxy/envoy/xds/ads/service.go new file mode 100644 index 000000000..6c3ff4662 --- /dev/null +++ b/internal/proxy/envoy/xds/ads/service.go @@ -0,0 +1,316 @@ +package ads + +import ( + "context" + "fmt" + "net/url" + "strconv" + "time" + + accesslog "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3" + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + uri_template "github.com/envoyproxy/go-control-plane/envoy/extensions/path/match/uri_template/v3" + matcherv3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "github.com/envoyproxy/go-control-plane/pkg/wellknown" + "github.com/goto/shield/core/rule" + "github.com/goto/shield/internal/proxy" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" +) + +type Service struct { + config proxy.Config + repository Repository +} + +func NewService(config proxy.Config, repository Repository) Service { + return Service{ + config: config, + repository: repository, + } +} + +func (s Service) Get(ctx context.Context) (*DiscoveryResource, error) { + ruleset, err := s.repository.Fetch(ctx) + if err != nil { + return &DiscoveryResource{}, err + } + + var clusters []*cluster.Cluster + var listeners []*listener.Listener + var routes []*route.RouteConfiguration + backendmap := make(map[string]bool) + for _, rule := range ruleset { + for _, r := range rule.Rules { + if _, ok := backendmap[r.Backend.Namespace]; ok { + continue + } + backendmap[r.Backend.Namespace] = true + clusters = append(clusters, s.getCluster(r)) + } + } + + routes = append(routes, s.getRoute(ruleset)) + + ls, err := s.getListener() + if err != nil { + return &DiscoveryResource{}, err + } + + listeners = append(listeners, ls) + + return &DiscoveryResource{ + Clusters: clusters, + Listeners: listeners, + Routes: routes, + }, nil +} + +func (s Service) getCluster(rule rule.Rule) *cluster.Cluster { + return &cluster.Cluster{ + ClusterDiscoveryType: &cluster.Cluster_Type{ + Type: cluster.Cluster_LOGICAL_DNS, + }, + DnsLookupFamily: cluster.Cluster_V4_PREFERRED, + Name: rule.Backend.Namespace, + ConnectTimeout: durationpb.New(1 * time.Second), + LoadAssignment: s.getEndpoint(rule), + } +} + +func (s Service) getEndpoint(rule rule.Rule) *endpoint.ClusterLoadAssignment { + host, port, err := resolveHostPort(rule.Backend.URL) + if err != nil { + return nil + } + + lbEndpoint := &endpoint.LbEndpoint{ + HostIdentifier: &endpoint.LbEndpoint_Endpoint{ + Endpoint: &endpoint.Endpoint{ + Hostname: host, + Address: &core.Address{ + Address: &core.Address_SocketAddress{ + SocketAddress: &core.SocketAddress{ + Protocol: core.SocketAddress_TCP, + Address: host, + PortSpecifier: &core.SocketAddress_PortValue{ + PortValue: port, + }, + }, + }, + }, + }, + }, + } + + lbEndpoints := &endpoint.LocalityLbEndpoints{ + LbEndpoints: []*endpoint.LbEndpoint{lbEndpoint}, + } + + return &endpoint.ClusterLoadAssignment{ + ClusterName: rule.Backend.Namespace, + Endpoints: []*endpoint.LocalityLbEndpoints{lbEndpoints}, + } +} + +func (s Service) getRoute(ruleset []rule.Ruleset) *route.RouteConfiguration { + vh := &route.VirtualHost{ + Name: s.config.Name, + Domains: []string{"*"}, + Routes: []*route.Route{}, + } + + rc := &route.RouteConfiguration{ + Name: s.config.Name, + VirtualHosts: []*route.VirtualHost{vh}, + } + + for _, rule := range ruleset { + for _, r := range rule.Rules { + host, _, err := resolveHostPort(r.Backend.URL) + if err != nil { + continue + } + headerMatcher := &route.HeaderMatcher{ + Name: ":method", + HeaderMatchSpecifier: &route.HeaderMatcher_StringMatch{ + StringMatch: &matcherv3.StringMatcher{ + MatchPattern: &matcherv3.StringMatcher_Exact{ + Exact: r.Frontend.Method, + }, + }, + }, + } + + pathTemplate := uri_template.UriTemplateMatchConfig{ + PathTemplate: r.Frontend.URL, + } + + pathTemplateBytes, err := proto.Marshal(&pathTemplate) + if err != nil { + continue + } + + rt := &route.Route{ + Match: &route.RouteMatch{ + PathSpecifier: &route.RouteMatch_PathMatchPolicy{ + PathMatchPolicy: &core.TypedExtensionConfig{ + Name: "envoy.extensions.path.match.uri_template.v3.UriTemplateMatchConfig", + TypedConfig: &anypb.Any{ + TypeUrl: URI_TEMPLATE_TYPE_URL, + Value: pathTemplateBytes, + }, + }, + }, + Headers: []*route.HeaderMatcher{ + headerMatcher, + }, + }, + } + if r.Backend.Prefix != "" { + rt.Action = &route.Route_Route{ + Route: &route.RouteAction{ + ClusterSpecifier: &route.RouteAction_Cluster{ + Cluster: r.Backend.Namespace, + }, + HostRewriteSpecifier: &route.RouteAction_HostRewriteLiteral{ + HostRewriteLiteral: host, + }, + RegexRewrite: &matcherv3.RegexMatchAndSubstitute{ + Pattern: &matcherv3.RegexMatcher{ + Regex: fmt.Sprintf("^(%s)(/.+$)", r.Backend.Prefix), + }, + Substitution: "\\2", + }, + }, + } + } else { + rt.Action = &route.Route_Route{ + Route: &route.RouteAction{ + ClusterSpecifier: &route.RouteAction_Cluster{ + Cluster: r.Backend.Namespace, + }, + HostRewriteSpecifier: &route.RouteAction_HostRewriteLiteral{ + HostRewriteLiteral: host, + }, + }, + } + } + vh.Routes = append(vh.Routes, rt) + } + } + + return rc +} + +func (s Service) getListener() (*listener.Listener, error) { + ads := core.ConfigSource{ + ConfigSourceSpecifier: &core.ConfigSource_Ads{ + Ads: &core.AggregatedConfigSource{}, + }, + } + + routerFilter := &http_connection_manager.HttpFilter{ + Name: wellknown.Router, + ConfigType: &http_connection_manager.HttpFilter_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: ROUTER_TYPE_URL, + }, + }, + } + + al := accesslog.AccessLog{ + Name: "envoy.access_loggers.stdout", + ConfigType: &accesslog.AccessLog_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: STDOUT_LOGGER_TYPE_URL, + }, + }, + } + + httpConnManager := http_connection_manager.HttpConnectionManager{ + CodecType: http_connection_manager.HttpConnectionManager_AUTO, + StatPrefix: "http", + AccessLog: []*accesslog.AccessLog{&al}, + RouteSpecifier: &http_connection_manager.HttpConnectionManager_Rds{ + Rds: &http_connection_manager.Rds{ + ConfigSource: &ads, + RouteConfigName: s.config.Name, + }, + }, + HttpFilters: []*http_connection_manager.HttpFilter{ + routerFilter, + }, + } + + httpConnManagerBytes, err := proto.Marshal(&httpConnManager) + if err != nil { + return &listener.Listener{}, err + } + + filterChain := &listener.FilterChain{ + Filters: []*listener.Filter{ + { + Name: wellknown.HTTPConnectionManager, + ConfigType: &listener.Filter_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: HTTP_CONNECTION_MANAGER_TYPE_URL, + Value: httpConnManagerBytes, + }, + }, + }, + }, + } + + ls := &listener.Listener{ + Name: s.config.Name, + Address: &core.Address{ + Address: &core.Address_SocketAddress{ + SocketAddress: &core.SocketAddress{ + Protocol: core.SocketAddress_TCP, + Address: s.config.Host, + PortSpecifier: &core.SocketAddress_PortValue{ + PortValue: uint32(s.config.Port), + }, + }, + }, + }, + FilterChains: []*listener.FilterChain{filterChain}, + } + + return ls, nil +} + +func (s Service) IsUpdated(ctx context.Context, since time.Time) bool { + return s.repository.IsUpdated(ctx, since) +} + +func resolveHostPort(urlString string) (string, uint32, error) { + parsed, err := url.Parse(urlString) + if err != nil { + return "", 0, err + } + + port := parsed.Port() + if parsed.Port() == "" { + switch parsed.Scheme { + case "https": + return parsed.Host, 443, nil + default: + return parsed.Host, 80, nil + } + } + + uintPort, err := strconv.ParseUint(port, 10, 32) + if err != nil { + return "", 0, err + } + + return parsed.Hostname(), uint32(uintPort), nil +} diff --git a/internal/proxy/envoy/xds/ads/service_test.go b/internal/proxy/envoy/xds/ads/service_test.go new file mode 100644 index 000000000..277df34f0 --- /dev/null +++ b/internal/proxy/envoy/xds/ads/service_test.go @@ -0,0 +1,322 @@ +package ads_test + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + accesslog "github.com/envoyproxy/go-control-plane/envoy/config/accesslog/v3" + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + core "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + http_connection_manager "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + uri_template "github.com/envoyproxy/go-control-plane/envoy/extensions/path/match/uri_template/v3" + matcherv3 "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + "github.com/envoyproxy/go-control-plane/pkg/wellknown" + "github.com/goto/shield/core/rule" + "github.com/goto/shield/internal/proxy" + "github.com/goto/shield/internal/proxy/envoy/xds/ads" + "github.com/goto/shield/internal/proxy/envoy/xds/ads/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" +) + +var ( + testConfig = proxy.Config{ + Name: "test-proxy", + Port: 5556, + Host: "0.0.0.0", + } + + testRule = rule.Rule{ + Frontend: rule.Frontend{ + URL: "/shield/test", + Method: "GET", + }, + Backend: rule.Backend{ + URL: "http://localhost:8080", + Namespace: "shield", + Prefix: "/shield", + }, + Middlewares: rule.MiddlewareSpecs{}, + Hooks: rule.HookSpecs{}, + } + + testDiscoveryResource = ads.DiscoveryResource{ + Clusters: []*cluster.Cluster{testCluster}, + Listeners: []*listener.Listener{testListener}, + Routes: []*route.RouteConfiguration{testRouteConfiguration}, + } + + testCluster = &cluster.Cluster{ + ClusterDiscoveryType: &cluster.Cluster_Type{ + Type: cluster.Cluster_LOGICAL_DNS, + }, + DnsLookupFamily: cluster.Cluster_V4_PREFERRED, + Name: "shield", + ConnectTimeout: durationpb.New(1 * time.Second), + LoadAssignment: &testCLA, + } + + testLbEndppoint = &endpoint.LbEndpoint{ + HostIdentifier: &endpoint.LbEndpoint_Endpoint{ + Endpoint: &endpoint.Endpoint{ + Hostname: "localhost", + Address: &core.Address{ + Address: &core.Address_SocketAddress{ + SocketAddress: &core.SocketAddress{ + Protocol: core.SocketAddress_TCP, + Address: "localhost", + PortSpecifier: &core.SocketAddress_PortValue{ + PortValue: 8080, + }, + }, + }, + }, + }, + }, + } + + testLbEndppoints = &endpoint.LocalityLbEndpoints{ + LbEndpoints: []*endpoint.LbEndpoint{testLbEndppoint}, + } + + testCLA = endpoint.ClusterLoadAssignment{ + ClusterName: "shield", + Endpoints: []*endpoint.LocalityLbEndpoints{testLbEndppoints}, + } + + testListener = &listener.Listener{ + Name: "test-proxy", + Address: &core.Address{ + Address: &core.Address_SocketAddress{ + SocketAddress: &core.SocketAddress{ + Protocol: core.SocketAddress_TCP, + Address: "0.0.0.0", + PortSpecifier: &core.SocketAddress_PortValue{ + PortValue: 5556, + }, + }, + }, + }, + FilterChains: []*listener.FilterChain{testFilterChain}, + } + + testAds = core.ConfigSource{ + ConfigSourceSpecifier: &core.ConfigSource_Ads{ + Ads: &core.AggregatedConfigSource{}, + }, + } + + testRouterFilter = &http_connection_manager.HttpFilter{ + Name: wellknown.Router, + ConfigType: &http_connection_manager.HttpFilter_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: ads.ROUTER_TYPE_URL, + }, + }, + } + + testAL = accesslog.AccessLog{ + Name: "envoy.access_loggers.stdout", + ConfigType: &accesslog.AccessLog_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: ads.STDOUT_LOGGER_TYPE_URL, + }, + }, + } + + testHttpConnManager = http_connection_manager.HttpConnectionManager{ + CodecType: http_connection_manager.HttpConnectionManager_AUTO, + StatPrefix: "http", + AccessLog: []*accesslog.AccessLog{&testAL}, + RouteSpecifier: &http_connection_manager.HttpConnectionManager_Rds{ + Rds: &http_connection_manager.Rds{ + ConfigSource: &testAds, + RouteConfigName: "test-proxy", + }, + }, + HttpFilters: []*http_connection_manager.HttpFilter{ + testRouterFilter, + }, + } + + testHttpConnManagerBytes, _ = proto.Marshal(&testHttpConnManager) + + testFilterChain = &listener.FilterChain{ + Filters: []*listener.Filter{ + { + Name: wellknown.HTTPConnectionManager, + ConfigType: &listener.Filter_TypedConfig{ + TypedConfig: &anypb.Any{ + TypeUrl: ads.HTTP_CONNECTION_MANAGER_TYPE_URL, + Value: testHttpConnManagerBytes, + }, + }, + }, + }, + } + + testHeaderMatcher = &route.HeaderMatcher{ + Name: ":method", + HeaderMatchSpecifier: &route.HeaderMatcher_StringMatch{ + StringMatch: &matcherv3.StringMatcher{ + MatchPattern: &matcherv3.StringMatcher_Exact{ + Exact: "GET", + }, + }, + }, + } + + testPathTemplate = uri_template.UriTemplateMatchConfig{ + PathTemplate: "/shield/test", + } + + testPathTemplateBytes, _ = proto.Marshal(&testPathTemplate) + + testRoute = &route.Route{ + Match: &route.RouteMatch{ + PathSpecifier: &route.RouteMatch_PathMatchPolicy{ + PathMatchPolicy: &core.TypedExtensionConfig{ + Name: "envoy.extensions.path.match.uri_template.v3.UriTemplateMatchConfig", + TypedConfig: &anypb.Any{ + TypeUrl: ads.URI_TEMPLATE_TYPE_URL, + Value: testPathTemplateBytes, + }, + }, + }, + Headers: []*route.HeaderMatcher{ + testHeaderMatcher, + }, + }, + Action: &route.Route_Route{ + Route: &route.RouteAction{ + ClusterSpecifier: &route.RouteAction_Cluster{ + Cluster: "shield", + }, + HostRewriteSpecifier: &route.RouteAction_HostRewriteLiteral{ + HostRewriteLiteral: "localhost", + }, + RegexRewrite: &matcherv3.RegexMatchAndSubstitute{ + Pattern: &matcherv3.RegexMatcher{ + Regex: fmt.Sprintf("^(%s)(/.+$)", "/shield"), + }, + Substitution: "\\2", + }, + }, + }, + } + + testVH = &route.VirtualHost{ + Name: "test-proxy", + Domains: []string{"*"}, + Routes: []*route.Route{testRoute}, + } + + testRouteConfiguration = &route.RouteConfiguration{ + Name: "test-proxy", + VirtualHosts: []*route.VirtualHost{testVH}, + } +) + +func TestGet(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + setup func(t *testing.T) ads.Service + want *ads.DiscoveryResource + wantErr error + }{ + { + name: "should return discovery resource", + setup: func(t *testing.T) ads.Service { + t.Helper() + repository := &mocks.Repository{} + repository.EXPECT().Fetch(mock.Anything).Return([]rule.Ruleset{ + { + Rules: []rule.Rule{testRule}, + }, + }, nil) + return ads.NewService(testConfig, repository) + }, + want: &testDiscoveryResource, + wantErr: nil, + }, + { + name: "should return discovery resource", + setup: func(t *testing.T) ads.Service { + t.Helper() + repository := &mocks.Repository{} + repository.EXPECT().Fetch(mock.Anything).Return([]rule.Ruleset{}, rule.ErrMarshal) + return ads.NewService(testConfig, repository) + }, + want: &ads.DiscoveryResource{}, + wantErr: rule.ErrMarshal, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + ctx := context.Background() + got, err := svc.Get(ctx) + if tt.wantErr != nil { + assert.Error(t, err) + assert.True(t, errors.Is(err, tt.wantErr)) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, tt.want, got) + }) + } +} + +func TestIsUpdated(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + since time.Time + setup func(t *testing.T) ads.Service + want bool + }{ + { + name: "should return discovery resource", + since: time.Time{}, + setup: func(t *testing.T) ads.Service { + t.Helper() + repository := &mocks.Repository{} + repository.EXPECT().IsUpdated(mock.Anything, time.Time{}).Return(true) + return ads.NewService(testConfig, repository) + }, + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := tt.setup(t) + + assert.NotNil(t, svc) + ctx := context.Background() + got := svc.IsUpdated(ctx, tt.since) + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/proxy/envoy/xds/ads/stream.go b/internal/proxy/envoy/xds/ads/stream.go new file mode 100644 index 000000000..1d96757cb --- /dev/null +++ b/internal/proxy/envoy/xds/ads/stream.go @@ -0,0 +1,179 @@ +package ads + +import ( + "context" + "io" + "strconv" + "time" + + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + xds "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "github.com/goto/salt/log" +) + +type DiscoveryResource struct { + Clusters []*cluster.Cluster + Listeners []*listener.Listener + Routes []*route.RouteConfiguration +} + +type Client struct { + NodeID string + LastUpdated time.Time +} + +type Stream struct { + ctx context.Context + cancel func() + logger log.Logger + stream xds.AggregatedDiscoveryService_StreamAggregatedResourcesServer + client Client + services map[string]Service + messageChan MessageChan + refreshInterval time.Duration +} + +func NewStream(logger log.Logger, refreshInterval time.Duration, stream xds.AggregatedDiscoveryService_StreamAggregatedResourcesServer, services map[string]Service) Stream { + ctx, cancel := context.WithCancel(context.Background()) + return Stream{ + ctx: ctx, + cancel: cancel, + logger: logger, + stream: stream, + services: services, + messageChan: make(MessageChan), + refreshInterval: refreshInterval, + } +} + +func (s Stream) Stream() error { + terminate := make(chan bool) + + go func() { + for { + select { + case <-s.ctx.Done(): + return + default: + in, err := s.stream.Recv() + if err == io.EOF { + return + } + + if err != nil { + s.logger.Error(err.Error()) + return + } + + if in.ResponseNonce == "" { + s.logger.Info("received request on stream", "typeurl", in.TypeUrl) + message := Message{ + NodeID: in.Node.Id, + VersionInfo: strconv.FormatInt(time.Now().UnixNano(), 10), + Nonce: strconv.FormatInt(time.Now().UnixNano(), 10), + TypeUrl: in.TypeUrl, + } + s.messageChan.Push(message) + s.client.LastUpdated = time.Now() + + if s.client.NodeID == "" { + s.client.NodeID = in.Node.Id + go s.PushUpdatePeriodically() + } + } else if in.ErrorDetail == nil { + s.logger.Info("received ACK on stream", "typeurl", in.TypeUrl, "version_info", in.VersionInfo) + } else { + s.logger.Info("received NACK on stream", "typeurl", in.TypeUrl, "version_info", in.VersionInfo, "error", in.ErrorDetail) + } + } + } + }() + + go func() { + for e := range s.messageChan { + if err := s.streamResponses(e); err != nil { + s.logger.Debug("error while streaming response", "error", err) + } + } + }() + + go func() { + <-s.stream.Context().Done() + close(s.messageChan) + s.cancel() + terminate <- true + }() + <-terminate + return nil +} + +func (s Stream) streamResponses(message Message) error { + cfg := &DiscoveryResource{} + var err error + if repo, ok := s.services[message.NodeID]; ok { + cfg, err = repo.Get(s.ctx) + if err != nil { + return err + } + } + + responseStream := NewResponseStream(s.stream, message.VersionInfo, message.Nonce) + switch message.TypeUrl { + case resource.ClusterType: + if err := responseStream.StreamCDS(cfg.Clusters); err != nil { + return err + } + case resource.ListenerType: + if err := responseStream.StreamLDS(cfg.Listeners); err != nil { + return err + } + case ROUTER_TYPE_URL: + if err := responseStream.StreamRDS(cfg.Routes); err != nil { + return err + } + default: + if err := responseStream.StreamCDS(cfg.Clusters); err != nil { + return err + } + if err := responseStream.StreamLDS(cfg.Listeners); err != nil { + return err + } + if err := responseStream.StreamRDS(cfg.Routes); err != nil { + return err + } + } + + return nil +} + +func (s Stream) PushUpdatePeriodically() { + service, ok := s.services[s.client.NodeID] + if !ok { + s.logger.Debug("service not found", "node_id", s.client.NodeID) + return + } + + for { + select { + case <-s.ctx.Done(): + return + default: + time.Sleep(s.refreshInterval) + if service.IsUpdated(s.ctx, s.client.LastUpdated) { + s.logger.Debug("discovery resource update found", "node_id", s.client.NodeID) + message := Message{ + NodeID: s.client.NodeID, + VersionInfo: strconv.FormatInt(time.Now().UnixNano(), 10), + Nonce: strconv.FormatInt(time.Now().UnixNano(), 10), + } + s.messageChan.Push(message) + s.client.LastUpdated = time.Now() + } else { + s.logger.Debug("no discovery resource update", "node_id", s.client.NodeID) + } + } + } +} diff --git a/internal/proxy/envoy/xds/xds.go b/internal/proxy/envoy/xds/xds.go new file mode 100644 index 000000000..0db9e1de1 --- /dev/null +++ b/internal/proxy/envoy/xds/xds.go @@ -0,0 +1,36 @@ +package xds + +import ( + "context" + "fmt" + "net" + + xds "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/goto/salt/log" + "github.com/goto/shield/internal/proxy" + "github.com/goto/shield/internal/proxy/envoy/xds/ads" + "google.golang.org/grpc" +) + +func Serve(ctx context.Context, logger log.Logger, cfg proxy.ServicesConfig, repositories map[string]ads.Repository) error { + xdsURL := fmt.Sprintf("%s:%d", cfg.EnvoyAgent.XDS.Host, cfg.EnvoyAgent.XDS.Port) + logger.Info("[envoy agent] starting envoy xds", "url", xdsURL) + + server := grpc.NewServer() + + services := make(map[string]ads.Service) + for _, c := range cfg.Services { + if repo, ok := repositories[c.Name]; ok { + services[c.Name] = ads.NewService(c, repo) + } + } + xds.RegisterAggregatedDiscoveryServiceServer(server, ads.New(logger, services, cfg.EnvoyAgent.XDS.RefreshInterval)) + + lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", cfg.EnvoyAgent.XDS.Host, cfg.EnvoyAgent.XDS.Port)) + if err != nil { + logger.Error("[envoy agent] envoy xds failed to listen: %v\n", err) + return err + } + + return server.Serve(lis) +} diff --git a/internal/store/blob/rule_repository.go b/internal/store/blob/rule_repository.go index 680a294dd..874ffc920 100644 --- a/internal/store/blob/rule_repository.go +++ b/internal/store/blob/rule_repository.go @@ -22,9 +22,10 @@ type RuleRepository struct { log log.Logger mu *sync.Mutex - cron *cron.Cron - bucket Bucket - cached []rule.Ruleset + cron *cron.Cron + bucket Bucket + cached []rule.Ruleset + updatedAt time.Time } func (repo *RuleRepository) GetAll(ctx context.Context) ([]rule.Ruleset, error) { @@ -40,6 +41,10 @@ func (repo *RuleRepository) GetAll(ctx context.Context) ([]rule.Ruleset, error) return repo.cached, err } +func (repo *RuleRepository) Fetch(ctx context.Context) ([]rule.Ruleset, error) { + return repo.GetAll(ctx) +} + func (repo *RuleRepository) refresh(ctx context.Context) error { var rulesets []rule.Ruleset @@ -96,6 +101,7 @@ func (repo *RuleRepository) refresh(ctx context.Context) error { repo.mu.Lock() repo.cached = rulesets + repo.updatedAt = time.Now() repo.mu.Unlock() repo.log.Debug("rule cache refreshed", "ruleset_count", len(repo.cached)) return nil @@ -128,6 +134,10 @@ func (repo *RuleRepository) Upsert(ctx context.Context, name string, config rule return rule.Config{}, rule.ErrUpsertConfigNotSupported } +func (repo *RuleRepository) IsUpdated(ctx context.Context, lastUpdated time.Time) bool { + return repo.updatedAt.After(lastUpdated) +} + func NewRuleRepository(logger log.Logger, b Bucket) *RuleRepository { return &RuleRepository{ log: logger, diff --git a/internal/store/postgres/resource_repository.go b/internal/store/postgres/resource_repository.go index 788c743ca..63a26f0bd 100644 --- a/internal/store/postgres/resource_repository.go +++ b/internal/store/postgres/resource_repository.go @@ -443,7 +443,7 @@ func (r ResourceRepository) UpsertConfig(ctx context.Context, name string, confi query, params, err := goqu.Insert(TABLE_RESOURCE_CONFIGS).Rows( goqu.Record{"name": name, "config": configJson}, ).OnConflict( - goqu.DoUpdate("name", goqu.Record{"name": name, "config": configJson})).Returning(&RuleConfig{}).ToSQL() + goqu.DoUpdate("name", goqu.Record{"name": name, "config": configJson, "updated_at": goqu.L("now()")})).Returning(&RuleConfig{}).ToSQL() if err != nil { return schema.Config{}, fmt.Errorf("%w: %s", queryErr, err) } diff --git a/internal/store/postgres/rule_repository.go b/internal/store/postgres/rule_repository.go index d77af1022..33bee734d 100644 --- a/internal/store/postgres/rule_repository.go +++ b/internal/store/postgres/rule_repository.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "github.com/doug-martin/goqu/v9" "github.com/goto/shield/core/rule" @@ -40,7 +41,7 @@ func (r *RuleRepository) Upsert(ctx context.Context, name string, config rule.Ru query, params, err := goqu.Insert(TABLE_RULE_CONFIGS).Rows( goqu.Record{"name": name, "config": configJson}, ).OnConflict( - goqu.DoUpdate("name", goqu.Record{"name": name, "config": configJson})).Returning(&RuleConfig{}).ToSQL() + goqu.DoUpdate("name", goqu.Record{"name": name, "config": configJson, "updated_at": goqu.L("now()")})).Returning(&RuleConfig{}).ToSQL() if err != nil { return rule.Config{}, fmt.Errorf("%w: %s", queryErr, err) } @@ -109,7 +110,7 @@ func (r *RuleRepository) InitCache(ctx context.Context) error { if nrCtx != nil { nr := newrelic.DatastoreSegment{ Product: newrelic.DatastorePostgres, - Collection: TABLE_RESOURCES, + Collection: TABLE_RULE_CONFIGS, Operation: "List", StartTime: nrCtx.StartSegmentNow(), } @@ -138,10 +139,96 @@ func (r *RuleRepository) InitCache(ctx context.Context) error { return nil } +func (r *RuleRepository) IsUpdated(ctx context.Context, since time.Time) bool { + query, params, err := dialect.From(TABLE_RULE_CONFIGS).Select(goqu.C("updated_at").Gt(since)).Order(goqu.C("updated_at").Desc()).Limit(1).ToSQL() + if err != nil { + return false + } + + ctx = otelsql.WithCustomAttributes( + ctx, + []attribute.KeyValue{ + attribute.String("db.repository.method", "List"), + attribute.String(string(semconv.DBSQLTableKey), TABLE_RULE_CONFIGS), + }..., + ) + + var isUpdated bool + if err = r.dbc.WithTimeout(ctx, func(ctx context.Context) error { + nrCtx := newrelic.FromContext(ctx) + if nrCtx != nil { + nr := newrelic.DatastoreSegment{ + Product: newrelic.DatastorePostgres, + Collection: TABLE_RULE_CONFIGS, + Operation: "List", + StartTime: nrCtx.StartSegmentNow(), + } + defer nr.End() + } + + return r.dbc.GetContext(ctx, &isUpdated, query, params...) + }); err != nil { + err = checkPostgresError(err) + if !errors.Is(err, sql.ErrNoRows) { + return false + } + } + + isUpdatedTest := isUpdated + return isUpdatedTest +} + func (r *RuleRepository) GetAll(ctx context.Context) ([]rule.Ruleset, error) { return r.cached, nil } +func (r *RuleRepository) Fetch(ctx context.Context) ([]rule.Ruleset, error) { + query, params, err := dialect.From(TABLE_RULE_CONFIGS).ToSQL() + if err != nil { + return []rule.Ruleset{}, err + } + ctx = otelsql.WithCustomAttributes( + ctx, + []attribute.KeyValue{ + attribute.String("db.repository.method", "List"), + attribute.String(string(semconv.DBSQLTableKey), TABLE_RULE_CONFIGS), + }..., + ) + + var ruleConfigModel []RuleConfig + if err = r.dbc.WithTimeout(ctx, func(ctx context.Context) error { + nrCtx := newrelic.FromContext(ctx) + if nrCtx != nil { + nr := newrelic.DatastoreSegment{ + Product: newrelic.DatastorePostgres, + Collection: TABLE_RULE_CONFIGS, + Operation: "List", + StartTime: nrCtx.StartSegmentNow(), + } + defer nr.End() + } + + return r.dbc.SelectContext(ctx, &ruleConfigModel, query, params...) + }); err != nil { + err = checkPostgresError(err) + if !errors.Is(err, sql.ErrNoRows) { + return []rule.Ruleset{}, err + } + } + + rules := []rule.Ruleset{} + for _, ruleConfig := range ruleConfigModel { + rc := ruleConfig.transformToRuleConfig() + var targetRuleset rule.Ruleset + if err := json.Unmarshal([]byte(rc.Config), &targetRuleset); err != nil { + return []rule.Ruleset{}, err + } + rules = append(rules, targetRuleset) + } + + return rules, nil +} + func (r *RuleRepository) WithTransaction(ctx context.Context) context.Context { return r.dbc.WithTransaction(ctx, sql.TxOptions{}) } diff --git a/internal/store/postgres/rule_repository_test.go b/internal/store/postgres/rule_repository_test.go index 5bebc9ea1..609e8c048 100644 --- a/internal/store/postgres/rule_repository_test.go +++ b/internal/store/postgres/rule_repository_test.go @@ -2,7 +2,9 @@ package postgres_test import ( "context" + "encoding/json" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -21,7 +23,7 @@ type RuleRepositoryTestSuite struct { pool *dockertest.Pool resource *dockertest.Resource repository *postgres.RuleRepository - Config []rule.Config + config []rule.Config } func (s *RuleRepositoryTestSuite) SetupSuite() { @@ -34,12 +36,128 @@ func (s *RuleRepositoryTestSuite) SetupSuite() { } s.ctx = context.TODO() + + s.config, err = bootstrapRuleConfig(s.client) + if err != nil { + s.T().Fatal(err) + } + s.repository = postgres.NewRuleRepository(s.client) + err = s.repository.InitCache(s.ctx) + if err != nil { + s.T().Fatal(err) + } +} + +func (s *RuleRepositoryTestSuite) mergeRules() ([]rule.Ruleset, error) { + rules := []rule.Ruleset{} + for _, ruleConfig := range s.config { + var targetRuleset rule.Ruleset + if err := json.Unmarshal([]byte(ruleConfig.Config), &targetRuleset); err != nil { + return []rule.Ruleset{}, err + } + rules = append(rules, targetRuleset) + } - s.Config, err = bootstrapRuleConfig(s.client) + return rules, nil +} + +func (s *RuleRepositoryTestSuite) TestGetAll() { + expected, err := s.mergeRules() if err != nil { s.T().Fatal(err) } + + type testCase struct { + Description string + Expected []rule.Ruleset + ErrString string + } + + testCases := []testCase{ + { + Description: "should get all rules from repository cache", + Expected: expected, + }, + } + + for _, tc := range testCases { + s.Run(tc.Description, func() { + got, err := s.repository.GetAll(s.ctx) + if tc.ErrString != "" { + if err.Error() != tc.ErrString { + s.T().Fatalf("got error %s, expected was %s", err.Error(), tc.ErrString) + } + } + if !cmp.Equal(got, tc.Expected) { + s.T().Fatalf("got result %+v, expected was %+v", got, tc.Expected) + } + }) + } +} + +func (s *RuleRepositoryTestSuite) TestFetch() { + expected, err := s.mergeRules() + if err != nil { + s.T().Fatal(err) + } + + type testCase struct { + Description string + Expected []rule.Ruleset + ErrString string + } + + testCases := []testCase{ + { + Description: "should get all rules from repository cache", + Expected: expected, + }, + } + + for _, tc := range testCases { + s.Run(tc.Description, func() { + got, err := s.repository.Fetch(s.ctx) + if tc.ErrString != "" { + if err.Error() != tc.ErrString { + s.T().Fatalf("got error %s, expected was %s", err.Error(), tc.ErrString) + } + } + if !cmp.Equal(got, tc.Expected) { + s.T().Fatalf("got result %+v, expected was %+v", got, tc.Expected) + } + }) + } +} + +func (s *RuleRepositoryTestSuite) TestIsUpdated() { + type testCase struct { + Description string + Since time.Time + Expected bool + } + + testCases := []testCase{ + { + Description: "should get true if since before last updated", + Since: time.Time{}, + Expected: true, + }, + { + Description: "should get false if since after last updated", + Since: s.config[0].UpdatedAt.Add(10 * time.Hour), + Expected: false, + }, + } + + for _, tc := range testCases { + s.Run(tc.Description, func() { + got := s.repository.IsUpdated(s.ctx, tc.Since) + if !cmp.Equal(got, tc.Expected) { + s.T().Fatalf("got result %+v, expected was %+v", got, tc.Expected) + } + }) + } } func (s *RuleRepositoryTestSuite) TestUpsert() { @@ -66,13 +184,13 @@ func (s *RuleRepositoryTestSuite) TestUpsert() { }, { Description: "should update a resource config", - Name: s.Config[0].Name, + Name: s.config[0].Name, Config: rule.Ruleset{ Rules: []rule.Rule{{}}, }, Expected: rule.Config{ - ID: s.Config[0].ID, - Name: s.Config[0].Name, + ID: s.config[0].ID, + Name: s.config[0].Name, Config: "{\"Rules\": [{\"Hooks\": null, \"Backend\": {\"URL\": \"\", \"Prefix\": \"\", \"Namespace\": \"\"}, \"Frontend\": {\"URL\": \"\", \"URLRx\": null, \"Method\": \"\"}, \"Middlewares\": null}]}", }, },