From 43c18b4e8198590281fcebb3547b5d56d6f62f49 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 25 Jun 2024 15:06:02 +0900 Subject: [PATCH 01/47] wip --- awsiface/iface.go | 7 +++++ cage.go | 3 ++ canary_task.go | 76 ++++++++++++++++++++++++++++++++++------------- go.mod | 7 +++-- go.sum | 8 +++++ rollout.go | 4 +++ 6 files changed, 81 insertions(+), 24 deletions(-) diff --git a/awsiface/iface.go b/awsiface/iface.go index d377f3e..24cc2bd 100644 --- a/awsiface/iface.go +++ b/awsiface/iface.go @@ -6,6 +6,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" ) type ( @@ -36,8 +37,14 @@ type ( DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) } + SrvClient interface { + DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) + RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) + DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) + } ) var _ EcsClient = (*ecs.Client)(nil) var _ AlbClient = (*elbv2.Client)(nil) var _ Ec2Client = (*ec2.Client)(nil) +var _ SrvClient = (*servicediscovery.Client)(nil) diff --git a/cage.go b/cage.go index ee01f01..b472725 100644 --- a/cage.go +++ b/cage.go @@ -23,6 +23,7 @@ type cage struct { Ecs awsiface.EcsClient Alb awsiface.AlbClient Ec2 awsiface.Ec2Client + Srv awsiface.SrvClient Time Time MaxWait time.Duration } @@ -32,6 +33,7 @@ type Input struct { ECS awsiface.EcsClient ALB awsiface.AlbClient EC2 awsiface.Ec2Client + SRV awsiface.SrvClient Time Time MaxWait time.Duration } @@ -45,6 +47,7 @@ func NewCage(input *Input) Cage { Ecs: input.ECS, Alb: input.ALB, Ec2: input.EC2, + Srv: input.SRV, Time: input.Time, MaxWait: 5 * time.Minute, } diff --git a/canary_task.go b/canary_task.go index aacbd6b..5e6ebf5 100644 --- a/canary_task.go +++ b/canary_task.go @@ -13,11 +13,13 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" "golang.org/x/xerrors" ) type CanaryTarget struct { - targetGroupArn *string + // targetGroupArn *string + // serviceRegistry *string targetId *string targetPort *int32 availabilityZone *string @@ -27,6 +29,7 @@ type CanaryTask struct { *cage td *ecstypes.TaskDefinition lb *ecstypes.LoadBalancer + srv *ecstypes.ServiceRegistry networkConfiguration *ecstypes.NetworkConfiguration platformVersion *string taskArn *string @@ -127,6 +130,25 @@ func (c *CanaryTask) waitForIdleDuration(ctx context.Context) error { return nil } +func (c *CanaryTask) waitUntilSrvInstanceHelthCheckPassed(ctx context.Context) error { + task := c.taskArn + attrs := map[string]string{ + "AWS_INSTANCE_IPV4": "", + "AVAILABILITY_ZONE": c.Env.Region, + "AWS_INIT_HEALTH_STATUS": "UNHEALTHY", + "ECS_CLUSTER_NAME": c.Env.Cluster, + "ECS_SERVICE_NAME": c.Env.Service, + "ECS_TASK_DEFINITION_FAMILY": *c.td.Family, + "REGION": c.Env.Region, + } + if c.srv.ContainerName + c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ + Attributes: attrs, + ServiceId: c.serviceRegistry, + InstanceId: c.taskArn, + }) +} + func (c *CanaryTask) waitUntilHealthCeheckPassed(ctx context.Context) error { log.Infof("😷 ensuring canary task container(s) to become healthy...") containerHasHealthChecks := map[string]struct{}{} @@ -167,27 +189,32 @@ func (c *CanaryTask) waitUntilHealthCeheckPassed(ctx context.Context) error { return xerrors.Errorf("😨 canary task hasn't become to be healthy") } -func (c *CanaryTask) registerToTargetGroup(ctx context.Context) error { +func (c *CanaryTask) describeTaskTarget(ctx context.Context) (*CanaryTarget, error) { // Phase 3: Get task details after network interface is attached var task ecstypes.Task + var result CanaryTarget if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &c.Env.Cluster, Tasks: []string{*c.taskArn}, }); err != nil { - return err + return nil, err } else { task = o.Tasks[0] } var targetId *string var targetPort *int32 var subnet ec2types.Subnet - for _, container := range c.td.ContainerDefinitions { - if *container.Name == *c.lb.ContainerName { - targetPort = container.PortMappings[0].HostPort + if c.lb != nil { + for _, container := range c.td.ContainerDefinitions { + if *container.Name == *c.lb.ContainerName { + targetPort = container.PortMappings[0].HostPort + } } + } else if c.serviceRegistry != nil { + targetPort = aws.Int32(80) } if targetPort == nil { - return xerrors.Errorf("couldn't find host port in container definition") + return nil, xerrors.Errorf("couldn't find host port in container definition") } if c.Env.CanaryInstanceArn == "" { // Fargate details := task.Attachments[0].Details @@ -201,12 +228,12 @@ func (c *CanaryTask) registerToTargetGroup(ctx context.Context) error { } } if subnetId == nil || privateIp == nil { - return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + return nil, xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") } if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ SubnetIds: []string{*subnetId}, }); err != nil { - return err + return nil, err } else { subnet = o.Subnets[0] } @@ -218,40 +245,47 @@ func (c *CanaryTask) registerToTargetGroup(ctx context.Context) error { Cluster: &c.Env.Cluster, ContainerInstances: []string{c.Env.CanaryInstanceArn}, }); err != nil { - return err + return nil, err } else { containerInstance = outputs.ContainerInstances[0] } if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ InstanceIds: []string{*containerInstance.Ec2InstanceId}, }); err != nil { - return err + return nil, err } else if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ SubnetIds: []string{*o.Reservations[0].Instances[0].SubnetId}, }); err != nil { - return err + return nil, err } else { targetId = containerInstance.Ec2InstanceId subnet = sn.Subnets[0] } log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", *targetId, *targetPort, *subnet.AvailabilityZone) } + return &CanaryTarget{ + targetId: targetId, + targetPort: targetPort, + availabilityZone: subnet.AvailabilityZone, + }, nil +} + +func (c *CanaryTask) registerToTargetGroup(ctx context.Context) error { + info, err := c.describeTaskTarget(ctx) + if err != nil { + return err + } if _, err := c.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: subnet.AvailabilityZone, - Id: targetId, - Port: targetPort, + AvailabilityZone: info.availabilityZone, + Id: info.targetId, + Port: info.targetPort, }}, }); err != nil { return err } - c.target = &CanaryTarget{ - targetGroupArn: c.lb.TargetGroupArn, - targetId: targetId, - targetPort: targetPort, - availabilityZone: subnet.AvailabilityZone, - } + c.target = info return nil } diff --git a/go.mod b/go.mod index 0f41e7c..bd7dc09 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/apex/log v1.9.0 - github.com/aws/aws-sdk-go-v2 v1.27.0 + github.com/aws/aws-sdk-go-v2 v1.30.0 github.com/aws/aws-sdk-go-v2/config v1.27.16 github.com/aws/aws-sdk-go-v2/service/ec2 v1.161.4 github.com/aws/aws-sdk-go-v2/service/ecs v1.41.11 @@ -23,11 +23,12 @@ require ( github.com/Masterminds/semver/v3 v3.2.1 github.com/aws/aws-sdk-go-v2/credentials v1.17.16 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect + github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1 github.com/aws/aws-sdk-go-v2/service/sso v1.20.9 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.28.10 // indirect diff --git a/go.sum b/go.sum index 1649eec..760eb4b 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3st github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2 v1.30.0 h1:6qAwtzlfcTtcL8NHtbDQAqgM5s6NDipQTkPxyH/6kAA= +github.com/aws/aws-sdk-go-v2 v1.30.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2/config v1.27.16 h1:knpCuH7laFVGYTNd99Ns5t+8PuRjDn4HnnZK48csipM= github.com/aws/aws-sdk-go-v2/config v1.27.16/go.mod h1:vutqgRhDUktwSge3hrC3nkuirzkJ4E/mLj5GvI0BQas= github.com/aws/aws-sdk-go-v2/credentials v1.17.16 h1:7d2QxY83uYl0l58ceyiSpxg9bSbStqBC6BeEeHEchwo= @@ -16,8 +18,12 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 h1:dQLK4TjtnlRGb0czOht2Cev github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3/go.mod h1:TL79f2P6+8Q7dTsILpiVST+AL9lkF6PPGI167Ny0Cjw= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 h1:SJ04WXGTwnHlWIODtC5kJzKbeuHt+OUNOgKg7nfnUGw= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12/go.mod h1:FkpvXhA92gb3GE9LD6Og0pHHycTxW7xGpnEh5E7Opwo= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 h1:hb5KgeYfObi5MHkSSZMEudnIvX30iB+E21evI4r6BnQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12/go.mod h1:CroKe/eWJdyfy9Vx4rljP5wTUjNJfb+fPz1uMYUhEGM= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= github.com/aws/aws-sdk-go-v2/service/ec2 v1.161.4 h1:JBcPadBAnSwqUZQ1o2XOkTXy7GBcidpupkXZf02parw= @@ -30,6 +36,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1x github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 h1:Wx0rlZoEJR7JwlSZcHnEa7CNjrSIyVxMFWGAaXy4fJY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9/go.mod h1:aVMHdE0aHO3v+f/iw01fmXV/5DbfQ3Bi9nN7nd9bE9Y= +github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1 h1:N19/J0IqsoNlkbXLe+JYWLjOyGmRijt6dw0+MaL/9wE= +github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1/go.mod h1:uuMsqZ2ATDqrzaAldWWuEUd9KGqi1NmnjroG6Eoe7W4= github.com/aws/aws-sdk-go-v2/service/sso v1.20.9 h1:aD7AGQhvPuAxlSUfo0CWU7s6FpkbyykMhGYMvlqTjVs= github.com/aws/aws-sdk-go-v2/service/sso v1.20.9/go.mod h1:c1qtZUWtygI6ZdvKppzCSXsDOq5I4luJPZ0Ud3juFCA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.3 h1:Pav5q3cA260Zqez42T9UhIlsd9QeypszRPwC9LdSSsQ= diff --git a/rollout.go b/rollout.go index 47e9601..3421687 100644 --- a/rollout.go +++ b/rollout.go @@ -129,10 +129,12 @@ func (c *cage) StartCanaryTasks( var networkConfiguration *ecstypes.NetworkConfiguration var platformVersion *string var loadBalancers []ecstypes.LoadBalancer + var serviceRegistries []ecstypes.ServiceRegistry if input.UpdateService { networkConfiguration = c.Env.ServiceDefinitionInput.NetworkConfiguration platformVersion = c.Env.ServiceDefinitionInput.PlatformVersion loadBalancers = c.Env.ServiceDefinitionInput.LoadBalancers + serviceRegistries = c.Env.ServiceDefinitionInput.ServiceRegistries } else { if o, err := c.Ecs.DescribeServices(ctx, &ecs.DescribeServicesInput{ Cluster: &c.Env.Cluster, @@ -144,9 +146,11 @@ func (c *cage) StartCanaryTasks( networkConfiguration = service.NetworkConfiguration platformVersion = service.PlatformVersion loadBalancers = service.LoadBalancers + serviceRegistries = service.ServiceRegistries } } var results []*CanaryTask + if len(loadBalancers) == 0 { task := &CanaryTask{ c, nextTaskDefinition, nil, networkConfiguration, platformVersion, nil, nil, From cc3e62744232f7d7f9cd3898692ec87c970fe91c Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 17:55:02 +0900 Subject: [PATCH 02/47] w --- awsiface/iface.go | 1 + cage.go | 7 +- canary/alb_task.go | 188 +++++++++++++++++++ canary/common.go | 248 ++++++++++++++++++++++++ canary/simple_task.go | 55 ++++++ canary/srv_task.go | 147 +++++++++++++++ canary_task.go | 424 ------------------------------------------ rollout.go | 66 ++++--- rollout_test.go | 27 +-- run.go | 18 +- run_test.go | 13 +- types/iface.go | 69 +++++++ up.go | 10 +- 13 files changed, 780 insertions(+), 493 deletions(-) create mode 100644 canary/alb_task.go create mode 100644 canary/common.go create mode 100644 canary/simple_task.go create mode 100644 canary/srv_task.go delete mode 100644 canary_task.go create mode 100644 types/iface.go diff --git a/awsiface/iface.go b/awsiface/iface.go index 24cc2bd..2dc2fa9 100644 --- a/awsiface/iface.go +++ b/awsiface/iface.go @@ -41,6 +41,7 @@ type ( DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) + GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) } ) diff --git a/cage.go b/cage.go index 58eca31..4e8469f 100644 --- a/cage.go +++ b/cage.go @@ -6,12 +6,13 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" ) type Cage interface { - Up(ctx context.Context) (*UpResult, error) - Run(ctx context.Context, input *RunInput) (*RunResult, error) - RollOut(ctx context.Context, input *RollOutInput) (*RollOutResult, error) + Up(ctx context.Context) (*types.UpResult, error) + Run(ctx context.Context, input *types.RunInput) (*types.RunResult, error) + RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) } type Time interface { diff --git a/canary/alb_task.go b/canary/alb_task.go new file mode 100644 index 0000000..8f8829b --- /dev/null +++ b/canary/alb_task.go @@ -0,0 +1,188 @@ +package canary + +import ( + "context" + "strconv" + "time" + + "github.com/apex/log" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "golang.org/x/xerrors" +) + +type albTask struct { + *common + lb *ecstypes.LoadBalancer + target *CanaryTarget +} + +func NewAlbTask(input *Input, + lb *ecstypes.LoadBalancer, +) Task { + return &albTask{ + common: &common{Input: input}, + lb: lb, + } +} + +func (c *albTask) Wait(ctx context.Context) error { + if err := c.wait(ctx); err != nil { + return err + } + if err := c.registerToTargetGroup(ctx); err != nil { + return err + } + log.Infof("😷 ensuring canary task to become healthy...") + if err := c.waitUntilTargetHealthy(ctx); err != nil { + return err + } + log.Info("🤩 canary task is healthy!") + return nil +} + +func (c *albTask) Stop(ctx context.Context) error { + if c.target == nil { + log.Info("no target is registered. skip deregisteration.") + } else { + deregistrationDelay, err := c.targetDeregistrationDelay(ctx) + if err != nil { + log.Errorf("failed to get deregistration delay: %v", err) + log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) + } + log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) + if _, err := c.Alb.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + AvailabilityZone: &c.target.availabilityZone, + Id: &c.target.targetId, + Port: &c.target.targetPort, + }}, + }); err != nil { + log.Errorf("failed to deregister the canary task from target group: %v", err) + log.Errorf("continuing to stop the canary task...") + } else { + log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") + deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety + if err := elbv2.NewTargetDeregisteredWaiter(c.Alb).Wait(ctx, &elbv2.DescribeTargetHealthInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + AvailabilityZone: &c.target.availabilityZone, + Id: &c.target.targetId, + Port: &c.target.targetPort, + }}, + }, deregisterWait); err != nil { + log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) + log.Errorf("continuing to stop the canary task...") + } else { + log.Infof( + "canary task '%s' has successfully been deregistered from target group '%s'", + *c.taskArn, c.target.targetId, + ) + } + } + } + return c.stopTask(ctx) +} + +func (c *albTask) getTargetPort(ctx context.Context) (int32, error) { + for _, container := range c.TaskDefinition.ContainerDefinitions { + if *container.Name == *c.lb.ContainerName { + return *container.PortMappings[0].ContainerPort, nil + } + } + return 0, xerrors.Errorf("couldn't find host port in container definition") +} + +func (c *albTask) registerToTargetGroup(ctx context.Context) error { + info, err := c.describeTaskTarget(ctx, c.getTargetPort) + if err != nil { + return err + } + if _, err := c.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + AvailabilityZone: &info.availabilityZone, + Id: &info.targetId, + Port: &info.targetPort, + }}, + }); err != nil { + return err + } + c.target = info + return nil +} + +func (c *albTask) waitUntilTargetHealthy( + ctx context.Context, +) error { + log.Infof("checking the health state of canary task...") + var unusedCount = 0 + var initialized = false + var recentState *elbv2types.TargetHealthStateEnum + for { + <-c.Time.NewTimer(time.Duration(15) * time.Second).C + if o, err := c.Alb.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + Id: &c.target.targetId, + Port: &c.target.targetPort, + AvailabilityZone: &c.target.availabilityZone, + }}, + }); err != nil { + return err + } else { + for _, desc := range o.TargetHealthDescriptions { + if *desc.Target.Id == c.target.targetId && *desc.Target.Port == c.target.targetPort { + recentState = &desc.TargetHealth.State + } + } + if recentState == nil { + return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) + } + log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) + switch *recentState { + case "healthy": + return nil + case "initial": + initialized = true + log.Infof("still checking the state...") + continue + case "unused": + unusedCount++ + if !initialized && unusedCount < 5 { + continue + } + default: + } + } + // unhealthy, draining, unused + log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) + return xerrors.Errorf( + "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", + *c.taskArn, c.target.targetId, c.target.targetPort, *recentState, + ) + } +} + +func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, error) { + deregistrationDelay := 300 * time.Second + if o, err := c.Alb.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ + TargetGroupArn: c.lb.TargetGroupArn, + }); err != nil { + return deregistrationDelay, err + } else { + // find deregistration_delay.timeout_seconds + for _, attr := range o.Attributes { + if *attr.Key == "deregistration_delay.timeout_seconds" { + if value, err := strconv.ParseInt(*attr.Value, 10, 64); err != nil { + return deregistrationDelay, err + } else { + deregistrationDelay = time.Duration(value) * time.Second + } + } + } + } + return deregistrationDelay, nil +} diff --git a/canary/common.go b/canary/common.go new file mode 100644 index 0000000..f904a3c --- /dev/null +++ b/canary/common.go @@ -0,0 +1,248 @@ +package canary + +import ( + "context" + "fmt" + "time" + + "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "golang.org/x/xerrors" +) + +type CanaryTarget struct { + targetId string + targetIpv4 string + targetPort int32 + availabilityZone string +} + +type Task interface { + Start(ctx context.Context) error + Wait(ctx context.Context) error + Stop(ctx context.Context) error +} + +type Input struct { + *types.Input + TaskDefinition *ecstypes.TaskDefinition + Network *ecstypes.NetworkConfiguration + PlatformVersion *string + Timeout timeout.Manager +} + +type common struct { + *Input + taskArn *string +} + +func NewSimpleTask(input *Input) Task { + return &simpleTask{common: &common{Input: input}} +} + +func (c *common) Start(ctx context.Context) error { + if c.Env.CanaryInstanceArn != "" { + // ec2 + startTask := &ecs.StartTaskInput{ + Cluster: &c.Env.Cluster, + Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + NetworkConfiguration: c.Network, + TaskDefinition: c.TaskDefinition.TaskDefinitionArn, + ContainerInstances: []string{c.Env.CanaryInstanceArn}, + } + if o, err := c.Ecs.StartTask(ctx, startTask); err != nil { + return err + } else { + c.taskArn = o.Tasks[0].TaskArn + } + } else { + // fargate + if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ + Cluster: &c.Env.Cluster, + Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + NetworkConfiguration: c.Network, + TaskDefinition: c.TaskDefinition.TaskDefinitionArn, + LaunchType: ecstypes.LaunchTypeFargate, + PlatformVersion: c.PlatformVersion, + }); err != nil { + return err + } else { + c.taskArn = o.Tasks[0].TaskArn + } + } + return nil +} + +func (c *common) wait(ctx context.Context) error { + log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) + if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }, c.Timeout.TaskRunning()); err != nil { + return err + } + log.Infof("🐣 canary task '%s' is running!", *c.taskArn) + if err := c.waitUntilHealthCheckPassed(ctx); err != nil { + return err + } + log.Info("🤩 canary task container(s) is healthy!") + log.Infof("canary task '%s' ensured.", *c.taskArn) + return nil +} + +func (c *common) waitUntilHealthCheckPassed(ctx context.Context) error { + log.Infof("😷 ensuring canary task container(s) to become healthy...") + containerHasHealthChecks := map[string]struct{}{} + for _, definition := range c.TaskDefinition.ContainerDefinitions { + if definition.HealthCheck != nil { + containerHasHealthChecks[*definition.Name] = struct{}{} + } + } + healthCheckWait := c.Timeout.TaskHealthCheck() + healthCheckPeriod := 15 * time.Second + countPerPeriod := int(healthCheckWait.Seconds() / 15) + for count := 0; count < countPerPeriod; count++ { + <-c.Time.NewTimer(healthCheckPeriod).C + log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) + if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task := o.Tasks[0] + if *task.LastStatus != "RUNNING" { + return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) + } + + for _, container := range task.Containers { + if _, ok := containerHasHealthChecks[*container.Name]; !ok { + continue + } + if container.HealthStatus != ecstypes.HealthStatusHealthy { + log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) + continue + } + delete(containerHasHealthChecks, *container.Name) + } + if len(containerHasHealthChecks) == 0 { + return nil + } + } + } + return xerrors.Errorf("😨 canary task hasn't become to be healthy") +} + +func (c *common) describeTaskTarget( + ctx context.Context, + getTargetPort func(ctx context.Context) (int32, error), +) (*CanaryTarget, error) { + var targetPort int32 + if v, err := getTargetPort(ctx); err != nil { + return nil, err + } else { + targetPort = v + } + target := CanaryTarget{targetPort: targetPort} + if c.Env.CanaryInstanceArn == "" { // Fargate + if err := c.getFargateTarget(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } else { + if err := c.getEc2Target(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } + return &target, nil +} + +func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { + var task ecstypes.Task + if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task = o.Tasks[0] + } + details := task.Attachments[0].Details + var subnetId *string + var privateIp *string + for _, v := range details { + if *v.Name == "subnetId" { + subnetId = v.Value + } else if *v.Name == "privateIPv4Address" { + privateIp = v.Value + } + } + if subnetId == nil || privateIp == nil { + return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + } + if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*subnetId}, + }); err != nil { + return err + } else { + dest.targetId = *privateIp + dest.targetIpv4 = *privateIp + dest.availabilityZone = *o.Subnets[0].AvailabilityZone + } + return nil +} + +func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { + var containerInstance ecstypes.ContainerInstance + if outputs, err := c.Ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ + Cluster: &c.Env.Cluster, + ContainerInstances: []string{c.Env.CanaryInstanceArn}, + }); err != nil { + return err + } else { + containerInstance = outputs.ContainerInstances[0] + } + var ec2Instance ec2types.Instance + if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: []string{*containerInstance.Ec2InstanceId}, + }); err != nil { + return err + } else { + ec2Instance = o.Reservations[0].Instances[0] + } + if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*ec2Instance.SubnetId}, + }); err != nil { + return err + } else { + dest.targetId = *containerInstance.Ec2InstanceId + dest.targetIpv4 = *ec2Instance.PrivateIpAddress + dest.availabilityZone = *sn.Subnets[0].AvailabilityZone + } + return nil +} + +func (c *common) stopTask(ctx context.Context) error { + log.Infof("stopping the canary task '%s'...", *c.taskArn) + if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ + Cluster: &c.Env.Cluster, + Task: c.taskArn, + }); err != nil { + return err + } + if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }, c.Timeout.TaskStopped()); err != nil { + return err + } + log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) + return nil +} diff --git a/canary/simple_task.go b/canary/simple_task.go new file mode 100644 index 0000000..1189833 --- /dev/null +++ b/canary/simple_task.go @@ -0,0 +1,55 @@ +package canary + +import ( + "context" + "time" + + "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "golang.org/x/xerrors" +) + +type simpleTask struct { + *common +} + +func (c *simpleTask) Wait(ctx context.Context) error { + if err := c.wait(ctx); err != nil { + return err + } + return c.waitForIdleDuration(ctx) +} + +func (c *simpleTask) Stop(ctx context.Context) error { + return c.stopTask(ctx) +} + +func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { + log.Infof("wait %d seconds for canary task to be stable...", c.Env.CanaryTaskIdleDuration) + duration := c.Env.CanaryTaskIdleDuration + for duration > 0 { + wt := 10 + if duration < 10 { + wt = duration + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.Time.NewTimer(time.Duration(wt) * time.Second).C: + duration -= 10 + } + log.Infof("still waiting...; %d seconds left", duration) + } + o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }) + if err != nil { + return err + } + task := o.Tasks[0] + if *task.LastStatus != "RUNNING" { + return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) + } + return nil +} diff --git a/canary/srv_task.go b/canary/srv_task.go new file mode 100644 index 0000000..835c295 --- /dev/null +++ b/canary/srv_task.go @@ -0,0 +1,147 @@ +package canary + +import ( + "context" + "regexp" + "time" + + "github.com/apex/log" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" + "golang.org/x/xerrors" +) + +type srvTask struct { + *common + registry *ecstypes.ServiceRegistry + target *CanaryTarget + srv *srvtypes.Service + inst *srvtypes.HttpInstanceSummary +} + +func NewSrvTask(input *Input, registry *ecstypes.ServiceRegistry) Task { + return &srvTask{ + common: &common{Input: input}, + registry: registry, + } +} + +func (c *srvTask) Wait(ctx context.Context) error { + if err := c.wait(ctx); err != nil { + return err + } + if err := c.registerToSrvDiscovery(ctx); err != nil { + return err + } + log.Infof("😷 ensuring canary task to become healthy...") + if err := c.waitUntilSrvInstHelthy(ctx); err != nil { + return err + } + log.Info("🤩 canary task is healthy!") + return nil +} + +func (c *srvTask) Stop(ctx context.Context) error { + if err := c.deregisterSrvInst(ctx); err != nil { + return err + } + return c.stopTask(ctx) +} + +func (c *srvTask) getTargetPort(ctx context.Context) (int32, error) { + if c.registry.Port != nil { + return *c.registry.Port, nil + } + return 80, nil +} + +func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { + target, err := c.describeTaskTarget(ctx, c.getTargetPort) + if err != nil { + return err + } + c.target = target + // get the service id from service registry arn + pat := regexp.MustCompile("arn://.+/(srv-.+)$") + matches := pat.FindStringSubmatch(*c.registry.RegistryArn) + if len(matches) != 2 { + return xerrors.Errorf("service name '%s' doesn't match the pattern", c.Env.Service) + } + srvId := matches[1] + var svc *srvtypes.Service + if o, err := c.Srv.GetService(ctx, &servicediscovery.GetServiceInput{ + Id: &srvId, + }); err != nil { + return xerrors.Errorf("failed to get the service: %w", err) + } else { + svc = o.Service + } + attrs := map[string]string{ + "AWS_INSTANCE_IPV4": target.targetIpv4, + "AVAILABILITY_ZONE": target.availabilityZone, + "AWS_INIT_HEALTH_STATUS": "UNHEALTHY", + "ECS_CLUSTER_NAME": c.Env.Cluster, + "ECS_SERVICE_NAME": c.Env.Service, + "ECS_TASK_DEFINITION_FAMILY": *c.TaskDefinition.Family, + "REGION": c.Env.Region, + "CAGE_CANARY_TASK": "1", + } + if _, err := c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ + ServiceId: &srvId, + InstanceId: c.taskArn, + Attributes: attrs, + }); err != nil { + return xerrors.Errorf("failed to register the canary task to service discovery: %w", err) + } + c.srv = svc + return nil +} + +func (c *srvTask) waitUntilSrvInstHelthy( + ctx context.Context, +) error { + var maxWait = 900 + var waitPeriod = 15 + for maxWait > 0 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.Time.NewTimer(time.Duration(waitPeriod) * time.Second).C: + if list, err := c.Srv.DiscoverInstances(ctx, &servicediscovery.DiscoverInstancesInput{ + NamespaceName: c.inst.NamespaceName, + ServiceName: c.inst.ServiceName, + HealthStatus: srvtypes.HealthStatusFilterHealthy, + QueryParameters: map[string]string{ + "CAGE_CANARY_TASK": "1", + }, + }); err != nil { + return xerrors.Errorf("failed to discover instances: %w", err) + } else { + if len(list.Instances) == 0 { + return xerrors.Errorf("no healthy instances found") + } + for _, inst := range list.Instances { + if ipv4 := inst.Attributes["AWS_INSTANCE_IPV4"]; ipv4 == c.target.targetIpv4 { + c.inst = &inst + return nil + } + } + maxWait -= waitPeriod + } + } + } + return xerrors.Errorf("timed out waiting for healthy instances") +} + +func (c *srvTask) deregisterSrvInst( + ctx context.Context, +) error { + if _, err := c.Srv.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ + ServiceId: c.srv.Id, + InstanceId: c.inst.InstanceId, + }); err != nil { + return xerrors.Errorf("failed to deregister the canary task from service discovery: %w", err) + } + return nil +} diff --git a/canary_task.go b/canary_task.go deleted file mode 100644 index 4e04035..0000000 --- a/canary_task.go +++ /dev/null @@ -1,424 +0,0 @@ -package cage - -import ( - "context" - "fmt" - "strconv" - "time" - - "github.com/apex/log" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go-v2/service/ecs" - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" - "golang.org/x/xerrors" -) - -type CanaryTarget struct { - // targetGroupArn *string - // serviceRegistry *string - targetId *string - targetPort *int32 - availabilityZone *string -} - -type CanaryTask struct { - *cage - td *ecstypes.TaskDefinition - lb *ecstypes.LoadBalancer - srv *ecstypes.ServiceRegistry - networkConfiguration *ecstypes.NetworkConfiguration - platformVersion *string - taskArn *string - target *CanaryTarget -} - -func (c *CanaryTask) Start(ctx context.Context) error { - if c.Env.CanaryInstanceArn != "" { - // ec2 - startTask := &ecs.StartTaskInput{ - Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), - NetworkConfiguration: c.networkConfiguration, - TaskDefinition: c.td.TaskDefinitionArn, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, - } - if o, err := c.Ecs.StartTask(ctx, startTask); err != nil { - return err - } else { - c.taskArn = o.Tasks[0].TaskArn - } - } else { - // fargate - if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ - Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), - NetworkConfiguration: c.networkConfiguration, - TaskDefinition: c.td.TaskDefinitionArn, - LaunchType: ecstypes.LaunchTypeFargate, - PlatformVersion: c.platformVersion, - }); err != nil { - return err - } else { - c.taskArn = o.Tasks[0].TaskArn - } - } - return nil -} - -func (c *CanaryTask) Wait(ctx context.Context) error { - log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) - if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskRunning()); err != nil { - return err - } - log.Infof("🐣 canary task '%s' is running!", *c.taskArn) - if err := c.waitUntilHealthCheckPassed(ctx); err != nil { - return err - } - log.Info("🤩 canary task container(s) is healthy!") - log.Infof("canary task '%s' ensured.", *c.taskArn) - if c.lb == nil { - log.Infof("no load balancer is attached to service '%s'. skip registration to target group", c.Env.Service) - return c.waitForIdleDuration(ctx) - } else { - if err := c.registerToTargetGroup(ctx); err != nil { - return err - } - log.Infof("😷 ensuring canary task to become healthy...") - if err := c.waitUntilTargetHealthy(ctx); err != nil { - return err - } - log.Info("🤩 canary task is healthy!") - return nil - } -} - -func (c *CanaryTask) waitForIdleDuration(ctx context.Context) error { - log.Infof("wait %d seconds for canary task to be stable...", c.Env.CanaryTaskIdleDuration) - duration := c.Env.CanaryTaskIdleDuration - for duration > 0 { - wt := 10 - if duration < 10 { - wt = duration - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.Time.NewTimer(time.Duration(wt) * time.Second).C: - duration -= 10 - } - log.Infof("still waiting...; %d seconds left", duration) - } - o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }) - if err != nil { - return err - } - task := o.Tasks[0] - if *task.LastStatus != "RUNNING" { - return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) - } - return nil -} - -func (c *CanaryTask) waitUntilSrvInstanceHelthCheckPassed(ctx context.Context) error { - task := c.taskArn - attrs := map[string]string{ - "AWS_INSTANCE_IPV4": "", - "AVAILABILITY_ZONE": c.Env.Region, - "AWS_INIT_HEALTH_STATUS": "UNHEALTHY", - "ECS_CLUSTER_NAME": c.Env.Cluster, - "ECS_SERVICE_NAME": c.Env.Service, - "ECS_TASK_DEFINITION_FAMILY": *c.td.Family, - "REGION": c.Env.Region, - } - if c.srv.ContainerName - c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ - Attributes: attrs, - ServiceId: c.serviceRegistry, - InstanceId: c.taskArn, - }) -} - -func (c *CanaryTask) waitUntilHealthCheckPassed(ctx context.Context) error { - log.Infof("😷 ensuring canary task container(s) to become healthy...") - containerHasHealthChecks := map[string]struct{}{} - for _, definition := range c.td.ContainerDefinitions { - if definition.HealthCheck != nil { - containerHasHealthChecks[*definition.Name] = struct{}{} - } - } - healthCheckWait := c.Timeout.TaskHealthCheck() - healthCheckPeriod := 15 * time.Second - countPerPeriod := int(healthCheckWait.Seconds() / 15) - for count := 0; count < countPerPeriod; count++ { - <-c.Time.NewTimer(healthCheckPeriod).C - log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return err - } else { - task := o.Tasks[0] - if *task.LastStatus != "RUNNING" { - return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) - } - - for _, container := range task.Containers { - if _, ok := containerHasHealthChecks[*container.Name]; !ok { - continue - } - if container.HealthStatus != ecstypes.HealthStatusHealthy { - log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) - continue - } - delete(containerHasHealthChecks, *container.Name) - } - if len(containerHasHealthChecks) == 0 { - return nil - } - } - } - return xerrors.Errorf("😨 canary task hasn't become to be healthy") -} - -func (c *CanaryTask) describeTaskTarget(ctx context.Context) (*CanaryTarget, error) { - // Phase 3: Get task details after network interface is attached - var task ecstypes.Task - var result CanaryTarget - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return nil, err - } else { - task = o.Tasks[0] - } - var targetId *string - var targetPort *int32 - var subnet ec2types.Subnet - if c.lb != nil { - for _, container := range c.td.ContainerDefinitions { - if *container.Name == *c.lb.ContainerName { - targetPort = container.PortMappings[0].HostPort - } - } - } else if c.serviceRegistry != nil { - targetPort = aws.Int32(80) - } - if targetPort == nil { - return nil, xerrors.Errorf("couldn't find host port in container definition") - } - if c.Env.CanaryInstanceArn == "" { // Fargate - details := task.Attachments[0].Details - var subnetId *string - var privateIp *string - for _, v := range details { - if *v.Name == "subnetId" { - subnetId = v.Value - } else if *v.Name == "privateIPv4Address" { - privateIp = v.Value - } - } - if subnetId == nil || privateIp == nil { - return nil, xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") - } - if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*subnetId}, - }); err != nil { - return nil, err - } else { - subnet = o.Subnets[0] - } - targetId = privateIp - log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", *targetId, *targetPort, *subnet.AvailabilityZone) - } else { - var containerInstance ecstypes.ContainerInstance - if outputs, err := c.Ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ - Cluster: &c.Env.Cluster, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, - }); err != nil { - return nil, err - } else { - containerInstance = outputs.ContainerInstances[0] - } - if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ - InstanceIds: []string{*containerInstance.Ec2InstanceId}, - }); err != nil { - return nil, err - } else if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*o.Reservations[0].Instances[0].SubnetId}, - }); err != nil { - return nil, err - } else { - targetId = containerInstance.Ec2InstanceId - subnet = sn.Subnets[0] - } - log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", *targetId, *targetPort, *subnet.AvailabilityZone) - } - return &CanaryTarget{ - targetId: targetId, - targetPort: targetPort, - availabilityZone: subnet.AvailabilityZone, - }, nil -} - -func (c *CanaryTask) registerToTargetGroup(ctx context.Context) error { - info, err := c.describeTaskTarget(ctx) - if err != nil { - return err - } - if _, err := c.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ - TargetGroupArn: c.lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: info.availabilityZone, - Id: info.targetId, - Port: info.targetPort, - }}, - }); err != nil { - return err - } - c.target = info - return nil -} - -func (c *CanaryTask) waitUntilTargetHealthy( - ctx context.Context, -) error { - log.Infof("checking the health state of canary task...") - var unusedCount = 0 - var initialized = false - var recentState *elbv2types.TargetHealthStateEnum - for { - <-c.Time.NewTimer(time.Duration(15) * time.Second).C - if o, err := c.Alb.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.target.targetGroupArn, - Targets: []elbv2types.TargetDescription{{ - Id: c.target.targetId, - Port: c.target.targetPort, - AvailabilityZone: c.target.availabilityZone, - }}, - }); err != nil { - return err - } else { - for _, desc := range o.TargetHealthDescriptions { - if *desc.Target.Id == *c.target.targetId && *desc.Target.Port == *c.target.targetPort { - recentState = &desc.TargetHealth.State - } - } - if recentState == nil { - return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.target.targetId, c.target.targetGroupArn) - } - log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) - switch *recentState { - case "healthy": - return nil - case "initial": - initialized = true - log.Infof("still checking the state...") - continue - case "unused": - unusedCount++ - if !initialized && unusedCount < 5 { - continue - } - default: - } - } - // unhealthy, draining, unused - log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) - return xerrors.Errorf( - "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *c.taskArn, c.target.targetId, c.target.targetPort, *recentState, - ) - } -} - -func (c *CanaryTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, error) { - deregistrationDelay := 300 * time.Second - if o, err := c.Alb.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ - TargetGroupArn: c.target.targetGroupArn, - }); err != nil { - return deregistrationDelay, err - } else { - // find deregistration_delay.timeout_seconds - for _, attr := range o.Attributes { - if *attr.Key == "deregistration_delay.timeout_seconds" { - if value, err := strconv.ParseInt(*attr.Value, 10, 64); err != nil { - return deregistrationDelay, err - } else { - deregistrationDelay = time.Duration(value) * time.Second - } - } - } - } - return deregistrationDelay, nil -} - -func (c *CanaryTask) Stop(ctx context.Context) error { - if c.target == nil { - log.Info("no load balancer is attached to service. Skip deregisteration.") - } else { - deregistrationDelay, err := c.targetDeregistrationDelay(ctx) - if err != nil { - log.Errorf("failed to get deregistration delay: %v", err) - log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) - } - log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) - if _, err := c.Alb.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ - TargetGroupArn: c.target.targetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: c.target.availabilityZone, - Id: c.target.targetId, - Port: c.target.targetPort, - }}, - }); err != nil { - log.Errorf("failed to deregister the canary task from target group: %v", err) - log.Errorf("continuing to stop the canary task...") - } else { - log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") - deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety - if err := elbv2.NewTargetDeregisteredWaiter(c.Alb).Wait(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.target.targetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: c.target.availabilityZone, - Id: c.target.targetId, - Port: c.target.targetPort, - }}, - }, deregisterWait); err != nil { - log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) - log.Errorf("continuing to stop the canary task...") - } else { - log.Infof( - "canary task '%s' has successfully been deregistered from target group '%s'", - *c.taskArn, *c.target.targetId, - ) - } - } - } - log.Infof("stopping the canary task '%s'...", *c.taskArn) - if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ - Cluster: &c.Env.Cluster, - Task: c.taskArn, - }); err != nil { - return err - } - if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskStopped()); err != nil { - return err - } - log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) - return nil -} diff --git a/rollout.go b/rollout.go index 444788b..4bba347 100644 --- a/rollout.go +++ b/rollout.go @@ -6,21 +6,14 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/canary" + "github.com/loilo-inc/canarycage/types" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" ) -type RollOutInput struct { - // UpdateService is a flag to update service with changed configurations except for task definition - UpdateService bool -} - -type RollOutResult struct { - ServiceIntact bool -} - -func (c *cage) RollOut(ctx context.Context, input *RollOutInput) (*RollOutResult, error) { - result := &RollOutResult{ +func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) { + result := &types.RollOutResult{ ServiceIntact: true, } if out, err := c.Ecs.DescribeServices(ctx, &ecs.DescribeServicesInput{ @@ -124,8 +117,8 @@ func (c *cage) RollOut(ctx context.Context, input *RollOutInput) (*RollOutResult func (c *cage) StartCanaryTasks( ctx context.Context, nextTaskDefinition *ecstypes.TaskDefinition, - input *RollOutInput, -) ([]*CanaryTask, error) { + input *types.RollOutInput, +) ([]canary.Task, error) { var networkConfiguration *ecstypes.NetworkConfiguration var platformVersion *string var loadBalancers []ecstypes.LoadBalancer @@ -149,25 +142,44 @@ func (c *cage) StartCanaryTasks( serviceRegistries = service.ServiceRegistries } } - var results []*CanaryTask - - if len(loadBalancers) == 0 { - task := &CanaryTask{ - c, nextTaskDefinition, nil, networkConfiguration, platformVersion, nil, nil, + var results []canary.Task + for _, lb := range loadBalancers { + task := canary.NewAlbTask(&canary.Input{ + Input: c.Input, + Network: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, + }, &lb) + results = append(results, task) + if err := task.Start(ctx); err != nil { + return results, err } + } + for _, srv := range serviceRegistries { + task := canary.NewSrvTask(&canary.Input{ + Input: c.Input, + Network: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, + }, &srv) results = append(results, task) if err := task.Start(ctx); err != nil { return results, err } - } else { - for _, lb := range loadBalancers { - task := &CanaryTask{ - c, nextTaskDefinition, &lb, networkConfiguration, platformVersion, nil, nil, - } - results = append(results, task) - if err := task.Start(ctx); err != nil { - return results, err - } + } + if len(results) == 0 { + task := canary.NewSimpleTask(&canary.Input{ + Input: c.Input, + Network: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, + }) + results = append(results, task) + if err := task.Start(ctx); err != nil { + return results, err } } return results, nil diff --git a/rollout_test.go b/rollout_test.go index d81cead..6b45c2b 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -16,6 +16,7 @@ import ( cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" ) @@ -44,7 +45,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) assert.False(t, result.ServiceIntact) assert.Equal(t, 1, mctx.ActiveServiceSize()) @@ -67,7 +68,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) assert.False(t, result.ServiceIntact) assert.Equal(t, 1, mctx.ActiveServiceSize()) @@ -106,7 +107,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) assert.NotNil(t, result) }) @@ -150,7 +151,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - _, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NotNil(t, err) }) t.Run("update service", func(t *testing.T) { @@ -183,7 +184,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { assert.Equal(t, "1.4.0", *service.PlatformVersion) assert.NotNil(t, service.NetworkConfiguration) assert.NotNil(t, service.LoadBalancers) - _, err := cagecli.RollOut(ctx, &cage.RollOutInput{UpdateService: true}) + _, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) assert.NoError(t, err) service, _ = mctx.GetService(envars.Service) assert.Equal(t, "LATEST", *service.PlatformVersion) @@ -209,7 +210,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { TargetGroupArn: aws.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/new-target-group/abcdefg"), }, } - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{UpdateService: true}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "failed to wait for canary task due to: couldn't find host port in container definition") assert.Equal(t, result.ServiceIntact, true) assert.Equal(t, 1, mctx.RunningTaskSize()) @@ -226,7 +227,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Alb: albMock, }) ctx := context.Background() - _, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.EqualError(t, err, "service 'service' doesn't exist. Run 'cage up' or create service before rolling out") }) t.Run("Roll out even if the service does not have a load balancer", func(t *testing.T) { @@ -243,7 +244,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - if res, err := cagecli.RollOut(ctx, &cage.RollOutInput{}); err != nil { + if res, err := cagecli.RollOut(ctx, &types.RollOutInput{}); err != nil { t.Fatalf(err.Error()) } else if res.ServiceIntact { t.Fatalf("no") @@ -266,7 +267,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - _, err := cagecli.RollOut(context.Background(), &cage.RollOutInput{}) + _, err := cagecli.RollOut(context.Background(), &types.RollOutInput{}) assert.EqualError(t, err, "😵 'service' status is 'INACTIVE'. Stop rolling out") }) t.Run("Stop rolling out if the canary task container does not become healthy", func(t *testing.T) { @@ -312,7 +313,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - res, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + res, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NotNil(t, res) assert.NotNil(t, err) @@ -358,7 +359,7 @@ func TestCage_RollOut_EC2(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) if err != nil { t.Fatalf("%s", err) } @@ -388,7 +389,7 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) if err == nil { t.Fatal("Rollout with no container instance should be error") } else { @@ -423,7 +424,7 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - result, err := cagecli.RollOut(ctx, &cage.RollOutInput{}) + result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) if err != nil { t.Fatalf("%s", err) } diff --git a/run.go b/run.go index 4aff00c..038a4d5 100644 --- a/run.go +++ b/run.go @@ -6,19 +6,11 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" - "github.com/aws/aws-sdk-go-v2/service/ecs/types" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) -type RunInput struct { - Container *string - Overrides *types.TaskOverride -} - -type RunResult struct { - ExitCode int32 -} - func containerExistsInDefinition(td *ecs.RegisterTaskDefinitionInput, container *string) bool { for _, v := range td.ContainerDefinitions { if *v.Name == *container { @@ -28,7 +20,7 @@ func containerExistsInDefinition(td *ecs.RegisterTaskDefinitionInput, container return false } -func (c *cage) Run(ctx context.Context, input *RunInput) (*RunResult, error) { +func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult, error) { if !containerExistsInDefinition(c.Env.TaskDefinitionInput, input.Container) { return nil, xerrors.Errorf("🚫 '%s' not found in container definitions", *input.Container) } @@ -39,7 +31,7 @@ func (c *cage) Run(ctx context.Context, input *RunInput) (*RunResult, error) { o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ Cluster: &c.Env.Cluster, TaskDefinition: td.TaskDefinitionArn, - LaunchType: types.LaunchTypeFargate, + LaunchType: ecstypes.LaunchTypeFargate, NetworkConfiguration: c.Env.ServiceDefinitionInput.NetworkConfiguration, PlatformVersion: c.Env.ServiceDefinitionInput.PlatformVersion, Overrides: input.Overrides, @@ -72,7 +64,7 @@ func (c *cage) Run(ctx context.Context, input *RunInput) (*RunResult, error) { } else if *c.ExitCode != 0 { return nil, xerrors.Errorf("task exited with %d", *c.ExitCode) } - return &RunResult{ExitCode: *c.ExitCode}, nil + return &types.RunResult{ExitCode: *c.ExitCode}, nil } } // Never reached? diff --git a/run_test.go b/run_test.go index d7371de..1fb155b 100644 --- a/run_test.go +++ b/run_test.go @@ -11,6 +11,7 @@ import ( cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" ) @@ -43,7 +44,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, }) @@ -73,7 +74,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, }) @@ -95,7 +96,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, }) @@ -123,7 +124,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, }) @@ -151,7 +152,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, }) @@ -167,7 +168,7 @@ func TestCage_Run(t *testing.T) { Ecs: ecsMock, Time: test.NewFakeTime(), }) - result, err := cagecli.Run(ctx, &cage.RunInput{ + result, err := cagecli.Run(ctx, &types.RunInput{ Container: aws.String("foo"), Overrides: overrides, }) diff --git a/types/iface.go b/types/iface.go new file mode 100644 index 0000000..7ba1773 --- /dev/null +++ b/types/iface.go @@ -0,0 +1,69 @@ +package types + +import ( + "context" + "time" + + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" +) + +type Envars struct { + _ struct{} `type:"struct"` + CI bool `json:"ci" type:"bool"` + Region string `json:"region" type:"string"` + Cluster string `json:"cluster" type:"string" required:"true"` + Service string `json:"service" type:"string" required:"true"` + CanaryInstanceArn string + TaskDefinitionArn string `json:"nextTaskDefinitionArn" type:"string"` + TaskDefinitionInput *ecs.RegisterTaskDefinitionInput + ServiceDefinitionInput *ecs.CreateServiceInput + CanaryTaskIdleDuration int // sec + CanaryTaskRunningWait int // sec + CanaryTaskHealthCheckWait int // sec + CanaryTaskStoppedWait int // sec + ServiceStableWait int // sec +} + +type Cage interface { + Up(ctx context.Context) (*UpResult, error) + Run(ctx context.Context, input *RunInput) (*RunResult, error) + RollOut(ctx context.Context, input *RollOutInput) (*RollOutResult, error) +} + +type Time interface { + Now() time.Time + NewTimer(time.Duration) *time.Timer +} + +type Input struct { + Env *Envars + Ecs awsiface.EcsClient + Alb awsiface.AlbClient + Ec2 awsiface.Ec2Client + Srv awsiface.SrvClient + Time Time +} +type RunInput struct { + Container *string + Overrides *ecstypes.TaskOverride +} + +type RunResult struct { + ExitCode int32 +} + +type RollOutInput struct { + // UpdateService is a flag to update service with changed configurations except for task definition + UpdateService bool +} + +type RollOutResult struct { + ServiceIntact bool +} + +type UpResult struct { + TaskDefinition *ecstypes.TaskDefinition + Service *ecstypes.Service +} diff --git a/up.go b/up.go index 62b193c..ae9da4d 100644 --- a/up.go +++ b/up.go @@ -6,15 +6,11 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) -type UpResult struct { - TaskDefinition *ecstypes.TaskDefinition - Service *ecstypes.Service -} - -func (c *cage) Up(ctx context.Context) (*UpResult, error) { +func (c *cage) Up(ctx context.Context) (*types.UpResult, error) { td, err := c.CreateNextTaskDefinition(ctx) if err != nil { return nil, err @@ -35,7 +31,7 @@ func (c *cage) Up(ctx context.Context) (*UpResult, error) { if service, err := c.createService(ctx, c.Env.ServiceDefinitionInput); err != nil { return nil, err } else { - return &UpResult{TaskDefinition: td, Service: service}, nil + return &types.UpResult{TaskDefinition: td, Service: service}, nil } } From 903eff191130f7011b316739ca3191e8c6c057a5 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 18:13:45 +0900 Subject: [PATCH 03/47] w --- Makefile | 6 +- cage.go | 26 +---- canary/common.go | 8 +- cli/cage/commands/command.go | 18 +-- cli/cage/commands/command_test.go | 29 ++--- cli/cage/commands/flags.go | 20 ++-- cli/cage/commands/rollout.go | 11 +- cli/cage/commands/run.go | 7 +- cli/cage/commands/up.go | 4 +- cli/cage/main.go | 4 +- cli/cage/prompt/prompt.go | 8 +- cli/cage/prompt/prompt_test.go | 4 +- env.go => env/env.go | 2 +- env_test.go => env/env_test.go | 41 +++---- util.go => env/util.go | 2 +- util_test.go => env/util_test.go | 9 +- mocks/mock_awsiface/iface.go | 104 ++++++++++++++++++ .../cage.go => mock_types/iface.go} | 20 ++-- rollout.go | 9 +- rollout_test.go | 26 ++--- run_test.go | 15 +-- task_definition_test.go | 14 ++- test/fake_timer.go | 4 +- test/setup.go | 8 +- types/iface.go | 22 +--- up_test.go | 5 +- 26 files changed, 251 insertions(+), 175 deletions(-) rename env.go => env/env.go (99%) rename env_test.go => env/env_test.go (72%) rename util.go => env/util.go (97%) rename util_test.go => env/util_test.go (82%) rename mocks/{mock_cage/cage.go => mock_types/iface.go} (86%) diff --git a/Makefile b/Makefile index bdefd47..2be61e7 100644 --- a/Makefile +++ b/Makefile @@ -9,11 +9,11 @@ push-test-container: test-container docker push loilodev/http-server:latest version: go run cli/cage/main.go -v | cut -f 3 -d ' ' -mocks: mocks/mock_awsiface/iface.go mocks/mock_cage/iface.go mocks/mock_upgrade/upgrade.go +mocks: mocks/mock_awsiface/iface.go mocks/mock_types/iface.go mocks/mock_upgrade/upgrade.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go -mocks/mock_cage/iface.go: cage.go - $(MOCKGEN) -source=./cage.go > mocks/mock_cage/cage.go +mocks/mock_types/iface.go: cage.go + $(MOCKGEN) -source=./types/iface.go > mocks/mock_types/iface.go mocks/mock_upgrade/upgrade.go: cli/cage/upgrade/upgrade.go $(MOCKGEN) -source=./cli/cage/upgrade/upgrade.go > mocks/mock_upgrade/upgrade.go .PHONY: mocks diff --git a/cage.go b/cage.go index 4e8469f..8a4f514 100644 --- a/cage.go +++ b/cage.go @@ -1,40 +1,18 @@ package cage import ( - "context" "time" - "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" ) -type Cage interface { - Up(ctx context.Context) (*types.UpResult, error) - Run(ctx context.Context, input *types.RunInput) (*types.RunResult, error) - RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) -} - -type Time interface { - Now() time.Time - NewTimer(time.Duration) *time.Timer -} - type cage struct { - *Input + *types.Input Timeout timeout.Manager } -type Input struct { - Env *Envars - Ecs awsiface.EcsClient - Alb awsiface.AlbClient - Ec2 awsiface.Ec2Client - Srv awsiface.SrvClient - Time Time -} - -func NewCage(input *Input) Cage { +func NewCage(input *types.Input) types.Cage { if input.Time == nil { input.Time = &timeImpl{} } diff --git a/canary/common.go b/canary/common.go index f904a3c..d9d6911 100644 --- a/canary/common.go +++ b/canary/common.go @@ -230,18 +230,22 @@ func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { } func (c *common) stopTask(ctx context.Context) error { + if c.taskArn == nil { + log.Info("no canary task to stop") + return nil + } log.Infof("stopping the canary task '%s'...", *c.taskArn) if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ Cluster: &c.Env.Cluster, Task: c.taskArn, }); err != nil { - return err + return xerrors.Errorf("failed to stop canary task: %w", err) } if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &c.Env.Cluster, Tasks: []string{*c.taskArn}, }, c.Timeout.TaskStopped()); err != nil { - return err + return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) return nil diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index 3f3fd87..d36a9f3 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -10,6 +10,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/prompt" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/types" "github.com/urfave/cli/v2" "golang.org/x/xerrors" ) @@ -29,16 +31,16 @@ func NewCageCommands( } } -type cageCliProvier = func(envars *cage.Envars) (cage.Cage, error) +type cageCliProvier = func(envars *env.Envars) (types.Cage, error) -func DefalutCageCliProvider(envars *cage.Envars) (cage.Cage, error) { +func DefalutCageCliProvider(envars *env.Envars) (types.Cage, error) { conf, err := config.LoadDefaultConfig( context.Background(), config.WithRegion(envars.Region)) if err != nil { return nil, xerrors.Errorf("failed to load aws config: %w", err) } - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecs.NewFromConfig(conf), Ec2: ec2.NewFromConfig(conf), @@ -63,20 +65,20 @@ func (c *CageCommands) requireArgs( } func (c *CageCommands) setupCage( - envars *cage.Envars, + envars *env.Envars, dir string, -) (cage.Cage, error) { - td, svc, err := cage.LoadDefinitionsFromFiles(dir) +) (types.Cage, error) { + td, svc, err := env.LoadDefinitionsFromFiles(dir) if err != nil { return nil, err } - cage.MergeEnvars(envars, &cage.Envars{ + env.MergeEnvars(envars, &env.Envars{ Cluster: *svc.Cluster, Service: *svc.ServiceName, TaskDefinitionInput: td, ServiceDefinitionInput: svc, }) - if err := cage.EnsureEnvars(envars); err != nil { + if err := env.EnsureEnvars(envars); err != nil { return nil, err } cagecli, err := c.cageCliProvier(envars) diff --git a/cli/cage/commands/command_test.go b/cli/cage/commands/command_test.go index e6e480c..776c39b 100644 --- a/cli/cage/commands/command_test.go +++ b/cli/cage/commands/command_test.go @@ -6,9 +6,10 @@ import ( "testing" "github.com/golang/mock/gomock" - cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/commands" - "github.com/loilo-inc/canarycage/mocks/mock_cage" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/mocks/mock_types" + "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" "github.com/urfave/cli/v2" ) @@ -19,15 +20,15 @@ func TestCommands(t *testing.T) { service := "service" stdinService := fmt.Sprintf("%s\n%s\n%s\n%s\n", region, cluster, service, "yes") stdinTask := fmt.Sprintf("%s\n%s\n%s\n", region, cluster, "yes") - setup := func(t *testing.T, input string) (*cli.App, *mock_cage.MockCage) { + setup := func(t *testing.T, input string) (*cli.App, *mock_types.MockCage) { ctrl := gomock.NewController(t) stdin := strings.NewReader(input) - cagecli := mock_cage.NewMockCage(ctrl) + cagecli := mock_types.NewMockCage(ctrl) app := cli.NewApp() - cmds := commands.NewCageCommands(stdin, func(envars *cage.Envars) (cage.Cage, error) { + cmds := commands.NewCageCommands(stdin, func(envars *env.Envars) (types.Cage, error) { return cagecli, nil }) - envars := cage.Envars{CI: input == ""} + envars := env.Envars{CI: input == ""} app.Commands = []*cli.Command{ cmds.Up(&envars), cmds.RollOut(&envars), @@ -38,25 +39,25 @@ func TestCommands(t *testing.T) { t.Run("rollout", func(t *testing.T) { t.Run("basic", func(t *testing.T) { app, cagecli := setup(t, stdinService) - cagecli.EXPECT().RollOut(gomock.Any(), &cage.RollOutInput{}).Return(&cage.RollOutResult{}, nil) + cagecli.EXPECT().RollOut(gomock.Any(), &types.RollOutInput{}).Return(&types.RollOutResult{}, nil) err := app.Run([]string{"cage", "rollout", "--region", "ap-notheast-1", "../../../fixtures"}) assert.NoError(t, err) }) t.Run("basic/ci", func(t *testing.T) { app, cagecli := setup(t, "") - cagecli.EXPECT().RollOut(gomock.Any(), &cage.RollOutInput{}).Return(&cage.RollOutResult{}, nil) + cagecli.EXPECT().RollOut(gomock.Any(), &types.RollOutInput{}).Return(&types.RollOutResult{}, nil) err := app.Run([]string{"cage", "rollout", "--region", "ap-notheast-1", "../../../fixtures"}) assert.NoError(t, err) }) t.Run("basic/udate-service", func(t *testing.T) { app, cagecli := setup(t, stdinService) - cagecli.EXPECT().RollOut(gomock.Any(), &cage.RollOutInput{UpdateService: true}).Return(&cage.RollOutResult{}, nil) + cagecli.EXPECT().RollOut(gomock.Any(), &types.RollOutInput{UpdateService: true}).Return(&types.RollOutResult{}, nil) err := app.Run([]string{"cage", "rollout", "--region", "ap-notheast-1", "--updateService", "../../../fixtures"}) assert.NoError(t, err) }) t.Run("error", func(t *testing.T) { app, cagecli := setup(t, stdinService) - cagecli.EXPECT().RollOut(gomock.Any(), &cage.RollOutInput{}).Return(&cage.RollOutResult{}, fmt.Errorf("error")) + cagecli.EXPECT().RollOut(gomock.Any(), &types.RollOutInput{}).Return(&types.RollOutResult{}, fmt.Errorf("error")) err := app.Run([]string{"cage", "rollout", "--region", "ap-notheast-1", "../../../fixtures"}) assert.EqualError(t, err, "error") }) @@ -64,13 +65,13 @@ func TestCommands(t *testing.T) { t.Run("up", func(t *testing.T) { t.Run("basic", func(t *testing.T) { app, cagecli := setup(t, stdinService) - cagecli.EXPECT().Up(gomock.Any()).Return(&cage.UpResult{}, nil) + cagecli.EXPECT().Up(gomock.Any()).Return(&types.UpResult{}, nil) err := app.Run([]string{"cage", "up", "--region", "ap-notheast-1", "../../../fixtures"}) assert.NoError(t, err) }) t.Run("basic/ci", func(t *testing.T) { app, cagecli := setup(t, "") - cagecli.EXPECT().Up(gomock.Any()).Return(&cage.UpResult{}, nil) + cagecli.EXPECT().Up(gomock.Any()).Return(&types.UpResult{}, nil) err := app.Run([]string{"cage", "up", "--region", "ap-notheast-1", "../../../fixtures"}) assert.NoError(t, err) }) @@ -84,13 +85,13 @@ func TestCommands(t *testing.T) { t.Run("run", func(t *testing.T) { t.Run("basic", func(t *testing.T) { app, cagecli := setup(t, stdinTask) - cagecli.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&cage.RunResult{}, nil) + cagecli.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&types.RunResult{}, nil) err := app.Run([]string{"cage", "run", "--region", "ap-notheast-1", "../../../fixtures", "container", "exec"}) assert.NoError(t, err) }) t.Run("basic/ci", func(t *testing.T) { app, cagecli := setup(t, "") - cagecli.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&cage.RunResult{}, nil) + cagecli.EXPECT().Run(gomock.Any(), gomock.Any()).Return(&types.RunResult{}, nil) err := app.Run([]string{"cage", "run", "--region", "ap-notheast-1", "../../../fixtures", "container", "exec"}) assert.NoError(t, err) }) diff --git a/cli/cage/commands/flags.go b/cli/cage/commands/flags.go index 1ea47ce..009a8c3 100644 --- a/cli/cage/commands/flags.go +++ b/cli/cage/commands/flags.go @@ -1,14 +1,14 @@ package commands import ( - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/urfave/cli/v2" ) func RegionFlag(dest *string) *cli.StringFlag { return &cli.StringFlag{ Name: "region", - EnvVars: []string{cage.RegionKey}, + EnvVars: []string{env.RegionKey}, Usage: "aws region for ecs. if not specified, try to load from aws sessions automatically", Destination: dest, Required: true, @@ -17,7 +17,7 @@ func RegionFlag(dest *string) *cli.StringFlag { func ClusterFlag(dest *string) *cli.StringFlag { return &cli.StringFlag{ Name: "cluster", - EnvVars: []string{cage.ClusterKey}, + EnvVars: []string{env.ClusterKey}, Usage: "ecs cluster name. if not specified, load from service.json", Destination: dest, } @@ -25,7 +25,7 @@ func ClusterFlag(dest *string) *cli.StringFlag { func ServiceFlag(dest *string) *cli.StringFlag { return &cli.StringFlag{ Name: "service", - EnvVars: []string{cage.ServiceKey}, + EnvVars: []string{env.ServiceKey}, Usage: "service name. if not specified, load from service.json", Destination: dest, } @@ -33,7 +33,7 @@ func ServiceFlag(dest *string) *cli.StringFlag { func TaskDefinitionArnFlag(dest *string) *cli.StringFlag { return &cli.StringFlag{ Name: "taskDefinitionArn", - EnvVars: []string{cage.TaskDefinitionArnKey}, + EnvVars: []string{env.TaskDefinitionArnKey}, Usage: "full arn for next task definition. if not specified, use task-definition.json for registration", Destination: dest, } @@ -42,7 +42,7 @@ func TaskDefinitionArnFlag(dest *string) *cli.StringFlag { func CanaryTaskIdleDurationFlag(dest *int) *cli.IntFlag { return &cli.IntFlag{ Name: "canaryTaskIdleDuration", - EnvVars: []string{cage.CanaryTaskIdleDuration}, + EnvVars: []string{env.CanaryTaskIdleDuration}, Usage: "duration seconds for waiting canary task that isn't attached to target group considered as ready for serving traffic", Destination: dest, Value: 10, @@ -52,7 +52,7 @@ func CanaryTaskIdleDurationFlag(dest *int) *cli.IntFlag { func TaskRunningWaitFlag(dest *int) *cli.IntFlag { return &cli.IntFlag{ Name: "taskRunningTimeout", - EnvVars: []string{cage.TaskRunningTimeout}, + EnvVars: []string{env.TaskRunningTimeout}, Usage: "max duration seconds for waiting canary task running", Destination: dest, Category: "ADVANCED", @@ -63,7 +63,7 @@ func TaskRunningWaitFlag(dest *int) *cli.IntFlag { func TaskHealthCheckWaitFlag(dest *int) *cli.IntFlag { return &cli.IntFlag{ Name: "taskHealthCheckTimeout", - EnvVars: []string{cage.TaskHealthCheckTimeout}, + EnvVars: []string{env.TaskHealthCheckTimeout}, Usage: "max duration seconds for waiting canary task health check", Destination: dest, Category: "ADVANCED", @@ -74,7 +74,7 @@ func TaskHealthCheckWaitFlag(dest *int) *cli.IntFlag { func TaskStoppedWaitFlag(dest *int) *cli.IntFlag { return &cli.IntFlag{ Name: "taskStoppedTimeout", - EnvVars: []string{cage.TaskStoppedTimeout}, + EnvVars: []string{env.TaskStoppedTimeout}, Usage: "max duration seconds for waiting canary task stopped", Destination: dest, Category: "ADVANCED", @@ -85,7 +85,7 @@ func TaskStoppedWaitFlag(dest *int) *cli.IntFlag { func ServiceStableWaitFlag(dest *int) *cli.IntFlag { return &cli.IntFlag{ Name: "serviceStableTimeout", - EnvVars: []string{cage.ServiceStableTimeout}, + EnvVars: []string{env.ServiceStableTimeout}, Usage: "max duration seconds for waiting service stable", Destination: dest, Category: "ADVANCED", diff --git a/cli/cage/commands/rollout.go b/cli/cage/commands/rollout.go index 94f8b51..87d6129 100644 --- a/cli/cage/commands/rollout.go +++ b/cli/cage/commands/rollout.go @@ -4,12 +4,13 @@ import ( "context" "github.com/apex/log" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/types" "github.com/urfave/cli/v2" ) func (c *CageCommands) RollOut( - envars *cage.Envars, + envars *env.Envars, ) *cli.Command { var updateServiceConf bool return &cli.Command{ @@ -26,13 +27,13 @@ func (c *CageCommands) RollOut( CanaryTaskIdleDurationFlag(&envars.CanaryTaskIdleDuration), &cli.StringFlag{ Name: "canaryInstanceArn", - EnvVars: []string{cage.CanaryInstanceArnKey}, + EnvVars: []string{env.CanaryInstanceArnKey}, Usage: "EC2 instance ARN for placing canary task. required only when LaunchType is EC2", Destination: &envars.CanaryInstanceArn, }, &cli.BoolFlag{ Name: "updateService", - EnvVars: []string{cage.UpdateServiceKey}, + EnvVars: []string{env.UpdateServiceKey}, Usage: "Update service configurations except for task definiton. Default is false.", Destination: &updateServiceConf, }, @@ -53,7 +54,7 @@ func (c *CageCommands) RollOut( if err := c.Prompt.ConfirmService(envars); err != nil { return err } - result, err := cagecli.RollOut(context.Background(), &cage.RollOutInput{UpdateService: updateServiceConf}) + result, err := cagecli.RollOut(context.Background(), &types.RollOutInput{UpdateService: updateServiceConf}) if err != nil { if result.ServiceIntact { log.Errorf("🤕 failed to roll out new tasks but service '%s' is not changed", envars.Service) diff --git a/cli/cage/commands/run.go b/cli/cage/commands/run.go index 26506c0..c830f5b 100644 --- a/cli/cage/commands/run.go +++ b/cli/cage/commands/run.go @@ -5,12 +5,13 @@ import ( "github.com/apex/log" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/types" "github.com/urfave/cli/v2" ) func (c *CageCommands) Run( - envars *cage.Envars, + envars *env.Envars, ) *cli.Command { return &cli.Command{ Name: "run", @@ -38,7 +39,7 @@ func (c *CageCommands) Run( } container := rest[0] commands := rest[1:] - if _, err := cagecli.Run(context.Background(), &cage.RunInput{ + if _, err := cagecli.Run(context.Background(), &types.RunInput{ Container: &container, Overrides: &ecstypes.TaskOverride{ ContainerOverrides: []ecstypes.ContainerOverride{ diff --git a/cli/cage/commands/up.go b/cli/cage/commands/up.go index 1ea12ce..3ecd906 100644 --- a/cli/cage/commands/up.go +++ b/cli/cage/commands/up.go @@ -3,12 +3,12 @@ package commands import ( "context" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/urfave/cli/v2" ) func (c *CageCommands) Up( - envars *cage.Envars, + envars *env.Envars, ) *cli.Command { return &cli.Command{ Name: "up", diff --git a/cli/cage/main.go b/cli/cage/main.go index 34549b6..40093bf 100644 --- a/cli/cage/main.go +++ b/cli/cage/main.go @@ -5,9 +5,9 @@ import ( "log" "os" - cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/commands" "github.com/loilo-inc/canarycage/cli/cage/upgrade" + "github.com/loilo-inc/canarycage/env" "github.com/urfave/cli/v2" ) @@ -25,7 +25,7 @@ func main() { app.Version = fmt.Sprintf("%s (commit: %s, date: %s)", version, commit, date) app.Usage = "A deployment tool for AWS ECS" app.Description = "A deployment tool for AWS ECS" - envars := cage.Envars{} + envars := env.Envars{} cmds := commands.NewCageCommands(os.Stdin, commands.DefalutCageCliProvider) app.Commands = []*cli.Command{ cmds.Up(&envars), diff --git a/cli/cage/prompt/prompt.go b/cli/cage/prompt/prompt.go index e970efc..05b8e84 100644 --- a/cli/cage/prompt/prompt.go +++ b/cli/cage/prompt/prompt.go @@ -6,7 +6,7 @@ import ( "io" "os" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "golang.org/x/xerrors" ) @@ -32,19 +32,19 @@ func (s *Prompter) Confirm( } func (s *Prompter) ConfirmTask( - envars *cage.Envars, + envars *env.Envars, ) error { return s.confirmStackChange(envars, false) } func (s *Prompter) ConfirmService( - envars *cage.Envars, + envars *env.Envars, ) error { return s.confirmStackChange(envars, true) } func (s *Prompter) confirmStackChange( - envars *cage.Envars, + envars *env.Envars, service bool, ) error { // Skip confirmation if running in CI diff --git a/cli/cage/prompt/prompt_test.go b/cli/cage/prompt/prompt_test.go index 1d495e0..b8cb743 100644 --- a/cli/cage/prompt/prompt_test.go +++ b/cli/cage/prompt/prompt_test.go @@ -4,8 +4,8 @@ import ( "strings" "testing" - cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/prompt" + "github.com/loilo-inc/canarycage/env" "github.com/stretchr/testify/assert" ) @@ -24,7 +24,7 @@ func TestPrompter(t *testing.T) { assert.Error(t, err) }) }) - envars := &cage.Envars{ + envars := &env.Envars{ Region: "ap-northeast-1", Cluster: "test-cluster", Service: "test-service", diff --git a/env.go b/env/env.go similarity index 99% rename from env.go rename to env/env.go index a47c310..f9ea511 100644 --- a/env.go +++ b/env/env.go @@ -1,4 +1,4 @@ -package cage +package env import ( "encoding/json" diff --git a/env_test.go b/env/env_test.go similarity index 72% rename from env_test.go rename to env/env_test.go index 84b4ebe..520e15b 100644 --- a/env_test.go +++ b/env/env_test.go @@ -1,66 +1,67 @@ -package cage +package env_test import ( "testing" "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/loilo-inc/canarycage/env" "github.com/stretchr/testify/assert" ) func TestEnsureEnvars(t *testing.T) { t.Run("basic", func(t *testing.T) { - e := &Envars{ + e := &env.Envars{ Region: "us-west-2", Cluster: "cluster", Service: "service-next", TaskDefinitionInput: &ecs.RegisterTaskDefinitionInput{}, } - if err := EnsureEnvars(e); err != nil { + if err := env.EnsureEnvars(e); err != nil { t.Fatalf(err.Error()) } }) t.Run("with td arn", func(t *testing.T) { - e := &Envars{ + e := &env.Envars{ Region: "us-west-2", Cluster: "cluster", Service: "next", TaskDefinitionArn: "arn://aaa", } - if err := EnsureEnvars(e); err != nil { + if err := env.EnsureEnvars(e); err != nil { t.Fatalf(err.Error()) } }) t.Run("should return err if nor taskDefinitionArn neither TaskDefinitionInput is defined", func(t *testing.T) { - e := &Envars{ + e := &env.Envars{ Region: "us-west-2", Cluster: "cluster", Service: "next", } - err := EnsureEnvars(e) + err := env.EnsureEnvars(e) assert.Errorf(t, err, "--nextTaskDefinitionArn or deploy context must be provided") }) t.Run("should return err if required props are not defined", func(t *testing.T) { dummy := "aaa" arr := []string{ - RegionKey, - ServiceKey, - ClusterKey, + env.RegionKey, + env.ServiceKey, + env.ClusterKey, } for i, v := range arr { m := make(map[string]string) - m[ServiceKey] = dummy - m[TaskDefinitionArnKey] = dummy - m[ClusterKey] = dummy + m[env.ServiceKey] = dummy + m[env.TaskDefinitionArnKey] = dummy + m[env.ClusterKey] = dummy for j, u := range arr { if i == j { m[u] = "" } } - e := &Envars{ - Service: m[ServiceKey], - Cluster: m[ClusterKey], + e := &env.Envars{ + Service: m[env.ServiceKey], + Cluster: m[env.ClusterKey], } - err := EnsureEnvars(e) + err := env.EnsureEnvars(e) if err == nil { t.Fatalf("should return error if %s is not defined", v) } @@ -69,15 +70,15 @@ func TestEnsureEnvars(t *testing.T) { } func TestMergeEnvars(t *testing.T) { - e1 := &Envars{ + e1 := &env.Envars{ Region: "us-west-2", Cluster: "cluster", } - e2 := &Envars{ + e2 := &env.Envars{ Cluster: "hoge", Service: "fuga", } - MergeEnvars(e1, e2) + env.MergeEnvars(e1, e2) assert.Equal(t, e1.Region, "us-west-2") assert.Equal(t, e1.Cluster, "hoge") assert.Equal(t, e1.Service, "fuga") diff --git a/util.go b/env/util.go similarity index 97% rename from util.go rename to env/util.go index 522dc46..9b1efdb 100644 --- a/util.go +++ b/env/util.go @@ -1,4 +1,4 @@ -package cage +package env import ( "os" diff --git a/util_test.go b/env/util_test.go similarity index 82% rename from util_test.go rename to env/util_test.go index bcfdc5b..2a73ddc 100644 --- a/util_test.go +++ b/env/util_test.go @@ -1,11 +1,14 @@ -package cage +package env_test import ( - "github.com/stretchr/testify/assert" "log" "os" "testing" "time" + + "github.com/stretchr/testify/assert" + + "github.com/loilo-inc/canarycage/env" ) func TestTimeAdd(t *testing.T) { @@ -19,7 +22,7 @@ func TestTimeAdd(t *testing.T) { func TestReadFileAndApplyEnvars(t *testing.T) { os.Setenv("HOGE", "hogehoge") os.Setenv("FUGA", "fugafuga") - d, err := ReadFileAndApplyEnvars("./fixtures/template.txt") + d, err := env.ReadFileAndApplyEnvars("./fixtures/template.txt") if err != nil { t.Fatalf(err.Error()) } diff --git a/mocks/mock_awsiface/iface.go b/mocks/mock_awsiface/iface.go index b582870..3e79b10 100644 --- a/mocks/mock_awsiface/iface.go +++ b/mocks/mock_awsiface/iface.go @@ -11,6 +11,7 @@ import ( ec2 "github.com/aws/aws-sdk-go-v2/service/ec2" ecs "github.com/aws/aws-sdk-go-v2/service/ecs" elasticloadbalancingv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + servicediscovery "github.com/aws/aws-sdk-go-v2/service/servicediscovery" gomock "github.com/golang/mock/gomock" ) @@ -502,3 +503,106 @@ func (mr *MockEc2ClientMockRecorder) DescribeSubnets(ctx, params interface{}, op varargs := append([]interface{}{ctx, params}, optFns...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeSubnets", reflect.TypeOf((*MockEc2Client)(nil).DescribeSubnets), varargs...) } + +// MockSrvClient is a mock of SrvClient interface. +type MockSrvClient struct { + ctrl *gomock.Controller + recorder *MockSrvClientMockRecorder +} + +// MockSrvClientMockRecorder is the mock recorder for MockSrvClient. +type MockSrvClientMockRecorder struct { + mock *MockSrvClient +} + +// NewMockSrvClient creates a new mock instance. +func NewMockSrvClient(ctrl *gomock.Controller) *MockSrvClient { + mock := &MockSrvClient{ctrl: ctrl} + mock.recorder = &MockSrvClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSrvClient) EXPECT() *MockSrvClientMockRecorder { + return m.recorder +} + +// DeregisterInstance mocks base method. +func (m *MockSrvClient) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeregisterInstance", varargs...) + ret0, _ := ret[0].(*servicediscovery.DeregisterInstanceOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeregisterInstance indicates an expected call of DeregisterInstance. +func (mr *MockSrvClientMockRecorder) DeregisterInstance(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeregisterInstance", reflect.TypeOf((*MockSrvClient)(nil).DeregisterInstance), varargs...) +} + +// DiscoverInstances mocks base method. +func (m *MockSrvClient) DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DiscoverInstances", varargs...) + ret0, _ := ret[0].(*servicediscovery.DiscoverInstancesOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DiscoverInstances indicates an expected call of DiscoverInstances. +func (mr *MockSrvClientMockRecorder) DiscoverInstances(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverInstances", reflect.TypeOf((*MockSrvClient)(nil).DiscoverInstances), varargs...) +} + +// GetService mocks base method. +func (m *MockSrvClient) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetService", varargs...) + ret0, _ := ret[0].(*servicediscovery.GetServiceOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetService indicates an expected call of GetService. +func (mr *MockSrvClientMockRecorder) GetService(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockSrvClient)(nil).GetService), varargs...) +} + +// RegisterInstance mocks base method. +func (m *MockSrvClient) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RegisterInstance", varargs...) + ret0, _ := ret[0].(*servicediscovery.RegisterInstanceOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RegisterInstance indicates an expected call of RegisterInstance. +func (mr *MockSrvClientMockRecorder) RegisterInstance(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterInstance", reflect.TypeOf((*MockSrvClient)(nil).RegisterInstance), varargs...) +} diff --git a/mocks/mock_cage/cage.go b/mocks/mock_types/iface.go similarity index 86% rename from mocks/mock_cage/cage.go rename to mocks/mock_types/iface.go index e8aaa4c..e473b4d 100644 --- a/mocks/mock_cage/cage.go +++ b/mocks/mock_types/iface.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./cage.go +// Source: ./types/iface.go -// Package mock_cage is a generated GoMock package. -package mock_cage +// Package mock_types is a generated GoMock package. +package mock_types import ( context "context" @@ -10,7 +10,7 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - cage "github.com/loilo-inc/canarycage" + types "github.com/loilo-inc/canarycage/types" ) // MockCage is a mock of Cage interface. @@ -37,10 +37,10 @@ func (m *MockCage) EXPECT() *MockCageMockRecorder { } // RollOut mocks base method. -func (m *MockCage) RollOut(ctx context.Context, input *cage.RollOutInput) (*cage.RollOutResult, error) { +func (m *MockCage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RollOut", ctx, input) - ret0, _ := ret[0].(*cage.RollOutResult) + ret0, _ := ret[0].(*types.RollOutResult) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -52,10 +52,10 @@ func (mr *MockCageMockRecorder) RollOut(ctx, input interface{}) *gomock.Call { } // Run mocks base method. -func (m *MockCage) Run(ctx context.Context, input *cage.RunInput) (*cage.RunResult, error) { +func (m *MockCage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Run", ctx, input) - ret0, _ := ret[0].(*cage.RunResult) + ret0, _ := ret[0].(*types.RunResult) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -67,10 +67,10 @@ func (mr *MockCageMockRecorder) Run(ctx, input interface{}) *gomock.Call { } // Up mocks base method. -func (m *MockCage) Up(ctx context.Context) (*cage.UpResult, error) { +func (m *MockCage) Up(ctx context.Context) (*types.UpResult, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Up", ctx) - ret0, _ := ret[0].(*cage.UpResult) + ret0, _ := ret[0].(*types.UpResult) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/rollout.go b/rollout.go index 4bba347..55d830a 100644 --- a/rollout.go +++ b/rollout.go @@ -51,15 +51,8 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R _ = recover() eg := errgroup.Group{} for _, canaryTask := range canaryTasks { - if canaryTask.taskArn == nil { - continue - } eg.Go(func() error { - err := canaryTask.Stop(ctx) - if err != nil { - log.Errorf("failed to stop canary task '%s': %s", *canaryTask.taskArn, err) - } - return err + return canaryTask.Stop(ctx) }) } if err := eg.Wait(); err != nil { diff --git a/rollout_test.go b/rollout_test.go index 6b45c2b..fee41f4 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -37,7 +37,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -60,7 +60,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -99,7 +99,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil).Times(2), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -143,7 +143,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -172,7 +172,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{newLb} envars.ServiceDefinitionInput.NetworkConfiguration = newNetwork envars.ServiceDefinitionInput.PlatformVersion = aws.String("LATEST") - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -195,7 +195,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars := test.DefaultEnvars() ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -220,7 +220,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mocker, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") delete(mocker.Services, envars.Service) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -236,7 +236,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.CanaryTaskIdleDuration = 1 ctrl := gomock.NewController(t) _, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -262,7 +262,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, }, nil, ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -305,7 +305,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask).AnyTimes() ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StopTask).AnyTimes() - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -351,7 +351,7 @@ func TestCage_RollOut_EC2(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != v { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -381,7 +381,7 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != 1 { t.Fatalf("current tasks not setup: %d/%d", 1, taskCnt) } - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -416,7 +416,7 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { Attributes: []ecstypes.Attribute{}, }, nil).AnyTimes() ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil).AnyTimes() - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, diff --git a/run_test.go b/run_test.go index 1fb155b..ffa95fe 100644 --- a/run_test.go +++ b/run_test.go @@ -9,6 +9,7 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/canarycage/types" @@ -16,7 +17,7 @@ import ( ) func TestCage_Run(t *testing.T) { - setupForBasic := func(t *testing.T) (*cage.Envars, + setupForBasic := func(t *testing.T) (*env.Envars, *test.MockContext, *mock_awsiface.MockEcsClient) { env := test.DefaultEnvars() @@ -39,7 +40,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -69,7 +70,7 @@ func TestCage_Run(t *testing.T) { }, ), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -91,7 +92,7 @@ func TestCage_Run(t *testing.T) { ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks).Times(2), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -119,7 +120,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -147,7 +148,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -163,7 +164,7 @@ func TestCage_Run(t *testing.T) { overrides := &ecstypes.TaskOverride{} ctx := context.Background() env, _, ecsMock := setupForBasic(t) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), diff --git a/task_definition_test.go b/task_definition_test.go index 4177d66..c700ca6 100644 --- a/task_definition_test.go +++ b/task_definition_test.go @@ -8,8 +8,10 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" "golang.org/x/xerrors" ) @@ -18,11 +20,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { t.Run("should return task definition if taskDefinitionArn is set", func(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - env := &cage.Envars{ + env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } c := &cage.CageExport{ - Input: &cage.Input{ + Input: &types.Input{ Env: env, Ecs: ecsMock, }, @@ -37,11 +39,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { t.Run("should return error if taskDefinitionArn is set and failed to describe task definition", func(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - env := &cage.Envars{ + env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } c := &cage.CageExport{ - Input: &cage.Input{ + Input: &types.Input{ Env: env, Ecs: ecsMock, }, @@ -56,7 +58,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() c := &cage.CageExport{ - Input: &cage.Input{ + Input: &types.Input{ Env: env, Ecs: ecsMock, }, @@ -76,7 +78,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() c := &cage.CageExport{ - Input: &cage.Input{ + Input: &types.Input{ Env: env, Ecs: ecsMock, }, diff --git a/test/fake_timer.go b/test/fake_timer.go index 0ebeb14..111c5d8 100644 --- a/test/fake_timer.go +++ b/test/fake_timer.go @@ -3,7 +3,7 @@ package test import ( "time" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/types" ) func newTimer(_ time.Duration) *time.Timer { @@ -24,6 +24,6 @@ func (t *timeImpl) Now() time.Time { func (t *timeImpl) NewTimer(d time.Duration) *time.Timer { return newTimer(d) } -func NewFakeTime() cage.Time { +func NewFakeTime() types.Time { return &timeImpl{} } diff --git a/test/setup.go b/test/setup.go index b2e77a2..e3bc09e 100644 --- a/test/setup.go +++ b/test/setup.go @@ -11,11 +11,11 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" "github.com/golang/mock/gomock" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" ) -func Setup(ctrl *gomock.Controller, envars *cage.Envars, currentTaskCount int, launchType ecstypes.LaunchType) ( +func Setup(ctrl *gomock.Controller, envars *env.Envars, currentTaskCount int, launchType ecstypes.LaunchType) ( *MockContext, *mock_awsiface.MockEcsClient, *mock_awsiface.MockAlbClient, @@ -62,13 +62,13 @@ func Setup(ctrl *gomock.Controller, envars *cage.Envars, currentTaskCount int, l return mocker, ecsMock, albMock, ec2Mock } -func DefaultEnvars() *cage.Envars { +func DefaultEnvars() *env.Envars { d, _ := os.ReadFile("fixtures/task-definition.json") var taskDefinition ecs.RegisterTaskDefinitionInput if err := json.Unmarshal(d, &taskDefinition); err != nil { log.Fatalf(err.Error()) } - return &cage.Envars{ + return &env.Envars{ Region: "us-west-2", Cluster: "cage-test", Service: "service", diff --git a/types/iface.go b/types/iface.go index 7ba1773..db997af 100644 --- a/types/iface.go +++ b/types/iface.go @@ -4,28 +4,11 @@ import ( "context" "time" - "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" ) -type Envars struct { - _ struct{} `type:"struct"` - CI bool `json:"ci" type:"bool"` - Region string `json:"region" type:"string"` - Cluster string `json:"cluster" type:"string" required:"true"` - Service string `json:"service" type:"string" required:"true"` - CanaryInstanceArn string - TaskDefinitionArn string `json:"nextTaskDefinitionArn" type:"string"` - TaskDefinitionInput *ecs.RegisterTaskDefinitionInput - ServiceDefinitionInput *ecs.CreateServiceInput - CanaryTaskIdleDuration int // sec - CanaryTaskRunningWait int // sec - CanaryTaskHealthCheckWait int // sec - CanaryTaskStoppedWait int // sec - ServiceStableWait int // sec -} - type Cage interface { Up(ctx context.Context) (*UpResult, error) Run(ctx context.Context, input *RunInput) (*RunResult, error) @@ -38,13 +21,14 @@ type Time interface { } type Input struct { - Env *Envars + Env *env.Envars Ecs awsiface.EcsClient Alb awsiface.AlbClient Ec2 awsiface.Ec2Client Srv awsiface.SrvClient Time Time } + type RunInput struct { Container *string Overrides *ecstypes.TaskOverride diff --git a/up_test.go b/up_test.go index 06b5b9d..9a20e0e 100644 --- a/up_test.go +++ b/up_test.go @@ -7,6 +7,7 @@ import ( "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" ) @@ -16,7 +17,7 @@ func TestCage_Up(t *testing.T) { ctrl := gomock.NewController(t) ctx, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") delete(ctx.Services, env.Service) - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, }) @@ -29,7 +30,7 @@ func TestCage_Up(t *testing.T) { env := test.DefaultEnvars() ctrl := gomock.NewController(t) _, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") - cagecli := cage.NewCage(&cage.Input{ + cagecli := cage.NewCage(&types.Input{ Env: env, Ecs: ecsMock, }) From 78cecbd5afceb4f73e97e036b332567a9455fb84 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 18:34:22 +0900 Subject: [PATCH 04/47] wip --- cage.go | 6 +- canary/common.go | 6 +- cli/cage/commands/command.go | 2 +- env/env.go | 20 ++ env/env_test.go | 18 ++ {fixtures => env/fixtures}/template.txt | 0 env/util.go | 27 --- env/util_test.go | 36 ---- rollout.go | 42 ++-- rollout_test.go | 26 +-- run_test.go | 12 +- {canary => task}/alb_task.go | 21 +- task/common.go | 242 ++++++++++++++++++++++++ {canary => task}/simple_task.go | 7 +- {canary => task}/srv_task.go | 12 +- task_definition_test.go | 8 +- test/context.go | 4 +- types/iface.go | 2 +- up_test.go | 4 +- 19 files changed, 356 insertions(+), 139 deletions(-) rename {fixtures => env/fixtures}/template.txt (100%) delete mode 100644 env/util.go delete mode 100644 env/util_test.go rename {canary => task}/alb_task.go (91%) create mode 100644 task/common.go rename {canary => task}/simple_task.go (85%) rename {canary => task}/srv_task.go (94%) diff --git a/cage.go b/cage.go index 8a4f514..553d0f6 100644 --- a/cage.go +++ b/cage.go @@ -8,11 +8,11 @@ import ( ) type cage struct { - *types.Input + *types.Deps Timeout timeout.Manager } -func NewCage(input *types.Input) types.Cage { +func NewCage(input *types.Deps) types.Cage { if input.Time == nil { input.Time = &timeImpl{} } @@ -21,7 +21,7 @@ func NewCage(input *types.Input) types.Cage { taskStoppedWait := (time.Duration)(input.Env.CanaryTaskStoppedWait) * time.Second serviceStableWait := (time.Duration)(input.Env.ServiceStableWait) * time.Second return &cage{ - Input: input, + Deps: input, Timeout: timeout.NewManager( 15*time.Minute, &timeout.Input{ diff --git a/canary/common.go b/canary/common.go index d9d6911..9ff38fb 100644 --- a/canary/common.go +++ b/canary/common.go @@ -30,7 +30,7 @@ type Task interface { } type Input struct { - *types.Input + *types.Deps TaskDefinition *ecstypes.TaskDefinition Network *ecstypes.NetworkConfiguration PlatformVersion *string @@ -42,10 +42,6 @@ type common struct { taskArn *string } -func NewSimpleTask(input *Input) Task { - return &simpleTask{common: &common{Input: input}} -} - func (c *common) Start(ctx context.Context) error { if c.Env.CanaryInstanceArn != "" { // ec2 diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index d36a9f3..32fea4d 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -40,7 +40,7 @@ func DefalutCageCliProvider(envars *env.Envars) (types.Cage, error) { if err != nil { return nil, xerrors.Errorf("failed to load aws config: %w", err) } - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecs.NewFromConfig(conf), Ec2: ec2.NewFromConfig(conf), diff --git a/env/env.go b/env/env.go index f9ea511..225565f 100644 --- a/env/env.go +++ b/env/env.go @@ -4,6 +4,8 @@ import ( "encoding/json" "os" "path/filepath" + "regexp" + "strings" "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" @@ -119,3 +121,21 @@ func ReadAndUnmarshalJson(path string, dest interface{}) ([]byte, error) { return d, nil } } + +func ReadFileAndApplyEnvars(path string) ([]byte, error) { + d, err := os.ReadFile(path) + if err != nil { + return nil, err + } + str := string(d) + reg := regexp.MustCompile(`\${(.+?)}`) + submatches := reg.FindAllStringSubmatch(str, -1) + for _, m := range submatches { + if envar, ok := os.LookupEnv(m[1]); ok { + str = strings.Replace(str, m[0], envar, -1) + } else { + log.Fatalf("envar literal '%s' found in %s but was not defined", m[0], path) + } + } + return []byte(str), nil +} diff --git a/env/env_test.go b/env/env_test.go index 520e15b..1645f8c 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -1,8 +1,10 @@ package env_test import ( + "os" "testing" + "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/loilo-inc/canarycage/env" "github.com/stretchr/testify/assert" @@ -83,3 +85,19 @@ func TestMergeEnvars(t *testing.T) { assert.Equal(t, e1.Cluster, "hoge") assert.Equal(t, e1.Service, "fuga") } + +func TestReadFileAndApplyEnvars(t *testing.T) { + os.Setenv("HOGE", "hogehoge") + os.Setenv("FUGA", "fugafuga") + d, err := env.ReadFileAndApplyEnvars("./fixtures/template.txt") + if err != nil { + t.Fatalf(err.Error()) + } + s := string(d) + e := `HOGE=hogehoge +FUGA=fugafuga +fugafuga=hogehoge` + if s != e { + log.Fatalf("e: %s, a: %s", e, s) + } +} diff --git a/fixtures/template.txt b/env/fixtures/template.txt similarity index 100% rename from fixtures/template.txt rename to env/fixtures/template.txt diff --git a/env/util.go b/env/util.go deleted file mode 100644 index 9b1efdb..0000000 --- a/env/util.go +++ /dev/null @@ -1,27 +0,0 @@ -package env - -import ( - "os" - "regexp" - "strings" - - "github.com/apex/log" -) - -func ReadFileAndApplyEnvars(path string) ([]byte, error) { - d, err := os.ReadFile(path) - if err != nil { - return nil, err - } - str := string(d) - reg := regexp.MustCompile(`\${(.+?)}`) - submatches := reg.FindAllStringSubmatch(str, -1) - for _, m := range submatches { - if envar, ok := os.LookupEnv(m[1]); ok { - str = strings.Replace(str, m[0], envar, -1) - } else { - log.Fatalf("envar literal '%s' found in %s but was not defined", m[0], path) - } - } - return []byte(str), nil -} diff --git a/env/util_test.go b/env/util_test.go deleted file mode 100644 index 2a73ddc..0000000 --- a/env/util_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package env_test - -import ( - "log" - "os" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/loilo-inc/canarycage/env" -) - -func TestTimeAdd(t *testing.T) { - now := time.Now() - after5min := now - after5min = now.Add(time.Duration(5) * time.Minute) - assert.Equal(t, after5min.After(now), true) - assert.NotEqual(t, now.Unix(), after5min.Unix()) -} - -func TestReadFileAndApplyEnvars(t *testing.T) { - os.Setenv("HOGE", "hogehoge") - os.Setenv("FUGA", "fugafuga") - d, err := env.ReadFileAndApplyEnvars("./fixtures/template.txt") - if err != nil { - t.Fatalf(err.Error()) - } - s := string(d) - e := `HOGE=hogehoge -FUGA=fugafuga -fugafuga=hogehoge` - if s != e { - log.Fatalf("e: %s, a: %s", e, s) - } -} diff --git a/rollout.go b/rollout.go index 55d830a..3f75e7b 100644 --- a/rollout.go +++ b/rollout.go @@ -6,7 +6,7 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/canary" + "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/types" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -111,7 +111,7 @@ func (c *cage) StartCanaryTasks( ctx context.Context, nextTaskDefinition *ecstypes.TaskDefinition, input *types.RollOutInput, -) ([]canary.Task, error) { +) ([]task.Task, error) { var networkConfiguration *ecstypes.NetworkConfiguration var platformVersion *string var loadBalancers []ecstypes.LoadBalancer @@ -135,14 +135,14 @@ func (c *cage) StartCanaryTasks( serviceRegistries = service.ServiceRegistries } } - var results []canary.Task + var results []task.Task for _, lb := range loadBalancers { - task := canary.NewAlbTask(&canary.Input{ - Input: c.Input, - Network: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, + task := task.NewAlbTask(&task.Input{ + Deps: c.Deps, + NetworkConfiguration: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, }, &lb) results = append(results, task) if err := task.Start(ctx); err != nil { @@ -150,12 +150,12 @@ func (c *cage) StartCanaryTasks( } } for _, srv := range serviceRegistries { - task := canary.NewSrvTask(&canary.Input{ - Input: c.Input, - Network: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, + task := task.NewSrvTask(&task.Input{ + Deps: c.Deps, + NetworkConfiguration: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, }, &srv) results = append(results, task) if err := task.Start(ctx); err != nil { @@ -163,12 +163,12 @@ func (c *cage) StartCanaryTasks( } } if len(results) == 0 { - task := canary.NewSimpleTask(&canary.Input{ - Input: c.Input, - Network: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, + task := task.NewSimpleTask(&task.Input{ + Deps: c.Deps, + NetworkConfiguration: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + Timeout: c.Timeout, }) results = append(results, task) if err := task.Start(ctx); err != nil { diff --git a/rollout_test.go b/rollout_test.go index fee41f4..06cc18b 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -37,7 +37,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -60,7 +60,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -99,7 +99,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil).Times(2), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -143,7 +143,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -172,7 +172,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{newLb} envars.ServiceDefinitionInput.NetworkConfiguration = newNetwork envars.ServiceDefinitionInput.PlatformVersion = aws.String("LATEST") - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -195,7 +195,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars := test.DefaultEnvars() ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -220,7 +220,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mocker, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") delete(mocker.Services, envars.Service) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -236,7 +236,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.CanaryTaskIdleDuration = 1 ctrl := gomock.NewController(t) _, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Alb: albMock, @@ -262,7 +262,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, }, nil, ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -305,7 +305,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask).AnyTimes() ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StopTask).AnyTimes() - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -351,7 +351,7 @@ func TestCage_RollOut_EC2(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != v { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -381,7 +381,7 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != 1 { t.Fatalf("current tasks not setup: %d/%d", 1, taskCnt) } - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, @@ -416,7 +416,7 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { Attributes: []ecstypes.Attribute{}, }, nil).AnyTimes() ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil).AnyTimes() - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, Ec2: ec2Mock, diff --git a/run_test.go b/run_test.go index ffa95fe..8169ba7 100644 --- a/run_test.go +++ b/run_test.go @@ -40,7 +40,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -70,7 +70,7 @@ func TestCage_Run(t *testing.T) { }, ), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -92,7 +92,7 @@ func TestCage_Run(t *testing.T) { ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks).Times(2), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -120,7 +120,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -148,7 +148,7 @@ func TestCage_Run(t *testing.T) { return mocker.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), @@ -164,7 +164,7 @@ func TestCage_Run(t *testing.T) { overrides := &ecstypes.TaskOverride{} ctx := context.Background() env, _, ecsMock := setupForBasic(t) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, Time: test.NewFakeTime(), diff --git a/canary/alb_task.go b/task/alb_task.go similarity index 91% rename from canary/alb_task.go rename to task/alb_task.go index 8f8829b..9bfd86c 100644 --- a/canary/alb_task.go +++ b/task/alb_task.go @@ -1,4 +1,4 @@ -package canary +package task import ( "context" @@ -12,6 +12,7 @@ import ( "golang.org/x/xerrors" ) +// albTask is a task that is attached to an Application Load Balancer type albTask struct { *common lb *ecstypes.LoadBalancer @@ -86,31 +87,33 @@ func (c *albTask) Stop(ctx context.Context) error { return c.stopTask(ctx) } -func (c *albTask) getTargetPort(ctx context.Context) (int32, error) { +func (c *albTask) getTargetPort() (int32, error) { for _, container := range c.TaskDefinition.ContainerDefinitions { if *container.Name == *c.lb.ContainerName { - return *container.PortMappings[0].ContainerPort, nil + return *container.PortMappings[0].HostPort, nil } } return 0, xerrors.Errorf("couldn't find host port in container definition") } func (c *albTask) registerToTargetGroup(ctx context.Context) error { - info, err := c.describeTaskTarget(ctx, c.getTargetPort) - if err != nil { + if targetPort, err := c.getTargetPort(); err != nil { return err + } else if target, err := c.describeTaskTarget(ctx, targetPort); err != nil { + return err + } else { + c.target = target } if _, err := c.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &info.availabilityZone, - Id: &info.targetId, - Port: &info.targetPort, + AvailabilityZone: &c.target.availabilityZone, + Id: &c.target.targetId, + Port: &c.target.targetPort, }}, }); err != nil { return err } - c.target = info return nil } diff --git a/task/common.go b/task/common.go new file mode 100644 index 0000000..933b581 --- /dev/null +++ b/task/common.go @@ -0,0 +1,242 @@ +package task + +import ( + "context" + "fmt" + "time" + + "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "golang.org/x/xerrors" +) + +type CanaryTarget struct { + targetId string + targetIpv4 string + targetPort int32 + availabilityZone string +} + +type Task interface { + Start(ctx context.Context) error + Wait(ctx context.Context) error + Stop(ctx context.Context) error +} + +type Input struct { + *types.Deps + TaskDefinition *ecstypes.TaskDefinition + NetworkConfiguration *ecstypes.NetworkConfiguration + PlatformVersion *string + Timeout timeout.Manager +} + +type common struct { + *Input + taskArn *string +} + +func (c *common) Start(ctx context.Context) error { + if c.Env.CanaryInstanceArn != "" { + // ec2 + startTask := &ecs.StartTaskInput{ + Cluster: &c.Env.Cluster, + Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + NetworkConfiguration: c.NetworkConfiguration, + TaskDefinition: c.TaskDefinition.TaskDefinitionArn, + ContainerInstances: []string{c.Env.CanaryInstanceArn}, + } + if o, err := c.Ecs.StartTask(ctx, startTask); err != nil { + return err + } else { + c.taskArn = o.Tasks[0].TaskArn + } + } else { + // fargate + if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ + Cluster: &c.Env.Cluster, + Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + NetworkConfiguration: c.NetworkConfiguration, + TaskDefinition: c.TaskDefinition.TaskDefinitionArn, + LaunchType: ecstypes.LaunchTypeFargate, + PlatformVersion: c.PlatformVersion, + }); err != nil { + return err + } else { + c.taskArn = o.Tasks[0].TaskArn + } + } + return nil +} + +func (c *common) wait(ctx context.Context) error { + log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) + if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }, c.Timeout.TaskRunning()); err != nil { + return err + } + log.Infof("🐣 canary task '%s' is running!", *c.taskArn) + if err := c.waitContainerHealthCheck(ctx); err != nil { + return err + } + log.Info("🤩 canary task container(s) is healthy!") + log.Infof("canary task '%s' ensured.", *c.taskArn) + return nil +} + +func (c *common) waitContainerHealthCheck(ctx context.Context) error { + log.Infof("😷 ensuring canary task container(s) to become healthy...") + containerHasHealthChecks := map[string]struct{}{} + for _, definition := range c.TaskDefinition.ContainerDefinitions { + if definition.HealthCheck != nil { + containerHasHealthChecks[*definition.Name] = struct{}{} + } + } + healthCheckWait := c.Timeout.TaskHealthCheck() + healthCheckPeriod := 15 * time.Second + countPerPeriod := int(healthCheckWait.Seconds() / 15) + for count := 0; count < countPerPeriod; count++ { + <-c.Time.NewTimer(healthCheckPeriod).C + log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) + if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task := o.Tasks[0] + if *task.LastStatus != "RUNNING" { + return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) + } + + for _, container := range task.Containers { + if _, ok := containerHasHealthChecks[*container.Name]; !ok { + continue + } + if container.HealthStatus != ecstypes.HealthStatusHealthy { + log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) + continue + } + delete(containerHasHealthChecks, *container.Name) + } + if len(containerHasHealthChecks) == 0 { + return nil + } + } + } + return xerrors.Errorf("😨 canary task hasn't become to be healthy") +} + +func (c *common) describeTaskTarget( + ctx context.Context, + targetPort int32, +) (*CanaryTarget, error) { + target := CanaryTarget{targetPort: targetPort} + if c.Env.CanaryInstanceArn == "" { // Fargate + if err := c.getFargateTarget(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } else { + if err := c.getEc2Target(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } + return &target, nil +} + +func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { + var task ecstypes.Task + if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task = o.Tasks[0] + } + details := task.Attachments[0].Details + var subnetId *string + var privateIp *string + for _, v := range details { + if *v.Name == "subnetId" { + subnetId = v.Value + } else if *v.Name == "privateIPv4Address" { + privateIp = v.Value + } + } + if subnetId == nil || privateIp == nil { + return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + } + if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*subnetId}, + }); err != nil { + return err + } else { + dest.targetId = *privateIp + dest.targetIpv4 = *privateIp + dest.availabilityZone = *o.Subnets[0].AvailabilityZone + } + return nil +} + +func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { + var containerInstance ecstypes.ContainerInstance + if outputs, err := c.Ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ + Cluster: &c.Env.Cluster, + ContainerInstances: []string{c.Env.CanaryInstanceArn}, + }); err != nil { + return err + } else { + containerInstance = outputs.ContainerInstances[0] + } + var ec2Instance ec2types.Instance + if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: []string{*containerInstance.Ec2InstanceId}, + }); err != nil { + return err + } else { + ec2Instance = o.Reservations[0].Instances[0] + } + if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*ec2Instance.SubnetId}, + }); err != nil { + return err + } else { + dest.targetId = *containerInstance.Ec2InstanceId + dest.targetIpv4 = *ec2Instance.PrivateIpAddress + dest.availabilityZone = *sn.Subnets[0].AvailabilityZone + } + return nil +} + +func (c *common) stopTask(ctx context.Context) error { + if c.taskArn == nil { + log.Info("no canary task to stop") + return nil + } + log.Infof("stopping the canary task '%s'...", *c.taskArn) + if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ + Cluster: &c.Env.Cluster, + Task: c.taskArn, + }); err != nil { + return xerrors.Errorf("failed to stop canary task: %w", err) + } + if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }, c.Timeout.TaskStopped()); err != nil { + return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) + } + log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) + return nil +} diff --git a/canary/simple_task.go b/task/simple_task.go similarity index 85% rename from canary/simple_task.go rename to task/simple_task.go index 1189833..bee5667 100644 --- a/canary/simple_task.go +++ b/task/simple_task.go @@ -1,4 +1,4 @@ -package canary +package task import ( "context" @@ -9,10 +9,15 @@ import ( "golang.org/x/xerrors" ) +// simpleTask is a task that isn't attachet to any load balancer or service discovery type simpleTask struct { *common } +func NewSimpleTask(input *Input) Task { + return &simpleTask{common: &common{Input: input}} +} + func (c *simpleTask) Wait(ctx context.Context) error { if err := c.wait(ctx); err != nil { return err diff --git a/canary/srv_task.go b/task/srv_task.go similarity index 94% rename from canary/srv_task.go rename to task/srv_task.go index 835c295..a12d6fe 100644 --- a/canary/srv_task.go +++ b/task/srv_task.go @@ -1,4 +1,4 @@ -package canary +package task import ( "context" @@ -12,6 +12,7 @@ import ( "golang.org/x/xerrors" ) +// srvTask is a task that is attached to an Service Discovery type srvTask struct { *common registry *ecstypes.ServiceRegistry @@ -49,15 +50,8 @@ func (c *srvTask) Stop(ctx context.Context) error { return c.stopTask(ctx) } -func (c *srvTask) getTargetPort(ctx context.Context) (int32, error) { - if c.registry.Port != nil { - return *c.registry.Port, nil - } - return 80, nil -} - func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { - target, err := c.describeTaskTarget(ctx, c.getTargetPort) + target, err := c.describeTaskTarget(ctx, *c.registry.Port) if err != nil { return err } diff --git a/task_definition_test.go b/task_definition_test.go index c700ca6..3c51608 100644 --- a/task_definition_test.go +++ b/task_definition_test.go @@ -24,7 +24,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { TaskDefinitionArn: "arn://aaa", } c := &cage.CageExport{ - Input: &types.Input{ + Deps: &types.Deps{ Env: env, Ecs: ecsMock, }, @@ -43,7 +43,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { TaskDefinitionArn: "arn://aaa", } c := &cage.CageExport{ - Input: &types.Input{ + Deps: &types.Deps{ Env: env, Ecs: ecsMock, }, @@ -58,7 +58,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() c := &cage.CageExport{ - Input: &types.Input{ + Deps: &types.Deps{ Env: env, Ecs: ecsMock, }, @@ -78,7 +78,7 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() c := &cage.CageExport{ - Input: &types.Input{ + Deps: &types.Deps{ Env: env, Ecs: ecsMock, }, diff --git a/test/context.go b/test/context.go index e161401..746aa92 100644 --- a/test/context.go +++ b/test/context.go @@ -453,7 +453,9 @@ func (ctx *MockContext) DescribeInstances(_ context.Context, input *ec2.Describe return &ec2.DescribeInstancesOutput{ Reservations: []ec2types.Reservation{{ Instances: []ec2types.Instance{{ - SubnetId: aws.String("us-west-2a"), + InstanceId: aws.String("i-123456"), + PrivateIpAddress: aws.String("127.0.1.0"), + SubnetId: aws.String("us-west-2a"), }}, }}, }, nil diff --git a/types/iface.go b/types/iface.go index db997af..7e28a9d 100644 --- a/types/iface.go +++ b/types/iface.go @@ -20,7 +20,7 @@ type Time interface { NewTimer(time.Duration) *time.Timer } -type Input struct { +type Deps struct { Env *env.Envars Ecs awsiface.EcsClient Alb awsiface.AlbClient diff --git a/up_test.go b/up_test.go index 9a20e0e..1f3023e 100644 --- a/up_test.go +++ b/up_test.go @@ -17,7 +17,7 @@ func TestCage_Up(t *testing.T) { ctrl := gomock.NewController(t) ctx, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") delete(ctx.Services, env.Service) - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, }) @@ -30,7 +30,7 @@ func TestCage_Up(t *testing.T) { env := test.DefaultEnvars() ctrl := gomock.NewController(t) _, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") - cagecli := cage.NewCage(&types.Input{ + cagecli := cage.NewCage(&types.Deps{ Env: env, Ecs: ecsMock, }) From c23f33bf9cc28bb0357e31737b23ba84af7dc638 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 18:35:43 +0900 Subject: [PATCH 05/47] Delete common.go --- canary/common.go | 248 ----------------------------------------------- 1 file changed, 248 deletions(-) delete mode 100644 canary/common.go diff --git a/canary/common.go b/canary/common.go deleted file mode 100644 index 9ff38fb..0000000 --- a/canary/common.go +++ /dev/null @@ -1,248 +0,0 @@ -package canary - -import ( - "context" - "fmt" - "time" - - "github.com/apex/log" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go-v2/service/ecs" - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/timeout" - "github.com/loilo-inc/canarycage/types" - "golang.org/x/xerrors" -) - -type CanaryTarget struct { - targetId string - targetIpv4 string - targetPort int32 - availabilityZone string -} - -type Task interface { - Start(ctx context.Context) error - Wait(ctx context.Context) error - Stop(ctx context.Context) error -} - -type Input struct { - *types.Deps - TaskDefinition *ecstypes.TaskDefinition - Network *ecstypes.NetworkConfiguration - PlatformVersion *string - Timeout timeout.Manager -} - -type common struct { - *Input - taskArn *string -} - -func (c *common) Start(ctx context.Context) error { - if c.Env.CanaryInstanceArn != "" { - // ec2 - startTask := &ecs.StartTaskInput{ - Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), - NetworkConfiguration: c.Network, - TaskDefinition: c.TaskDefinition.TaskDefinitionArn, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, - } - if o, err := c.Ecs.StartTask(ctx, startTask); err != nil { - return err - } else { - c.taskArn = o.Tasks[0].TaskArn - } - } else { - // fargate - if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ - Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), - NetworkConfiguration: c.Network, - TaskDefinition: c.TaskDefinition.TaskDefinitionArn, - LaunchType: ecstypes.LaunchTypeFargate, - PlatformVersion: c.PlatformVersion, - }); err != nil { - return err - } else { - c.taskArn = o.Tasks[0].TaskArn - } - } - return nil -} - -func (c *common) wait(ctx context.Context) error { - log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) - if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskRunning()); err != nil { - return err - } - log.Infof("🐣 canary task '%s' is running!", *c.taskArn) - if err := c.waitUntilHealthCheckPassed(ctx); err != nil { - return err - } - log.Info("🤩 canary task container(s) is healthy!") - log.Infof("canary task '%s' ensured.", *c.taskArn) - return nil -} - -func (c *common) waitUntilHealthCheckPassed(ctx context.Context) error { - log.Infof("😷 ensuring canary task container(s) to become healthy...") - containerHasHealthChecks := map[string]struct{}{} - for _, definition := range c.TaskDefinition.ContainerDefinitions { - if definition.HealthCheck != nil { - containerHasHealthChecks[*definition.Name] = struct{}{} - } - } - healthCheckWait := c.Timeout.TaskHealthCheck() - healthCheckPeriod := 15 * time.Second - countPerPeriod := int(healthCheckWait.Seconds() / 15) - for count := 0; count < countPerPeriod; count++ { - <-c.Time.NewTimer(healthCheckPeriod).C - log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return err - } else { - task := o.Tasks[0] - if *task.LastStatus != "RUNNING" { - return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) - } - - for _, container := range task.Containers { - if _, ok := containerHasHealthChecks[*container.Name]; !ok { - continue - } - if container.HealthStatus != ecstypes.HealthStatusHealthy { - log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) - continue - } - delete(containerHasHealthChecks, *container.Name) - } - if len(containerHasHealthChecks) == 0 { - return nil - } - } - } - return xerrors.Errorf("😨 canary task hasn't become to be healthy") -} - -func (c *common) describeTaskTarget( - ctx context.Context, - getTargetPort func(ctx context.Context) (int32, error), -) (*CanaryTarget, error) { - var targetPort int32 - if v, err := getTargetPort(ctx); err != nil { - return nil, err - } else { - targetPort = v - } - target := CanaryTarget{targetPort: targetPort} - if c.Env.CanaryInstanceArn == "" { // Fargate - if err := c.getFargateTarget(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) - } else { - if err := c.getEc2Target(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) - } - return &target, nil -} - -func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { - var task ecstypes.Task - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return err - } else { - task = o.Tasks[0] - } - details := task.Attachments[0].Details - var subnetId *string - var privateIp *string - for _, v := range details { - if *v.Name == "subnetId" { - subnetId = v.Value - } else if *v.Name == "privateIPv4Address" { - privateIp = v.Value - } - } - if subnetId == nil || privateIp == nil { - return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") - } - if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*subnetId}, - }); err != nil { - return err - } else { - dest.targetId = *privateIp - dest.targetIpv4 = *privateIp - dest.availabilityZone = *o.Subnets[0].AvailabilityZone - } - return nil -} - -func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { - var containerInstance ecstypes.ContainerInstance - if outputs, err := c.Ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ - Cluster: &c.Env.Cluster, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, - }); err != nil { - return err - } else { - containerInstance = outputs.ContainerInstances[0] - } - var ec2Instance ec2types.Instance - if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ - InstanceIds: []string{*containerInstance.Ec2InstanceId}, - }); err != nil { - return err - } else { - ec2Instance = o.Reservations[0].Instances[0] - } - if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*ec2Instance.SubnetId}, - }); err != nil { - return err - } else { - dest.targetId = *containerInstance.Ec2InstanceId - dest.targetIpv4 = *ec2Instance.PrivateIpAddress - dest.availabilityZone = *sn.Subnets[0].AvailabilityZone - } - return nil -} - -func (c *common) stopTask(ctx context.Context) error { - if c.taskArn == nil { - log.Info("no canary task to stop") - return nil - } - log.Infof("stopping the canary task '%s'...", *c.taskArn) - if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ - Cluster: &c.Env.Cluster, - Task: c.taskArn, - }); err != nil { - return xerrors.Errorf("failed to stop canary task: %w", err) - } - if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskStopped()); err != nil { - return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) - } - log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) - return nil -} From fba195a9cf86b65bcb080326bf973c361fa131fe Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 18:58:04 +0900 Subject: [PATCH 06/47] wip --- cli/cage/commands/command.go | 22 +++-- cli/cage/commands/flags.go | 2 +- env/env.go | 35 ++++---- fixtures/task-definition.json | 165 +--------------------------------- rollout_test.go | 13 +-- task/common.go | 11 ++- 6 files changed, 46 insertions(+), 202 deletions(-) diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index 32fea4d..8c2e721 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -68,15 +68,25 @@ func (c *CageCommands) setupCage( envars *env.Envars, dir string, ) (types.Cage, error) { - td, svc, err := env.LoadDefinitionsFromFiles(dir) - if err != nil { + var service *ecs.CreateServiceInput + var taskDefinition *ecs.RegisterTaskDefinitionInput + if srv, err := env.LoadServiceDefiniton(dir); err != nil { return nil, err + } else { + service = srv + } + if envars.TaskDefinitionArn == "" { + if td, err := env.LoadTaskDefiniton(dir); err != nil { + return nil, err + } else { + taskDefinition = td + } } env.MergeEnvars(envars, &env.Envars{ - Cluster: *svc.Cluster, - Service: *svc.ServiceName, - TaskDefinitionInput: td, - ServiceDefinitionInput: svc, + Cluster: *service.Cluster, + Service: *service.ServiceName, + TaskDefinitionInput: taskDefinition, + ServiceDefinitionInput: service, }) if err := env.EnsureEnvars(envars); err != nil { return nil, err diff --git a/cli/cage/commands/flags.go b/cli/cage/commands/flags.go index 009a8c3..0e228b5 100644 --- a/cli/cage/commands/flags.go +++ b/cli/cage/commands/flags.go @@ -34,7 +34,7 @@ func TaskDefinitionArnFlag(dest *string) *cli.StringFlag { return &cli.StringFlag{ Name: "taskDefinitionArn", EnvVars: []string{env.TaskDefinitionArnKey}, - Usage: "full arn for next task definition. if not specified, use task-definition.json for registration", + Usage: "full arn or family:revision of task definition. if not specified, new task definition will be created based on task-definition.json", Destination: dest, } } diff --git a/env/env.go b/env/env.go index 225565f..f9b0e1d 100644 --- a/env/env.go +++ b/env/env.go @@ -50,6 +50,9 @@ func EnsureEnvars( dest *Envars, ) error { // required + if dest.Region == "" { + return xerrors.Errorf("--region [%s] is required", RegionKey) + } if dest.Cluster == "" { return xerrors.Errorf("--cluster [%s] is required", ClusterKey) } else if dest.Service == "" { @@ -58,33 +61,33 @@ func EnsureEnvars( if dest.TaskDefinitionArn == "" && dest.TaskDefinitionInput == nil { return xerrors.Errorf("--nextTaskDefinitionArn or deploy context must be provided") } - if dest.Region == "" { - log.Fatalf("region must be specified. set --region flag or see also https://docs.aws.amazon.com/sdk-for-go/v1/developer-guide/configuring-sdk.html") - } return nil } -func LoadDefinitionsFromFiles(dir string) ( - *ecs.RegisterTaskDefinitionInput, - *ecs.CreateServiceInput, - error, -) { +func LoadServiceDefiniton(dir string) (*ecs.CreateServiceInput, error) { svcPath := filepath.Join(dir, "service.json") - tdPath := filepath.Join(dir, "task-definition.json") _, noSvc := os.Stat(svcPath) - _, noTd := os.Stat(tdPath) var service ecs.CreateServiceInput - var td ecs.RegisterTaskDefinitionInput - if noSvc != nil || noTd != nil { - return nil, nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) + if noSvc != nil { + return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) } if _, err := ReadAndUnmarshalJson(svcPath, &service); err != nil { - return nil, nil, xerrors.Errorf("failed to read and unmarshal service.json: %s", err) + return nil, xerrors.Errorf("failed to read and unmarshal service.json: %s", err) + } + return &service, nil +} + +func LoadTaskDefiniton(dir string) (*ecs.RegisterTaskDefinitionInput, error) { + tdPath := filepath.Join(dir, "task-definition.json") + _, noTd := os.Stat(tdPath) + var td ecs.RegisterTaskDefinitionInput + if noTd != nil { + return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) } if _, err := ReadAndUnmarshalJson(tdPath, &td); err != nil { - return nil, nil, xerrors.Errorf("failed to read and unmarshal task-definition.json: %s", err) + return nil, xerrors.Errorf("failed to read and unmarshal task-definition.json: %s", err) } - return &td, &service, nil + return &td, nil } func MergeEnvars(dest *Envars, src *Envars) { diff --git a/fixtures/task-definition.json b/fixtures/task-definition.json index 23ecc4e..639030b 100644 --- a/fixtures/task-definition.json +++ b/fixtures/task-definition.json @@ -7,13 +7,6 @@ { "name": "container", "image": "", - "repositoryCredentials": { - "credentialsParameter": "" - }, - "cpu": 0, - "memory": 0, - "memoryReservation": 0, - "links": [""], "portMappings": [ { "containerPort": 8000, @@ -21,80 +14,6 @@ } ], "essential": true, - "entryPoint": [""], - "command": [""], - "environment": [ - { - "name": "", - "value": "" - } - ], - "mountPoints": [ - { - "sourceVolume": "", - "containerPath": "", - "readOnly": true - } - ], - "volumesFrom": [ - { - "sourceContainer": "", - "readOnly": true - } - ], - "linuxParameters": { - "capabilities": { - "add": [""], - "drop": [""] - }, - "devices": [ - { - "hostPath": "", - "containerPath": "", - "permissions": ["mknod"] - } - ], - "initProcessEnabled": true, - "sharedMemorySize": 0, - "tmpfs": [ - { - "containerPath": "", - "size": 0, - "mountOptions": [""] - } - ] - }, - "hostname": "", - "user": "", - "workingDirectory": "", - "disableNetworking": true, - "privileged": true, - "readonlyRootFilesystem": true, - "dnsServers": [""], - "dnsSearchDomains": [""], - "extraHosts": [ - { - "hostname": "", - "ipAddress": "" - } - ], - "dockerSecurityOptions": [""], - "dockerLabels": { - "KeyName": "" - }, - "ulimits": [ - { - "name": "core", - "softLimit": 0, - "hardLimit": 0 - } - ], - "logConfiguration": { - "logDriver": "gelf", - "options": { - "KeyName": "" - } - }, "healthCheck": { "command": [""], "interval": 0, @@ -106,94 +25,12 @@ { "name": "containerWithoutHealthCheck", "image": "", - "repositoryCredentials": { - "credentialsParameter": "" - }, - "cpu": 0, - "memory": 0, - "memoryReservation": 0, - "links": [""], "portMappings": [ { "containerPort": 8000, "hostPort": 80 } - ], - "essential": true, - "entryPoint": [""], - "command": [""], - "environment": [ - { - "name": "", - "value": "" - } - ], - "mountPoints": [ - { - "sourceVolume": "", - "containerPath": "", - "readOnly": true - } - ], - "volumesFrom": [ - { - "sourceContainer": "", - "readOnly": true - } - ], - "linuxParameters": { - "capabilities": { - "add": [""], - "drop": [""] - }, - "devices": [ - { - "hostPath": "", - "containerPath": "", - "permissions": ["mknod"] - } - ], - "initProcessEnabled": true, - "sharedMemorySize": 0, - "tmpfs": [ - { - "containerPath": "", - "size": 0, - "mountOptions": [""] - } - ] - }, - "hostname": "", - "user": "", - "workingDirectory": "", - "disableNetworking": true, - "privileged": true, - "readonlyRootFilesystem": true, - "dnsServers": [""], - "dnsSearchDomains": [""], - "extraHosts": [ - { - "hostname": "", - "ipAddress": "" - } - ], - "dockerSecurityOptions": [""], - "dockerLabels": { - "KeyName": "" - }, - "ulimits": [ - { - "name": "core", - "softLimit": 0, - "hardLimit": 0 - } - ], - "logConfiguration": { - "logDriver": "gelf", - "options": { - "KeyName": "" - } - } + ] } ], "requiresCompatibilities": ["FARGATE"], diff --git a/rollout_test.go b/rollout_test.go index 06cc18b..68c832c 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -390,12 +390,9 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { }) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - if err == nil { - t.Fatal("Rollout with no container instance should be error") - } else { - assert.True(t, regexp.MustCompile("canaryInstanceArn is required").MatchString(err.Error())) - assert.NotNil(t, result) - } + assert.NoError(t, err) + assert.True(t, regexp.MustCompile("canaryInstanceArn is required").MatchString(err.Error())) + assert.NotNil(t, result) } func TestCage_RollOut_EC2_no_attribute(t *testing.T) { @@ -425,9 +422,7 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { }) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - if err != nil { - t.Fatalf("%s", err) - } + assert.NoError(t, err) assert.False(t, result.ServiceIntact) assert.Equal(t, 1, mctx.ActiveServiceSize()) assert.Equal(t, 1, mctx.RunningTaskSize()) diff --git a/task/common.go b/task/common.go index 933b581..17e56db 100644 --- a/task/common.go +++ b/task/common.go @@ -6,7 +6,6 @@ import ( "time" "github.com/apex/log" - "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/ecs" @@ -43,16 +42,16 @@ type common struct { } func (c *common) Start(ctx context.Context) error { + group := fmt.Sprintf("cage:canary-task:%s", c.Env.Service) if c.Env.CanaryInstanceArn != "" { // ec2 - startTask := &ecs.StartTaskInput{ + if o, err := c.Ecs.StartTask(ctx, &ecs.StartTaskInput{ Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + Group: &group, NetworkConfiguration: c.NetworkConfiguration, TaskDefinition: c.TaskDefinition.TaskDefinitionArn, ContainerInstances: []string{c.Env.CanaryInstanceArn}, - } - if o, err := c.Ecs.StartTask(ctx, startTask); err != nil { + }); err != nil { return err } else { c.taskArn = o.Tasks[0].TaskArn @@ -61,7 +60,7 @@ func (c *common) Start(ctx context.Context) error { // fargate if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ Cluster: &c.Env.Cluster, - Group: aws.String(fmt.Sprintf("cage:canary-task:%s", c.Env.Service)), + Group: &group, NetworkConfiguration: c.NetworkConfiguration, TaskDefinition: c.TaskDefinition.TaskDefinitionArn, LaunchType: ecstypes.LaunchTypeFargate, From b14e30eaec81b309d97086403ce28a0442bbd477 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 19:35:25 +0900 Subject: [PATCH 07/47] wip --- rollout_test.go | 4 +-- task/alb_task.go | 81 +++++++++++++++++++++++---------------------- task/simple_task.go | 16 ++++----- task/srv_task.go | 11 +++--- timeout/timeout.go | 17 +++++++--- 5 files changed, 71 insertions(+), 58 deletions(-) diff --git a/rollout_test.go b/rollout_test.go index 68c832c..1618ef3 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -2,7 +2,6 @@ package cage_test import ( "context" - "regexp" "strings" "testing" @@ -390,8 +389,7 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { }) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NoError(t, err) - assert.True(t, regexp.MustCompile("canaryInstanceArn is required").MatchString(err.Error())) + assert.ErrorContains(t, err, "canaryInstanceArn is required") assert.NotNil(t, result) } diff --git a/task/alb_task.go b/task/alb_task.go index 9bfd86c..0c91165 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -122,51 +122,54 @@ func (c *albTask) waitUntilTargetHealthy( ) error { log.Infof("checking the health state of canary task...") var unusedCount = 0 - var initialized = false var recentState *elbv2types.TargetHealthStateEnum - for { - <-c.Time.NewTimer(time.Duration(15) * time.Second).C - if o, err := c.Alb.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - Id: &c.target.targetId, - Port: &c.target.targetPort, - AvailabilityZone: &c.target.availabilityZone, - }}, - }); err != nil { - return err - } else { - for _, desc := range o.TargetHealthDescriptions { - if *desc.Target.Id == c.target.targetId && *desc.Target.Port == c.target.targetPort { - recentState = &desc.TargetHealth.State + rest := c.Timeout.TargetHealthCheck() + waitPeriod := 15 * time.Second + for rest > 0 && unusedCount < 5 { + if rest < waitPeriod { + waitPeriod = rest + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.Time.NewTimer(waitPeriod).C: + if o, err := c.Alb.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + Id: &c.target.targetId, + Port: &c.target.targetPort, + AvailabilityZone: &c.target.availabilityZone, + }}, + }); err != nil { + return err + } else { + for _, desc := range o.TargetHealthDescriptions { + if *desc.Target.Id == c.target.targetId && *desc.Target.Port == c.target.targetPort { + recentState = &desc.TargetHealth.State + } } - } - if recentState == nil { - return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) - } - log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) - switch *recentState { - case "healthy": - return nil - case "initial": - initialized = true - log.Infof("still checking the state...") - continue - case "unused": - unusedCount++ - if !initialized && unusedCount < 5 { - continue + if recentState == nil { + return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) + } + log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) + switch *recentState { + case "healthy": + return nil + case "unused": + unusedCount++ + default: + log.Infof("still checking the state...") } - default: } } - // unhealthy, draining, unused - log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) - return xerrors.Errorf( - "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *c.taskArn, c.target.targetId, c.target.targetPort, *recentState, - ) + rest -= waitPeriod } + // unhealthy, draining, unused + log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) + return xerrors.Errorf( + "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", + *c.taskArn, c.target.targetId, c.target.targetPort, *recentState, + ) } func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, error) { diff --git a/task/simple_task.go b/task/simple_task.go index bee5667..69e28e4 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -31,19 +31,19 @@ func (c *simpleTask) Stop(ctx context.Context) error { func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { log.Infof("wait %d seconds for canary task to be stable...", c.Env.CanaryTaskIdleDuration) - duration := c.Env.CanaryTaskIdleDuration - for duration > 0 { - wt := 10 - if duration < 10 { - wt = duration + rest := time.Duration(c.Env.CanaryTaskIdleDuration) * time.Second + waitPeriod := 15 * time.Second + for rest > 0 { + if rest < waitPeriod { + waitPeriod = rest } select { case <-ctx.Done(): return ctx.Err() - case <-c.Time.NewTimer(time.Duration(wt) * time.Second).C: - duration -= 10 + case <-c.Time.NewTimer(waitPeriod).C: + rest -= waitPeriod } - log.Infof("still waiting...; %d seconds left", duration) + log.Infof("still waiting...; %d seconds left", rest) } o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &c.Env.Cluster, diff --git a/task/srv_task.go b/task/srv_task.go index a12d6fe..1eae60a 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -95,9 +95,12 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { func (c *srvTask) waitUntilSrvInstHelthy( ctx context.Context, ) error { - var maxWait = 900 - var waitPeriod = 15 - for maxWait > 0 { + var rest = c.Timeout.TargetHealthCheck() + var waitPeriod = 15 * time.Second + for rest > 0 { + if rest < waitPeriod { + waitPeriod = rest + } select { case <-ctx.Done(): return ctx.Err() @@ -121,7 +124,7 @@ func (c *srvTask) waitUntilSrvInstHelthy( return nil } } - maxWait -= waitPeriod + rest -= waitPeriod } } } diff --git a/timeout/timeout.go b/timeout/timeout.go index e9e7b1a..7ad4ee3 100644 --- a/timeout/timeout.go +++ b/timeout/timeout.go @@ -3,10 +3,11 @@ package timeout import "time" type Input struct { - TaskStoppedWait time.Duration - TaskRunningWait time.Duration - TaskHealthCheckWait time.Duration - ServiceStableWait time.Duration + TaskStoppedWait time.Duration + TaskRunningWait time.Duration + TaskHealthCheckWait time.Duration + TargetHealthCheckWait time.Duration + ServiceStableWait time.Duration } type Manager interface { @@ -14,6 +15,7 @@ type Manager interface { TaskHealthCheck() time.Duration TaskStopped() time.Duration ServiceStable() time.Duration + TargetHealthCheck() time.Duration } type manager struct { @@ -58,3 +60,10 @@ func (t *manager) ServiceStable() time.Duration { } return t.DefaultTimeout } + +func (t *manager) TargetHealthCheck() time.Duration { + if t.TargetHealthCheckWait > 0 { + return t.TargetHealthCheckWait + } + return t.DefaultTimeout +} From 010efae64d5c7456adbe0a874129435c233c2f8b Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 27 Jun 2024 19:55:18 +0900 Subject: [PATCH 08/47] wip --- task/alb_task.go | 93 +++++++++++++++++++++++---------------------- task/common.go | 3 +- task/simple_task.go | 2 +- task/srv_task.go | 20 +++++----- 4 files changed, 61 insertions(+), 57 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index 0c91165..0bdb731 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -29,61 +29,23 @@ func NewAlbTask(input *Input, } func (c *albTask) Wait(ctx context.Context) error { - if err := c.wait(ctx); err != nil { + if err := c.waitForTask(ctx); err != nil { return err } if err := c.registerToTargetGroup(ctx); err != nil { return err } - log.Infof("😷 ensuring canary task to become healthy...") + log.Infof("canary task '%s' is registered to target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) + log.Infof("😷 waiting canary target to be healthy...") if err := c.waitUntilTargetHealthy(ctx); err != nil { return err } - log.Info("🤩 canary task is healthy!") + log.Info("🤩 canary target is healthy!") return nil } func (c *albTask) Stop(ctx context.Context) error { - if c.target == nil { - log.Info("no target is registered. skip deregisteration.") - } else { - deregistrationDelay, err := c.targetDeregistrationDelay(ctx) - if err != nil { - log.Errorf("failed to get deregistration delay: %v", err) - log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) - } - log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) - if _, err := c.Alb.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ - TargetGroupArn: c.lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.target.availabilityZone, - Id: &c.target.targetId, - Port: &c.target.targetPort, - }}, - }); err != nil { - log.Errorf("failed to deregister the canary task from target group: %v", err) - log.Errorf("continuing to stop the canary task...") - } else { - log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") - deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety - if err := elbv2.NewTargetDeregisteredWaiter(c.Alb).Wait(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.target.availabilityZone, - Id: &c.target.targetId, - Port: &c.target.targetPort, - }}, - }, deregisterWait); err != nil { - log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) - log.Errorf("continuing to stop the canary task...") - } else { - log.Infof( - "canary task '%s' has successfully been deregistered from target group '%s'", - *c.taskArn, c.target.targetId, - ) - } - } - } + c.deregisterTarget(ctx) return c.stopTask(ctx) } @@ -97,6 +59,7 @@ func (c *albTask) getTargetPort() (int32, error) { } func (c *albTask) registerToTargetGroup(ctx context.Context) error { + log.Infof("registering the canary task to target group '%s'...", *c.lb.TargetGroupArn) if targetPort, err := c.getTargetPort(); err != nil { return err } else if target, err := c.describeTaskTarget(ctx, targetPort); err != nil { @@ -157,8 +120,6 @@ func (c *albTask) waitUntilTargetHealthy( return nil case "unused": unusedCount++ - default: - log.Infof("still checking the state...") } } } @@ -192,3 +153,45 @@ func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, } return deregistrationDelay, nil } + +func (c *albTask) deregisterTarget(ctx context.Context) { + if c.target == nil { + return + } + deregistrationDelay, err := c.targetDeregistrationDelay(ctx) + if err != nil { + log.Errorf("failed to get deregistration delay: %v", err) + log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) + } + log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) + if _, err := c.Alb.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + AvailabilityZone: &c.target.availabilityZone, + Id: &c.target.targetId, + Port: &c.target.targetPort, + }}, + }); err != nil { + log.Errorf("failed to deregister the canary task from target group: %v", err) + log.Errorf("continuing to stop the canary task...") + } else { + log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") + deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety + if err := elbv2.NewTargetDeregisteredWaiter(c.Alb).Wait(ctx, &elbv2.DescribeTargetHealthInput{ + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + AvailabilityZone: &c.target.availabilityZone, + Id: &c.target.targetId, + Port: &c.target.targetPort, + }}, + }, deregisterWait); err != nil { + log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) + log.Errorf("continuing to stop the canary task...") + } else { + log.Infof( + "canary task '%s' has successfully been deregistered from target group '%s'", + *c.taskArn, c.target.targetId, + ) + } + } +} diff --git a/task/common.go b/task/common.go index 17e56db..84383ed 100644 --- a/task/common.go +++ b/task/common.go @@ -74,7 +74,7 @@ func (c *common) Start(ctx context.Context) error { return nil } -func (c *common) wait(ctx context.Context) error { +func (c *common) waitForTask(ctx context.Context) error { log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &c.Env.Cluster, @@ -220,7 +220,6 @@ func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { func (c *common) stopTask(ctx context.Context) error { if c.taskArn == nil { - log.Info("no canary task to stop") return nil } log.Infof("stopping the canary task '%s'...", *c.taskArn) diff --git a/task/simple_task.go b/task/simple_task.go index 69e28e4..f7ae23e 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -19,7 +19,7 @@ func NewSimpleTask(input *Input) Task { } func (c *simpleTask) Wait(ctx context.Context) error { - if err := c.wait(ctx); err != nil { + if err := c.waitForTask(ctx); err != nil { return err } return c.waitForIdleDuration(ctx) diff --git a/task/srv_task.go b/task/srv_task.go index 1eae60a..d58d81b 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -29,24 +29,23 @@ func NewSrvTask(input *Input, registry *ecstypes.ServiceRegistry) Task { } func (c *srvTask) Wait(ctx context.Context) error { - if err := c.wait(ctx); err != nil { + if err := c.waitForTask(ctx); err != nil { return err } if err := c.registerToSrvDiscovery(ctx); err != nil { return err } - log.Infof("😷 ensuring canary task to become healthy...") + log.Infof("canary task '%s' is registered to service discovery instance '%s'", *c.taskArn, *c.inst.InstanceId) + log.Infof("😷 ensuring canary service instance to become healthy...") if err := c.waitUntilSrvInstHelthy(ctx); err != nil { return err } - log.Info("🤩 canary task is healthy!") + log.Info("🤩 canary service instance is healthy!") return nil } func (c *srvTask) Stop(ctx context.Context) error { - if err := c.deregisterSrvInst(ctx); err != nil { - return err - } + c.deregisterSrvInst(ctx) return c.stopTask(ctx) } @@ -133,12 +132,15 @@ func (c *srvTask) waitUntilSrvInstHelthy( func (c *srvTask) deregisterSrvInst( ctx context.Context, -) error { +) { + if c.inst == nil { + return + } if _, err := c.Srv.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ ServiceId: c.srv.Id, InstanceId: c.inst.InstanceId, }); err != nil { - return xerrors.Errorf("failed to deregister the canary task from service discovery: %w", err) + log.Errorf("failed to deregister the canary task from service discovery: %v", err) + log.Errorf("continuing to stop the canary task...") } - return nil } From 76ea966fac783b2c9f3459741a9c099ca7e5d168 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Fri, 28 Jun 2024 16:34:24 +0900 Subject: [PATCH 09/47] wip --- Makefile | 10 +- cage.go | 4 +- mocks/mock_cage/task_factory.go | 78 ++++++ mocks/mock_task/task.go | 77 ++++++ rollout.go | 6 +- rollout_test.go | 16 +- task/alb_task_test.go | 1 + task/common.go | 6 - task/task.go | 9 + task_factory.go | 26 ++ test/alb.go | 98 +++++++ test/context.go | 452 ++------------------------------ test/ec2.go | 52 ++++ test/ecs.go | 334 +++++++++++++++++++++++ test/setup.go | 8 +- test/srv.go | 91 +++++++ 16 files changed, 815 insertions(+), 453 deletions(-) create mode 100644 mocks/mock_cage/task_factory.go create mode 100644 mocks/mock_task/task.go create mode 100644 task/alb_task_test.go create mode 100644 task/task.go create mode 100644 task_factory.go create mode 100644 test/alb.go create mode 100644 test/ec2.go create mode 100644 test/ecs.go create mode 100644 test/srv.go diff --git a/Makefile b/Makefile index 2be61e7..758b07f 100644 --- a/Makefile +++ b/Makefile @@ -9,11 +9,19 @@ push-test-container: test-container docker push loilodev/http-server:latest version: go run cli/cage/main.go -v | cut -f 3 -d ' ' -mocks: mocks/mock_awsiface/iface.go mocks/mock_types/iface.go mocks/mock_upgrade/upgrade.go +mocks: mocks/mock_awsiface/iface.go \ + mocks/mock_types/iface.go \ + mocks/mock_upgrade/upgrade.go \ + mocks/mock_cage/task_factory.go \ + mocks/mock_task/task.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go mocks/mock_types/iface.go: cage.go $(MOCKGEN) -source=./types/iface.go > mocks/mock_types/iface.go +mocks/mock_cage/task_factory.go: task_factory.go + $(MOCKGEN) -source=./task_factory.go > mocks/mock_cage/task_factory.go mocks/mock_upgrade/upgrade.go: cli/cage/upgrade/upgrade.go $(MOCKGEN) -source=./cli/cage/upgrade/upgrade.go > mocks/mock_upgrade/upgrade.go +mocks/mock_task/task.go: task/task.go + $(MOCKGEN) -source=./task/task.go > mocks/mock_task/task.go .PHONY: mocks diff --git a/cage.go b/cage.go index 553d0f6..2c12236 100644 --- a/cage.go +++ b/cage.go @@ -9,7 +9,8 @@ import ( type cage struct { *types.Deps - Timeout timeout.Manager + Timeout timeout.Manager + TaskFactory TaskFactory } func NewCage(input *types.Deps) types.Cage { @@ -30,5 +31,6 @@ func NewCage(input *types.Deps) types.Cage { TaskStoppedWait: taskStoppedWait, ServiceStableWait: serviceStableWait, }), + TaskFactory: &taskFactory{}, } } diff --git a/mocks/mock_cage/task_factory.go b/mocks/mock_cage/task_factory.go new file mode 100644 index 0000000..a8f7cea --- /dev/null +++ b/mocks/mock_cage/task_factory.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./task_factory.go + +// Package mock_cage is a generated GoMock package. +package mock_cage + +import ( + reflect "reflect" + + types "github.com/aws/aws-sdk-go-v2/service/ecs/types" + gomock "github.com/golang/mock/gomock" + task "github.com/loilo-inc/canarycage/task" +) + +// MockTaskFactory is a mock of TaskFactory interface. +type MockTaskFactory struct { + ctrl *gomock.Controller + recorder *MockTaskFactoryMockRecorder +} + +// MockTaskFactoryMockRecorder is the mock recorder for MockTaskFactory. +type MockTaskFactoryMockRecorder struct { + mock *MockTaskFactory +} + +// NewMockTaskFactory creates a new mock instance. +func NewMockTaskFactory(ctrl *gomock.Controller) *MockTaskFactory { + mock := &MockTaskFactory{ctrl: ctrl} + mock.recorder = &MockTaskFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTaskFactory) EXPECT() *MockTaskFactoryMockRecorder { + return m.recorder +} + +// NewAlbTask mocks base method. +func (m *MockTaskFactory) NewAlbTask(input *task.Input, lb *types.LoadBalancer) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAlbTask", input, lb) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewAlbTask indicates an expected call of NewAlbTask. +func (mr *MockTaskFactoryMockRecorder) NewAlbTask(input, lb interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAlbTask", reflect.TypeOf((*MockTaskFactory)(nil).NewAlbTask), input, lb) +} + +// NewSimpleTask mocks base method. +func (m *MockTaskFactory) NewSimpleTask(input *task.Input) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSimpleTask", input) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewSimpleTask indicates an expected call of NewSimpleTask. +func (mr *MockTaskFactoryMockRecorder) NewSimpleTask(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSimpleTask", reflect.TypeOf((*MockTaskFactory)(nil).NewSimpleTask), input) +} + +// NewSrvTask mocks base method. +func (m *MockTaskFactory) NewSrvTask(input *task.Input, srv *types.ServiceRegistry) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSrvTask", input, srv) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewSrvTask indicates an expected call of NewSrvTask. +func (mr *MockTaskFactoryMockRecorder) NewSrvTask(input, srv interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSrvTask", reflect.TypeOf((*MockTaskFactory)(nil).NewSrvTask), input, srv) +} diff --git a/mocks/mock_task/task.go b/mocks/mock_task/task.go new file mode 100644 index 0000000..fa6ea25 --- /dev/null +++ b/mocks/mock_task/task.go @@ -0,0 +1,77 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./task/task.go + +// Package mock_task is a generated GoMock package. +package mock_task + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockTask is a mock of Task interface. +type MockTask struct { + ctrl *gomock.Controller + recorder *MockTaskMockRecorder +} + +// MockTaskMockRecorder is the mock recorder for MockTask. +type MockTaskMockRecorder struct { + mock *MockTask +} + +// NewMockTask creates a new mock instance. +func NewMockTask(ctrl *gomock.Controller) *MockTask { + mock := &MockTask{ctrl: ctrl} + mock.recorder = &MockTaskMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTask) EXPECT() *MockTaskMockRecorder { + return m.recorder +} + +// Start mocks base method. +func (m *MockTask) Start(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Start", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Start indicates an expected call of Start. +func (mr *MockTaskMockRecorder) Start(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockTask)(nil).Start), ctx) +} + +// Stop mocks base method. +func (m *MockTask) Stop(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Stop", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Stop indicates an expected call of Stop. +func (mr *MockTaskMockRecorder) Stop(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockTask)(nil).Stop), ctx) +} + +// Wait mocks base method. +func (m *MockTask) Wait(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Wait", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Wait indicates an expected call of Wait. +func (mr *MockTaskMockRecorder) Wait(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Wait", reflect.TypeOf((*MockTask)(nil).Wait), ctx) +} diff --git a/rollout.go b/rollout.go index 3f75e7b..13ef773 100644 --- a/rollout.go +++ b/rollout.go @@ -137,7 +137,7 @@ func (c *cage) StartCanaryTasks( } var results []task.Task for _, lb := range loadBalancers { - task := task.NewAlbTask(&task.Input{ + task := c.TaskFactory.NewAlbTask(&task.Input{ Deps: c.Deps, NetworkConfiguration: networkConfiguration, TaskDefinition: nextTaskDefinition, @@ -150,7 +150,7 @@ func (c *cage) StartCanaryTasks( } } for _, srv := range serviceRegistries { - task := task.NewSrvTask(&task.Input{ + task := c.TaskFactory.NewSrvTask(&task.Input{ Deps: c.Deps, NetworkConfiguration: networkConfiguration, TaskDefinition: nextTaskDefinition, @@ -163,7 +163,7 @@ func (c *cage) StartCanaryTasks( } } if len(results) == 0 { - task := task.NewSimpleTask(&task.Input{ + task := c.TaskFactory.NewSimpleTask(&task.Input{ Deps: c.Deps, NetworkConfiguration: networkConfiguration, TaskDefinition: nextTaskDefinition, diff --git a/rollout_test.go b/rollout_test.go index 1618ef3..97fbcab 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -79,9 +79,9 @@ func TestCage_RollOut_FARGATE(t *testing.T) { mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTarget).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTarget).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttibutes).Times(1) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).Times(1) + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).Times(1) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).Times(1) gomock.InOrder( albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ @@ -115,9 +115,9 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTarget).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTarget).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttibutes).Times(1) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).Times(1) + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).Times(1) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).Times(1) gomock.InOrder( albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ TargetHealthDescriptions: []elbv2types.TargetHealthDescription{{ @@ -179,13 +179,13 @@ func TestCage_RollOut_FARGATE(t *testing.T) { Time: test.NewFakeTime(), }) ctx := context.Background() - service, _ := mctx.GetService(envars.Service) + service, _ := mctx.GetEcsService(envars.Service) assert.Equal(t, "1.4.0", *service.PlatformVersion) assert.NotNil(t, service.NetworkConfiguration) assert.NotNil(t, service.LoadBalancers) _, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) assert.NoError(t, err) - service, _ = mctx.GetService(envars.Service) + service, _ = mctx.GetEcsService(envars.Service) assert.Equal(t, "LATEST", *service.PlatformVersion) assert.Equal(t, *newNetwork, *service.NetworkConfiguration) assert.Equal(t, *service.LoadBalancers[0].ContainerName, *newLb.ContainerName) diff --git a/task/alb_task_test.go b/task/alb_task_test.go new file mode 100644 index 0000000..651b54f --- /dev/null +++ b/task/alb_task_test.go @@ -0,0 +1 @@ +package task_test diff --git a/task/common.go b/task/common.go index 84383ed..0cb3a5e 100644 --- a/task/common.go +++ b/task/common.go @@ -22,12 +22,6 @@ type CanaryTarget struct { availabilityZone string } -type Task interface { - Start(ctx context.Context) error - Wait(ctx context.Context) error - Stop(ctx context.Context) error -} - type Input struct { *types.Deps TaskDefinition *ecstypes.TaskDefinition diff --git a/task/task.go b/task/task.go new file mode 100644 index 0000000..b5fed47 --- /dev/null +++ b/task/task.go @@ -0,0 +1,9 @@ +package task + +import "context" + +type Task interface { + Start(ctx context.Context) error + Wait(ctx context.Context) error + Stop(ctx context.Context) error +} diff --git a/task_factory.go b/task_factory.go new file mode 100644 index 0000000..bb11fcc --- /dev/null +++ b/task_factory.go @@ -0,0 +1,26 @@ +package cage + +import ( + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/task" +) + +type TaskFactory interface { + NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task + NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task + NewSimpleTask(input *task.Input) task.Task +} + +type taskFactory struct{} + +func (f *taskFactory) NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task { + return task.NewAlbTask(input, lb) +} + +func (f *taskFactory) NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task { + return task.NewSrvTask(input, srv) +} + +func (f *taskFactory) NewSimpleTask(input *task.Input) task.Task { + return task.NewSimpleTask(input) +} diff --git a/test/alb.go b/test/alb.go new file mode 100644 index 0000000..57782ec --- /dev/null +++ b/test/alb.go @@ -0,0 +1,98 @@ +package test + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" +) + +type AlbServer struct { + *commons +} + +func (ctx *AlbServer) DescribeTargetGroups(_ context.Context, input *elbv2.DescribeTargetGroupsInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { + return &elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []elbv2types.TargetGroup{ + { + TargetGroupName: aws.String("tgname"), + TargetGroupArn: aws.String(input.TargetGroupArns[0]), + HealthyThresholdCount: aws.Int32(1), + HealthCheckIntervalSeconds: aws.Int32(0), + LoadBalancerArns: []string{"arn://hoge/app/aa/bb"}, + }, + }, + }, nil +} +func (ctx *AlbServer) DescribeTargetGroupAttributes(_ context.Context, input *elbv2.DescribeTargetGroupAttributesInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { + return &elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + { + Key: aws.String("deregistration_delay.timeout_seconds"), + Value: aws.String("0"), + }, + }, + }, nil +} +func (ctx *AlbServer) DescribeTargetHealth(_ context.Context, input *elbv2.DescribeTargetHealthInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { + if _, ok := ctx.TargetGroups[*input.TargetGroupArn]; !ok { + return &elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + { + Target: &elbv2types.TargetDescription{ + Id: input.Targets[0].Id, + Port: input.Targets[0].Port, + AvailabilityZone: aws.String("us-west-2"), + }, + TargetHealth: &elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumUnused, + }, + }, + }, + }, nil + } + + var ret []elbv2types.TargetHealthDescription + for _, task := range ctx.Tasks { + if task.LastStatus != nil && *task.LastStatus == "RUNNING" { + ret = append(ret, elbv2types.TargetHealthDescription{ + Target: &elbv2types.TargetDescription{ + Id: input.Targets[0].Id, + Port: input.Targets[0].Port, + AvailabilityZone: aws.String("us-west-2"), + }, + TargetHealth: &elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumHealthy, + }, + }) + } + } + return &elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: ret, + }, nil +} + +func (ctx *AlbServer) RegisterTargets(_ context.Context, input *elbv2.RegisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { + ctx.TargetGroups[*input.TargetGroupArn] = struct{}{} + return &elbv2.RegisterTargetsOutput{}, nil +} + +func (ctx *AlbServer) DeregisterTargets(_ context.Context, input *elbv2.DeregisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { + delete(ctx.TargetGroups, *input.TargetGroupArn) + return &elbv2.DeregisterTargetsOutput{}, nil +} + +func (ctx *AlbServer) DescribeInstances(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(options *ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{{ + Instances: []ec2types.Instance{{ + InstanceId: aws.String("i-123456"), + PrivateIpAddress: aws.String("127.0.1.0"), + SubnetId: aws.String("us-west-2a"), + }}, + }}, + }, nil +} diff --git a/test/context.go b/test/context.go index 746aa92..00aab32 100644 --- a/test/context.go +++ b/test/context.go @@ -1,23 +1,13 @@ package test import ( - "context" - "fmt" - "regexp" "sync" - "github.com/apex/log" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/ecs/types" - elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" - "github.com/google/uuid" + "github.com/loilo-inc/canarycage/awsiface" ) -type MockContext struct { +type commons struct { Services map[string]*types.Service Tasks map[string]*types.Task TaskDefinitions *TaskDefinitionRepository @@ -25,8 +15,16 @@ type MockContext struct { mux sync.Mutex } +type MockContext struct { + *commons + awsiface.EcsClient + awsiface.AlbClient + awsiface.Ec2Client + awsiface.SrvClient +} + func NewMockContext() *MockContext { - return &MockContext{ + cm := &commons{ Services: make(map[string]*types.Service), Tasks: make(map[string]*types.Task), TaskDefinitions: &TaskDefinitionRepository{ @@ -34,16 +32,23 @@ func NewMockContext() *MockContext { }, TargetGroups: make(map[string]struct{}), } + return &MockContext{ + commons: cm, + EcsClient: &EcsServer{commons: cm}, + Ec2Client: &Ec2Server{commons: cm}, + SrvClient: &SrvServer{commons: cm}, + AlbClient: &AlbServer{commons: cm}, + } } -func (ctx *MockContext) GetTask(id string) (*types.Task, bool) { +func (ctx *commons) GetTask(id string) (*types.Task, bool) { ctx.mux.Lock() defer ctx.mux.Unlock() o, ok := ctx.Tasks[id] return o, ok } -func (ctx *MockContext) RunningTaskSize() int { +func (ctx *commons) RunningTaskSize() int { ctx.mux.Lock() defer ctx.mux.Unlock() @@ -57,14 +62,14 @@ func (ctx *MockContext) RunningTaskSize() int { return count } -func (ctx *MockContext) GetService(id string) (*types.Service, bool) { +func (ctx *commons) GetEcsService(id string) (*types.Service, bool) { ctx.mux.Lock() defer ctx.mux.Unlock() o, ok := ctx.Services[id] return o, ok } -func (ctx *MockContext) ActiveServiceSize() (count int) { +func (ctx *commons) ActiveServiceSize() (count int) { ctx.mux.Lock() defer ctx.mux.Unlock() for _, v := range ctx.Services { @@ -74,416 +79,3 @@ func (ctx *MockContext) ActiveServiceSize() (count int) { } return } - -func (ctx *MockContext) CreateService(c context.Context, input *ecs.CreateServiceInput, _ ...func(options *ecs.Options)) (*ecs.CreateServiceOutput, error) { - idstr := uuid.New().String() - st := "ACTIVE" - if old, ok := ctx.Services[*input.ServiceName]; ok { - if *old.Status == "ACTIVE" { - return nil, fmt.Errorf("service already exists: %s", *input.ServiceName) - } - } - ret := &types.Service{ - ServiceName: input.ServiceName, - RunningCount: 0, - LaunchType: input.LaunchType, - LoadBalancers: input.LoadBalancers, - DesiredCount: *input.DesiredCount, - TaskDefinition: input.TaskDefinition, - HealthCheckGracePeriodSeconds: aws.Int32(0), - Status: &st, - ServiceArn: &idstr, - PlatformVersion: input.PlatformVersion, - ServiceRegistries: input.ServiceRegistries, - NetworkConfiguration: input.NetworkConfiguration, - Deployments: []types.Deployment{ - { - DesiredCount: *input.DesiredCount, - LaunchType: input.LaunchType, - RunningCount: *input.DesiredCount, - Status: &st, - TaskDefinition: input.TaskDefinition, - }, - }, - } - ctx.mux.Lock() - ctx.Services[*input.ServiceName] = ret - ctx.mux.Unlock() - log.Debugf("%s: running=%d, desired=%d", *input.ServiceName, ret.RunningCount, *input.DesiredCount) - for i := 0; i < int(*input.DesiredCount); i++ { - ctx.StartTask(c, &ecs.StartTaskInput{ - Cluster: input.Cluster, - Group: aws.String(fmt.Sprintf("service:%s", *input.ServiceName)), - NetworkConfiguration: input.NetworkConfiguration, - TaskDefinition: input.TaskDefinition, - }) - } - ctx.mux.Lock() - ctx.Services[*input.ServiceName].RunningCount = *input.DesiredCount - ctx.mux.Unlock() - log.Debugf("%s: running=%d", *input.ServiceName, ret.RunningCount) - return &ecs.CreateServiceOutput{ - Service: ret, - }, nil -} - -func (ctx *MockContext) UpdateService(c context.Context, input *ecs.UpdateServiceInput, _ ...func(options *ecs.Options)) (*ecs.UpdateServiceOutput, error) { - ctx.mux.Lock() - s := ctx.Services[*input.Service] - ctx.mux.Unlock() - nextDesiredCount := s.DesiredCount - nextTaskDefinition := s.TaskDefinition - if input.TaskDefinition != nil { - nextTaskDefinition = input.TaskDefinition - } - if input.DesiredCount != nil { - nextDesiredCount = *input.DesiredCount - } - if diff := nextDesiredCount - s.DesiredCount; diff > 0 { - log.Debugf("diff=%d", diff) - // scale - for i := 0; i < int(diff); i++ { - ctx.StartTask(c, &ecs.StartTaskInput{ - Cluster: input.Cluster, - Group: aws.String(fmt.Sprintf("service:%s", *input.Service)), - TaskDefinition: nextTaskDefinition, - }) - } - } else if diff < 0 { - // descale - var i int32 = 0 - max := -diff - for k, v := range ctx.Tasks { - reg := regexp.MustCompile("service:" + *s.ServiceName) - if reg.MatchString(*v.Group) { - ctx.StopTask(c, &ecs.StopTaskInput{ - Cluster: input.Cluster, - Task: &k, - }) - i++ - if i >= max { - break - } - } - } - } - ctx.mux.Lock() - s.DesiredCount = nextDesiredCount - s.TaskDefinition = nextTaskDefinition - s.RunningCount = nextDesiredCount - s.PlatformVersion = input.PlatformVersion - s.ServiceRegistries = input.ServiceRegistries - s.NetworkConfiguration = input.NetworkConfiguration - s.LoadBalancers = input.LoadBalancers - s.Deployments = []types.Deployment{ - { - DesiredCount: nextDesiredCount, - LaunchType: s.LaunchType, - RunningCount: nextDesiredCount, - Status: s.Status, - TaskDefinition: s.TaskDefinition, - }, - } - ctx.mux.Unlock() - return &ecs.UpdateServiceOutput{ - Service: s, - }, nil -} - -func (ctx *MockContext) DeleteService(c context.Context, input *ecs.DeleteServiceInput, _ ...func(options *ecs.Options)) (*ecs.DeleteServiceOutput, error) { - service := ctx.Services[*input.Service] - reg := regexp.MustCompile(fmt.Sprintf("service:%s", *service.ServiceName)) - for _, v := range ctx.Tasks { - if reg.MatchString(*v.Group) { - _, err := ctx.StopTask(c, &ecs.StopTaskInput{ - Cluster: input.Cluster, - Task: v.TaskArn, - }) - if err != nil { - return nil, err - } - } - } - ctx.mux.Lock() - defer ctx.mux.Unlock() - service.Status = aws.String("INACTIVE") - return &ecs.DeleteServiceOutput{Service: service}, nil -} - -func (ctx *MockContext) RegisterTaskDefinition(_ context.Context, input *ecs.RegisterTaskDefinitionInput, _ ...func(options *ecs.Options)) (*ecs.RegisterTaskDefinitionOutput, error) { - td, err := ctx.TaskDefinitions.Register(input) - if err != nil { - return nil, err - } - return &ecs.RegisterTaskDefinitionOutput{TaskDefinition: td}, nil -} - -func (ctx *MockContext) StartTask(_ context.Context, input *ecs.StartTaskInput, _ ...func(options *ecs.Options)) (*ecs.StartTaskOutput, error) { - ctx.mux.Lock() - defer ctx.mux.Unlock() - td := ctx.TaskDefinitions.Get(*input.TaskDefinition) - if td == nil { - return nil, fmt.Errorf("task definition not found: %s", *input.TaskDefinition) - } - taskArn := fmt.Sprintf("arn:aws:ecs:us-west-2:012345678910:task/%s", uuid.New().String()) - attachment := types.Attachment{ - Details: []types.KeyValuePair{ - { - Name: aws.String("privateIPv4Address"), - Value: aws.String("127.0.0.1"), - }, - }, - } - if input.NetworkConfiguration != nil { - subnet := input.NetworkConfiguration.AwsvpcConfiguration.Subnets[0] - attachment.Details = append(attachment.Details, types.KeyValuePair{ - Name: aws.String("subnetId"), - Value: &subnet, - }) - } - containers := make([]types.Container, len(td.ContainerDefinitions)) - for i, v := range td.ContainerDefinitions { - containers[i] = types.Container{ - Name: v.Name, - Image: v.Image, - LastStatus: aws.String("RUNNING"), - } - if v.HealthCheck != nil { - containers[i].HealthStatus = "HEALTHY" - } else { - containers[i].HealthStatus = "UNKNOWN" - } - } - - ret := types.Task{ - TaskArn: &taskArn, - ClusterArn: input.Cluster, - TaskDefinitionArn: input.TaskDefinition, - Group: input.Group, - Containers: containers, - } - ctx.Tasks[taskArn] = &ret - var launchType types.LaunchType - if len(input.ContainerInstances) > 0 { - launchType = types.LaunchTypeEc2 - } else { - launchType = types.LaunchTypeFargate - } - ret.LaunchType = launchType - if launchType == types.LaunchTypeFargate { - ret.Attachments = []types.Attachment{attachment} - } else { - ret.ContainerInstanceArn = aws.String("arn:aws:ecs:us-west-2:1234567890:container-instance/12345678-hoge-hoge-1234-1f2o3o4ba5r") - } - ret.LastStatus = aws.String("RUNNING") - return &ecs.StartTaskOutput{ - Tasks: []types.Task{ret}, - }, nil -} - -func (ctx *MockContext) RunTask(c context.Context, input *ecs.RunTaskInput, _ ...func(options *ecs.Options)) (*ecs.RunTaskOutput, error) { - o, err := ctx.StartTask(c, &ecs.StartTaskInput{ - Cluster: input.Cluster, - Group: input.Group, - TaskDefinition: input.TaskDefinition, - NetworkConfiguration: input.NetworkConfiguration, - }) - if err != nil { - return nil, err - } - return &ecs.RunTaskOutput{ - Tasks: o.Tasks, - }, nil -} - -func (ctx *MockContext) StopTask(_ context.Context, input *ecs.StopTaskInput, _ ...func(options *ecs.Options)) (*ecs.StopTaskOutput, error) { - ctx.mux.Lock() - defer ctx.mux.Unlock() - log.Debugf("stop: %s", *input.Task) - ret, ok := ctx.Tasks[*input.Task] - if !ok { - return nil, fmt.Errorf("task not found: %s", *input.Task) - } - for i := range ret.Containers { - v := &ret.Containers[i] - v.ExitCode = aws.Int32(0) - v.LastStatus = aws.String("STOPPED") - } - ret.LastStatus = aws.String("STOPPED") - ret.DesiredStatus = aws.String("STOPPED") - service, ok := ctx.Services[*ret.Group] - if ok { - service.RunningCount -= 1 - } - return &ecs.StopTaskOutput{Task: ret}, nil -} - -func (ctx *MockContext) ListTasks(_ context.Context, input *ecs.ListTasksInput, _ ...func(options *ecs.Options)) (*ecs.ListTasksOutput, error) { - var ret []string - ctx.mux.Lock() - defer ctx.mux.Unlock() - for _, v := range ctx.Tasks { - group := fmt.Sprintf("service:%s", *input.ServiceName) - if *v.Group == group { - ret = append(ret, *v.TaskArn) - } - } - return &ecs.ListTasksOutput{ - TaskArns: ret, - }, nil -} - -func (ctx *MockContext) DescribeServices(_ context.Context, input *ecs.DescribeServicesInput, _ ...func(options *ecs.Options)) (*ecs.DescribeServicesOutput, error) { - var ret []types.Service - ctx.mux.Lock() - defer ctx.mux.Unlock() - for _, v := range input.Services { - if s, ok := ctx.Services[v]; ok { - ret = append(ret, *s) - } - } - return &ecs.DescribeServicesOutput{ - Services: ret, - }, nil -} - -func (ctx *MockContext) DescribeTasks(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(options *ecs.Options)) (*ecs.DescribeTasksOutput, error) { - ctx.mux.Lock() - defer ctx.mux.Unlock() - var ret []types.Task - for _, task := range ctx.Tasks { - for _, v := range input.Tasks { - if *task.TaskArn == v { - ret = append(ret, *task) - } - } - } - return &ecs.DescribeTasksOutput{ - Tasks: ret, - }, nil -} -func (ctx *MockContext) DescribeContainerInstances(_ context.Context, input *ecs.DescribeContainerInstancesInput, _ ...func(options *ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) { - ctx.mux.Lock() - defer ctx.mux.Unlock() - var ret []types.ContainerInstance - ec2Id := "i-1234567890abcdefg" - instance := types.ContainerInstance{ - Ec2InstanceId: &ec2Id, - } - ret = append(ret, instance) - return &ecs.DescribeContainerInstancesOutput{ - ContainerInstances: ret, - }, nil -} - -// - -func (ctx *MockContext) DescribeTargetGroups(_ context.Context, input *elbv2.DescribeTargetGroupsInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { - return &elbv2.DescribeTargetGroupsOutput{ - TargetGroups: []elbv2types.TargetGroup{ - { - TargetGroupName: aws.String("tgname"), - TargetGroupArn: aws.String(input.TargetGroupArns[0]), - HealthyThresholdCount: aws.Int32(1), - HealthCheckIntervalSeconds: aws.Int32(0), - LoadBalancerArns: []string{"arn://hoge/app/aa/bb"}, - }, - }, - }, nil -} -func (ctx *MockContext) DescribeTargetGroupAttibutes(_ context.Context, input *elbv2.DescribeTargetGroupAttributesInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetGroupAttributesOutput, error) { - return &elbv2.DescribeTargetGroupAttributesOutput{ - Attributes: []elbv2types.TargetGroupAttribute{ - { - Key: aws.String("deregistration_delay.timeout_seconds"), - Value: aws.String("0"), - }, - }, - }, nil -} -func (ctx *MockContext) DescribeTargetHealth(_ context.Context, input *elbv2.DescribeTargetHealthInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { - if _, ok := ctx.TargetGroups[*input.TargetGroupArn]; !ok { - return &elbv2.DescribeTargetHealthOutput{ - TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ - { - Target: &elbv2types.TargetDescription{ - Id: input.Targets[0].Id, - Port: input.Targets[0].Port, - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumUnused, - }, - }, - }, - }, nil - } - - var ret []elbv2types.TargetHealthDescription - for _, task := range ctx.Tasks { - if task.LastStatus != nil && *task.LastStatus == "RUNNING" { - ret = append(ret, elbv2types.TargetHealthDescription{ - Target: &elbv2types.TargetDescription{ - Id: input.Targets[0].Id, - Port: input.Targets[0].Port, - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumHealthy, - }, - }) - } - } - return &elbv2.DescribeTargetHealthOutput{ - TargetHealthDescriptions: ret, - }, nil -} - -func (ctx *MockContext) RegisterTarget(_ context.Context, input *elbv2.RegisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { - ctx.TargetGroups[*input.TargetGroupArn] = struct{}{} - return &elbv2.RegisterTargetsOutput{}, nil -} - -func (ctx *MockContext) DeregisterTarget(_ context.Context, input *elbv2.DeregisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { - delete(ctx.TargetGroups, *input.TargetGroupArn) - return &elbv2.DeregisterTargetsOutput{}, nil -} - -func (ctx *MockContext) DescribeInstances(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(options *ec2.Options)) (*ec2.DescribeInstancesOutput, error) { - return &ec2.DescribeInstancesOutput{ - Reservations: []ec2types.Reservation{{ - Instances: []ec2types.Instance{{ - InstanceId: aws.String("i-123456"), - PrivateIpAddress: aws.String("127.0.1.0"), - SubnetId: aws.String("us-west-2a"), - }}, - }}, - }, nil -} - -func (ctx *MockContext) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(options *ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { - return &ec2.DescribeSubnetsOutput{ - Subnets: []ec2types.Subnet{{ - AvailabilityZone: aws.String("us-west-2"), - AvailabilityZoneId: nil, - AvailableIpAddressCount: nil, - CidrBlock: nil, - CustomerOwnedIpv4Pool: nil, - DefaultForAz: nil, - EnableDns64: nil, - EnableLniAtDeviceIndex: nil, - Ipv6CidrBlockAssociationSet: nil, - Ipv6Native: nil, - MapCustomerOwnedIpOnLaunch: nil, - MapPublicIpOnLaunch: nil, - OutpostArn: nil, - OwnerId: nil, - PrivateDnsNameOptionsOnLaunch: nil, - State: ec2types.SubnetStateAvailable, - SubnetArn: nil, - SubnetId: aws.String("subnet-1234567890abcdefg"), - Tags: nil, - VpcId: nil, - }}, - }, nil -} diff --git a/test/ec2.go b/test/ec2.go new file mode 100644 index 0000000..8e677c1 --- /dev/null +++ b/test/ec2.go @@ -0,0 +1,52 @@ +package test + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +type Ec2Server struct { + *commons +} + +func (c *Ec2Server) DescribeSubnets(_ context.Context, input *ec2.DescribeSubnetsInput, _ ...func(options *ec2.Options)) (*ec2.DescribeSubnetsOutput, error) { + return &ec2.DescribeSubnetsOutput{ + Subnets: []ec2types.Subnet{{ + AvailabilityZone: aws.String("us-west-2"), + AvailabilityZoneId: nil, + AvailableIpAddressCount: nil, + CidrBlock: nil, + CustomerOwnedIpv4Pool: nil, + DefaultForAz: nil, + EnableDns64: nil, + EnableLniAtDeviceIndex: nil, + Ipv6CidrBlockAssociationSet: nil, + Ipv6Native: nil, + MapCustomerOwnedIpOnLaunch: nil, + MapPublicIpOnLaunch: nil, + OutpostArn: nil, + OwnerId: nil, + PrivateDnsNameOptionsOnLaunch: nil, + State: ec2types.SubnetStateAvailable, + SubnetArn: nil, + SubnetId: aws.String("subnet-1234567890abcdefg"), + Tags: nil, + VpcId: nil, + }}, + }, nil +} + +func (c *Ec2Server) DescribeInstances(_ context.Context, input *ec2.DescribeInstancesInput, _ ...func(options *ec2.Options)) (*ec2.DescribeInstancesOutput, error) { + return &ec2.DescribeInstancesOutput{ + Reservations: []ec2types.Reservation{{ + Instances: []ec2types.Instance{{ + InstanceId: aws.String("i-123456"), + PrivateIpAddress: aws.String("127.0.1.0"), + SubnetId: aws.String("us-west-2a"), + }}, + }}, + }, nil +} diff --git a/test/ecs.go b/test/ecs.go new file mode 100644 index 0000000..727bd1a --- /dev/null +++ b/test/ecs.go @@ -0,0 +1,334 @@ +package test + +import ( + "context" + "fmt" + "regexp" + + "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/google/uuid" +) + +type EcsServer struct { + *commons +} + +func (ctx *EcsServer) CreateService(c context.Context, input *ecs.CreateServiceInput, _ ...func(options *ecs.Options)) (*ecs.CreateServiceOutput, error) { + idstr := uuid.New().String() + st := "ACTIVE" + if old, ok := ctx.Services[*input.ServiceName]; ok { + if *old.Status == "ACTIVE" { + return nil, fmt.Errorf("service already exists: %s", *input.ServiceName) + } + } + ret := &types.Service{ + ServiceName: input.ServiceName, + RunningCount: 0, + LaunchType: input.LaunchType, + LoadBalancers: input.LoadBalancers, + DesiredCount: *input.DesiredCount, + TaskDefinition: input.TaskDefinition, + HealthCheckGracePeriodSeconds: aws.Int32(0), + Status: &st, + ServiceArn: &idstr, + PlatformVersion: input.PlatformVersion, + ServiceRegistries: input.ServiceRegistries, + NetworkConfiguration: input.NetworkConfiguration, + Deployments: []types.Deployment{ + { + DesiredCount: *input.DesiredCount, + LaunchType: input.LaunchType, + RunningCount: *input.DesiredCount, + Status: &st, + TaskDefinition: input.TaskDefinition, + }, + }, + } + ctx.mux.Lock() + ctx.Services[*input.ServiceName] = ret + ctx.mux.Unlock() + log.Debugf("%s: running=%d, desired=%d", *input.ServiceName, ret.RunningCount, *input.DesiredCount) + for i := 0; i < int(*input.DesiredCount); i++ { + ctx.StartTask(c, &ecs.StartTaskInput{ + Cluster: input.Cluster, + Group: aws.String(fmt.Sprintf("service:%s", *input.ServiceName)), + NetworkConfiguration: input.NetworkConfiguration, + TaskDefinition: input.TaskDefinition, + }) + } + ctx.mux.Lock() + ctx.Services[*input.ServiceName].RunningCount = *input.DesiredCount + ctx.mux.Unlock() + log.Debugf("%s: running=%d", *input.ServiceName, ret.RunningCount) + return &ecs.CreateServiceOutput{ + Service: ret, + }, nil +} + +func (ctx *EcsServer) UpdateService(c context.Context, input *ecs.UpdateServiceInput, _ ...func(options *ecs.Options)) (*ecs.UpdateServiceOutput, error) { + ctx.mux.Lock() + s := ctx.Services[*input.Service] + ctx.mux.Unlock() + nextDesiredCount := s.DesiredCount + nextTaskDefinition := s.TaskDefinition + if input.TaskDefinition != nil { + nextTaskDefinition = input.TaskDefinition + } + if input.DesiredCount != nil { + nextDesiredCount = *input.DesiredCount + } + if diff := nextDesiredCount - s.DesiredCount; diff > 0 { + log.Debugf("diff=%d", diff) + // scale + for i := 0; i < int(diff); i++ { + ctx.StartTask(c, &ecs.StartTaskInput{ + Cluster: input.Cluster, + Group: aws.String(fmt.Sprintf("service:%s", *input.Service)), + TaskDefinition: nextTaskDefinition, + }) + } + } else if diff < 0 { + // descale + var i int32 = 0 + max := -diff + for k, v := range ctx.Tasks { + reg := regexp.MustCompile("service:" + *s.ServiceName) + if reg.MatchString(*v.Group) { + ctx.StopTask(c, &ecs.StopTaskInput{ + Cluster: input.Cluster, + Task: &k, + }) + i++ + if i >= max { + break + } + } + } + } + ctx.mux.Lock() + s.DesiredCount = nextDesiredCount + s.TaskDefinition = nextTaskDefinition + s.RunningCount = nextDesiredCount + s.PlatformVersion = input.PlatformVersion + s.ServiceRegistries = input.ServiceRegistries + s.NetworkConfiguration = input.NetworkConfiguration + s.LoadBalancers = input.LoadBalancers + s.Deployments = []types.Deployment{ + { + DesiredCount: nextDesiredCount, + LaunchType: s.LaunchType, + RunningCount: nextDesiredCount, + Status: s.Status, + TaskDefinition: s.TaskDefinition, + }, + } + ctx.mux.Unlock() + return &ecs.UpdateServiceOutput{ + Service: s, + }, nil +} + +func (ctx *EcsServer) DeleteService(c context.Context, input *ecs.DeleteServiceInput, _ ...func(options *ecs.Options)) (*ecs.DeleteServiceOutput, error) { + service := ctx.Services[*input.Service] + reg := regexp.MustCompile(fmt.Sprintf("service:%s", *service.ServiceName)) + for _, v := range ctx.Tasks { + if reg.MatchString(*v.Group) { + _, err := ctx.StopTask(c, &ecs.StopTaskInput{ + Cluster: input.Cluster, + Task: v.TaskArn, + }) + if err != nil { + return nil, err + } + } + } + ctx.mux.Lock() + defer ctx.mux.Unlock() + service.Status = aws.String("INACTIVE") + return &ecs.DeleteServiceOutput{Service: service}, nil +} + +func (ctx *EcsServer) RegisterTaskDefinition(_ context.Context, input *ecs.RegisterTaskDefinitionInput, _ ...func(options *ecs.Options)) (*ecs.RegisterTaskDefinitionOutput, error) { + td, err := ctx.TaskDefinitions.Register(input) + if err != nil { + return nil, err + } + return &ecs.RegisterTaskDefinitionOutput{TaskDefinition: td}, nil +} + +func (ctx *EcsServer) StartTask(_ context.Context, input *ecs.StartTaskInput, _ ...func(options *ecs.Options)) (*ecs.StartTaskOutput, error) { + ctx.mux.Lock() + defer ctx.mux.Unlock() + td := ctx.TaskDefinitions.Get(*input.TaskDefinition) + if td == nil { + return nil, fmt.Errorf("task definition not found: %s", *input.TaskDefinition) + } + taskArn := fmt.Sprintf("arn:aws:ecs:us-west-2:012345678910:task/%s", uuid.New().String()) + attachment := types.Attachment{ + Details: []types.KeyValuePair{ + { + Name: aws.String("privateIPv4Address"), + Value: aws.String("127.0.0.1"), + }, + }, + } + if input.NetworkConfiguration != nil { + subnet := input.NetworkConfiguration.AwsvpcConfiguration.Subnets[0] + attachment.Details = append(attachment.Details, types.KeyValuePair{ + Name: aws.String("subnetId"), + Value: &subnet, + }) + } + containers := make([]types.Container, len(td.ContainerDefinitions)) + for i, v := range td.ContainerDefinitions { + containers[i] = types.Container{ + Name: v.Name, + Image: v.Image, + LastStatus: aws.String("RUNNING"), + } + if v.HealthCheck != nil { + containers[i].HealthStatus = "HEALTHY" + } else { + containers[i].HealthStatus = "UNKNOWN" + } + } + + ret := types.Task{ + TaskArn: &taskArn, + ClusterArn: input.Cluster, + TaskDefinitionArn: input.TaskDefinition, + Group: input.Group, + Containers: containers, + } + ctx.Tasks[taskArn] = &ret + var launchType types.LaunchType + if len(input.ContainerInstances) > 0 { + launchType = types.LaunchTypeEc2 + } else { + launchType = types.LaunchTypeFargate + } + ret.LaunchType = launchType + if launchType == types.LaunchTypeFargate { + ret.Attachments = []types.Attachment{attachment} + } else { + ret.ContainerInstanceArn = aws.String("arn:aws:ecs:us-west-2:1234567890:container-instance/12345678-hoge-hoge-1234-1f2o3o4ba5r") + } + ret.LastStatus = aws.String("RUNNING") + return &ecs.StartTaskOutput{ + Tasks: []types.Task{ret}, + }, nil +} + +func (ctx *EcsServer) RunTask(c context.Context, input *ecs.RunTaskInput, _ ...func(options *ecs.Options)) (*ecs.RunTaskOutput, error) { + o, err := ctx.StartTask(c, &ecs.StartTaskInput{ + Cluster: input.Cluster, + Group: input.Group, + TaskDefinition: input.TaskDefinition, + NetworkConfiguration: input.NetworkConfiguration, + }) + if err != nil { + return nil, err + } + return &ecs.RunTaskOutput{ + Tasks: o.Tasks, + }, nil +} + +func (ctx *EcsServer) StopTask(_ context.Context, input *ecs.StopTaskInput, _ ...func(options *ecs.Options)) (*ecs.StopTaskOutput, error) { + ctx.mux.Lock() + defer ctx.mux.Unlock() + log.Debugf("stop: %s", *input.Task) + ret, ok := ctx.Tasks[*input.Task] + if !ok { + return nil, fmt.Errorf("task not found: %s", *input.Task) + } + for i := range ret.Containers { + v := &ret.Containers[i] + v.ExitCode = aws.Int32(0) + v.LastStatus = aws.String("STOPPED") + } + ret.LastStatus = aws.String("STOPPED") + ret.DesiredStatus = aws.String("STOPPED") + service, ok := ctx.Services[*ret.Group] + if ok { + service.RunningCount -= 1 + } + return &ecs.StopTaskOutput{Task: ret}, nil +} + +func (ctx *EcsServer) ListTasks(_ context.Context, input *ecs.ListTasksInput, _ ...func(options *ecs.Options)) (*ecs.ListTasksOutput, error) { + var ret []string + ctx.mux.Lock() + defer ctx.mux.Unlock() + for _, v := range ctx.Tasks { + group := fmt.Sprintf("service:%s", *input.ServiceName) + if *v.Group == group { + ret = append(ret, *v.TaskArn) + } + } + return &ecs.ListTasksOutput{ + TaskArns: ret, + }, nil +} + +func (ctx *EcsServer) DescribeServices(_ context.Context, input *ecs.DescribeServicesInput, _ ...func(options *ecs.Options)) (*ecs.DescribeServicesOutput, error) { + var ret []types.Service + ctx.mux.Lock() + defer ctx.mux.Unlock() + for _, v := range input.Services { + if s, ok := ctx.Services[v]; ok { + ret = append(ret, *s) + } + } + return &ecs.DescribeServicesOutput{ + Services: ret, + }, nil +} + +func (ctx *EcsServer) DescribeTasks(_ context.Context, input *ecs.DescribeTasksInput, _ ...func(options *ecs.Options)) (*ecs.DescribeTasksOutput, error) { + ctx.mux.Lock() + defer ctx.mux.Unlock() + var ret []types.Task + for _, task := range ctx.Tasks { + for _, v := range input.Tasks { + if *task.TaskArn == v { + ret = append(ret, *task) + } + } + } + return &ecs.DescribeTasksOutput{ + Tasks: ret, + }, nil +} +func (ctx *EcsServer) DescribeContainerInstances(_ context.Context, input *ecs.DescribeContainerInstancesInput, _ ...func(options *ecs.Options)) (*ecs.DescribeContainerInstancesOutput, error) { + ctx.mux.Lock() + defer ctx.mux.Unlock() + var ret []types.ContainerInstance + ec2Id := "i-1234567890abcdefg" + instance := types.ContainerInstance{ + Ec2InstanceId: &ec2Id, + } + ret = append(ret, instance) + return &ecs.DescribeContainerInstancesOutput{ + ContainerInstances: ret, + }, nil +} + +func (e *EcsServer) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) { + td := e.TaskDefinitions.Get(*params.TaskDefinition) + if td == nil { + return nil, fmt.Errorf("task definition not found: %s", *params.TaskDefinition) + } + return &ecs.DescribeTaskDefinitionOutput{TaskDefinition: td}, nil +} + +func (e *EcsServer) ListAttributes(ctx context.Context, params *ecs.ListAttributesInput, optFns ...func(*ecs.Options)) (*ecs.ListAttributesOutput, error) { + return &ecs.ListAttributesOutput{}, nil +} + +func (e *EcsServer) PutAttributes(ctx context.Context, params *ecs.PutAttributesInput, optFns ...func(*ecs.Options)) (*ecs.PutAttributesOutput, error) { + return nil, nil +} diff --git a/test/setup.go b/test/setup.go index e3bc09e..fb5d152 100644 --- a/test/setup.go +++ b/test/setup.go @@ -39,9 +39,9 @@ func Setup(ctrl *gomock.Controller, envars *env.Envars, currentTaskCount int, la albMock := mock_awsiface.NewMockAlbClient(ctrl) albMock.EXPECT().DescribeTargetGroups(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroups).AnyTimes() albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes() - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttibutes).AnyTimes() - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTarget).AnyTimes() - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTarget).AnyTimes() + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).AnyTimes() + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).AnyTimes() + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).AnyTimes() ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeSubnets).AnyTimes() @@ -54,7 +54,7 @@ func Setup(ctrl *gomock.Controller, envars *env.Envars, currentTaskCount int, la input.LaunchType = launchType svc, _ := mocker.CreateService(context.Background(), &input) if len(svc.Service.LoadBalancers) > 0 { - _, _ = mocker.RegisterTarget(context.Background(), &elbv2.RegisterTargetsInput{ + _, _ = mocker.RegisterTargets(context.Background(), &elbv2.RegisterTargetsInput{ TargetGroupArn: svc.Service.LoadBalancers[0].TargetGroupArn, }) } diff --git a/test/srv.go b/test/srv.go new file mode 100644 index 0000000..58ed9f9 --- /dev/null +++ b/test/srv.go @@ -0,0 +1,91 @@ +package test + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" + "golang.org/x/xerrors" +) + +type SrvServer struct { + *commons + services []*srvtypes.Service + // service.Name -> []*instance + insts map[string][]*srvtypes.Instance + // instance.Id -> HealthStatus + instHelths map[string]srvtypes.HealthStatus +} + +func (s *SrvServer) getServiceById(id string) *srvtypes.Service { + for _, svc := range s.services { + if *svc.Id == id { + return svc + } + } + return nil +} + +func (s *SrvServer) putInstHealth(id string, health srvtypes.HealthStatus) { + s.instHelths[id] = health +} + +func (s *SrvServer) DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) { + insts, ok := s.insts[*params.ServiceName] + if !ok { + return nil, xerrors.Errorf("service not found: %s", *params.ServiceName) + } + var summories []srvtypes.HttpInstanceSummary + for _, inst := range insts { + health := s.instHelths[*inst.Id] + summories = append(summories, srvtypes.HttpInstanceSummary{ + Attributes: inst.Attributes, + InstanceId: inst.Id, + ServiceName: params.ServiceName, + NamespaceName: params.NamespaceName, + HealthStatus: health, + }) + } + return &servicediscovery.DiscoverInstancesOutput{Instances: summories}, nil +} + +func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) { + if srv := s.getServiceById(*params.ServiceId); srv == nil { + return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) + } else { + inst := &srvtypes.Instance{ + Id: params.InstanceId, + Attributes: params.Attributes, + } + s.insts[*srv.Name] = append(s.insts[*params.ServiceId], inst) + s.instHelths[*params.InstanceId] = srvtypes.HealthStatusUnhealthy + return &servicediscovery.RegisterInstanceOutput{}, nil + } +} + +func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { + srv := s.getServiceById(*params.ServiceId) + if srv == nil { + return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) + } + insts, ok := s.insts[*srv.Name] + if !ok { + return nil, xerrors.Errorf("service not found: %s", *srv.Name) + } + var newInsts []*srvtypes.Instance + for _, inst := range insts { + if *inst.Id != *params.InstanceId { + newInsts = append(newInsts, inst) + } + } + s.insts[*srv.Name] = newInsts + return &servicediscovery.DeregisterInstanceOutput{}, nil +} + +func (s *SrvServer) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { + svc := s.getServiceById(*params.Id) + if svc == nil { + return nil, xerrors.Errorf("service not found: %s", *params.Id) + } + return &servicediscovery.GetServiceOutput{Service: svc}, nil +} From d7e82daaff4ee6fd62d86ab5a8d23fc057d7695b Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Fri, 28 Jun 2024 17:59:17 +0900 Subject: [PATCH 10/47] wip --- awsiface/iface.go | 1 + cage.go | 10 ++-- cli/cage/commands/flags.go | 11 ++++ cli/cage/commands/rollout.go | 1 + env/env.go | 2 + fixtures/service.json | 30 ---------- fixtures/task-definition.json | 39 ------------- rollout_test.go | 40 ++++++------- run_test.go | 34 +++++------ task/arn.go | 8 +++ task/arn_test.go | 13 +++++ task/common.go | 58 ++++++++++--------- task/srv_task.go | 49 ++++++++++------ task/srv_task_test.go | 47 +++++++++++++++ test/context.go | 46 +++++++++------ test/setup.go | 105 +++++++++++++++++++++++++--------- test/srv.go | 56 +++++++++++++----- 17 files changed, 335 insertions(+), 215 deletions(-) delete mode 100644 fixtures/service.json delete mode 100644 fixtures/task-definition.json create mode 100644 task/arn.go create mode 100644 task/arn_test.go create mode 100644 task/srv_task_test.go diff --git a/awsiface/iface.go b/awsiface/iface.go index 2dc2fa9..8135fd1 100644 --- a/awsiface/iface.go +++ b/awsiface/iface.go @@ -42,6 +42,7 @@ type ( RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) + GetNamespace(ctx context.Context, params *servicediscovery.GetNamespaceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetNamespaceOutput, error) } ) diff --git a/cage.go b/cage.go index 2c12236..24b3998 100644 --- a/cage.go +++ b/cage.go @@ -21,15 +21,17 @@ func NewCage(input *types.Deps) types.Cage { taskHealthCheckWait := (time.Duration)(input.Env.CanaryTaskHealthCheckWait) * time.Second taskStoppedWait := (time.Duration)(input.Env.CanaryTaskStoppedWait) * time.Second serviceStableWait := (time.Duration)(input.Env.ServiceStableWait) * time.Second + targetHealthCheckWait := (time.Duration)(input.Env.TargetHealthCheckWait) * time.Second return &cage{ Deps: input, Timeout: timeout.NewManager( 15*time.Minute, &timeout.Input{ - TaskRunningWait: taskRunningWait, - TaskHealthCheckWait: taskHealthCheckWait, - TaskStoppedWait: taskStoppedWait, - ServiceStableWait: serviceStableWait, + TaskRunningWait: taskRunningWait, + TaskHealthCheckWait: taskHealthCheckWait, + TaskStoppedWait: taskStoppedWait, + ServiceStableWait: serviceStableWait, + TargetHealthCheckWait: targetHealthCheckWait, }), TaskFactory: &taskFactory{}, } diff --git a/cli/cage/commands/flags.go b/cli/cage/commands/flags.go index 0e228b5..f60cf4a 100644 --- a/cli/cage/commands/flags.go +++ b/cli/cage/commands/flags.go @@ -92,3 +92,14 @@ func ServiceStableWaitFlag(dest *int) *cli.IntFlag { Value: 900, } } + +func TargetHealthCheckWaitFlag(dest *int) *cli.IntFlag { + return &cli.IntFlag{ + Name: "targetHealthCheckTimeout", + EnvVars: []string{env.TargetHealthCheckTimeout}, + Usage: "max duration seconds for waiting target health check", + Destination: dest, + Category: "ADVANCED", + Value: 900, + } +} diff --git a/cli/cage/commands/rollout.go b/cli/cage/commands/rollout.go index 87d6129..362b34c 100644 --- a/cli/cage/commands/rollout.go +++ b/cli/cage/commands/rollout.go @@ -41,6 +41,7 @@ func (c *CageCommands) RollOut( TaskHealthCheckWaitFlag(&envars.CanaryTaskHealthCheckWait), TaskStoppedWaitFlag(&envars.CanaryTaskStoppedWait), ServiceStableWaitFlag(&envars.ServiceStableWait), + TargetHealthCheckWaitFlag(&envars.TargetHealthCheckWait), }, Action: func(ctx *cli.Context) error { dir, _, err := c.requireArgs(ctx, 1, 1) diff --git a/env/env.go b/env/env.go index f9b0e1d..2d7a5e5 100644 --- a/env/env.go +++ b/env/env.go @@ -27,6 +27,7 @@ type Envars struct { CanaryTaskHealthCheckWait int // sec CanaryTaskStoppedWait int // sec ServiceStableWait int // sec + TargetHealthCheckWait int // sec } // required @@ -45,6 +46,7 @@ const TaskRunningTimeout = "CAGE_TASK_RUNNING_TIMEOUT" const TaskHealthCheckTimeout = "CAGE_TASK_HEALTH_CHECK_TIMEOUT" const TaskStoppedTimeout = "CAGE_TASK_STOPPED_TIMEOUT" const ServiceStableTimeout = "CAGE_SERVICE_STABLE_TIMEOUT" +const TargetHealthCheckTimeout = "CAGE_TARGET_HEALTH_CHECK_TIMEOUT" func EnsureEnvars( dest *Envars, diff --git a/fixtures/service.json b/fixtures/service.json deleted file mode 100644 index a7bab44..0000000 --- a/fixtures/service.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "cluster": "cluster", - "serviceName": "service", - "taskDefinition": "test-task:1", - "loadBalancers": [ - { - "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/tg/1234567890123456", - "loadBalancerName": "lb", - "containerName": "container", - "containerPort": 8000 - } - ], - "desiredCount": 1, - "launchType": "FARGATE", - "platformVersion": "1.4.0", - "role": "ecsServiceRole", - "deploymentConfiguration": { - "maximumPercent": 200, - "minimumHealthyPercent": 100 - }, - "networkConfiguration": { - "awsvpcConfiguration": { - "subnets": ["subnet-111", "subnet-222"], - "securityGroups": ["sg-111", "sg-222"], - "assignPublicIp": "ENABLED" - } - }, - "healthCheckGracePeriodSeconds": 0, - "schedulingStrategy": "REPLICA" -} diff --git a/fixtures/task-definition.json b/fixtures/task-definition.json deleted file mode 100644 index 639030b..0000000 --- a/fixtures/task-definition.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "family": "test-task", - "taskRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", - "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", - "networkMode": "awsvpc", - "containerDefinitions": [ - { - "name": "container", - "image": "", - "portMappings": [ - { - "containerPort": 8000, - "hostPort": 80 - } - ], - "essential": true, - "healthCheck": { - "command": [""], - "interval": 0, - "timeout": 0, - "retries": 0, - "startPeriod": 0 - } - }, - { - "name": "containerWithoutHealthCheck", - "image": "", - "portMappings": [ - { - "containerPort": 8000, - "hostPort": 80 - } - ] - } - ], - "requiresCompatibilities": ["FARGATE"], - "cpu": "256", - "memory": "512" -} diff --git a/rollout_test.go b/rollout_test.go index 97fbcab..3a19fbe 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -79,9 +79,9 @@ func TestCage_RollOut_FARGATE(t *testing.T) { mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).Times(1) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.RegisterTargets).Times(1) + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DeregisterTargets).Times(1) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroupAttributes).Times(1) gomock.InOrder( albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ @@ -96,7 +96,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, }}, }, nil).Times(2), - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), ) cagecli := cage.NewCage(&types.Deps{ Env: envars, @@ -115,9 +115,9 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).Times(1) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.RegisterTargets).Times(1) + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DeregisterTargets).Times(1) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroupAttributes).Times(1) gomock.InOrder( albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ TargetHealthDescriptions: []elbv2types.TargetHealthDescription{{ @@ -140,7 +140,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, }}, }, nil), - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes(), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), ) cagecli := cage.NewCage(&types.Deps{ Env: envars, @@ -275,19 +275,18 @@ func TestCage_RollOut_FARGATE(t *testing.T) { mocker, _, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - ecsMock.EXPECT().CreateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.CreateService).AnyTimes() - ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.UpdateService).AnyTimes() - ecsMock.EXPECT().DeleteService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeleteService).AnyTimes() - ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StartTask).AnyTimes() - ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTaskDefinition).AnyTimes() - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeServices).AnyTimes() + ecsMock.EXPECT().CreateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.CreateService).AnyTimes() + ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.UpdateService).AnyTimes() + ecsMock.EXPECT().DeleteService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DeleteService).AnyTimes() + ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StartTask).AnyTimes() + ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RegisterTaskDefinition).AnyTimes() + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeServices).AnyTimes() ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *ecs.DescribeTasksInput, opts ...func(options *ecs.Options)) (*ecs.DescribeTasksOutput, error) { - out, err := mocker.DescribeTasks(ctx, input, opts...) + out, err := mocker.Ecs.DescribeTasks(ctx, input, opts...) if err != nil { return out, err } - task := mocker.Tasks[input.Tasks[0]] if strings.Contains(*task.Group, "canary-task") { for i := range out.Tasks { @@ -299,11 +298,10 @@ func TestCage_RollOut_FARGATE(t *testing.T) { return out, err }, ).AnyTimes() - ecsMock.EXPECT().ListTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.ListTasks).AnyTimes() - ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeContainerInstances).AnyTimes() - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask).AnyTimes() - ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StopTask).AnyTimes() - + ecsMock.EXPECT().ListTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.ListTasks).AnyTimes() + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeContainerInstances).AnyTimes() + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask).AnyTimes() + ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StopTask).AnyTimes() cagecli := cage.NewCage(&types.Deps{ Env: envars, Ecs: ecsMock, diff --git a/run_test.go b/run_test.go index 8169ba7..780b0bc 100644 --- a/run_test.go +++ b/run_test.go @@ -24,7 +24,7 @@ func TestCage_Run(t *testing.T) { mocker := test.NewMockContext() ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTaskDefinition).AnyTimes() + ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RegisterTaskDefinition).AnyTimes() return env, mocker, ecsMock } t.Run("basic", func(t *testing.T) { @@ -33,11 +33,11 @@ func TestCage_Run(t *testing.T) { ctx := context.Background() env, mocker, ecsMock := setupForBasic(t) gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks), + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { - mocker.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) - return mocker.DescribeTasks(ctx, input) + mocker.Ecs.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) + return mocker.Ecs.DescribeTasks(ctx, input) }), ) cagecli := cage.NewCage(&types.Deps{ @@ -59,10 +59,10 @@ func TestCage_Run(t *testing.T) { env, mocker, ecsMock := setupForBasic(t) env.CanaryTaskRunningWait = 1 gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { - res, err := mocker.DescribeTasks(ctx, input) + res, err := mocker.Ecs.DescribeTasks(ctx, input) for i := range res.Tasks { res.Tasks[i].LastStatus = aws.String("PROVISIONING") } @@ -89,8 +89,8 @@ func TestCage_Run(t *testing.T) { env, mocker, ecsMock := setupForBasic(t) env.CanaryTaskStoppedWait = 1 gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks).Times(2), + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks).Times(2), ) cagecli := cage.NewCage(&types.Deps{ Env: env, @@ -110,14 +110,14 @@ func TestCage_Run(t *testing.T) { ctx := context.Background() env, mocker, ecsMock := setupForBasic(t) gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks), + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { - stop, _ := mocker.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) + stop, _ := mocker.Ecs.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) for i := range stop.Task.Containers { stop.Task.Containers[i].ExitCode = aws.Int32(1) } - return mocker.DescribeTasks(ctx, input) + return mocker.Ecs.DescribeTasks(ctx, input) }), ) cagecli := cage.NewCage(&types.Deps{ @@ -138,14 +138,14 @@ func TestCage_Run(t *testing.T) { ctx := context.Background() env, mocker, ecsMock := setupForBasic(t) gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask), - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks), + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, input *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error) { - stop, _ := mocker.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) + stop, _ := mocker.Ecs.StopTask(ctx, &ecs.StopTaskInput{Cluster: &env.Cluster, Task: &input.Tasks[0]}) for i := range stop.Task.Containers { stop.Task.Containers[i].ExitCode = nil } - return mocker.DescribeTasks(ctx, input) + return mocker.Ecs.DescribeTasks(ctx, input) }), ) cagecli := cage.NewCage(&types.Deps{ diff --git a/task/arn.go b/task/arn.go new file mode 100644 index 0000000..bb0dace --- /dev/null +++ b/task/arn.go @@ -0,0 +1,8 @@ +package task + +import "strings" + +func ArnToId(arn string) string { + list := strings.Split(arn, "/") + return list[len(list)-1] +} diff --git a/task/arn_test.go b/task/arn_test.go new file mode 100644 index 0000000..28f3657 --- /dev/null +++ b/task/arn_test.go @@ -0,0 +1,13 @@ +package task_test + +import ( + "testing" + + "github.com/loilo-inc/canarycage/task" + "github.com/stretchr/testify/assert" +) + +func TestArnToId(t *testing.T) { + arn := "arn://aaa/srv-1234" + assert.Equal(t, "srv-1234", task.ArnToId(arn)) +} diff --git a/task/common.go b/task/common.go index 0cb3a5e..2db2e9a 100644 --- a/task/common.go +++ b/task/common.go @@ -93,37 +93,43 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { containerHasHealthChecks[*definition.Name] = struct{}{} } } - healthCheckWait := c.Timeout.TaskHealthCheck() + rest := c.Timeout.TaskHealthCheck() healthCheckPeriod := 15 * time.Second - countPerPeriod := int(healthCheckWait.Seconds() / 15) - for count := 0; count < countPerPeriod; count++ { - <-c.Time.NewTimer(healthCheckPeriod).C - log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return err - } else { - task := o.Tasks[0] - if *task.LastStatus != "RUNNING" { - return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) - } - - for _, container := range task.Containers { - if _, ok := containerHasHealthChecks[*container.Name]; !ok { - continue + for rest > 0 { + if rest < healthCheckPeriod { + healthCheckPeriod = rest + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.Time.NewTimer(healthCheckPeriod).C: + log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) + if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &c.Env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task := o.Tasks[0] + if *task.LastStatus != "RUNNING" { + return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) } - if container.HealthStatus != ecstypes.HealthStatusHealthy { - log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) - continue + for _, container := range task.Containers { + if _, ok := containerHasHealthChecks[*container.Name]; !ok { + continue + } + if container.HealthStatus != ecstypes.HealthStatusHealthy { + log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) + continue + } + delete(containerHasHealthChecks, *container.Name) + } + if len(containerHasHealthChecks) == 0 { + return nil } - delete(containerHasHealthChecks, *container.Name) - } - if len(containerHasHealthChecks) == 0 { - return nil } } + rest -= healthCheckPeriod } return xerrors.Errorf("😨 canary task hasn't become to be healthy") } diff --git a/task/srv_task.go b/task/srv_task.go index d58d81b..f3a2af5 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -2,7 +2,6 @@ package task import ( "context" - "regexp" "time" "github.com/apex/log" @@ -18,7 +17,8 @@ type srvTask struct { registry *ecstypes.ServiceRegistry target *CanaryTarget srv *srvtypes.Service - inst *srvtypes.HttpInstanceSummary + instId *string + ns *srvtypes.Namespace } func NewSrvTask(input *Input, registry *ecstypes.ServiceRegistry) Task { @@ -35,7 +35,7 @@ func (c *srvTask) Wait(ctx context.Context) error { if err := c.registerToSrvDiscovery(ctx); err != nil { return err } - log.Infof("canary task '%s' is registered to service discovery instance '%s'", *c.taskArn, *c.inst.InstanceId) + log.Infof("canary task '%s' is registered to service discovery instance '%s'", *c.taskArn, *c.instId) log.Infof("😷 ensuring canary service instance to become healthy...") if err := c.waitUntilSrvInstHelthy(ctx); err != nil { return err @@ -50,19 +50,20 @@ func (c *srvTask) Stop(ctx context.Context) error { } func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { - target, err := c.describeTaskTarget(ctx, *c.registry.Port) + var targetPort int32 + if c.registry.Port != nil { + targetPort = *c.registry.Port + } else { + targetPort = 80 + } + target, err := c.describeTaskTarget(ctx, targetPort) if err != nil { return err } - c.target = target - // get the service id from service registry arn - pat := regexp.MustCompile("arn://.+/(srv-.+)$") - matches := pat.FindStringSubmatch(*c.registry.RegistryArn) - if len(matches) != 2 { - return xerrors.Errorf("service name '%s' doesn't match the pattern", c.Env.Service) - } - srvId := matches[1] + c.target = target // get the service id from service registry arn + srvId := ArnToId(*c.registry.RegistryArn) var svc *srvtypes.Service + var ns *srvtypes.Namespace if o, err := c.Srv.GetService(ctx, &servicediscovery.GetServiceInput{ Id: &srvId, }); err != nil { @@ -70,6 +71,13 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { } else { svc = o.Service } + if o, err := c.Srv.GetNamespace(ctx, &servicediscovery.GetNamespaceInput{ + Id: svc.NamespaceId, + }); err != nil { + return xerrors.Errorf("failed to get the namespace: %w", err) + } else { + ns = o.Namespace + } attrs := map[string]string{ "AWS_INSTANCE_IPV4": target.targetIpv4, "AVAILABILITY_ZONE": target.availabilityZone, @@ -80,14 +88,17 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { "REGION": c.Env.Region, "CAGE_CANARY_TASK": "1", } + taskId := ArnToId(*c.taskArn) if _, err := c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ ServiceId: &srvId, - InstanceId: c.taskArn, + InstanceId: &taskId, Attributes: attrs, }); err != nil { return xerrors.Errorf("failed to register the canary task to service discovery: %w", err) } c.srv = svc + c.instId = &taskId + c.ns = ns return nil } @@ -105,11 +116,12 @@ func (c *srvTask) waitUntilSrvInstHelthy( return ctx.Err() case <-c.Time.NewTimer(time.Duration(waitPeriod) * time.Second).C: if list, err := c.Srv.DiscoverInstances(ctx, &servicediscovery.DiscoverInstancesInput{ - NamespaceName: c.inst.NamespaceName, - ServiceName: c.inst.ServiceName, + NamespaceName: c.ns.Name, + ServiceName: c.srv.Name, HealthStatus: srvtypes.HealthStatusFilterHealthy, QueryParameters: map[string]string{ "CAGE_CANARY_TASK": "1", + "AWS_PRIVATE_IPV4": c.target.targetIpv4, }, }); err != nil { return xerrors.Errorf("failed to discover instances: %w", err) @@ -118,8 +130,7 @@ func (c *srvTask) waitUntilSrvInstHelthy( return xerrors.Errorf("no healthy instances found") } for _, inst := range list.Instances { - if ipv4 := inst.Attributes["AWS_INSTANCE_IPV4"]; ipv4 == c.target.targetIpv4 { - c.inst = &inst + if *inst.InstanceId == *c.instId { return nil } } @@ -133,12 +144,12 @@ func (c *srvTask) waitUntilSrvInstHelthy( func (c *srvTask) deregisterSrvInst( ctx context.Context, ) { - if c.inst == nil { + if c.instId == nil { return } if _, err := c.Srv.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ ServiceId: c.srv.Id, - InstanceId: c.inst.InstanceId, + InstanceId: c.instId, }); err != nil { log.Errorf("failed to deregister the canary task from service discovery: %v", err) log.Errorf("continuing to stop the canary task...") diff --git a/task/srv_task_test.go b/task/srv_task_test.go new file mode 100644 index 0000000..66b9de4 --- /dev/null +++ b/task/srv_task_test.go @@ -0,0 +1,47 @@ +package task_test + +import ( + "context" + "testing" + + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/stretchr/testify/assert" +) + +func TestSrvTask(t *testing.T) { + srvSvcName := "internal" + srvNsName := "dev.local" + registryArn := "arn://aaa/srv-" + srvSvcName + mocker := test.NewMockContext() + env := test.DefaultEnvars() + + ctx := context.TODO() + td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn + ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) + stask := task.NewSrvTask(&task.Input{ + Deps: &types.Deps{ + Env: env, + Ecs: mocker.Ecs, + Ec2: mocker.Ec2, + Alb: mocker.Alb, + Srv: mocker.Srv, + Time: test.NewFakeTime(), + }, + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + Timeout: timeout.NewManager(1, &timeout.Input{}), + }, &ecstypes.ServiceRegistry{RegistryArn: ®istryArn}) + srvSvc := mocker.CreateSrvService(srvNsName, srvSvcName) + err := stask.Start(ctx) + assert.NoError(t, err) + mocker.PutSrvInstHealth(*srvSvc.Id, "healthy") + err = stask.Wait(ctx) + assert.NoError(t, err) + err = stask.Stop(ctx) + assert.NoError(t, err) +} diff --git a/test/context.go b/test/context.go index 00aab32..675c422 100644 --- a/test/context.go +++ b/test/context.go @@ -3,45 +3,55 @@ package test import ( "sync" - "github.com/aws/aws-sdk-go-v2/service/ecs/types" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" "github.com/loilo-inc/canarycage/awsiface" ) type commons struct { - Services map[string]*types.Service - Tasks map[string]*types.Task + Services map[string]*ecstypes.Service + Tasks map[string]*ecstypes.Task TaskDefinitions *TaskDefinitionRepository TargetGroups map[string]struct{} - mux sync.Mutex + SrvNamespaces []*srvtypes.Namespace + SrvServices []*srvtypes.Service + // service.Name -> []*instance + SrvInsts map[string][]*srvtypes.Instance + // instance.Id -> HealthStatus + SrvInstHelths map[string]srvtypes.HealthStatus + mux sync.Mutex } type MockContext struct { *commons - awsiface.EcsClient - awsiface.AlbClient - awsiface.Ec2Client - awsiface.SrvClient + Ecs awsiface.EcsClient + Alb awsiface.AlbClient + Ec2 awsiface.Ec2Client + Srv awsiface.SrvClient } func NewMockContext() *MockContext { cm := &commons{ - Services: make(map[string]*types.Service), - Tasks: make(map[string]*types.Task), + Services: make(map[string]*ecstypes.Service), + Tasks: make(map[string]*ecstypes.Task), TaskDefinitions: &TaskDefinitionRepository{ families: make(map[string]*TaskDefinitionFamily), }, - TargetGroups: make(map[string]struct{}), + TargetGroups: make(map[string]struct{}), + SrvServices: make([]*srvtypes.Service, 0), + SrvInsts: make(map[string][]*srvtypes.Instance), + SrvInstHelths: make(map[string]srvtypes.HealthStatus), } return &MockContext{ - commons: cm, - EcsClient: &EcsServer{commons: cm}, - Ec2Client: &Ec2Server{commons: cm}, - SrvClient: &SrvServer{commons: cm}, - AlbClient: &AlbServer{commons: cm}, + commons: cm, + Ecs: &EcsServer{commons: cm}, + Ec2: &Ec2Server{commons: cm}, + Srv: &SrvServer{commons: cm}, + Alb: &AlbServer{commons: cm}, } } -func (ctx *commons) GetTask(id string) (*types.Task, bool) { +func (ctx *commons) GetTask(id string) (*ecstypes.Task, bool) { ctx.mux.Lock() defer ctx.mux.Unlock() o, ok := ctx.Tasks[id] @@ -62,7 +72,7 @@ func (ctx *commons) RunningTaskSize() int { return count } -func (ctx *commons) GetEcsService(id string) (*types.Service, bool) { +func (ctx *commons) GetEcsService(id string) (*ecstypes.Service, bool) { ctx.mux.Lock() defer ctx.mux.Unlock() o, ok := ctx.Services[id] diff --git a/test/setup.go b/test/setup.go index fb5d152..8c351dd 100644 --- a/test/setup.go +++ b/test/setup.go @@ -24,37 +24,37 @@ func Setup(ctrl *gomock.Controller, envars *env.Envars, currentTaskCount int, la mocker := NewMockContext() ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - ecsMock.EXPECT().CreateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.CreateService).AnyTimes() - ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.UpdateService).AnyTimes() - ecsMock.EXPECT().DeleteService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeleteService).AnyTimes() - ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StartTask).AnyTimes() - ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTaskDefinition).AnyTimes() - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeServices).AnyTimes() - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTasks).AnyTimes() - ecsMock.EXPECT().ListTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.ListTasks).AnyTimes() - ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeContainerInstances).AnyTimes() - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RunTask).AnyTimes() - ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.StopTask).AnyTimes() + ecsMock.EXPECT().CreateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.CreateService).AnyTimes() + ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.UpdateService).AnyTimes() + ecsMock.EXPECT().DeleteService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DeleteService).AnyTimes() + ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StartTask).AnyTimes() + ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RegisterTaskDefinition).AnyTimes() + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeServices).AnyTimes() + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks).AnyTimes() + ecsMock.EXPECT().ListTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.ListTasks).AnyTimes() + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeContainerInstances).AnyTimes() + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask).AnyTimes() + ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StopTask).AnyTimes() albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().DescribeTargetGroups(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroups).AnyTimes() - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetHealth).AnyTimes() - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeTargetGroupAttributes).AnyTimes() - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.RegisterTargets).AnyTimes() - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DeregisterTargets).AnyTimes() + albMock.EXPECT().DescribeTargetGroups(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroups).AnyTimes() + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes() + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroupAttributes).AnyTimes() + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.RegisterTargets).AnyTimes() + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DeregisterTargets).AnyTimes() ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) - ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeSubnets).AnyTimes() - ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.DescribeInstances).AnyTimes() - td, _ := mocker.RegisterTaskDefinition(context.Background(), envars.TaskDefinitionInput) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ec2.DescribeSubnets).AnyTimes() + ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ec2.DescribeInstances).AnyTimes() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.Background(), envars.TaskDefinitionInput) if currentTaskCount >= 0 { input := *envars.ServiceDefinitionInput input.TaskDefinition = td.TaskDefinition.TaskDefinitionArn input.DesiredCount = aws.Int32(int32(currentTaskCount)) input.LaunchType = launchType - svc, _ := mocker.CreateService(context.Background(), &input) + svc, _ := mocker.Ecs.CreateService(context.Background(), &input) if len(svc.Service.LoadBalancers) > 0 { - _, _ = mocker.RegisterTargets(context.Background(), &elbv2.RegisterTargetsInput{ + _, _ = mocker.Alb.RegisterTargets(context.Background(), &elbv2.RegisterTargetsInput{ TargetGroupArn: svc.Service.LoadBalancers[0].TargetGroupArn, }) } @@ -63,17 +63,68 @@ func Setup(ctrl *gomock.Controller, envars *env.Envars, currentTaskCount int, la } func DefaultEnvars() *env.Envars { - d, _ := os.ReadFile("fixtures/task-definition.json") - var taskDefinition ecs.RegisterTaskDefinitionInput - if err := json.Unmarshal(d, &taskDefinition); err != nil { - log.Fatalf(err.Error()) + service := &ecs.CreateServiceInput{ + Cluster: aws.String("cluster"), + ServiceName: aws.String("service"), + TaskDefinition: aws.String("-"), + LoadBalancers: []ecstypes.LoadBalancer{ + {TargetGroupArn: aws.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/test/123456789012"), + ContainerName: aws.String("container"), + ContainerPort: aws.Int32(8000), + LoadBalancerName: aws.String("lb"), + }, + }, + DesiredCount: aws.Int32(1), + LaunchType: ecstypes.LaunchTypeFargate, + PlatformVersion: aws.String("1.4.0"), + Role: aws.String("arn:aws:iam::123456789012:role/ecsServiceRole"), + DeploymentConfiguration: &ecstypes.DeploymentConfiguration{ + MaximumPercent: aws.Int32(200), + MinimumHealthyPercent: aws.Int32(100), + }, + NetworkConfiguration: &ecstypes.NetworkConfiguration{ + AwsvpcConfiguration: &ecstypes.AwsVpcConfiguration{ + Subnets: []string{"subnet-12345678"}, + SecurityGroups: []string{"sg-12345678"}, + AssignPublicIp: ecstypes.AssignPublicIpDisabled, + }, + }, + HealthCheckGracePeriodSeconds: aws.Int32(0), + SchedulingStrategy: ecstypes.SchedulingStrategyReplica, + } + taskDefinition := &ecs.RegisterTaskDefinitionInput{ + Family: aws.String("test-task"), + TaskRoleArn: aws.String("arn:aws:iam::123456789012:role/ecsTaskExecutionRole"), + ExecutionRoleArn: aws.String("arn:aws:iam::123456789012:role/ecsTaskExecutionRole"), + NetworkMode: ecstypes.NetworkModeAwsvpc, + ContainerDefinitions: []ecstypes.ContainerDefinition{ + { + Name: aws.String("container"), + PortMappings: []ecstypes.PortMapping{ + {ContainerPort: aws.Int32(8000), HostPort: aws.Int32(80)}, + }, + Essential: aws.Bool(true), + HealthCheck: &ecstypes.HealthCheck{ + Command: []string{"CMD-SHELL", "curl -f http://localhost:8000/ || exit 1"}, + }, + }, + { + Name: aws.String("containerWithoutHealthCheck"), + PortMappings: []ecstypes.PortMapping{ + {ContainerPort: aws.Int32(8000), HostPort: aws.Int32(81)}, + }, + }, + }, + RequiresCompatibilities: []ecstypes.Compatibility{ecstypes.CompatibilityFargate}, + Cpu: aws.String("256"), + Memory: aws.String("512"), } return &env.Envars{ Region: "us-west-2", Cluster: "cage-test", Service: "service", - ServiceDefinitionInput: ReadServiceDefinition("fixtures/service.json"), - TaskDefinitionInput: &taskDefinition, + ServiceDefinitionInput: service, + TaskDefinitionInput: taskDefinition, } } diff --git a/test/srv.go b/test/srv.go index 58ed9f9..5a8e3e0 100644 --- a/test/srv.go +++ b/test/srv.go @@ -2,7 +2,9 @@ package test import ( "context" + "fmt" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/servicediscovery" srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" "golang.org/x/xerrors" @@ -10,15 +12,10 @@ import ( type SrvServer struct { *commons - services []*srvtypes.Service - // service.Name -> []*instance - insts map[string][]*srvtypes.Instance - // instance.Id -> HealthStatus - instHelths map[string]srvtypes.HealthStatus } func (s *SrvServer) getServiceById(id string) *srvtypes.Service { - for _, svc := range s.services { + for _, svc := range s.SrvServices { if *svc.Id == id { return svc } @@ -26,18 +23,40 @@ func (s *SrvServer) getServiceById(id string) *srvtypes.Service { return nil } -func (s *SrvServer) putInstHealth(id string, health srvtypes.HealthStatus) { - s.instHelths[id] = health +func (s *commons) CreateSrvService( + namepsaceName string, + serviceName string) *srvtypes.Service { + nsId := fmt.Sprintf("ns-%s", namepsaceName) + ns := &srvtypes.Namespace{ + Id: &nsId, + Name: &namepsaceName, + Arn: aws.String(fmt.Sprintf("arn:aws:servicediscovery:ap-northeast-1:123456789012:namespace/%s", nsId)), + } + svId := fmt.Sprintf("srv-%s", serviceName) + svc := &srvtypes.Service{ + NamespaceId: ns.Id, + Id: &svId, + Name: &serviceName, + Arn: aws.String(fmt.Sprintf("arn:aws:servicediscovery:ap-northeast-1:123456789012:service/%s", svId)), + InstanceCount: aws.Int32(0), + } + s.SrvNamespaces = append(s.SrvNamespaces, ns) + s.SrvServices = append(s.SrvServices, svc) + return svc +} + +func (s *commons) PutSrvInstHealth(id string, health srvtypes.HealthStatus) { + s.SrvInstHelths[id] = health } func (s *SrvServer) DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) { - insts, ok := s.insts[*params.ServiceName] + insts, ok := s.SrvInsts[*params.ServiceName] if !ok { return nil, xerrors.Errorf("service not found: %s", *params.ServiceName) } var summories []srvtypes.HttpInstanceSummary for _, inst := range insts { - health := s.instHelths[*inst.Id] + health := s.SrvInstHelths[*inst.Id] summories = append(summories, srvtypes.HttpInstanceSummary{ Attributes: inst.Attributes, InstanceId: inst.Id, @@ -57,8 +76,8 @@ func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscove Id: params.InstanceId, Attributes: params.Attributes, } - s.insts[*srv.Name] = append(s.insts[*params.ServiceId], inst) - s.instHelths[*params.InstanceId] = srvtypes.HealthStatusUnhealthy + s.SrvInsts[*srv.Name] = append(s.SrvInsts[*params.ServiceId], inst) + s.SrvInstHelths[*params.InstanceId] = srvtypes.HealthStatusUnhealthy return &servicediscovery.RegisterInstanceOutput{}, nil } } @@ -68,7 +87,7 @@ func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicedisco if srv == nil { return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) } - insts, ok := s.insts[*srv.Name] + insts, ok := s.SrvInsts[*srv.Name] if !ok { return nil, xerrors.Errorf("service not found: %s", *srv.Name) } @@ -78,7 +97,7 @@ func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicedisco newInsts = append(newInsts, inst) } } - s.insts[*srv.Name] = newInsts + s.SrvInsts[*srv.Name] = newInsts return &servicediscovery.DeregisterInstanceOutput{}, nil } @@ -89,3 +108,12 @@ func (s *SrvServer) GetService(ctx context.Context, params *servicediscovery.Get } return &servicediscovery.GetServiceOutput{Service: svc}, nil } + +func (s *SrvServer) GetNamespace(ctx context.Context, params *servicediscovery.GetNamespaceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetNamespaceOutput, error) { + for _, ns := range s.SrvNamespaces { + if *ns.Id == *params.Id { + return &servicediscovery.GetNamespaceOutput{Namespace: ns}, nil + } + } + return nil, xerrors.Errorf("namespace not found: %s", *params.Id) +} From 5b29e1b3340cdad4d75588da3e015c4c998638df Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Fri, 28 Jun 2024 18:00:38 +0900 Subject: [PATCH 11/47] a --- fixtures/service.json | 30 +++++ fixtures/task-definition.json | 202 ++++++++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 fixtures/service.json create mode 100644 fixtures/task-definition.json diff --git a/fixtures/service.json b/fixtures/service.json new file mode 100644 index 0000000..a7bab44 --- /dev/null +++ b/fixtures/service.json @@ -0,0 +1,30 @@ +{ + "cluster": "cluster", + "serviceName": "service", + "taskDefinition": "test-task:1", + "loadBalancers": [ + { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/tg/1234567890123456", + "loadBalancerName": "lb", + "containerName": "container", + "containerPort": 8000 + } + ], + "desiredCount": 1, + "launchType": "FARGATE", + "platformVersion": "1.4.0", + "role": "ecsServiceRole", + "deploymentConfiguration": { + "maximumPercent": 200, + "minimumHealthyPercent": 100 + }, + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["subnet-111", "subnet-222"], + "securityGroups": ["sg-111", "sg-222"], + "assignPublicIp": "ENABLED" + } + }, + "healthCheckGracePeriodSeconds": 0, + "schedulingStrategy": "REPLICA" +} diff --git a/fixtures/task-definition.json b/fixtures/task-definition.json new file mode 100644 index 0000000..23ecc4e --- /dev/null +++ b/fixtures/task-definition.json @@ -0,0 +1,202 @@ +{ + "family": "test-task", + "taskRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + "executionRoleArn": "arn:aws:iam::123456789012:role/ecsTaskExecutionRole", + "networkMode": "awsvpc", + "containerDefinitions": [ + { + "name": "container", + "image": "", + "repositoryCredentials": { + "credentialsParameter": "" + }, + "cpu": 0, + "memory": 0, + "memoryReservation": 0, + "links": [""], + "portMappings": [ + { + "containerPort": 8000, + "hostPort": 80 + } + ], + "essential": true, + "entryPoint": [""], + "command": [""], + "environment": [ + { + "name": "", + "value": "" + } + ], + "mountPoints": [ + { + "sourceVolume": "", + "containerPath": "", + "readOnly": true + } + ], + "volumesFrom": [ + { + "sourceContainer": "", + "readOnly": true + } + ], + "linuxParameters": { + "capabilities": { + "add": [""], + "drop": [""] + }, + "devices": [ + { + "hostPath": "", + "containerPath": "", + "permissions": ["mknod"] + } + ], + "initProcessEnabled": true, + "sharedMemorySize": 0, + "tmpfs": [ + { + "containerPath": "", + "size": 0, + "mountOptions": [""] + } + ] + }, + "hostname": "", + "user": "", + "workingDirectory": "", + "disableNetworking": true, + "privileged": true, + "readonlyRootFilesystem": true, + "dnsServers": [""], + "dnsSearchDomains": [""], + "extraHosts": [ + { + "hostname": "", + "ipAddress": "" + } + ], + "dockerSecurityOptions": [""], + "dockerLabels": { + "KeyName": "" + }, + "ulimits": [ + { + "name": "core", + "softLimit": 0, + "hardLimit": 0 + } + ], + "logConfiguration": { + "logDriver": "gelf", + "options": { + "KeyName": "" + } + }, + "healthCheck": { + "command": [""], + "interval": 0, + "timeout": 0, + "retries": 0, + "startPeriod": 0 + } + }, + { + "name": "containerWithoutHealthCheck", + "image": "", + "repositoryCredentials": { + "credentialsParameter": "" + }, + "cpu": 0, + "memory": 0, + "memoryReservation": 0, + "links": [""], + "portMappings": [ + { + "containerPort": 8000, + "hostPort": 80 + } + ], + "essential": true, + "entryPoint": [""], + "command": [""], + "environment": [ + { + "name": "", + "value": "" + } + ], + "mountPoints": [ + { + "sourceVolume": "", + "containerPath": "", + "readOnly": true + } + ], + "volumesFrom": [ + { + "sourceContainer": "", + "readOnly": true + } + ], + "linuxParameters": { + "capabilities": { + "add": [""], + "drop": [""] + }, + "devices": [ + { + "hostPath": "", + "containerPath": "", + "permissions": ["mknod"] + } + ], + "initProcessEnabled": true, + "sharedMemorySize": 0, + "tmpfs": [ + { + "containerPath": "", + "size": 0, + "mountOptions": [""] + } + ] + }, + "hostname": "", + "user": "", + "workingDirectory": "", + "disableNetworking": true, + "privileged": true, + "readonlyRootFilesystem": true, + "dnsServers": [""], + "dnsSearchDomains": [""], + "extraHosts": [ + { + "hostname": "", + "ipAddress": "" + } + ], + "dockerSecurityOptions": [""], + "dockerLabels": { + "KeyName": "" + }, + "ulimits": [ + { + "name": "core", + "softLimit": 0, + "hardLimit": 0 + } + ], + "logConfiguration": { + "logDriver": "gelf", + "options": { + "KeyName": "" + } + } + } + ], + "requiresCompatibilities": ["FARGATE"], + "cpu": "256", + "memory": "512" +} From 1e9975591e7fea3b4c85dbd1aadbb86f2f243881 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Fri, 28 Jun 2024 18:36:46 +0900 Subject: [PATCH 12/47] wip --- task/common.go | 4 ++++ task/srv_task.go | 6 +++--- task/srv_task_test.go | 7 +++++-- task/task.go | 1 + test/context.go | 7 ------- test/srv.go | 46 +++++++++++++++++++++++++++++++------------ 6 files changed, 46 insertions(+), 25 deletions(-) diff --git a/task/common.go b/task/common.go index 2db2e9a..b28d78c 100644 --- a/task/common.go +++ b/task/common.go @@ -68,6 +68,10 @@ func (c *common) Start(ctx context.Context) error { return nil } +func (c *common) TaskArn() *string { + return c.taskArn +} + func (c *common) waitForTask(ctx context.Context) error { log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ diff --git a/task/srv_task.go b/task/srv_task.go index f3a2af5..e36a3b6 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -50,6 +50,7 @@ func (c *srvTask) Stop(ctx context.Context) error { } func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { + log.Infof("registring canary task '%s' to service discovery...", *c.taskArn) var targetPort int32 if c.registry.Port != nil { targetPort = *c.registry.Port @@ -126,9 +127,6 @@ func (c *srvTask) waitUntilSrvInstHelthy( }); err != nil { return xerrors.Errorf("failed to discover instances: %w", err) } else { - if len(list.Instances) == 0 { - return xerrors.Errorf("no healthy instances found") - } for _, inst := range list.Instances { if *inst.InstanceId == *c.instId { return nil @@ -147,6 +145,7 @@ func (c *srvTask) deregisterSrvInst( if c.instId == nil { return } + log.Info("deregistering the canary task from service discovery...") if _, err := c.Srv.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ ServiceId: c.srv.Id, InstanceId: c.instId, @@ -154,4 +153,5 @@ func (c *srvTask) deregisterSrvInst( log.Errorf("failed to deregister the canary task from service discovery: %v", err) log.Errorf("continuing to stop the canary task...") } + log.Infof("canary task '%s' is deregistered from service discovery", *c.taskArn) } diff --git a/task/srv_task_test.go b/task/srv_task_test.go index 66b9de4..e4a5c07 100644 --- a/task/srv_task_test.go +++ b/task/srv_task_test.go @@ -36,12 +36,15 @@ func TestSrvTask(t *testing.T) { NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, Timeout: timeout.NewManager(1, &timeout.Input{}), }, &ecstypes.ServiceRegistry{RegistryArn: ®istryArn}) - srvSvc := mocker.CreateSrvService(srvNsName, srvSvcName) + _ = mocker.CreateSrvService(srvNsName, srvSvcName) err := stask.Start(ctx) assert.NoError(t, err) - mocker.PutSrvInstHealth(*srvSvc.Id, "healthy") + taskId := task.ArnToId(*stask.TaskArn()) + mocker.PutSrvInstHealth(taskId, "healthy") err = stask.Wait(ctx) assert.NoError(t, err) err = stask.Stop(ctx) assert.NoError(t, err) + assert.Equal(t, 1, mocker.RunningTaskSize()) + assert.Equal(t, 0, len(mocker.SrvInsts)) } diff --git a/task/task.go b/task/task.go index b5fed47..dd21a00 100644 --- a/task/task.go +++ b/task/task.go @@ -6,4 +6,5 @@ type Task interface { Start(ctx context.Context) error Wait(ctx context.Context) error Stop(ctx context.Context) error + TaskArn() *string } diff --git a/test/context.go b/test/context.go index 675c422..3f2e359 100644 --- a/test/context.go +++ b/test/context.go @@ -51,13 +51,6 @@ func NewMockContext() *MockContext { } } -func (ctx *commons) GetTask(id string) (*ecstypes.Task, bool) { - ctx.mux.Lock() - defer ctx.mux.Unlock() - o, ok := ctx.Tasks[id] - return o, ok -} - func (ctx *commons) RunningTaskSize() int { ctx.mux.Lock() defer ctx.mux.Unlock() diff --git a/test/srv.go b/test/srv.go index 5a8e3e0..25aed50 100644 --- a/test/srv.go +++ b/test/srv.go @@ -57,6 +57,19 @@ func (s *SrvServer) DiscoverInstances(ctx context.Context, params *servicediscov var summories []srvtypes.HttpInstanceSummary for _, inst := range insts { health := s.SrvInstHelths[*inst.Id] + if !matchInst(inst, params) { + continue + } + switch params.HealthStatus { + case srvtypes.HealthStatusFilterHealthy: + if health != srvtypes.HealthStatusHealthy { + continue + } + case srvtypes.HealthStatusFilterUnhealthy: + if health != srvtypes.HealthStatusUnhealthy { + continue + } + } summories = append(summories, srvtypes.HttpInstanceSummary{ Attributes: inst.Attributes, InstanceId: inst.Id, @@ -68,6 +81,15 @@ func (s *SrvServer) DiscoverInstances(ctx context.Context, params *servicediscov return &servicediscovery.DiscoverInstancesOutput{Instances: summories}, nil } +func matchInst(inst *srvtypes.Instance, params *servicediscovery.DiscoverInstancesInput) bool { + for k, v := range params.QueryParameters { + if act, ok := inst.Attributes[k]; !ok || act != v { + return false + } + } + return true +} + func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) { if srv := s.getServiceById(*params.ServiceId); srv == nil { return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) @@ -83,22 +105,20 @@ func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscove } func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { - srv := s.getServiceById(*params.ServiceId) - if srv == nil { + if srv := s.getServiceById(*params.ServiceId); srv == nil { return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) - } - insts, ok := s.SrvInsts[*srv.Name] - if !ok { - return nil, xerrors.Errorf("service not found: %s", *srv.Name) - } - var newInsts []*srvtypes.Instance - for _, inst := range insts { - if *inst.Id != *params.InstanceId { - newInsts = append(newInsts, inst) + } else { + insts := s.SrvInsts[*srv.Name] + for i, inst := range insts { + if *inst.Id == *params.InstanceId { + insts = append(insts[:i], insts[i+1:]...) + s.SrvInsts[*params.ServiceId] = insts + delete(s.SrvInstHelths, *params.InstanceId) + return &servicediscovery.DeregisterInstanceOutput{}, nil + } } + return nil, xerrors.Errorf("instance not found: %s", *params.InstanceId) } - s.SrvInsts[*srv.Name] = newInsts - return &servicediscovery.DeregisterInstanceOutput{}, nil } func (s *SrvServer) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { From 5e9296f62b25fed9f70eb78f2698ae2109a75ead Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 16:35:22 +0900 Subject: [PATCH 13/47] add tests --- task/alb_task_test.go | 44 ++++++++++++++++++++++++++++++++++++++++ task/simple_task_test.go | 41 +++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 task/simple_task_test.go diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 651b54f..cd65892 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -1 +1,45 @@ package task_test + +import ( + "context" + "testing" + + elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/stretchr/testify/assert" +) + +func TestAlbTask(t *testing.T) { + mocker := test.NewMockContext() + env := test.DefaultEnvars() + ctx := context.TODO() + td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn + ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) + stask := task.NewAlbTask(&task.Input{ + Deps: &types.Deps{ + Env: env, + Ecs: mocker.Ecs, + Ec2: mocker.Ec2, + Alb: mocker.Alb, + Srv: mocker.Srv, + Time: test.NewFakeTime(), + }, + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + Timeout: timeout.NewManager(1, &timeout.Input{}), + }, &ecsSvc.Service.LoadBalancers[0]) + mocker.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ + TargetGroupArn: ecsSvc.Service.LoadBalancers[0].TargetGroupArn, + }) + err := stask.Start(ctx) + assert.NoError(t, err) + err = stask.Wait(ctx) + assert.NoError(t, err) + err = stask.Stop(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, mocker.RunningTaskSize()) +} diff --git a/task/simple_task_test.go b/task/simple_task_test.go new file mode 100644 index 0000000..218cf57 --- /dev/null +++ b/task/simple_task_test.go @@ -0,0 +1,41 @@ +package task_test + +import ( + "context" + "testing" + + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/stretchr/testify/assert" +) + +func TestSimpleTask(t *testing.T) { + ctx := context.TODO() + mocker := test.NewMockContext() + env := test.DefaultEnvars() + td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn + ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) + stask := task.NewSimpleTask(&task.Input{ + Deps: &types.Deps{ + Env: env, + Ecs: mocker.Ecs, + Ec2: mocker.Ec2, + Alb: mocker.Alb, + Srv: mocker.Srv, + Time: test.NewFakeTime(), + }, + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + Timeout: timeout.NewManager(1, &timeout.Input{}), + }) + err := stask.Start(ctx) + assert.NoError(t, err) + err = stask.Wait(ctx) + assert.NoError(t, err) + err = stask.Stop(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, mocker.RunningTaskSize()) +} From a8cdd7008445dcbeffbd5195b1dcdda81db5fbfc Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 17:09:11 +0900 Subject: [PATCH 14/47] add tests --- task/srv_task.go | 5 ++--- task/srv_task_test.go | 2 +- test/srv.go | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/task/srv_task.go b/task/srv_task.go index e36a3b6..0ecb554 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -87,7 +87,7 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { "ECS_SERVICE_NAME": c.Env.Service, "ECS_TASK_DEFINITION_FAMILY": *c.TaskDefinition.Family, "REGION": c.Env.Region, - "CAGE_CANARY_TASK": "1", + "CAGE_TASK_ID": ArnToId(*c.taskArn), } taskId := ArnToId(*c.taskArn) if _, err := c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ @@ -121,8 +121,7 @@ func (c *srvTask) waitUntilSrvInstHelthy( ServiceName: c.srv.Name, HealthStatus: srvtypes.HealthStatusFilterHealthy, QueryParameters: map[string]string{ - "CAGE_CANARY_TASK": "1", - "AWS_PRIVATE_IPV4": c.target.targetIpv4, + "CAGE_TASK_ID": ArnToId(*c.taskArn), }, }); err != nil { return xerrors.Errorf("failed to discover instances: %w", err) diff --git a/task/srv_task_test.go b/task/srv_task_test.go index e4a5c07..374caad 100644 --- a/task/srv_task_test.go +++ b/task/srv_task_test.go @@ -46,5 +46,5 @@ func TestSrvTask(t *testing.T) { err = stask.Stop(ctx) assert.NoError(t, err) assert.Equal(t, 1, mocker.RunningTaskSize()) - assert.Equal(t, 0, len(mocker.SrvInsts)) + assert.Equal(t, 0, len(mocker.SrvInsts[srvSvcName])) } diff --git a/test/srv.go b/test/srv.go index 25aed50..08887e3 100644 --- a/test/srv.go +++ b/test/srv.go @@ -99,7 +99,7 @@ func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscove Attributes: params.Attributes, } s.SrvInsts[*srv.Name] = append(s.SrvInsts[*params.ServiceId], inst) - s.SrvInstHelths[*params.InstanceId] = srvtypes.HealthStatusUnhealthy + s.SrvInstHelths[*params.InstanceId] = srvtypes.HealthStatusHealthy return &servicediscovery.RegisterInstanceOutput{}, nil } } @@ -112,7 +112,7 @@ func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicedisco for i, inst := range insts { if *inst.Id == *params.InstanceId { insts = append(insts[:i], insts[i+1:]...) - s.SrvInsts[*params.ServiceId] = insts + s.SrvInsts[*srv.Name] = insts delete(s.SrvInstHelths, *params.InstanceId) return &servicediscovery.DeregisterInstanceOutput{}, nil } From fca8f4aed0dcf92cbcc318cc8e4f1de4e09705e1 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 17:24:44 +0900 Subject: [PATCH 15/47] Update timeout_test.go --- timeout/timeout_test.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/timeout/timeout_test.go b/timeout/timeout_test.go index 7643af2..9b4517d 100644 --- a/timeout/timeout_test.go +++ b/timeout/timeout_test.go @@ -15,17 +15,20 @@ func TestManager(t *testing.T) { assert.Equal(t, time.Duration(10), man.TaskStopped()) assert.Equal(t, time.Duration(10), man.TaskHealthCheck()) assert.Equal(t, time.Duration(10), man.ServiceStable()) + assert.Equal(t, time.Duration(10), man.TargetHealthCheck()) }) t.Run("with config", func(t *testing.T) { man := timeout.NewManager(10, &timeout.Input{ - TaskRunningWait: 1, - TaskStoppedWait: 2, - TaskHealthCheckWait: 3, - ServiceStableWait: 4, + TaskRunningWait: 1, + TaskStoppedWait: 2, + TaskHealthCheckWait: 3, + ServiceStableWait: 4, + TargetHealthCheckWait: 5, }) assert.Equal(t, time.Duration(1), man.TaskRunning()) assert.Equal(t, time.Duration(2), man.TaskStopped()) assert.Equal(t, time.Duration(3), man.TaskHealthCheck()) assert.Equal(t, time.Duration(4), man.ServiceStable()) + assert.Equal(t, time.Duration(5), man.TargetHealthCheck()) }) } From ad6eb7f3f3559a4e165e9ece0d03197022d4399c Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 18:43:33 +0900 Subject: [PATCH 16/47] add test --- Makefile | 11 ++- cage.go | 5 +- mocks/mock_task/task.go | 14 ++++ mocks/mock_taskset/factory.go | 78 +++++++++++++++++++ mocks/mock_taskset/taskset.go | 63 +++++++++++++++ rollout.go | 81 ++++++------------- task_factory.go => taskset/factory.go | 16 ++-- taskset/factory_test.go | 33 ++++++++ taskset/taskset.go | 73 ++++++++++++++++++ taskset/taskset_test.go | 107 ++++++++++++++++++++++++++ 10 files changed, 410 insertions(+), 71 deletions(-) create mode 100644 mocks/mock_taskset/factory.go create mode 100644 mocks/mock_taskset/taskset.go rename task_factory.go => taskset/factory.go (53%) create mode 100644 taskset/factory_test.go create mode 100644 taskset/taskset.go create mode 100644 taskset/taskset_test.go diff --git a/Makefile b/Makefile index 758b07f..17a9652 100644 --- a/Makefile +++ b/Makefile @@ -12,16 +12,19 @@ version: mocks: mocks/mock_awsiface/iface.go \ mocks/mock_types/iface.go \ mocks/mock_upgrade/upgrade.go \ - mocks/mock_cage/task_factory.go \ - mocks/mock_task/task.go + mocks/mock_task/task.go \ + mocks/mock_taskset/taskset.go \ + mocks/mock_taskset/factory.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go mocks/mock_types/iface.go: cage.go $(MOCKGEN) -source=./types/iface.go > mocks/mock_types/iface.go -mocks/mock_cage/task_factory.go: task_factory.go - $(MOCKGEN) -source=./task_factory.go > mocks/mock_cage/task_factory.go mocks/mock_upgrade/upgrade.go: cli/cage/upgrade/upgrade.go $(MOCKGEN) -source=./cli/cage/upgrade/upgrade.go > mocks/mock_upgrade/upgrade.go mocks/mock_task/task.go: task/task.go $(MOCKGEN) -source=./task/task.go > mocks/mock_task/task.go +mocks/mock_taskset/taskset.go: taskset/taskset.go + $(MOCKGEN) -source=./taskset/taskset.go > mocks/mock_taskset/taskset.go +mocks/mock_taskset/factory.go: taskset/factory.go + $(MOCKGEN) -source=./taskset/factory.go > mocks/mock_taskset/factory.go .PHONY: mocks diff --git a/cage.go b/cage.go index 24b3998..2e2c69a 100644 --- a/cage.go +++ b/cage.go @@ -3,6 +3,7 @@ package cage import ( "time" + "github.com/loilo-inc/canarycage/taskset" "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" ) @@ -10,7 +11,7 @@ import ( type cage struct { *types.Deps Timeout timeout.Manager - TaskFactory TaskFactory + TaskFactory taskset.Factory } func NewCage(input *types.Deps) types.Cage { @@ -33,6 +34,6 @@ func NewCage(input *types.Deps) types.Cage { ServiceStableWait: serviceStableWait, TargetHealthCheckWait: targetHealthCheckWait, }), - TaskFactory: &taskFactory{}, + TaskFactory: taskset.NewFactory(), } } diff --git a/mocks/mock_task/task.go b/mocks/mock_task/task.go index fa6ea25..efc19e6 100644 --- a/mocks/mock_task/task.go +++ b/mocks/mock_task/task.go @@ -62,6 +62,20 @@ func (mr *MockTaskMockRecorder) Stop(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockTask)(nil).Stop), ctx) } +// TaskArn mocks base method. +func (m *MockTask) TaskArn() *string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TaskArn") + ret0, _ := ret[0].(*string) + return ret0 +} + +// TaskArn indicates an expected call of TaskArn. +func (mr *MockTaskMockRecorder) TaskArn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskArn", reflect.TypeOf((*MockTask)(nil).TaskArn)) +} + // Wait mocks base method. func (m *MockTask) Wait(ctx context.Context) error { m.ctrl.T.Helper() diff --git a/mocks/mock_taskset/factory.go b/mocks/mock_taskset/factory.go new file mode 100644 index 0000000..c646b24 --- /dev/null +++ b/mocks/mock_taskset/factory.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./taskset/factory.go + +// Package mock_taskset is a generated GoMock package. +package mock_taskset + +import ( + reflect "reflect" + + types "github.com/aws/aws-sdk-go-v2/service/ecs/types" + gomock "github.com/golang/mock/gomock" + task "github.com/loilo-inc/canarycage/task" +) + +// MockFactory is a mock of Factory interface. +type MockFactory struct { + ctrl *gomock.Controller + recorder *MockFactoryMockRecorder +} + +// MockFactoryMockRecorder is the mock recorder for MockFactory. +type MockFactoryMockRecorder struct { + mock *MockFactory +} + +// NewMockFactory creates a new mock instance. +func NewMockFactory(ctrl *gomock.Controller) *MockFactory { + mock := &MockFactory{ctrl: ctrl} + mock.recorder = &MockFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFactory) EXPECT() *MockFactoryMockRecorder { + return m.recorder +} + +// NewAlbTask mocks base method. +func (m *MockFactory) NewAlbTask(input *task.Input, lb *types.LoadBalancer) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAlbTask", input, lb) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewAlbTask indicates an expected call of NewAlbTask. +func (mr *MockFactoryMockRecorder) NewAlbTask(input, lb interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAlbTask", reflect.TypeOf((*MockFactory)(nil).NewAlbTask), input, lb) +} + +// NewSimpleTask mocks base method. +func (m *MockFactory) NewSimpleTask(input *task.Input) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSimpleTask", input) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewSimpleTask indicates an expected call of NewSimpleTask. +func (mr *MockFactoryMockRecorder) NewSimpleTask(input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSimpleTask", reflect.TypeOf((*MockFactory)(nil).NewSimpleTask), input) +} + +// NewSrvTask mocks base method. +func (m *MockFactory) NewSrvTask(input *task.Input, srv *types.ServiceRegistry) task.Task { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewSrvTask", input, srv) + ret0, _ := ret[0].(task.Task) + return ret0 +} + +// NewSrvTask indicates an expected call of NewSrvTask. +func (mr *MockFactoryMockRecorder) NewSrvTask(input, srv interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSrvTask", reflect.TypeOf((*MockFactory)(nil).NewSrvTask), input, srv) +} diff --git a/mocks/mock_taskset/taskset.go b/mocks/mock_taskset/taskset.go new file mode 100644 index 0000000..8127ce8 --- /dev/null +++ b/mocks/mock_taskset/taskset.go @@ -0,0 +1,63 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./taskset/taskset.go + +// Package mock_taskset is a generated GoMock package. +package mock_taskset + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockSet is a mock of Set interface. +type MockSet struct { + ctrl *gomock.Controller + recorder *MockSetMockRecorder +} + +// MockSetMockRecorder is the mock recorder for MockSet. +type MockSetMockRecorder struct { + mock *MockSet +} + +// NewMockSet creates a new mock instance. +func NewMockSet(ctrl *gomock.Controller) *MockSet { + mock := &MockSet{ctrl: ctrl} + mock.recorder = &MockSetMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSet) EXPECT() *MockSetMockRecorder { + return m.recorder +} + +// Cleanup mocks base method. +func (m *MockSet) Cleanup(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Cleanup", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Cleanup indicates an expected call of Cleanup. +func (mr *MockSetMockRecorder) Cleanup(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Cleanup", reflect.TypeOf((*MockSet)(nil).Cleanup), ctx) +} + +// Exec mocks base method. +func (m *MockSet) Exec(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exec", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockSetMockRecorder) Exec(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockSet)(nil).Exec), ctx) +} diff --git a/rollout.go b/rollout.go index 13ef773..346223b 100644 --- a/rollout.go +++ b/rollout.go @@ -7,8 +7,8 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/taskset" "github.com/loilo-inc/canarycage/types" - "golang.org/x/sync/errgroup" "golang.org/x/xerrors" ) @@ -49,28 +49,20 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R // ensure canary task stopped after rolling out either success or failure defer func() { _ = recover() - eg := errgroup.Group{} - for _, canaryTask := range canaryTasks { - eg.Go(func() error { - return canaryTask.Stop(ctx) - }) - } - if err := eg.Wait(); err != nil { - log.Errorf("failed to stop canary tasks due to: %s", err) + if canaryTasks == nil { + return + } else if err := canaryTasks.Cleanup(ctx); err != nil { + log.Errorf("failed to cleanup canary tasks due to: %s", err) } }() if startCanaryTaskErr != nil { return result, xerrors.Errorf("failed to start canary task due to: %w", startCanaryTaskErr) } - eg := errgroup.Group{} - for _, canaryTask := range canaryTasks { - eg.Go(func() error { - return canaryTask.Wait(ctx) - }) - } - if err := eg.Wait(); err != nil { - return result, xerrors.Errorf("failed to wait for canary task due to: %w", err) + log.Infof("executing canary tasks...") + if err := canaryTasks.Exec(ctx); err != nil { + return result, xerrors.Errorf("failed to exec canary task due to: %w", err) } + log.Infof("canary tasks have been executed successfully!") log.Infof( "updating the task definition of '%s' into '%s:%d'...", c.Env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, @@ -111,7 +103,7 @@ func (c *cage) StartCanaryTasks( ctx context.Context, nextTaskDefinition *ecstypes.TaskDefinition, input *types.RollOutInput, -) ([]task.Task, error) { +) (taskset.Set, error) { var networkConfiguration *ecstypes.NetworkConfiguration var platformVersion *string var loadBalancers []ecstypes.LoadBalancer @@ -135,45 +127,16 @@ func (c *cage) StartCanaryTasks( serviceRegistries = service.ServiceRegistries } } - var results []task.Task - for _, lb := range loadBalancers { - task := c.TaskFactory.NewAlbTask(&task.Input{ - Deps: c.Deps, - NetworkConfiguration: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, - }, &lb) - results = append(results, task) - if err := task.Start(ctx); err != nil { - return results, err - } - } - for _, srv := range serviceRegistries { - task := c.TaskFactory.NewSrvTask(&task.Input{ - Deps: c.Deps, - NetworkConfiguration: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, - }, &srv) - results = append(results, task) - if err := task.Start(ctx); err != nil { - return results, err - } - } - if len(results) == 0 { - task := c.TaskFactory.NewSimpleTask(&task.Input{ - Deps: c.Deps, - NetworkConfiguration: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - Timeout: c.Timeout, - }) - results = append(results, task) - if err := task.Start(ctx); err != nil { - return results, err - } - } - return results, nil + return taskset.NewSet( + c.TaskFactory, + &taskset.Input{ + Input: &task.Input{ + NetworkConfiguration: networkConfiguration, + TaskDefinition: nextTaskDefinition, + PlatformVersion: platformVersion, + }, + LoadBalancers: loadBalancers, + ServiceRegistries: serviceRegistries, + }, + ), nil } diff --git a/task_factory.go b/taskset/factory.go similarity index 53% rename from task_factory.go rename to taskset/factory.go index bb11fcc..8391c8b 100644 --- a/task_factory.go +++ b/taskset/factory.go @@ -1,26 +1,30 @@ -package cage +package taskset import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/loilo-inc/canarycage/task" ) -type TaskFactory interface { +type Factory interface { NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task NewSimpleTask(input *task.Input) task.Task } -type taskFactory struct{} +type factory struct{} -func (f *taskFactory) NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task { +func NewFactory() Factory { + return &factory{} +} + +func (f *factory) NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task { return task.NewAlbTask(input, lb) } -func (f *taskFactory) NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task { +func (f *factory) NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task { return task.NewSrvTask(input, srv) } -func (f *taskFactory) NewSimpleTask(input *task.Input) task.Task { +func (f *factory) NewSimpleTask(input *task.Input) task.Task { return task.NewSimpleTask(input) } diff --git a/taskset/factory_test.go b/taskset/factory_test.go new file mode 100644 index 0000000..4521491 --- /dev/null +++ b/taskset/factory_test.go @@ -0,0 +1,33 @@ +package taskset_test + +import ( + "testing" + + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/taskset" + "github.com/stretchr/testify/assert" +) + +func TestFactory(t *testing.T) { + t.Run("NewAlbTask", func(t *testing.T) { + f := taskset.NewFactory() + input := &task.Input{} + lb := &ecstypes.LoadBalancer{} + task := f.NewAlbTask(input, lb) + assert.NotNil(t, task) + }) + t.Run("NewSrvTask", func(t *testing.T) { + f := taskset.NewFactory() + input := &task.Input{} + srv := &ecstypes.ServiceRegistry{} + task := f.NewSrvTask(input, srv) + assert.NotNil(t, task) + }) + t.Run("NewSimpleTask", func(t *testing.T) { + f := taskset.NewFactory() + input := &task.Input{} + task := f.NewSimpleTask(input) + assert.NotNil(t, task) + }) +} diff --git a/taskset/taskset.go b/taskset/taskset.go new file mode 100644 index 0000000..9d08eb8 --- /dev/null +++ b/taskset/taskset.go @@ -0,0 +1,73 @@ +package taskset + +import ( + "context" + + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/task" + "golang.org/x/sync/errgroup" +) + +type Set interface { + Exec(ctx context.Context) error + Cleanup(ctx context.Context) error +} + +type set struct { + tasks []task.Task +} + +type Input struct { + *task.Input + LoadBalancers []ecstypes.LoadBalancer + ServiceRegistries []ecstypes.ServiceRegistry +} + +func NewSet(factory Factory, input *Input) Set { + var results []task.Task + taskInput := &task.Input{ + Deps: input.Deps, + NetworkConfiguration: input.NetworkConfiguration, + TaskDefinition: input.TaskDefinition, + PlatformVersion: input.PlatformVersion, + Timeout: input.Timeout, + } + for _, lb := range input.LoadBalancers { + task := factory.NewAlbTask(taskInput, &lb) + results = append(results, task) + } + for _, srv := range input.ServiceRegistries { + task := factory.NewSrvTask(taskInput, &srv) + results = append(results, task) + } + if len(results) == 0 { + task := factory.NewSimpleTask(taskInput) + results = append(results, task) + } + return &set{tasks: results} +} + +func (s *set) Exec(ctx context.Context) error { + for _, t := range s.tasks { + if err := t.Start(ctx); err != nil { + return err + } + } + eg := errgroup.Group{} + for _, t := range s.tasks { + eg.Go(func() error { + return t.Wait(ctx) + }) + } + return eg.Wait() +} + +func (s *set) Cleanup(ctx context.Context) error { + eg := errgroup.Group{} + for _, t := range s.tasks { + eg.Go(func() error { + return t.Stop(ctx) + }) + } + return eg.Wait() +} diff --git a/taskset/taskset_test.go b/taskset/taskset_test.go new file mode 100644 index 0000000..008bc83 --- /dev/null +++ b/taskset/taskset_test.go @@ -0,0 +1,107 @@ +package taskset_test + +import ( + "context" + "fmt" + "testing" + + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/golang/mock/gomock" + "github.com/loilo-inc/canarycage/mocks/mock_task" + "github.com/loilo-inc/canarycage/mocks/mock_taskset" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/taskset" + "github.com/stretchr/testify/assert" +) + +func TestSet(t *testing.T) { + t.Run("basic", func(t *testing.T) { + ctrl := gomock.NewController(t) + factory := mock_taskset.NewMockFactory(ctrl) + albTask := mock_task.NewMockTask(ctrl) + srvTask := mock_task.NewMockTask(ctrl) + lb := ecstypes.LoadBalancer{} + srg := ecstypes.ServiceRegistry{} + gomock.InOrder( + factory.EXPECT().NewAlbTask(gomock.Any(), &lb).Return(albTask), + factory.EXPECT().NewSrvTask(gomock.Any(), &srg).Return(srvTask), + ) + gomock.InOrder( + albTask.EXPECT().Start(gomock.Any()).Return(nil), + srvTask.EXPECT().Start(gomock.Any()).Return(nil), + ) + albTask.EXPECT().Wait(gomock.Any()).Return(nil) + srvTask.EXPECT().Wait(gomock.Any()).Return(nil) + albTask.EXPECT().Stop(gomock.Any()).Return(nil) + srvTask.EXPECT().Stop(gomock.Any()).Return(nil) + input := &taskset.Input{ + Input: &task.Input{}, + LoadBalancers: []ecstypes.LoadBalancer{lb}, + ServiceRegistries: []ecstypes.ServiceRegistry{srg}, + } + set := taskset.NewSet(factory, input) + ctx := context.TODO() + assert.NoError(t, set.Exec(ctx)) + assert.NoError(t, set.Cleanup(ctx)) + }) + t.Run("should add a simple task if no load balancer or service registry is given", func(t *testing.T) { + ctrl := gomock.NewController(t) + factory := mock_taskset.NewMockFactory(ctrl) + simpleTask := mock_task.NewMockTask(ctrl) + input := &taskset.Input{ + Input: &task.Input{}, + } + factory.EXPECT().NewSimpleTask(input.Input).Return(simpleTask) + simpleTask.EXPECT().Start(gomock.Any()).Return(nil) + simpleTask.EXPECT().Wait(gomock.Any()).Return(nil) + simpleTask.EXPECT().Stop(gomock.Any()).Return(nil) + set := taskset.NewSet(factory, input) + ctx := context.TODO() + assert.NoError(t, set.Exec(ctx)) + assert.NoError(t, set.Cleanup(ctx)) + }) + t.Run("should aggregate errors from task.Wait", func(t *testing.T) { + ctrl := gomock.NewController(t) + factory := mock_taskset.NewMockFactory(ctrl) + albTask := mock_task.NewMockTask(ctrl) + srvTask := mock_task.NewMockTask(ctrl) + lb := ecstypes.LoadBalancer{} + srg := ecstypes.ServiceRegistry{} + gomock.InOrder( + factory.EXPECT().NewAlbTask(gomock.Any(), &lb).Return(albTask), + factory.EXPECT().NewSrvTask(gomock.Any(), &srg).Return(srvTask), + ) + gomock.InOrder( + albTask.EXPECT().Start(gomock.Any()).Return(nil), + srvTask.EXPECT().Start(gomock.Any()).Return(nil), + ) + albTask.EXPECT().Wait(gomock.Any()).Return(fmt.Errorf("error")) + srvTask.EXPECT().Wait(gomock.Any()).Return(nil) + input := &taskset.Input{ + Input: &task.Input{}, + LoadBalancers: []ecstypes.LoadBalancer{lb}, + ServiceRegistries: []ecstypes.ServiceRegistry{srg}, + } + set := taskset.NewSet(factory, input) + ctx := context.TODO() + err := set.Exec(ctx) + assert.EqualError(t, err, "error") + }) + + t.Run("should error immediately if task failed to start", func(t *testing.T) { + ctrl := gomock.NewController(t) + factory := mock_taskset.NewMockFactory(ctrl) + mtask := mock_task.NewMockTask(ctrl) + lb := &ecstypes.LoadBalancer{} + factory.EXPECT().NewAlbTask(gomock.Any(), lb).Return(mtask) + mtask.EXPECT().Start(gomock.Any()).Return(fmt.Errorf("error")) + input := &taskset.Input{ + Input: &task.Input{}, + LoadBalancers: []ecstypes.LoadBalancer{{}}, + } + set := taskset.NewSet(factory, input) + ctx := context.TODO() + err := set.Exec(ctx) + assert.EqualError(t, err, "error") + }) +} From 9e99624546ccf9a5cfb79994534193ce5236aea3 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 18:55:25 +0900 Subject: [PATCH 17/47] add tests --- rollout.go | 3 ++- rollout_test.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/rollout.go b/rollout.go index 346223b..65d3de8 100644 --- a/rollout.go +++ b/rollout.go @@ -44,7 +44,6 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R if input.UpdateService { log.Info("--updateService flag is set. use provided service configurations for canary test instead of current service") } - log.Infof("starting canary task...") canaryTasks, startCanaryTaskErr := c.StartCanaryTasks(ctx, nextTaskDefinition, input) // ensure canary task stopped after rolling out either success or failure defer func() { @@ -131,9 +130,11 @@ func (c *cage) StartCanaryTasks( c.TaskFactory, &taskset.Input{ Input: &task.Input{ + Deps: c.Deps, NetworkConfiguration: networkConfiguration, TaskDefinition: nextTaskDefinition, PlatformVersion: platformVersion, + Timeout: c.Timeout, }, LoadBalancers: loadBalancers, ServiceRegistries: serviceRegistries, diff --git a/rollout_test.go b/rollout_test.go index 3a19fbe..61e9795 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -210,7 +210,7 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, } result, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) - assert.EqualError(t, err, "failed to wait for canary task due to: couldn't find host port in container definition") + assert.EqualError(t, err, "failed to exec canary task due to: couldn't find host port in container definition") assert.Equal(t, result.ServiceIntact, true) assert.Equal(t, 1, mctx.RunningTaskSize()) }) From e9d44e45d946626e8317cf6238986e39253ddcee Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 20:21:22 +0900 Subject: [PATCH 18/47] di --- cage.go | 33 +---- cli/cage/commands/command.go | 21 ++- export_test.go | 6 + go.mod | 1 + go.sum | 3 + key/keys.go | 14 ++ rollout.go | 76 ++++++----- rollout_test.go | 209 +++++++++++++++++------------- run.go | 29 +++-- run_test.go | 69 +++++----- task/alb_task.go | 31 +++-- task/alb_task_test.go | 23 ++-- task/common.go | 82 +++++++----- task/common_test.go | 9 ++ task/export_test.go | 3 + task/factory.go | 32 +++++ {taskset => task}/factory_test.go | 11 +- task/simple_task.go | 22 +++- task/simple_task_test.go | 23 ++-- task/srv_task.go | 36 +++-- task/srv_task_test.go | 23 ++-- task_definition.go | 15 ++- task_definition_test.go | 43 +++--- taskset/factory.go | 30 ----- taskset/taskset.go | 6 +- time.go | 12 -- timeout/time.go | 12 ++ timeout/timeout.go | 43 +++--- timeout/timeout_test.go | 27 ++-- types/iface.go | 11 -- up.go | 31 +++-- up_test.go | 22 ++-- 32 files changed, 573 insertions(+), 435 deletions(-) create mode 100644 key/keys.go create mode 100644 task/common_test.go create mode 100644 task/export_test.go create mode 100644 task/factory.go rename {taskset => task}/factory_test.go (81%) delete mode 100644 taskset/factory.go delete mode 100644 time.go create mode 100644 timeout/time.go diff --git a/cage.go b/cage.go index 2e2c69a..d526a60 100644 --- a/cage.go +++ b/cage.go @@ -1,39 +1,14 @@ package cage import ( - "time" - - "github.com/loilo-inc/canarycage/taskset" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" ) type cage struct { - *types.Deps - Timeout timeout.Manager - TaskFactory taskset.Factory + di *di.D } -func NewCage(input *types.Deps) types.Cage { - if input.Time == nil { - input.Time = &timeImpl{} - } - taskRunningWait := (time.Duration)(input.Env.CanaryTaskRunningWait) * time.Second - taskHealthCheckWait := (time.Duration)(input.Env.CanaryTaskHealthCheckWait) * time.Second - taskStoppedWait := (time.Duration)(input.Env.CanaryTaskStoppedWait) * time.Second - serviceStableWait := (time.Duration)(input.Env.ServiceStableWait) * time.Second - targetHealthCheckWait := (time.Duration)(input.Env.TargetHealthCheckWait) * time.Second - return &cage{ - Deps: input, - Timeout: timeout.NewManager( - 15*time.Minute, - &timeout.Input{ - TaskRunningWait: taskRunningWait, - TaskHealthCheckWait: taskHealthCheckWait, - TaskStoppedWait: taskStoppedWait, - ServiceStableWait: serviceStableWait, - TargetHealthCheckWait: targetHealthCheckWait, - }), - TaskFactory: taskset.NewFactory(), - } +func NewCage(di *di.D) types.Cage { + return &cage{di} } diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index 8c2e721..dcb3167 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -3,15 +3,21 @@ package commands import ( "context" "io" + "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/prompt" "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/urfave/cli/v2" "golang.org/x/xerrors" ) @@ -40,12 +46,17 @@ func DefalutCageCliProvider(envars *env.Envars) (types.Cage, error) { if err != nil { return nil, xerrors.Errorf("failed to load aws config: %w", err) } - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecs.NewFromConfig(conf), - Ec2: ec2.NewFromConfig(conf), - Alb: elasticloadbalancingv2.NewFromConfig(conf), + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecs.NewFromConfig(conf)) + b.Set(key.Ec2Cli, ec2.NewFromConfig(conf)) + b.Set(key.AlbCli, elasticloadbalancingv2.NewFromConfig(conf)) + b.Set(key.SrvCli, servicediscovery.NewFromConfig(conf)) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 15*time.Minute)) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, &timeout.Time{}) }) + cagecli := cage.NewCage(d) return cagecli, nil } diff --git a/export_test.go b/export_test.go index 901fd27..fd50637 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,9 @@ package cage +import "github.com/loilo-inc/logos/di" + type CageExport = cage + +func NewCageExport(di *di.D) *cage { + return &cage{di} +} diff --git a/go.mod b/go.mod index bd7dc09..3361920 100644 --- a/go.mod +++ b/go.mod @@ -36,6 +36,7 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/loilo-inc/logos v1.0.0 github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 760eb4b..4ed95c0 100644 --- a/go.sum +++ b/go.sum @@ -84,6 +84,8 @@ github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfn github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/loilo-inc/logos v1.0.0 h1:yB2ivcz+o3RdqOGlB8t3eEzNW4VhIhZH+jJfCO119IE= +github.com/loilo-inc/logos v1.0.0/go.mod h1:bFtwBXw8rrKp6pc92dytMJBLbfkZt2WTXJSaz27Vmew= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= @@ -108,6 +110,7 @@ github.com/smartystreets/gunit v1.0.0/go.mod h1:qwPWnhz6pn0NnRBP++URONOVyNkPyr4S github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= diff --git a/key/keys.go b/key/keys.go new file mode 100644 index 0000000..023072b --- /dev/null +++ b/key/keys.go @@ -0,0 +1,14 @@ +package key + +type DepsKey string + +const ( + EcsCli DepsKey = "ecs" + Ec2Cli DepsKey = "ec2" + SrvCli DepsKey = "srv" + AlbCli DepsKey = "alb" + Env DepsKey = "env" + TimeoutManager DepsKey = "timeout-manager" + Time DepsKey = "time" + TaskFactory DepsKey = "task-factory" +) diff --git a/rollout.go b/rollout.go index 65d3de8..7e8ad1f 100644 --- a/rollout.go +++ b/rollout.go @@ -6,8 +6,12 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/taskset" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -16,21 +20,21 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R result := &types.RollOutResult{ ServiceIntact: true, } - if out, err := c.Ecs.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: &c.Env.Cluster, - Services: []string{ - c.Env.Service, - }, + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + if out, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, }); err != nil { return result, xerrors.Errorf("failed to describe current service due to: %w", err) } else if len(out.Services) == 0 { - return result, xerrors.Errorf("service '%s' doesn't exist. Run 'cage up' or create service before rolling out", c.Env.Service) + return result, xerrors.Errorf("service '%s' doesn't exist. Run 'cage up' or create service before rolling out", env.Service) } else { service := out.Services[0] if *service.Status != "ACTIVE" { - return result, xerrors.Errorf("😵 '%s' status is '%s'. Stop rolling out", c.Env.Service, *service.Status) + return result, xerrors.Errorf("😵 '%s' status is '%s'. Stop rolling out", env.Service, *service.Status) } - if service.LaunchType == ecstypes.LaunchTypeEc2 && c.Env.CanaryInstanceArn == "" { + if service.LaunchType == ecstypes.LaunchTypeEc2 && env.CanaryInstanceArn == "" { return result, xerrors.Errorf("🥺 --canaryInstanceArn is required when LaunchType = 'EC2'") } } @@ -64,36 +68,37 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R log.Infof("canary tasks have been executed successfully!") log.Infof( "updating the task definition of '%s' into '%s:%d'...", - c.Env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, + env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, ) updateInput := &ecs.UpdateServiceInput{ - Cluster: &c.Env.Cluster, - Service: &c.Env.Service, + Cluster: &env.Cluster, + Service: &env.Service, TaskDefinition: nextTaskDefinition.TaskDefinitionArn, } if input.UpdateService { - updateInput.LoadBalancers = c.Env.ServiceDefinitionInput.LoadBalancers - updateInput.NetworkConfiguration = c.Env.ServiceDefinitionInput.NetworkConfiguration - updateInput.ServiceConnectConfiguration = c.Env.ServiceDefinitionInput.ServiceConnectConfiguration - updateInput.ServiceRegistries = c.Env.ServiceDefinitionInput.ServiceRegistries - updateInput.PlatformVersion = c.Env.ServiceDefinitionInput.PlatformVersion - updateInput.VolumeConfigurations = c.Env.ServiceDefinitionInput.VolumeConfigurations + updateInput.LoadBalancers = env.ServiceDefinitionInput.LoadBalancers + updateInput.NetworkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration + updateInput.ServiceConnectConfiguration = env.ServiceDefinitionInput.ServiceConnectConfiguration + updateInput.ServiceRegistries = env.ServiceDefinitionInput.ServiceRegistries + updateInput.PlatformVersion = env.ServiceDefinitionInput.PlatformVersion + updateInput.VolumeConfigurations = env.ServiceDefinitionInput.VolumeConfigurations } - if _, err := c.Ecs.UpdateService(ctx, updateInput); err != nil { + if _, err := ecsCli.UpdateService(ctx, updateInput); err != nil { return result, err } result.ServiceIntact = false - log.Infof("waiting for service '%s' to be stable...", c.Env.Service) - if err := ecs.NewServicesStableWaiter(c.Ecs).Wait(ctx, &ecs.DescribeServicesInput{ - Cluster: &c.Env.Cluster, - Services: []string{c.Env.Service}, - }, c.Timeout.ServiceStable()); err != nil { + log.Infof("waiting for service '%s' to be stable...", env.Service) + timeouManager := c.di.Get(key.TimeoutManager).(timeout.Manager) + if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, + }, timeouManager.ServiceStable()); err != nil { return result, err } - log.Infof("🥴 service '%s' has become to be stable!", c.Env.Service) + log.Infof("🥴 service '%s' has become to be stable!", env.Service) log.Infof( "🐥 service '%s' successfully rolled out to '%s:%d'!", - c.Env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, + env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, ) return result, nil } @@ -107,15 +112,17 @@ func (c *cage) StartCanaryTasks( var platformVersion *string var loadBalancers []ecstypes.LoadBalancer var serviceRegistries []ecstypes.ServiceRegistry + env := c.di.Get(key.Env).(*env.Envars) if input.UpdateService { - networkConfiguration = c.Env.ServiceDefinitionInput.NetworkConfiguration - platformVersion = c.Env.ServiceDefinitionInput.PlatformVersion - loadBalancers = c.Env.ServiceDefinitionInput.LoadBalancers - serviceRegistries = c.Env.ServiceDefinitionInput.ServiceRegistries + networkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration + platformVersion = env.ServiceDefinitionInput.PlatformVersion + loadBalancers = env.ServiceDefinitionInput.LoadBalancers + serviceRegistries = env.ServiceDefinitionInput.ServiceRegistries } else { - if o, err := c.Ecs.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: &c.Env.Cluster, - Services: []string{c.Env.Service}, + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + if o, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, }); err != nil { return nil, err } else { @@ -126,15 +133,14 @@ func (c *cage) StartCanaryTasks( serviceRegistries = service.ServiceRegistries } } + factory := c.di.Get(key.TaskFactory).(task.Factory) return taskset.NewSet( - c.TaskFactory, + factory, &taskset.Input{ Input: &task.Input{ - Deps: c.Deps, NetworkConfiguration: networkConfiguration, TaskDefinition: nextTaskDefinition, PlatformVersion: platformVersion, - Timeout: c.Timeout, }, LoadBalancers: loadBalancers, ServiceRegistries: serviceRegistries, diff --git a/rollout_test.go b/rollout_test.go index 61e9795..cfb1e73 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -4,6 +4,7 @@ import ( "context" "strings" "testing" + "time" "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/aws" @@ -13,9 +14,13 @@ import ( elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" + "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -35,14 +40,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != v { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) @@ -59,13 +65,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) @@ -98,13 +106,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil).Times(2), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1*time.Minute)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) @@ -142,13 +152,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, nil), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), ) - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NotNil(t, err) @@ -171,13 +183,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{newLb} envars.ServiceDefinitionInput.NetworkConfiguration = newNetwork envars.ServiceDefinitionInput.PlatformVersion = aws.String("LATEST") - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() service, _ := mctx.GetEcsService(envars.Service) assert.Equal(t, "1.4.0", *service.PlatformVersion) @@ -194,13 +208,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars := test.DefaultEnvars() ctrl := gomock.NewController(t) mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{ { @@ -219,12 +235,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ctrl := gomock.NewController(t) mocker, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") delete(mocker.Services, envars.Service) - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.EqualError(t, err, "service 'service' doesn't exist. Run 'cage up' or create service before rolling out") @@ -235,13 +254,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { envars.CanaryTaskIdleDuration = 1 ctrl := gomock.NewController(t) _, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Alb: albMock, - Ec2: ec2Mock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() if res, err := cagecli.RollOut(ctx, &types.RollOutInput{}); err != nil { t.Fatalf(err.Error()) @@ -261,11 +282,13 @@ func TestCage_RollOut_FARGATE(t *testing.T) { }, }, nil, ) - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) _, err := cagecli.RollOut(context.Background(), &types.RollOutInput{}) assert.EqualError(t, err, "😵 'service' status is 'INACTIVE'. Stop rolling out") }) @@ -302,13 +325,15 @@ func TestCage_RollOut_FARGATE(t *testing.T) { ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeContainerInstances).AnyTimes() ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask).AnyTimes() ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StopTask).AnyTimes() - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() res, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NotNil(t, res) @@ -348,13 +373,15 @@ func TestCage_RollOut_EC2(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != v { t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) } - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) if err != nil { @@ -378,13 +405,15 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { if taskCnt := mctx.RunningTaskSize(); taskCnt != 1 { t.Fatalf("current tasks not setup: %d/%d", 1, taskCnt) } - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.ErrorContains(t, err, "canaryInstanceArn is required") @@ -409,13 +438,15 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { Attributes: []ecstypes.Attribute{}, }, nil).AnyTimes() ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil).AnyTimes() - cagecli := cage.NewCage(&types.Deps{ - Env: envars, - Ecs: ecsMock, - Ec2: ec2Mock, - Alb: albMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) + })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) assert.NoError(t, err) diff --git a/run.go b/run.go index 038a4d5..82d6b51 100644 --- a/run.go +++ b/run.go @@ -7,6 +7,10 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -21,19 +25,22 @@ func containerExistsInDefinition(td *ecs.RegisterTaskDefinitionInput, container } func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult, error) { - if !containerExistsInDefinition(c.Env.TaskDefinitionInput, input.Container) { + env := c.di.Get(key.Env).(*env.Envars) + if !containerExistsInDefinition(env.TaskDefinitionInput, input.Container) { return nil, xerrors.Errorf("🚫 '%s' not found in container definitions", *input.Container) } td, err := c.CreateNextTaskDefinition(ctx) if err != nil { return nil, err } - o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ - Cluster: &c.Env.Cluster, + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) + o, err := ecsCli.RunTask(ctx, &ecs.RunTaskInput{ + Cluster: &env.Cluster, TaskDefinition: td.TaskDefinitionArn, LaunchType: ecstypes.LaunchTypeFargate, - NetworkConfiguration: c.Env.ServiceDefinitionInput.NetworkConfiguration, - PlatformVersion: c.Env.ServiceDefinitionInput.PlatformVersion, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + PlatformVersion: env.ServiceDefinitionInput.PlatformVersion, Overrides: input.Overrides, Group: aws.String("cage:run-task"), }) @@ -42,18 +49,18 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult } taskArn := o.Tasks[0].TaskArn log.Infof("waiting for task '%s' to start...", *taskArn) - if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, c.Timeout.TaskRunning()); err != nil { + }, timeoutManager.TaskRunning()); err != nil { return nil, xerrors.Errorf("task failed to start: %w", err) } log.Infof("task '%s' is running", *taskArn) log.Infof("waiting for task '%s' to stop...", *taskArn) - if result, err := ecs.NewTasksStoppedWaiter(c.Ecs).WaitForOutput(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + if result, err := ecs.NewTasksStoppedWaiter(ecsCli).WaitForOutput(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, c.Timeout.TaskStopped()); err != nil { + }, timeoutManager.TaskStopped()); err != nil { return nil, xerrors.Errorf("task failed to stop: %w", err) } else { task := result.Tasks[0] diff --git a/run_test.go b/run_test.go index 780b0bc..2143ea5 100644 --- a/run_test.go +++ b/run_test.go @@ -10,9 +10,12 @@ import ( "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -40,11 +43,12 @@ func TestCage_Run(t *testing.T) { return mocker.Ecs.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, @@ -70,11 +74,12 @@ func TestCage_Run(t *testing.T) { }, ), ) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, @@ -92,11 +97,12 @@ func TestCage_Run(t *testing.T) { ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeTasks).Times(2), ) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, @@ -120,11 +126,12 @@ func TestCage_Run(t *testing.T) { return mocker.Ecs.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, @@ -148,11 +155,12 @@ func TestCage_Run(t *testing.T) { return mocker.Ecs.DescribeTasks(ctx, input) }), ) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, Overrides: overrides, @@ -164,11 +172,12 @@ func TestCage_Run(t *testing.T) { overrides := &ecstypes.TaskOverride{} ctx := context.Background() env, _, ecsMock := setupForBasic(t) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - Time: test.NewFakeTime(), - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: aws.String("foo"), Overrides: overrides, diff --git a/task/alb_task.go b/task/alb_task.go index 0bdb731..8a9debf 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -9,6 +9,11 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" ) @@ -19,11 +24,13 @@ type albTask struct { target *CanaryTarget } -func NewAlbTask(input *Input, +func NewAlbTask( + di *di.D, + input *Input, lb *ecstypes.LoadBalancer, ) Task { return &albTask{ - common: &common{Input: input}, + common: &common{Input: input, di: di}, lb: lb, } } @@ -67,7 +74,8 @@ func (c *albTask) registerToTargetGroup(ctx context.Context) error { } else { c.target = target } - if _, err := c.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ + albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) + if _, err := albCli.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ AvailabilityZone: &c.target.availabilityZone, @@ -83,10 +91,13 @@ func (c *albTask) registerToTargetGroup(ctx context.Context) error { func (c *albTask) waitUntilTargetHealthy( ctx context.Context, ) error { + albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) + timer := c.di.Get(key.Time).(types.Time) log.Infof("checking the health state of canary task...") var unusedCount = 0 var recentState *elbv2types.TargetHealthStateEnum - rest := c.Timeout.TargetHealthCheck() + rest := timeoutManager.TargetHealthCheck() waitPeriod := 15 * time.Second for rest > 0 && unusedCount < 5 { if rest < waitPeriod { @@ -95,8 +106,8 @@ func (c *albTask) waitUntilTargetHealthy( select { case <-ctx.Done(): return ctx.Err() - case <-c.Time.NewTimer(waitPeriod).C: - if o, err := c.Alb.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ + case <-timer.NewTimer(waitPeriod).C: + if o, err := albCli.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ Id: &c.target.targetId, @@ -135,7 +146,8 @@ func (c *albTask) waitUntilTargetHealthy( func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, error) { deregistrationDelay := 300 * time.Second - if o, err := c.Alb.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ + albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) + if o, err := albCli.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ TargetGroupArn: c.lb.TargetGroupArn, }); err != nil { return deregistrationDelay, err @@ -158,13 +170,14 @@ func (c *albTask) deregisterTarget(ctx context.Context) { if c.target == nil { return } + albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) deregistrationDelay, err := c.targetDeregistrationDelay(ctx) if err != nil { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) } log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) - if _, err := c.Alb.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ + if _, err := albCli.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ AvailabilityZone: &c.target.availabilityZone, @@ -177,7 +190,7 @@ func (c *albTask) deregisterTarget(ctx context.Context) { } else { log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety - if err := elbv2.NewTargetDeregisteredWaiter(c.Alb).Wait(ctx, &elbv2.DescribeTargetHealthInput{ + if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ TargetGroupArn: c.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ AvailabilityZone: &c.target.availabilityZone, diff --git a/task/alb_task_test.go b/task/alb_task_test.go index cd65892..a9fc23a 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -5,10 +5,11 @@ import ( "testing" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/canarycage/timeout" - "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -19,18 +20,18 @@ func TestAlbTask(t *testing.T) { td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) - stask := task.NewAlbTask(&task.Input{ - Deps: &types.Deps{ - Env: env, - Ecs: mocker.Ecs, - Ec2: mocker.Ec2, - Alb: mocker.Alb, - Srv: mocker.Srv, - Time: test.NewFakeTime(), - }, + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.AlbCli, mocker.Alb) + b.Set(key.SrvCli, mocker.Srv) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + b.Set(key.Time, test.NewFakeTime()) + }) + stask := task.NewAlbTask(d, &task.Input{ TaskDefinition: td.TaskDefinition, NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - Timeout: timeout.NewManager(1, &timeout.Input{}), }, &ecsSvc.Service.LoadBalancers[0]) mocker.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: ecsSvc.Service.LoadBalancers[0].TargetGroupArn, diff --git a/task/common.go b/task/common.go index b28d78c..3f8cf53 100644 --- a/task/common.go +++ b/task/common.go @@ -10,8 +10,12 @@ import ( ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" ) @@ -23,28 +27,29 @@ type CanaryTarget struct { } type Input struct { - *types.Deps TaskDefinition *ecstypes.TaskDefinition NetworkConfiguration *ecstypes.NetworkConfiguration PlatformVersion *string - Timeout timeout.Manager } type common struct { *Input + di *di.D taskArn *string } func (c *common) Start(ctx context.Context) error { - group := fmt.Sprintf("cage:canary-task:%s", c.Env.Service) - if c.Env.CanaryInstanceArn != "" { + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + group := fmt.Sprintf("cage:canary-task:%s", env.Service) + if env.CanaryInstanceArn != "" { // ec2 - if o, err := c.Ecs.StartTask(ctx, &ecs.StartTaskInput{ - Cluster: &c.Env.Cluster, + if o, err := ecsCli.StartTask(ctx, &ecs.StartTaskInput{ + Cluster: &env.Cluster, Group: &group, NetworkConfiguration: c.NetworkConfiguration, TaskDefinition: c.TaskDefinition.TaskDefinitionArn, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, + ContainerInstances: []string{env.CanaryInstanceArn}, }); err != nil { return err } else { @@ -52,8 +57,8 @@ func (c *common) Start(ctx context.Context) error { } } else { // fargate - if o, err := c.Ecs.RunTask(ctx, &ecs.RunTaskInput{ - Cluster: &c.Env.Cluster, + if o, err := ecsCli.RunTask(ctx, &ecs.RunTaskInput{ + Cluster: &env.Cluster, Group: &group, NetworkConfiguration: c.NetworkConfiguration, TaskDefinition: c.TaskDefinition.TaskDefinitionArn, @@ -73,11 +78,14 @@ func (c *common) TaskArn() *string { } func (c *common) waitForTask(ctx context.Context) error { + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) - if err := ecs.NewTasksRunningWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskRunning()); err != nil { + }, timeoutManager.TaskRunning()); err != nil { return err } log.Infof("🐣 canary task '%s' is running!", *c.taskArn) @@ -91,13 +99,17 @@ func (c *common) waitForTask(ctx context.Context) error { func (c *common) waitContainerHealthCheck(ctx context.Context) error { log.Infof("😷 ensuring canary task container(s) to become healthy...") + env := c.di.Get(key.Env).(*env.Envars) + timer := c.di.Get(key.Time).(types.Time) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) containerHasHealthChecks := map[string]struct{}{} for _, definition := range c.TaskDefinition.ContainerDefinitions { if definition.HealthCheck != nil { containerHasHealthChecks[*definition.Name] = struct{}{} } } - rest := c.Timeout.TaskHealthCheck() + rest := timeoutManager.TaskHealthCheck() healthCheckPeriod := 15 * time.Second for rest > 0 { if rest < healthCheckPeriod { @@ -106,10 +118,10 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case <-c.Time.NewTimer(healthCheckPeriod).C: + case <-timer.NewTimer(healthCheckPeriod).C: log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, }); err != nil { return err @@ -142,8 +154,9 @@ func (c *common) describeTaskTarget( ctx context.Context, targetPort int32, ) (*CanaryTarget, error) { + env := c.di.Get(key.Env).(*env.Envars) target := CanaryTarget{targetPort: targetPort} - if c.Env.CanaryInstanceArn == "" { // Fargate + if env.CanaryInstanceArn == "" { // Fargate if err := c.getFargateTarget(ctx, &target); err != nil { return nil, err } @@ -159,8 +172,11 @@ func (c *common) describeTaskTarget( func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { var task ecstypes.Task - if o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) + if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, }); err != nil { return err @@ -180,7 +196,7 @@ func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error if subnetId == nil || privateIp == nil { return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") } - if o, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + if o, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ SubnetIds: []string{*subnetId}, }); err != nil { return err @@ -194,23 +210,26 @@ func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { var containerInstance ecstypes.ContainerInstance - if outputs, err := c.Ecs.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ - Cluster: &c.Env.Cluster, - ContainerInstances: []string{c.Env.CanaryInstanceArn}, + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) + if outputs, err := ecsCli.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ + Cluster: &env.Cluster, + ContainerInstances: []string{env.CanaryInstanceArn}, }); err != nil { return err } else { containerInstance = outputs.ContainerInstances[0] } var ec2Instance ec2types.Instance - if o, err := c.Ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + if o, err := ec2Cli.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ InstanceIds: []string{*containerInstance.Ec2InstanceId}, }); err != nil { return err } else { ec2Instance = o.Reservations[0].Instances[0] } - if sn, err := c.Ec2.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + if sn, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ SubnetIds: []string{*ec2Instance.SubnetId}, }); err != nil { return err @@ -226,17 +245,20 @@ func (c *common) stopTask(ctx context.Context) error { if c.taskArn == nil { return nil } + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("stopping the canary task '%s'...", *c.taskArn) - if _, err := c.Ecs.StopTask(ctx, &ecs.StopTaskInput{ - Cluster: &c.Env.Cluster, + if _, err := ecsCli.StopTask(ctx, &ecs.StopTaskInput{ + Cluster: &env.Cluster, Task: c.taskArn, }); err != nil { return xerrors.Errorf("failed to stop canary task: %w", err) } - if err := ecs.NewTasksStoppedWaiter(c.Ecs).Wait(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + if err := ecs.NewTasksStoppedWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, c.Timeout.TaskStopped()); err != nil { + }, timeoutManager.TaskStopped()); err != nil { return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) diff --git a/task/common_test.go b/task/common_test.go new file mode 100644 index 0000000..aa3d246 --- /dev/null +++ b/task/common_test.go @@ -0,0 +1,9 @@ +package task_test + +import ( + "testing" +) + +func TestCommon_Start(t *testing.T) { + +} diff --git a/task/export_test.go b/task/export_test.go new file mode 100644 index 0000000..ee02268 --- /dev/null +++ b/task/export_test.go @@ -0,0 +1,3 @@ +package task + +type CommonExport = common diff --git a/task/factory.go b/task/factory.go new file mode 100644 index 0000000..4f98267 --- /dev/null +++ b/task/factory.go @@ -0,0 +1,32 @@ +package task + +import ( + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/logos/di" +) + +type Factory interface { + NewAlbTask(input *Input, lb *ecstypes.LoadBalancer) Task + NewSrvTask(input *Input, srv *ecstypes.ServiceRegistry) Task + NewSimpleTask(input *Input) Task +} + +type factory struct { + di *di.D +} + +func NewFactory(di *di.D) Factory { + return &factory{di: di} +} + +func (f *factory) NewAlbTask(input *Input, lb *ecstypes.LoadBalancer) Task { + return NewAlbTask(f.di, input, lb) +} + +func (f *factory) NewSrvTask(input *Input, srv *ecstypes.ServiceRegistry) Task { + return NewSrvTask(f.di, input, srv) +} + +func (f *factory) NewSimpleTask(input *Input) Task { + return NewSimpleTask(f.di, input) +} diff --git a/taskset/factory_test.go b/task/factory_test.go similarity index 81% rename from taskset/factory_test.go rename to task/factory_test.go index 4521491..ad13d72 100644 --- a/taskset/factory_test.go +++ b/task/factory_test.go @@ -1,31 +1,32 @@ -package taskset_test +package task_test import ( "testing" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/loilo-inc/canarycage/task" - "github.com/loilo-inc/canarycage/taskset" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) func TestFactory(t *testing.T) { + d := &di.D{} t.Run("NewAlbTask", func(t *testing.T) { - f := taskset.NewFactory() + f := task.NewFactory(d) input := &task.Input{} lb := &ecstypes.LoadBalancer{} task := f.NewAlbTask(input, lb) assert.NotNil(t, task) }) t.Run("NewSrvTask", func(t *testing.T) { - f := taskset.NewFactory() + f := task.NewFactory(d) input := &task.Input{} srv := &ecstypes.ServiceRegistry{} task := f.NewSrvTask(input, srv) assert.NotNil(t, task) }) t.Run("NewSimpleTask", func(t *testing.T) { - f := taskset.NewFactory() + f := task.NewFactory(d) input := &task.Input{} task := f.NewSimpleTask(input) assert.NotNil(t, task) diff --git a/task/simple_task.go b/task/simple_task.go index f7ae23e..e40d2c4 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -6,6 +6,11 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" ) @@ -14,8 +19,8 @@ type simpleTask struct { *common } -func NewSimpleTask(input *Input) Task { - return &simpleTask{common: &common{Input: input}} +func NewSimpleTask(di *di.D, input *Input) Task { + return &simpleTask{common: &common{Input: input, di: di}} } func (c *simpleTask) Wait(ctx context.Context) error { @@ -30,8 +35,10 @@ func (c *simpleTask) Stop(ctx context.Context) error { } func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { - log.Infof("wait %d seconds for canary task to be stable...", c.Env.CanaryTaskIdleDuration) - rest := time.Duration(c.Env.CanaryTaskIdleDuration) * time.Second + env := c.di.Get(key.Env).(*env.Envars) + timer := c.di.Get(key.Time).(types.Time) + log.Infof("wait %d seconds for canary task to be stable...", env.CanaryTaskIdleDuration) + rest := time.Duration(env.CanaryTaskIdleDuration) * time.Second waitPeriod := 15 * time.Second for rest > 0 { if rest < waitPeriod { @@ -40,13 +47,14 @@ func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case <-c.Time.NewTimer(waitPeriod).C: + case <-timer.NewTimer(waitPeriod).C: rest -= waitPeriod } log.Infof("still waiting...; %d seconds left", rest) } - o, err := c.Ecs.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &c.Env.Cluster, + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, }) if err != nil { diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 218cf57..27a0a59 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -4,10 +4,11 @@ import ( "context" "testing" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/canarycage/timeout" - "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -18,18 +19,18 @@ func TestSimpleTask(t *testing.T) { td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) - stask := task.NewSimpleTask(&task.Input{ - Deps: &types.Deps{ - Env: env, - Ecs: mocker.Ecs, - Ec2: mocker.Ec2, - Alb: mocker.Alb, - Srv: mocker.Srv, - Time: test.NewFakeTime(), - }, + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.AlbCli, mocker.Alb) + b.Set(key.SrvCli, mocker.Srv) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + b.Set(key.Time, test.NewFakeTime()) + }) + stask := task.NewSimpleTask(d, &task.Input{ TaskDefinition: td.TaskDefinition, NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - Timeout: timeout.NewManager(1, &timeout.Input{}), }) err := stask.Start(ctx) assert.NoError(t, err) diff --git a/task/srv_task.go b/task/srv_task.go index 0ecb554..ad6db5c 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -8,6 +8,12 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/aws/aws-sdk-go-v2/service/servicediscovery" srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" ) @@ -21,9 +27,9 @@ type srvTask struct { ns *srvtypes.Namespace } -func NewSrvTask(input *Input, registry *ecstypes.ServiceRegistry) Task { +func NewSrvTask(di *di.D, input *Input, registry *ecstypes.ServiceRegistry) Task { return &srvTask{ - common: &common{Input: input}, + common: &common{Input: input, di: di}, registry: registry, } } @@ -63,16 +69,18 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { } c.target = target // get the service id from service registry arn srvId := ArnToId(*c.registry.RegistryArn) + srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) + env := c.di.Get(key.Env).(*env.Envars) var svc *srvtypes.Service var ns *srvtypes.Namespace - if o, err := c.Srv.GetService(ctx, &servicediscovery.GetServiceInput{ + if o, err := srvCli.GetService(ctx, &servicediscovery.GetServiceInput{ Id: &srvId, }); err != nil { return xerrors.Errorf("failed to get the service: %w", err) } else { svc = o.Service } - if o, err := c.Srv.GetNamespace(ctx, &servicediscovery.GetNamespaceInput{ + if o, err := srvCli.GetNamespace(ctx, &servicediscovery.GetNamespaceInput{ Id: svc.NamespaceId, }); err != nil { return xerrors.Errorf("failed to get the namespace: %w", err) @@ -83,14 +91,14 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { "AWS_INSTANCE_IPV4": target.targetIpv4, "AVAILABILITY_ZONE": target.availabilityZone, "AWS_INIT_HEALTH_STATUS": "UNHEALTHY", - "ECS_CLUSTER_NAME": c.Env.Cluster, - "ECS_SERVICE_NAME": c.Env.Service, + "ECS_CLUSTER_NAME": env.Cluster, + "ECS_SERVICE_NAME": env.Service, "ECS_TASK_DEFINITION_FAMILY": *c.TaskDefinition.Family, - "REGION": c.Env.Region, + "REGION": env.Region, "CAGE_TASK_ID": ArnToId(*c.taskArn), } taskId := ArnToId(*c.taskArn) - if _, err := c.Srv.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ + if _, err := srvCli.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ ServiceId: &srvId, InstanceId: &taskId, Attributes: attrs, @@ -106,7 +114,10 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { func (c *srvTask) waitUntilSrvInstHelthy( ctx context.Context, ) error { - var rest = c.Timeout.TargetHealthCheck() + timer := c.di.Get(key.Time).(types.Time) + srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) + var rest = timeoutManager.TargetHealthCheck() var waitPeriod = 15 * time.Second for rest > 0 { if rest < waitPeriod { @@ -115,8 +126,8 @@ func (c *srvTask) waitUntilSrvInstHelthy( select { case <-ctx.Done(): return ctx.Err() - case <-c.Time.NewTimer(time.Duration(waitPeriod) * time.Second).C: - if list, err := c.Srv.DiscoverInstances(ctx, &servicediscovery.DiscoverInstancesInput{ + case <-timer.NewTimer(time.Duration(waitPeriod) * time.Second).C: + if list, err := srvCli.DiscoverInstances(ctx, &servicediscovery.DiscoverInstancesInput{ NamespaceName: c.ns.Name, ServiceName: c.srv.Name, HealthStatus: srvtypes.HealthStatusFilterHealthy, @@ -145,7 +156,8 @@ func (c *srvTask) deregisterSrvInst( return } log.Info("deregistering the canary task from service discovery...") - if _, err := c.Srv.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ + srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) + if _, err := srvCli.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ ServiceId: c.srv.Id, InstanceId: c.instId, }); err != nil { diff --git a/task/srv_task_test.go b/task/srv_task_test.go index 374caad..9853c05 100644 --- a/task/srv_task_test.go +++ b/task/srv_task_test.go @@ -5,10 +5,11 @@ import ( "testing" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/canarycage/timeout" - "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -23,18 +24,18 @@ func TestSrvTask(t *testing.T) { td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) - stask := task.NewSrvTask(&task.Input{ - Deps: &types.Deps{ - Env: env, - Ecs: mocker.Ecs, - Ec2: mocker.Ec2, - Alb: mocker.Alb, - Srv: mocker.Srv, - Time: test.NewFakeTime(), - }, + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.AlbCli, mocker.Alb) + b.Set(key.SrvCli, mocker.Srv) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + b.Set(key.Time, test.NewFakeTime()) + }) + stask := task.NewSrvTask(d, &task.Input{ TaskDefinition: td.TaskDefinition, NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - Timeout: timeout.NewManager(1, &timeout.Input{}), }, &ecstypes.ServiceRegistry{RegistryArn: ®istryArn}) _ = mocker.CreateSrvService(srvNsName, srvSvcName) err := stask.Start(ctx) diff --git a/task_definition.go b/task_definition.go index 25548f0..ee94e87 100644 --- a/task_definition.go +++ b/task_definition.go @@ -6,14 +6,19 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" "golang.org/x/xerrors" ) func (c *cage) CreateNextTaskDefinition(ctx context.Context) (*ecstypes.TaskDefinition, error) { - if c.Env.TaskDefinitionArn != "" { - log.Infof("--taskDefinitionArn was set to '%s'. skip registering new task definition.", c.Env.TaskDefinitionArn) - o, err := c.Ecs.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ - TaskDefinition: &c.Env.TaskDefinitionArn, + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + if env.TaskDefinitionArn != "" { + log.Infof("--taskDefinitionArn was set to '%s'. skip registering new task definition.", env.TaskDefinitionArn) + o, err := ecsCli.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{ + TaskDefinition: &env.TaskDefinitionArn, }) if err != nil { return nil, xerrors.Errorf("failed to describe next task definition: %w", err) @@ -21,7 +26,7 @@ func (c *cage) CreateNextTaskDefinition(ctx context.Context) (*ecstypes.TaskDefi return o.TaskDefinition, nil } else { log.Infof("creating next task definition...") - if out, err := c.Ecs.RegisterTaskDefinition(ctx, c.Env.TaskDefinitionInput); err != nil { + if out, err := ecsCli.RegisterTaskDefinition(ctx, env.TaskDefinitionInput); err != nil { return nil, xerrors.Errorf("failed to register next task definition: %w", err) } else { log.Infof( diff --git a/task_definition_test.go b/task_definition_test.go index 3c51608..db5f8ac 100644 --- a/task_definition_test.go +++ b/task_definition_test.go @@ -9,9 +9,10 @@ import ( "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" "golang.org/x/xerrors" ) @@ -23,12 +24,10 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } - c := &cage.CageExport{ - Deps: &types.Deps{ - Env: env, - Ecs: ecsMock, - }, - } + c := cage.NewCageExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })) ecsMock.EXPECT().DescribeTaskDefinition(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTaskDefinitionOutput{ TaskDefinition: &ecstypes.TaskDefinition{}, }, nil) @@ -42,12 +41,10 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } - c := &cage.CageExport{ - Deps: &types.Deps{ - Env: env, - Ecs: ecsMock, - }, - } + c := cage.NewCageExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })) ecsMock.EXPECT().DescribeTaskDefinition(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("error")) td, err := c.CreateNextTaskDefinition(context.Background()) assert.Errorf(t, err, "failed to describe next task definition: error") @@ -57,12 +54,10 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() - c := &cage.CageExport{ - Deps: &types.Deps{ - Env: env, - Ecs: ecsMock, - }, - } + c := cage.NewCageExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })) ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).Return(&ecs.RegisterTaskDefinitionOutput{ TaskDefinition: &ecstypes.TaskDefinition{ Family: env.TaskDefinitionInput.Family, @@ -77,12 +72,10 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() - c := &cage.CageExport{ - Deps: &types.Deps{ - Env: env, - Ecs: ecsMock, - }, - } + c := cage.NewCageExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })) ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("error")) td, err := c.CreateNextTaskDefinition(context.Background()) assert.Errorf(t, err, "failed to register next task definition: error") diff --git a/taskset/factory.go b/taskset/factory.go deleted file mode 100644 index 8391c8b..0000000 --- a/taskset/factory.go +++ /dev/null @@ -1,30 +0,0 @@ -package taskset - -import ( - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/task" -) - -type Factory interface { - NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task - NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task - NewSimpleTask(input *task.Input) task.Task -} - -type factory struct{} - -func NewFactory() Factory { - return &factory{} -} - -func (f *factory) NewAlbTask(input *task.Input, lb *ecstypes.LoadBalancer) task.Task { - return task.NewAlbTask(input, lb) -} - -func (f *factory) NewSrvTask(input *task.Input, srv *ecstypes.ServiceRegistry) task.Task { - return task.NewSrvTask(input, srv) -} - -func (f *factory) NewSimpleTask(input *task.Input) task.Task { - return task.NewSimpleTask(input) -} diff --git a/taskset/taskset.go b/taskset/taskset.go index 9d08eb8..d658187 100644 --- a/taskset/taskset.go +++ b/taskset/taskset.go @@ -23,14 +23,14 @@ type Input struct { ServiceRegistries []ecstypes.ServiceRegistry } -func NewSet(factory Factory, input *Input) Set { +func NewSet( + factory task.Factory, + input *Input) Set { var results []task.Task taskInput := &task.Input{ - Deps: input.Deps, NetworkConfiguration: input.NetworkConfiguration, TaskDefinition: input.TaskDefinition, PlatformVersion: input.PlatformVersion, - Timeout: input.Timeout, } for _, lb := range input.LoadBalancers { task := factory.NewAlbTask(taskInput, &lb) diff --git a/time.go b/time.go deleted file mode 100644 index d4f9e6a..0000000 --- a/time.go +++ /dev/null @@ -1,12 +0,0 @@ -package cage - -import "time" - -type timeImpl struct{} - -func (t *timeImpl) Now() time.Time { - return time.Now() -} -func (t *timeImpl) NewTimer(d time.Duration) *time.Timer { - return time.NewTimer(d) -} diff --git a/timeout/time.go b/timeout/time.go new file mode 100644 index 0000000..640def1 --- /dev/null +++ b/timeout/time.go @@ -0,0 +1,12 @@ +package timeout + +import "time" + +type Time struct{} + +func (t *Time) Now() time.Time { + return time.Now() +} +func (t *Time) NewTimer(d time.Duration) *time.Timer { + return time.NewTimer(d) +} diff --git a/timeout/timeout.go b/timeout/timeout.go index 7ad4ee3..0bcb9ed 100644 --- a/timeout/timeout.go +++ b/timeout/timeout.go @@ -1,14 +1,10 @@ package timeout -import "time" +import ( + "time" -type Input struct { - TaskStoppedWait time.Duration - TaskRunningWait time.Duration - TaskHealthCheckWait time.Duration - TargetHealthCheckWait time.Duration - ServiceStableWait time.Duration -} + "github.com/loilo-inc/canarycage/env" +) type Manager interface { TaskRunning() time.Duration @@ -19,51 +15,56 @@ type Manager interface { } type manager struct { - *Input + env *env.Envars DefaultTimeout time.Duration } func NewManager( + env *env.Envars, defaultTimeout time.Duration, - input *Input, ) Manager { return &manager{ - Input: input, + env: env, DefaultTimeout: defaultTimeout, } } func (t *manager) TaskRunning() time.Duration { - if t.TaskRunningWait > 0 { - return t.TaskRunningWait + wait := t.env.CanaryTaskRunningWait + if wait > 0 { + return time.Duration(wait) * time.Second } return t.DefaultTimeout } func (t *manager) TaskHealthCheck() time.Duration { - if t.TaskHealthCheckWait > 0 { - return t.TaskHealthCheckWait + wait := t.env.CanaryTaskHealthCheckWait + if wait > 0 { + return time.Duration(wait) * time.Second } return t.DefaultTimeout } func (t *manager) TaskStopped() time.Duration { - if t.TaskStoppedWait > 0 { - return t.TaskStoppedWait + wait := t.env.CanaryTaskStoppedWait + if wait > 0 { + return time.Duration(wait) * time.Second } return t.DefaultTimeout } func (t *manager) ServiceStable() time.Duration { - if t.ServiceStableWait > 0 { - return t.ServiceStableWait + wait := t.env.ServiceStableWait + if wait > 0 { + return time.Duration(wait) * time.Second } return t.DefaultTimeout } func (t *manager) TargetHealthCheck() time.Duration { - if t.TargetHealthCheckWait > 0 { - return t.TargetHealthCheckWait + wait := t.env.TargetHealthCheckWait + if wait > 0 { + return time.Duration(wait) * time.Second } return t.DefaultTimeout } diff --git a/timeout/timeout_test.go b/timeout/timeout_test.go index 9b4517d..f397f7e 100644 --- a/timeout/timeout_test.go +++ b/timeout/timeout_test.go @@ -4,13 +4,14 @@ import ( "testing" "time" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/timeout" "github.com/stretchr/testify/assert" ) func TestManager(t *testing.T) { t.Run("no config", func(t *testing.T) { - man := timeout.NewManager(10, &timeout.Input{}) + man := timeout.NewManager(&env.Envars{}, 10) assert.Equal(t, time.Duration(10), man.TaskRunning()) assert.Equal(t, time.Duration(10), man.TaskStopped()) assert.Equal(t, time.Duration(10), man.TaskHealthCheck()) @@ -18,17 +19,17 @@ func TestManager(t *testing.T) { assert.Equal(t, time.Duration(10), man.TargetHealthCheck()) }) t.Run("with config", func(t *testing.T) { - man := timeout.NewManager(10, &timeout.Input{ - TaskRunningWait: 1, - TaskStoppedWait: 2, - TaskHealthCheckWait: 3, - ServiceStableWait: 4, - TargetHealthCheckWait: 5, - }) - assert.Equal(t, time.Duration(1), man.TaskRunning()) - assert.Equal(t, time.Duration(2), man.TaskStopped()) - assert.Equal(t, time.Duration(3), man.TaskHealthCheck()) - assert.Equal(t, time.Duration(4), man.ServiceStable()) - assert.Equal(t, time.Duration(5), man.TargetHealthCheck()) + man := timeout.NewManager(&env.Envars{ + CanaryTaskRunningWait: 1, + CanaryTaskStoppedWait: 2, + CanaryTaskHealthCheckWait: 3, + ServiceStableWait: 4, + TargetHealthCheckWait: 5, + }, 10) + assert.Equal(t, 1*time.Second, man.TaskRunning()) + assert.Equal(t, 2*time.Second, man.TaskStopped()) + assert.Equal(t, 3*time.Second, man.TaskHealthCheck()) + assert.Equal(t, 4*time.Second, man.ServiceStable()) + assert.Equal(t, 5*time.Second, man.TargetHealthCheck()) }) } diff --git a/types/iface.go b/types/iface.go index 7e28a9d..38093b5 100644 --- a/types/iface.go +++ b/types/iface.go @@ -5,8 +5,6 @@ import ( "time" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/awsiface" - "github.com/loilo-inc/canarycage/env" ) type Cage interface { @@ -20,15 +18,6 @@ type Time interface { NewTimer(time.Duration) *time.Timer } -type Deps struct { - Env *env.Envars - Ecs awsiface.EcsClient - Alb awsiface.AlbClient - Ec2 awsiface.Ec2Client - Srv awsiface.SrvClient - Time Time -} - type RunInput struct { Container *string Overrides *ecstypes.TaskOverride diff --git a/up.go b/up.go index ae9da4d..434405e 100644 --- a/up.go +++ b/up.go @@ -6,6 +6,10 @@ import ( "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -15,20 +19,22 @@ func (c *cage) Up(ctx context.Context) (*types.UpResult, error) { if err != nil { return nil, err } - log.Infof("checking existence of service '%s'", c.Env.Service) - if o, err := c.Ecs.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: &c.Env.Cluster, - Services: []string{c.Env.Service}, + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + log.Infof("checking existence of service '%s'", env.Service) + if o, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, }); err != nil { return nil, xerrors.Errorf("couldn't describe service: %w", err) } else if len(o.Services) > 0 { svc := o.Services[0] if *svc.Status != "INACTIVE" { - return nil, xerrors.Errorf("service '%s' already exists. Use 'cage rollout' instead", c.Env.Service) + return nil, xerrors.Errorf("service '%s' already exists. Use 'cage rollout' instead", env.Service) } } - c.Env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinitionArn - if service, err := c.createService(ctx, c.Env.ServiceDefinitionInput); err != nil { + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinitionArn + if service, err := c.createService(ctx, env.ServiceDefinitionInput); err != nil { return nil, err } else { return &types.UpResult{TaskDefinition: td, Service: service}, nil @@ -36,16 +42,19 @@ func (c *cage) Up(ctx context.Context) (*types.UpResult, error) { } func (c *cage) createService(ctx context.Context, serviceDefinitionInput *ecs.CreateServiceInput) (*ecstypes.Service, error) { + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) log.Infof("creating service '%s' with task-definition '%s'...", *serviceDefinitionInput.ServiceName, *serviceDefinitionInput.TaskDefinition) - o, err := c.Ecs.CreateService(ctx, serviceDefinitionInput) + o, err := ecsCli.CreateService(ctx, serviceDefinitionInput) if err != nil { return nil, xerrors.Errorf("failed to create service '%s': %w", *serviceDefinitionInput.ServiceName, err) } + timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("waiting for service '%s' to be STABLE", *serviceDefinitionInput.ServiceName) - if err := ecs.NewServicesStableWaiter(c.Ecs).Wait(ctx, &ecs.DescribeServicesInput{ - Cluster: &c.Env.Cluster, + if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, Services: []string{*serviceDefinitionInput.ServiceName}, - }, c.Timeout.ServiceStable()); err != nil { + }, timeoutManager.ServiceStable()); err != nil { return nil, xerrors.Errorf("failed to wait for service '%s' to be STABLE: %w", *serviceDefinitionInput.ServiceName, err) } return o.Service, nil diff --git a/up_test.go b/up_test.go index 1f3023e..352a315 100644 --- a/up_test.go +++ b/up_test.go @@ -6,8 +6,10 @@ import ( "github.com/golang/mock/gomock" cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -17,10 +19,11 @@ func TestCage_Up(t *testing.T) { ctrl := gomock.NewController(t) ctx, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") delete(ctx.Services, env.Service) - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Up(context.Background()) assert.Nil(t, err) assert.NotNil(t, result.Service) @@ -30,10 +33,11 @@ func TestCage_Up(t *testing.T) { env := test.DefaultEnvars() ctrl := gomock.NewController(t) _, ecsMock, _, _ := test.Setup(ctrl, env, 1, "FARGATE") - cagecli := cage.NewCage(&types.Deps{ - Env: env, - Ecs: ecsMock, - }) + cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) + })) result, err := cagecli.Up(context.Background()) assert.Nil(t, result) assert.EqualError(t, err, "service 'service' already exists. Use 'cage rollout' instead") From d4345473b691e620f9496117d39a918163a4e55b Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 1 Jul 2024 20:50:26 +0900 Subject: [PATCH 19/47] add tests --- task/common_test.go | 80 ++++++++++++++++++++++++++++++++++++++++++++- task/export_test.go | 6 ++++ 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/task/common_test.go b/task/common_test.go index aa3d246..cf68cda 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -1,9 +1,87 @@ package task_test import ( + "context" + "fmt" "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/golang/mock/gomock" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/mocks/mock_awsiface" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/logos/di" + "github.com/stretchr/testify/assert" ) func TestCommon_Start(t *testing.T) { - + t.Run("Fargate", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + ctrl := gomock.NewController(t) + td := &ecstypes.TaskDefinition{} + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.RunTaskOutput{ + Tasks: []ecstypes.Task{{TaskArn: aws.String("task-arn")}}, + }, nil) + envars := test.DefaultEnvars() + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{TaskDefinition: td}) + err := cm.Start(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, "task-arn", *cm.TaskArn()) + }) + t.Run("should error if task failed to start", func(t *testing.T) { + ctrl := gomock.NewController(t) + td := &ecstypes.TaskDefinition{} + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) + envars := test.DefaultEnvars() + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{TaskDefinition: td}) + err := cm.Start(context.TODO()) + assert.EqualError(t, err, "error") + assert.Nil(t, cm.TaskArn()) + }) + }) + t.Run("EC2", func(t *testing.T) { + t.Run("basic", func(t *testing.T) { + ctrl := gomock.NewController(t) + td := &ecstypes.TaskDefinition{} + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.StartTaskOutput{ + Tasks: []ecstypes.Task{{TaskArn: aws.String("task-arn")}}, + }, nil) + envars := test.DefaultEnvars() + envars.CanaryInstanceArn = "instance-arn" + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{TaskDefinition: td}) + err := cm.Start(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, "task-arn", *cm.TaskArn()) + }) + t.Run("should error if task failed to start", func(t *testing.T) { + ctrl := gomock.NewController(t) + td := &ecstypes.TaskDefinition{} + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) + envars := test.DefaultEnvars() + envars.CanaryInstanceArn = "instance-arn" + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{TaskDefinition: td}) + err := cm.Start(context.TODO()) + assert.EqualError(t, err, "error") + assert.Nil(t, cm.TaskArn()) + }) + }) } diff --git a/task/export_test.go b/task/export_test.go index ee02268..2b73a37 100644 --- a/task/export_test.go +++ b/task/export_test.go @@ -1,3 +1,9 @@ package task +import "github.com/loilo-inc/logos/di" + type CommonExport = common + +func NewCommonExport(di *di.D, input *Input) *common { + return &common{di: di, Input: input} +} From 262122ed541be45bc65bd2febec49af84c426a39 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 18:40:26 +0900 Subject: [PATCH 20/47] drop timeout.Manager --- cli/cage/commands/command.go | 2 - env/timeout.go | 47 +++++++++++++++++++++ {timeout => env}/timeout_test.go | 19 ++++----- rollout.go | 4 +- rollout_test.go | 15 ------- run.go | 6 +-- run_test.go | 7 ---- task/alb_task.go | 6 +-- task/alb_task_test.go | 2 - task/common.go | 10 ++--- task/simple_task_test.go | 2 - task/srv_task.go | 5 +-- task/srv_task_test.go | 2 - timeout/timeout.go | 70 -------------------------------- up.go | 4 +- up_test.go | 3 -- 16 files changed, 68 insertions(+), 136 deletions(-) create mode 100644 env/timeout.go rename {timeout => env}/timeout_test.go (59%) delete mode 100644 timeout/timeout.go diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index dcb3167..d176641 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -3,7 +3,6 @@ package commands import ( "context" "io" - "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ec2" @@ -52,7 +51,6 @@ func DefalutCageCliProvider(envars *env.Envars) (types.Cage, error) { b.Set(key.Ec2Cli, ec2.NewFromConfig(conf)) b.Set(key.AlbCli, elasticloadbalancingv2.NewFromConfig(conf)) b.Set(key.SrvCli, servicediscovery.NewFromConfig(conf)) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 15*time.Minute)) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, &timeout.Time{}) }) diff --git a/env/timeout.go b/env/timeout.go new file mode 100644 index 0000000..c8d7baf --- /dev/null +++ b/env/timeout.go @@ -0,0 +1,47 @@ +package env + +import ( + "time" +) + +const defaultTimeout = 15 * time.Minute + +func (t *Envars) TaskRunning() time.Duration { + wait := t.CanaryTaskRunningWait + if wait > 0 { + return time.Duration(wait) * time.Second + } + return defaultTimeout +} + +func (t *Envars) TaskHealthCheck() time.Duration { + wait := t.CanaryTaskHealthCheckWait + if wait > 0 { + return time.Duration(wait) * time.Second + } + return defaultTimeout +} + +func (t *Envars) TaskStopped() time.Duration { + wait := t.CanaryTaskStoppedWait + if wait > 0 { + return time.Duration(wait) * time.Second + } + return defaultTimeout +} + +func (t *Envars) ServiceStable() time.Duration { + wait := t.ServiceStableWait + if wait > 0 { + return time.Duration(wait) * time.Second + } + return defaultTimeout +} + +func (t *Envars) TargetHealthCheck() time.Duration { + wait := t.TargetHealthCheckWait + if wait > 0 { + return time.Duration(wait) * time.Second + } + return defaultTimeout +} diff --git a/timeout/timeout_test.go b/env/timeout_test.go similarity index 59% rename from timeout/timeout_test.go rename to env/timeout_test.go index f397f7e..d6a434d 100644 --- a/timeout/timeout_test.go +++ b/env/timeout_test.go @@ -1,31 +1,30 @@ -package timeout_test +package env_test import ( "testing" "time" "github.com/loilo-inc/canarycage/env" - "github.com/loilo-inc/canarycage/timeout" "github.com/stretchr/testify/assert" ) func TestManager(t *testing.T) { t.Run("no config", func(t *testing.T) { - man := timeout.NewManager(&env.Envars{}, 10) - assert.Equal(t, time.Duration(10), man.TaskRunning()) - assert.Equal(t, time.Duration(10), man.TaskStopped()) - assert.Equal(t, time.Duration(10), man.TaskHealthCheck()) - assert.Equal(t, time.Duration(10), man.ServiceStable()) - assert.Equal(t, time.Duration(10), man.TargetHealthCheck()) + man := &env.Envars{} + assert.Equal(t, 15*time.Minute, man.TaskRunning()) + assert.Equal(t, 15*time.Minute, man.TaskStopped()) + assert.Equal(t, 15*time.Minute, man.TaskHealthCheck()) + assert.Equal(t, 15*time.Minute, man.ServiceStable()) + assert.Equal(t, 15*time.Minute, man.TargetHealthCheck()) }) t.Run("with config", func(t *testing.T) { - man := timeout.NewManager(&env.Envars{ + man := &env.Envars{ CanaryTaskRunningWait: 1, CanaryTaskStoppedWait: 2, CanaryTaskHealthCheckWait: 3, ServiceStableWait: 4, TargetHealthCheckWait: 5, - }, 10) + } assert.Equal(t, 1*time.Second, man.TaskRunning()) assert.Equal(t, 2*time.Second, man.TaskStopped()) assert.Equal(t, 3*time.Second, man.TaskHealthCheck()) diff --git a/rollout.go b/rollout.go index 7e8ad1f..cb09656 100644 --- a/rollout.go +++ b/rollout.go @@ -11,7 +11,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/taskset" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -88,11 +87,10 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R } result.ServiceIntact = false log.Infof("waiting for service '%s' to be stable...", env.Service) - timeouManager := c.di.Get(key.TimeoutManager).(timeout.Manager) if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ Cluster: &env.Cluster, Services: []string{env.Service}, - }, timeouManager.ServiceStable()); err != nil { + }, env.ServiceStable()); err != nil { return result, err } log.Infof("🥴 service '%s' has become to be stable!", env.Service) diff --git a/rollout_test.go b/rollout_test.go index cfb1e73..9f409e2 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -4,7 +4,6 @@ import ( "context" "strings" "testing" - "time" "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/aws" @@ -18,7 +17,6 @@ import ( "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" @@ -47,7 +45,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -72,7 +69,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -113,7 +109,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1*time.Minute)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -159,7 +154,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -190,7 +184,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() service, _ := mctx.GetEcsService(envars.Service) @@ -215,7 +208,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{ @@ -242,7 +234,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -261,7 +252,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() if res, err := cagecli.RollOut(ctx, &types.RollOutInput{}); err != nil { @@ -287,7 +277,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.EcsCli, ecsMock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) _, err := cagecli.RollOut(context.Background(), &types.RollOutInput{}) assert.EqualError(t, err, "😵 'service' status is 'INACTIVE'. Stop rolling out") @@ -332,7 +321,6 @@ func TestCage_RollOut_FARGATE(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() res, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -380,7 +368,6 @@ func TestCage_RollOut_EC2(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -412,7 +399,6 @@ func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) @@ -445,7 +431,6 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { b.Set(key.Ec2Cli, ec2Mock) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(envars, 1)) })) ctx := context.Background() result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) diff --git a/run.go b/run.go index 82d6b51..dde8692 100644 --- a/run.go +++ b/run.go @@ -10,7 +10,6 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -34,7 +33,6 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult return nil, err } ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) o, err := ecsCli.RunTask(ctx, &ecs.RunTaskInput{ Cluster: &env.Cluster, TaskDefinition: td.TaskDefinitionArn, @@ -52,7 +50,7 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, timeoutManager.TaskRunning()); err != nil { + }, env.TaskRunning()); err != nil { return nil, xerrors.Errorf("task failed to start: %w", err) } log.Infof("task '%s' is running", *taskArn) @@ -60,7 +58,7 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult if result, err := ecs.NewTasksStoppedWaiter(ecsCli).WaitForOutput(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, timeoutManager.TaskStopped()); err != nil { + }, env.TaskStopped()); err != nil { return nil, xerrors.Errorf("task failed to stop: %w", err) } else { task := result.Tasks[0] diff --git a/run_test.go b/run_test.go index 2143ea5..ac91ca3 100644 --- a/run_test.go +++ b/run_test.go @@ -13,7 +13,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" @@ -47,7 +46,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, @@ -78,7 +76,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, @@ -101,7 +98,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, @@ -130,7 +126,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, @@ -159,7 +154,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: &container, @@ -176,7 +170,6 @@ func TestCage_Run(t *testing.T) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) b.Set(key.Time, test.NewFakeTime()) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Run(ctx, &types.RunInput{ Container: aws.String("foo"), diff --git a/task/alb_task.go b/task/alb_task.go index 8a9debf..7a2f45e 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -10,8 +10,8 @@ import ( elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" @@ -91,13 +91,13 @@ func (c *albTask) registerToTargetGroup(ctx context.Context) error { func (c *albTask) waitUntilTargetHealthy( ctx context.Context, ) error { + env := c.di.Get(key.Env).(*env.Envars) albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) timer := c.di.Get(key.Time).(types.Time) log.Infof("checking the health state of canary task...") var unusedCount = 0 var recentState *elbv2types.TargetHealthStateEnum - rest := timeoutManager.TargetHealthCheck() + rest := env.TargetHealthCheck() waitPeriod := 15 * time.Second for rest > 0 && unusedCount < 5 { if rest < waitPeriod { diff --git a/task/alb_task_test.go b/task/alb_task_test.go index a9fc23a..e102fc5 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -8,7 +8,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -26,7 +25,6 @@ func TestAlbTask(t *testing.T) { b.Set(key.Ec2Cli, mocker.Ec2) b.Set(key.AlbCli, mocker.Alb) b.Set(key.SrvCli, mocker.Srv) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) b.Set(key.Time, test.NewFakeTime()) }) stask := task.NewAlbTask(d, &task.Input{ diff --git a/task/common.go b/task/common.go index 3f8cf53..ce18c0e 100644 --- a/task/common.go +++ b/task/common.go @@ -13,7 +13,6 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" @@ -80,12 +79,11 @@ func (c *common) TaskArn() *string { func (c *common) waitForTask(ctx context.Context) error { env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, timeoutManager.TaskRunning()); err != nil { + }, env.TaskRunning()); err != nil { return err } log.Infof("🐣 canary task '%s' is running!", *c.taskArn) @@ -101,7 +99,6 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { log.Infof("😷 ensuring canary task container(s) to become healthy...") env := c.di.Get(key.Env).(*env.Envars) timer := c.di.Get(key.Time).(types.Time) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) containerHasHealthChecks := map[string]struct{}{} for _, definition := range c.TaskDefinition.ContainerDefinitions { @@ -109,7 +106,7 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { containerHasHealthChecks[*definition.Name] = struct{}{} } } - rest := timeoutManager.TaskHealthCheck() + rest := env.TaskHealthCheck() healthCheckPeriod := 15 * time.Second for rest > 0 { if rest < healthCheckPeriod { @@ -247,7 +244,6 @@ func (c *common) stopTask(ctx context.Context) error { } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("stopping the canary task '%s'...", *c.taskArn) if _, err := ecsCli.StopTask(ctx, &ecs.StopTaskInput{ Cluster: &env.Cluster, @@ -258,7 +254,7 @@ func (c *common) stopTask(ctx context.Context) error { if err := ecs.NewTasksStoppedWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, timeoutManager.TaskStopped()); err != nil { + }, env.TaskStopped()); err != nil { return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 27a0a59..23765e1 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -7,7 +7,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -25,7 +24,6 @@ func TestSimpleTask(t *testing.T) { b.Set(key.Ec2Cli, mocker.Ec2) b.Set(key.AlbCli, mocker.Alb) b.Set(key.SrvCli, mocker.Srv) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) b.Set(key.Time, test.NewFakeTime()) }) stask := task.NewSimpleTask(d, &task.Input{ diff --git a/task/srv_task.go b/task/srv_task.go index ad6db5c..395b26b 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -11,7 +11,6 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" "golang.org/x/xerrors" @@ -114,10 +113,10 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { func (c *srvTask) waitUntilSrvInstHelthy( ctx context.Context, ) error { + env := c.di.Get(key.Env).(*env.Envars) timer := c.di.Get(key.Time).(types.Time) srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) - var rest = timeoutManager.TargetHealthCheck() + var rest = env.TargetHealthCheck() var waitPeriod = 15 * time.Second for rest > 0 { if rest < waitPeriod { diff --git a/task/srv_task_test.go b/task/srv_task_test.go index 9853c05..5ddefc8 100644 --- a/task/srv_task_test.go +++ b/task/srv_task_test.go @@ -8,7 +8,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -30,7 +29,6 @@ func TestSrvTask(t *testing.T) { b.Set(key.Ec2Cli, mocker.Ec2) b.Set(key.AlbCli, mocker.Alb) b.Set(key.SrvCli, mocker.Srv) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) b.Set(key.Time, test.NewFakeTime()) }) stask := task.NewSrvTask(d, &task.Input{ diff --git a/timeout/timeout.go b/timeout/timeout.go deleted file mode 100644 index 0bcb9ed..0000000 --- a/timeout/timeout.go +++ /dev/null @@ -1,70 +0,0 @@ -package timeout - -import ( - "time" - - "github.com/loilo-inc/canarycage/env" -) - -type Manager interface { - TaskRunning() time.Duration - TaskHealthCheck() time.Duration - TaskStopped() time.Duration - ServiceStable() time.Duration - TargetHealthCheck() time.Duration -} - -type manager struct { - env *env.Envars - DefaultTimeout time.Duration -} - -func NewManager( - env *env.Envars, - defaultTimeout time.Duration, -) Manager { - return &manager{ - env: env, - DefaultTimeout: defaultTimeout, - } -} - -func (t *manager) TaskRunning() time.Duration { - wait := t.env.CanaryTaskRunningWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return t.DefaultTimeout -} - -func (t *manager) TaskHealthCheck() time.Duration { - wait := t.env.CanaryTaskHealthCheckWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return t.DefaultTimeout -} - -func (t *manager) TaskStopped() time.Duration { - wait := t.env.CanaryTaskStoppedWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return t.DefaultTimeout -} - -func (t *manager) ServiceStable() time.Duration { - wait := t.env.ServiceStableWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return t.DefaultTimeout -} - -func (t *manager) TargetHealthCheck() time.Duration { - wait := t.env.TargetHealthCheckWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return t.DefaultTimeout -} diff --git a/up.go b/up.go index 434405e..a66d45c 100644 --- a/up.go +++ b/up.go @@ -9,7 +9,6 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) @@ -49,12 +48,11 @@ func (c *cage) createService(ctx context.Context, serviceDefinitionInput *ecs.Cr if err != nil { return nil, xerrors.Errorf("failed to create service '%s': %w", *serviceDefinitionInput.ServiceName, err) } - timeoutManager := c.di.Get(key.TimeoutManager).(timeout.Manager) log.Infof("waiting for service '%s' to be STABLE", *serviceDefinitionInput.ServiceName) if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ Cluster: &env.Cluster, Services: []string{*serviceDefinitionInput.ServiceName}, - }, timeoutManager.ServiceStable()); err != nil { + }, env.ServiceStable()); err != nil { return nil, xerrors.Errorf("failed to wait for service '%s' to be STABLE: %w", *serviceDefinitionInput.ServiceName, err) } return o.Service, nil diff --git a/up_test.go b/up_test.go index 352a315..e929dfd 100644 --- a/up_test.go +++ b/up_test.go @@ -8,7 +8,6 @@ import ( cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -22,7 +21,6 @@ func TestCage_Up(t *testing.T) { cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Up(context.Background()) assert.Nil(t, err) @@ -36,7 +34,6 @@ func TestCage_Up(t *testing.T) { cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.TimeoutManager, timeout.NewManager(env, 1)) })) result, err := cagecli.Up(context.Background()) assert.Nil(t, result) From b9c2c929c87ffc4cffedaa7600efac11a67a18d8 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 18:41:41 +0900 Subject: [PATCH 21/47] a --- Makefile | 6 ++-- cli/cage/commands/command.go | 30 ----------------- cli/cage/main.go | 35 +++++++++++++++++++- mocks/{mock_taskset => mock_task}/factory.go | 6 ++-- task/common.go | 3 ++ task/srv_task.go | 2 -- taskset/taskset_test.go | 9 +++-- timeout/time.go | 3 -- types/iface.go | 1 - 9 files changed, 47 insertions(+), 48 deletions(-) rename mocks/{mock_taskset => mock_task}/factory.go (95%) diff --git a/Makefile b/Makefile index 17a9652..a901369 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ mocks: mocks/mock_awsiface/iface.go \ mocks/mock_upgrade/upgrade.go \ mocks/mock_task/task.go \ mocks/mock_taskset/taskset.go \ - mocks/mock_taskset/factory.go + mocks/mock_task/factory.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go mocks/mock_types/iface.go: cage.go @@ -25,6 +25,6 @@ mocks/mock_task/task.go: task/task.go $(MOCKGEN) -source=./task/task.go > mocks/mock_task/task.go mocks/mock_taskset/taskset.go: taskset/taskset.go $(MOCKGEN) -source=./taskset/taskset.go > mocks/mock_taskset/taskset.go -mocks/mock_taskset/factory.go: taskset/factory.go - $(MOCKGEN) -source=./taskset/factory.go > mocks/mock_taskset/factory.go +mocks/mock_task/factory.go: task/factory.go + $(MOCKGEN) -source=./task/factory.go > mocks/mock_task/factory.go .PHONY: mocks diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index d176641..62812e6 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -1,22 +1,12 @@ package commands import ( - "context" "io" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" - "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" - cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/prompt" "github.com/loilo-inc/canarycage/env" - "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/task" - "github.com/loilo-inc/canarycage/timeout" "github.com/loilo-inc/canarycage/types" - "github.com/loilo-inc/logos/di" "github.com/urfave/cli/v2" "golang.org/x/xerrors" ) @@ -38,26 +28,6 @@ func NewCageCommands( type cageCliProvier = func(envars *env.Envars) (types.Cage, error) -func DefalutCageCliProvider(envars *env.Envars) (types.Cage, error) { - conf, err := config.LoadDefaultConfig( - context.Background(), - config.WithRegion(envars.Region)) - if err != nil { - return nil, xerrors.Errorf("failed to load aws config: %w", err) - } - d := di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecs.NewFromConfig(conf)) - b.Set(key.Ec2Cli, ec2.NewFromConfig(conf)) - b.Set(key.AlbCli, elasticloadbalancingv2.NewFromConfig(conf)) - b.Set(key.SrvCli, servicediscovery.NewFromConfig(conf)) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, &timeout.Time{}) - }) - cagecli := cage.NewCage(d) - return cagecli, nil -} - func (c *CageCommands) requireArgs( ctx *cli.Context, minArgs int, diff --git a/cli/cage/main.go b/cli/cage/main.go index 40093bf..d939a39 100644 --- a/cli/cage/main.go +++ b/cli/cage/main.go @@ -1,14 +1,27 @@ package main import ( + "context" "fmt" "log" "os" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ecs" + "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/aws/aws-sdk-go-v2/service/servicediscovery" + cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/commands" "github.com/loilo-inc/canarycage/cli/cage/upgrade" "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/timeout" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" "github.com/urfave/cli/v2" + "golang.org/x/xerrors" ) // set by goreleaser @@ -26,7 +39,7 @@ func main() { app.Usage = "A deployment tool for AWS ECS" app.Description = "A deployment tool for AWS ECS" envars := env.Envars{} - cmds := commands.NewCageCommands(os.Stdin, commands.DefalutCageCliProvider) + cmds := commands.NewCageCommands(os.Stdin, provideCageCli) app.Commands = []*cli.Command{ cmds.Up(&envars), cmds.RollOut(&envars), @@ -45,3 +58,23 @@ func main() { log.Fatal(err) } } + +func provideCageCli(envars *env.Envars) (types.Cage, error) { + conf, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion(envars.Region)) + if err != nil { + return nil, xerrors.Errorf("failed to load aws config: %w", err) + } + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecs.NewFromConfig(conf)) + b.Set(key.Ec2Cli, ec2.NewFromConfig(conf)) + b.Set(key.AlbCli, elasticloadbalancingv2.NewFromConfig(conf)) + b.Set(key.SrvCli, servicediscovery.NewFromConfig(conf)) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + b.Set(key.Time, &timeout.Time{}) + }) + cagecli := cage.NewCage(d) + return cagecli, nil +} diff --git a/mocks/mock_taskset/factory.go b/mocks/mock_task/factory.go similarity index 95% rename from mocks/mock_taskset/factory.go rename to mocks/mock_task/factory.go index c646b24..6d915a6 100644 --- a/mocks/mock_taskset/factory.go +++ b/mocks/mock_task/factory.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: ./taskset/factory.go +// Source: ./task/factory.go -// Package mock_taskset is a generated GoMock package. -package mock_taskset +// Package mock_task is a generated GoMock package. +package mock_task import ( reflect "reflect" diff --git a/task/common.go b/task/common.go index ce18c0e..2a352ba 100644 --- a/task/common.go +++ b/task/common.go @@ -77,6 +77,9 @@ func (c *common) TaskArn() *string { } func (c *common) waitForTask(ctx context.Context) error { + if c.taskArn == nil { + return xerrors.New("task is not started") + } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) diff --git a/task/srv_task.go b/task/srv_task.go index 395b26b..4df7cbe 100644 --- a/task/srv_task.go +++ b/task/srv_task.go @@ -20,7 +20,6 @@ import ( type srvTask struct { *common registry *ecstypes.ServiceRegistry - target *CanaryTarget srv *srvtypes.Service instId *string ns *srvtypes.Namespace @@ -66,7 +65,6 @@ func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { if err != nil { return err } - c.target = target // get the service id from service registry arn srvId := ArnToId(*c.registry.RegistryArn) srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) env := c.di.Get(key.Env).(*env.Envars) diff --git a/taskset/taskset_test.go b/taskset/taskset_test.go index 008bc83..7a2d653 100644 --- a/taskset/taskset_test.go +++ b/taskset/taskset_test.go @@ -8,7 +8,6 @@ import ( ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/golang/mock/gomock" "github.com/loilo-inc/canarycage/mocks/mock_task" - "github.com/loilo-inc/canarycage/mocks/mock_taskset" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/taskset" "github.com/stretchr/testify/assert" @@ -17,7 +16,7 @@ import ( func TestSet(t *testing.T) { t.Run("basic", func(t *testing.T) { ctrl := gomock.NewController(t) - factory := mock_taskset.NewMockFactory(ctrl) + factory := mock_task.NewMockFactory(ctrl) albTask := mock_task.NewMockTask(ctrl) srvTask := mock_task.NewMockTask(ctrl) lb := ecstypes.LoadBalancer{} @@ -46,7 +45,7 @@ func TestSet(t *testing.T) { }) t.Run("should add a simple task if no load balancer or service registry is given", func(t *testing.T) { ctrl := gomock.NewController(t) - factory := mock_taskset.NewMockFactory(ctrl) + factory := mock_task.NewMockFactory(ctrl) simpleTask := mock_task.NewMockTask(ctrl) input := &taskset.Input{ Input: &task.Input{}, @@ -62,7 +61,7 @@ func TestSet(t *testing.T) { }) t.Run("should aggregate errors from task.Wait", func(t *testing.T) { ctrl := gomock.NewController(t) - factory := mock_taskset.NewMockFactory(ctrl) + factory := mock_task.NewMockFactory(ctrl) albTask := mock_task.NewMockTask(ctrl) srvTask := mock_task.NewMockTask(ctrl) lb := ecstypes.LoadBalancer{} @@ -90,7 +89,7 @@ func TestSet(t *testing.T) { t.Run("should error immediately if task failed to start", func(t *testing.T) { ctrl := gomock.NewController(t) - factory := mock_taskset.NewMockFactory(ctrl) + factory := mock_task.NewMockFactory(ctrl) mtask := mock_task.NewMockTask(ctrl) lb := &ecstypes.LoadBalancer{} factory.EXPECT().NewAlbTask(gomock.Any(), lb).Return(mtask) diff --git a/timeout/time.go b/timeout/time.go index 640def1..857da44 100644 --- a/timeout/time.go +++ b/timeout/time.go @@ -4,9 +4,6 @@ import "time" type Time struct{} -func (t *Time) Now() time.Time { - return time.Now() -} func (t *Time) NewTimer(d time.Duration) *time.Timer { return time.NewTimer(d) } diff --git a/types/iface.go b/types/iface.go index 38093b5..d24bbcc 100644 --- a/types/iface.go +++ b/types/iface.go @@ -14,7 +14,6 @@ type Cage interface { } type Time interface { - Now() time.Time NewTimer(time.Duration) *time.Timer } From bf1f1b7fb808ac18426e081dcbd4debeff02058c Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 20:05:43 +0900 Subject: [PATCH 22/47] drop srv features --- awsiface/iface.go | 9 -- cli/cage/commands/rollout.go | 1 - cli/cage/main.go | 2 - env/env.go | 1 - env/timeout.go | 8 -- env/timeout_test.go | 3 - key/keys.go | 14 ++- mocks/mock_awsiface/iface.go | 104 -------------------- mocks/mock_cage/task_factory.go | 78 --------------- mocks/mock_task/factory.go | 14 --- rollout.go | 6 +- task/alb_task.go | 17 +--- task/alb_task_test.go | 1 - task/arn.go | 8 -- task/arn_test.go | 13 --- task/factory.go | 5 - task/factory_test.go | 7 -- task/simple_task_test.go | 1 - task/srv_task.go | 165 -------------------------------- task/srv_task_test.go | 49 ---------- taskset/taskset.go | 7 +- taskset/taskset_test.go | 27 +----- test/context.go | 16 +--- test/srv.go | 139 --------------------------- 24 files changed, 20 insertions(+), 675 deletions(-) delete mode 100644 mocks/mock_cage/task_factory.go delete mode 100644 task/arn.go delete mode 100644 task/arn_test.go delete mode 100644 task/srv_task.go delete mode 100644 task/srv_task_test.go delete mode 100644 test/srv.go diff --git a/awsiface/iface.go b/awsiface/iface.go index 8135fd1..d377f3e 100644 --- a/awsiface/iface.go +++ b/awsiface/iface.go @@ -6,7 +6,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" ) type ( @@ -37,16 +36,8 @@ type ( DescribeSubnets(ctx context.Context, params *ec2.DescribeSubnetsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeSubnetsOutput, error) DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) } - SrvClient interface { - DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) - RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) - DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) - GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) - GetNamespace(ctx context.Context, params *servicediscovery.GetNamespaceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetNamespaceOutput, error) - } ) var _ EcsClient = (*ecs.Client)(nil) var _ AlbClient = (*elbv2.Client)(nil) var _ Ec2Client = (*ec2.Client)(nil) -var _ SrvClient = (*servicediscovery.Client)(nil) diff --git a/cli/cage/commands/rollout.go b/cli/cage/commands/rollout.go index 362b34c..87d6129 100644 --- a/cli/cage/commands/rollout.go +++ b/cli/cage/commands/rollout.go @@ -41,7 +41,6 @@ func (c *CageCommands) RollOut( TaskHealthCheckWaitFlag(&envars.CanaryTaskHealthCheckWait), TaskStoppedWaitFlag(&envars.CanaryTaskStoppedWait), ServiceStableWaitFlag(&envars.ServiceStableWait), - TargetHealthCheckWaitFlag(&envars.TargetHealthCheckWait), }, Action: func(ctx *cli.Context) error { dir, _, err := c.requireArgs(ctx, 1, 1) diff --git a/cli/cage/main.go b/cli/cage/main.go index d939a39..5dcb4e4 100644 --- a/cli/cage/main.go +++ b/cli/cage/main.go @@ -10,7 +10,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/cli/cage/commands" "github.com/loilo-inc/canarycage/cli/cage/upgrade" @@ -71,7 +70,6 @@ func provideCageCli(envars *env.Envars) (types.Cage, error) { b.Set(key.EcsCli, ecs.NewFromConfig(conf)) b.Set(key.Ec2Cli, ec2.NewFromConfig(conf)) b.Set(key.AlbCli, elasticloadbalancingv2.NewFromConfig(conf)) - b.Set(key.SrvCli, servicediscovery.NewFromConfig(conf)) b.Set(key.TaskFactory, task.NewFactory(b.Future())) b.Set(key.Time, &timeout.Time{}) }) diff --git a/env/env.go b/env/env.go index 2d7a5e5..2bc152f 100644 --- a/env/env.go +++ b/env/env.go @@ -27,7 +27,6 @@ type Envars struct { CanaryTaskHealthCheckWait int // sec CanaryTaskStoppedWait int // sec ServiceStableWait int // sec - TargetHealthCheckWait int // sec } // required diff --git a/env/timeout.go b/env/timeout.go index c8d7baf..b08175c 100644 --- a/env/timeout.go +++ b/env/timeout.go @@ -37,11 +37,3 @@ func (t *Envars) ServiceStable() time.Duration { } return defaultTimeout } - -func (t *Envars) TargetHealthCheck() time.Duration { - wait := t.TargetHealthCheckWait - if wait > 0 { - return time.Duration(wait) * time.Second - } - return defaultTimeout -} diff --git a/env/timeout_test.go b/env/timeout_test.go index d6a434d..d7f1467 100644 --- a/env/timeout_test.go +++ b/env/timeout_test.go @@ -15,7 +15,6 @@ func TestManager(t *testing.T) { assert.Equal(t, 15*time.Minute, man.TaskStopped()) assert.Equal(t, 15*time.Minute, man.TaskHealthCheck()) assert.Equal(t, 15*time.Minute, man.ServiceStable()) - assert.Equal(t, 15*time.Minute, man.TargetHealthCheck()) }) t.Run("with config", func(t *testing.T) { man := &env.Envars{ @@ -23,12 +22,10 @@ func TestManager(t *testing.T) { CanaryTaskStoppedWait: 2, CanaryTaskHealthCheckWait: 3, ServiceStableWait: 4, - TargetHealthCheckWait: 5, } assert.Equal(t, 1*time.Second, man.TaskRunning()) assert.Equal(t, 2*time.Second, man.TaskStopped()) assert.Equal(t, 3*time.Second, man.TaskHealthCheck()) assert.Equal(t, 4*time.Second, man.ServiceStable()) - assert.Equal(t, 5*time.Second, man.TargetHealthCheck()) }) } diff --git a/key/keys.go b/key/keys.go index 023072b..7b32dda 100644 --- a/key/keys.go +++ b/key/keys.go @@ -3,12 +3,10 @@ package key type DepsKey string const ( - EcsCli DepsKey = "ecs" - Ec2Cli DepsKey = "ec2" - SrvCli DepsKey = "srv" - AlbCli DepsKey = "alb" - Env DepsKey = "env" - TimeoutManager DepsKey = "timeout-manager" - Time DepsKey = "time" - TaskFactory DepsKey = "task-factory" + EcsCli DepsKey = "ecs" + Ec2Cli DepsKey = "ec2" + AlbCli DepsKey = "alb" + Env DepsKey = "env" + Time DepsKey = "time" + TaskFactory DepsKey = "task-factory" ) diff --git a/mocks/mock_awsiface/iface.go b/mocks/mock_awsiface/iface.go index 3e79b10..b582870 100644 --- a/mocks/mock_awsiface/iface.go +++ b/mocks/mock_awsiface/iface.go @@ -11,7 +11,6 @@ import ( ec2 "github.com/aws/aws-sdk-go-v2/service/ec2" ecs "github.com/aws/aws-sdk-go-v2/service/ecs" elasticloadbalancingv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - servicediscovery "github.com/aws/aws-sdk-go-v2/service/servicediscovery" gomock "github.com/golang/mock/gomock" ) @@ -503,106 +502,3 @@ func (mr *MockEc2ClientMockRecorder) DescribeSubnets(ctx, params interface{}, op varargs := append([]interface{}{ctx, params}, optFns...) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeSubnets", reflect.TypeOf((*MockEc2Client)(nil).DescribeSubnets), varargs...) } - -// MockSrvClient is a mock of SrvClient interface. -type MockSrvClient struct { - ctrl *gomock.Controller - recorder *MockSrvClientMockRecorder -} - -// MockSrvClientMockRecorder is the mock recorder for MockSrvClient. -type MockSrvClientMockRecorder struct { - mock *MockSrvClient -} - -// NewMockSrvClient creates a new mock instance. -func NewMockSrvClient(ctrl *gomock.Controller) *MockSrvClient { - mock := &MockSrvClient{ctrl: ctrl} - mock.recorder = &MockSrvClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSrvClient) EXPECT() *MockSrvClientMockRecorder { - return m.recorder -} - -// DeregisterInstance mocks base method. -func (m *MockSrvClient) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DeregisterInstance", varargs...) - ret0, _ := ret[0].(*servicediscovery.DeregisterInstanceOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DeregisterInstance indicates an expected call of DeregisterInstance. -func (mr *MockSrvClientMockRecorder) DeregisterInstance(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeregisterInstance", reflect.TypeOf((*MockSrvClient)(nil).DeregisterInstance), varargs...) -} - -// DiscoverInstances mocks base method. -func (m *MockSrvClient) DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "DiscoverInstances", varargs...) - ret0, _ := ret[0].(*servicediscovery.DiscoverInstancesOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverInstances indicates an expected call of DiscoverInstances. -func (mr *MockSrvClientMockRecorder) DiscoverInstances(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverInstances", reflect.TypeOf((*MockSrvClient)(nil).DiscoverInstances), varargs...) -} - -// GetService mocks base method. -func (m *MockSrvClient) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "GetService", varargs...) - ret0, _ := ret[0].(*servicediscovery.GetServiceOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetService indicates an expected call of GetService. -func (mr *MockSrvClientMockRecorder) GetService(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockSrvClient)(nil).GetService), varargs...) -} - -// RegisterInstance mocks base method. -func (m *MockSrvClient) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "RegisterInstance", varargs...) - ret0, _ := ret[0].(*servicediscovery.RegisterInstanceOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// RegisterInstance indicates an expected call of RegisterInstance. -func (mr *MockSrvClientMockRecorder) RegisterInstance(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterInstance", reflect.TypeOf((*MockSrvClient)(nil).RegisterInstance), varargs...) -} diff --git a/mocks/mock_cage/task_factory.go b/mocks/mock_cage/task_factory.go deleted file mode 100644 index a8f7cea..0000000 --- a/mocks/mock_cage/task_factory.go +++ /dev/null @@ -1,78 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: ./task_factory.go - -// Package mock_cage is a generated GoMock package. -package mock_cage - -import ( - reflect "reflect" - - types "github.com/aws/aws-sdk-go-v2/service/ecs/types" - gomock "github.com/golang/mock/gomock" - task "github.com/loilo-inc/canarycage/task" -) - -// MockTaskFactory is a mock of TaskFactory interface. -type MockTaskFactory struct { - ctrl *gomock.Controller - recorder *MockTaskFactoryMockRecorder -} - -// MockTaskFactoryMockRecorder is the mock recorder for MockTaskFactory. -type MockTaskFactoryMockRecorder struct { - mock *MockTaskFactory -} - -// NewMockTaskFactory creates a new mock instance. -func NewMockTaskFactory(ctrl *gomock.Controller) *MockTaskFactory { - mock := &MockTaskFactory{ctrl: ctrl} - mock.recorder = &MockTaskFactoryMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockTaskFactory) EXPECT() *MockTaskFactoryMockRecorder { - return m.recorder -} - -// NewAlbTask mocks base method. -func (m *MockTaskFactory) NewAlbTask(input *task.Input, lb *types.LoadBalancer) task.Task { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewAlbTask", input, lb) - ret0, _ := ret[0].(task.Task) - return ret0 -} - -// NewAlbTask indicates an expected call of NewAlbTask. -func (mr *MockTaskFactoryMockRecorder) NewAlbTask(input, lb interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAlbTask", reflect.TypeOf((*MockTaskFactory)(nil).NewAlbTask), input, lb) -} - -// NewSimpleTask mocks base method. -func (m *MockTaskFactory) NewSimpleTask(input *task.Input) task.Task { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewSimpleTask", input) - ret0, _ := ret[0].(task.Task) - return ret0 -} - -// NewSimpleTask indicates an expected call of NewSimpleTask. -func (mr *MockTaskFactoryMockRecorder) NewSimpleTask(input interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSimpleTask", reflect.TypeOf((*MockTaskFactory)(nil).NewSimpleTask), input) -} - -// NewSrvTask mocks base method. -func (m *MockTaskFactory) NewSrvTask(input *task.Input, srv *types.ServiceRegistry) task.Task { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewSrvTask", input, srv) - ret0, _ := ret[0].(task.Task) - return ret0 -} - -// NewSrvTask indicates an expected call of NewSrvTask. -func (mr *MockTaskFactoryMockRecorder) NewSrvTask(input, srv interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSrvTask", reflect.TypeOf((*MockTaskFactory)(nil).NewSrvTask), input, srv) -} diff --git a/mocks/mock_task/factory.go b/mocks/mock_task/factory.go index 6d915a6..fbcee8e 100644 --- a/mocks/mock_task/factory.go +++ b/mocks/mock_task/factory.go @@ -62,17 +62,3 @@ func (mr *MockFactoryMockRecorder) NewSimpleTask(input interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSimpleTask", reflect.TypeOf((*MockFactory)(nil).NewSimpleTask), input) } - -// NewSrvTask mocks base method. -func (m *MockFactory) NewSrvTask(input *task.Input, srv *types.ServiceRegistry) task.Task { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "NewSrvTask", input, srv) - ret0, _ := ret[0].(task.Task) - return ret0 -} - -// NewSrvTask indicates an expected call of NewSrvTask. -func (mr *MockFactoryMockRecorder) NewSrvTask(input, srv interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewSrvTask", reflect.TypeOf((*MockFactory)(nil).NewSrvTask), input, srv) -} diff --git a/rollout.go b/rollout.go index cb09656..fb4f041 100644 --- a/rollout.go +++ b/rollout.go @@ -109,13 +109,11 @@ func (c *cage) StartCanaryTasks( var networkConfiguration *ecstypes.NetworkConfiguration var platformVersion *string var loadBalancers []ecstypes.LoadBalancer - var serviceRegistries []ecstypes.ServiceRegistry env := c.di.Get(key.Env).(*env.Envars) if input.UpdateService { networkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration platformVersion = env.ServiceDefinitionInput.PlatformVersion loadBalancers = env.ServiceDefinitionInput.LoadBalancers - serviceRegistries = env.ServiceDefinitionInput.ServiceRegistries } else { ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) if o, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ @@ -128,7 +126,6 @@ func (c *cage) StartCanaryTasks( networkConfiguration = service.NetworkConfiguration platformVersion = service.PlatformVersion loadBalancers = service.LoadBalancers - serviceRegistries = service.ServiceRegistries } } factory := c.di.Get(key.TaskFactory).(task.Factory) @@ -140,8 +137,7 @@ func (c *cage) StartCanaryTasks( TaskDefinition: nextTaskDefinition, PlatformVersion: platformVersion, }, - LoadBalancers: loadBalancers, - ServiceRegistries: serviceRegistries, + LoadBalancers: loadBalancers, }, ), nil } diff --git a/task/alb_task.go b/task/alb_task.go index 7a2f45e..2cfaabe 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -10,7 +10,6 @@ import ( elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/loilo-inc/canarycage/awsiface" - "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" @@ -91,18 +90,13 @@ func (c *albTask) registerToTargetGroup(ctx context.Context) error { func (c *albTask) waitUntilTargetHealthy( ctx context.Context, ) error { - env := c.di.Get(key.Env).(*env.Envars) albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) timer := c.di.Get(key.Time).(types.Time) log.Infof("checking the health state of canary task...") - var unusedCount = 0 + var notHealthyCount = 0 var recentState *elbv2types.TargetHealthStateEnum - rest := env.TargetHealthCheck() waitPeriod := 15 * time.Second - for rest > 0 && unusedCount < 5 { - if rest < waitPeriod { - waitPeriod = rest - } + for notHealthyCount < 5 { select { case <-ctx.Done(): return ctx.Err() @@ -127,14 +121,13 @@ func (c *albTask) waitUntilTargetHealthy( } log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) switch *recentState { - case "healthy": + case elbv2types.TargetHealthStateEnumHealthy: return nil - case "unused": - unusedCount++ + default: + notHealthyCount++ } } } - rest -= waitPeriod } // unhealthy, draining, unused log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index e102fc5..512f8ae 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -24,7 +24,6 @@ func TestAlbTask(t *testing.T) { b.Set(key.EcsCli, mocker.Ecs) b.Set(key.Ec2Cli, mocker.Ec2) b.Set(key.AlbCli, mocker.Alb) - b.Set(key.SrvCli, mocker.Srv) b.Set(key.Time, test.NewFakeTime()) }) stask := task.NewAlbTask(d, &task.Input{ diff --git a/task/arn.go b/task/arn.go deleted file mode 100644 index bb0dace..0000000 --- a/task/arn.go +++ /dev/null @@ -1,8 +0,0 @@ -package task - -import "strings" - -func ArnToId(arn string) string { - list := strings.Split(arn, "/") - return list[len(list)-1] -} diff --git a/task/arn_test.go b/task/arn_test.go deleted file mode 100644 index 28f3657..0000000 --- a/task/arn_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package task_test - -import ( - "testing" - - "github.com/loilo-inc/canarycage/task" - "github.com/stretchr/testify/assert" -) - -func TestArnToId(t *testing.T) { - arn := "arn://aaa/srv-1234" - assert.Equal(t, "srv-1234", task.ArnToId(arn)) -} diff --git a/task/factory.go b/task/factory.go index 4f98267..07bc6fc 100644 --- a/task/factory.go +++ b/task/factory.go @@ -7,7 +7,6 @@ import ( type Factory interface { NewAlbTask(input *Input, lb *ecstypes.LoadBalancer) Task - NewSrvTask(input *Input, srv *ecstypes.ServiceRegistry) Task NewSimpleTask(input *Input) Task } @@ -23,10 +22,6 @@ func (f *factory) NewAlbTask(input *Input, lb *ecstypes.LoadBalancer) Task { return NewAlbTask(f.di, input, lb) } -func (f *factory) NewSrvTask(input *Input, srv *ecstypes.ServiceRegistry) Task { - return NewSrvTask(f.di, input, srv) -} - func (f *factory) NewSimpleTask(input *Input) Task { return NewSimpleTask(f.di, input) } diff --git a/task/factory_test.go b/task/factory_test.go index ad13d72..642fd65 100644 --- a/task/factory_test.go +++ b/task/factory_test.go @@ -18,13 +18,6 @@ func TestFactory(t *testing.T) { task := f.NewAlbTask(input, lb) assert.NotNil(t, task) }) - t.Run("NewSrvTask", func(t *testing.T) { - f := task.NewFactory(d) - input := &task.Input{} - srv := &ecstypes.ServiceRegistry{} - task := f.NewSrvTask(input, srv) - assert.NotNil(t, task) - }) t.Run("NewSimpleTask", func(t *testing.T) { f := task.NewFactory(d) input := &task.Input{} diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 23765e1..da41c55 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -23,7 +23,6 @@ func TestSimpleTask(t *testing.T) { b.Set(key.EcsCli, mocker.Ecs) b.Set(key.Ec2Cli, mocker.Ec2) b.Set(key.AlbCli, mocker.Alb) - b.Set(key.SrvCli, mocker.Srv) b.Set(key.Time, test.NewFakeTime()) }) stask := task.NewSimpleTask(d, &task.Input{ diff --git a/task/srv_task.go b/task/srv_task.go deleted file mode 100644 index 4df7cbe..0000000 --- a/task/srv_task.go +++ /dev/null @@ -1,165 +0,0 @@ -package task - -import ( - "context" - "time" - - "github.com/apex/log" - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" - srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" - "github.com/loilo-inc/canarycage/awsiface" - "github.com/loilo-inc/canarycage/env" - "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/types" - "github.com/loilo-inc/logos/di" - "golang.org/x/xerrors" -) - -// srvTask is a task that is attached to an Service Discovery -type srvTask struct { - *common - registry *ecstypes.ServiceRegistry - srv *srvtypes.Service - instId *string - ns *srvtypes.Namespace -} - -func NewSrvTask(di *di.D, input *Input, registry *ecstypes.ServiceRegistry) Task { - return &srvTask{ - common: &common{Input: input, di: di}, - registry: registry, - } -} - -func (c *srvTask) Wait(ctx context.Context) error { - if err := c.waitForTask(ctx); err != nil { - return err - } - if err := c.registerToSrvDiscovery(ctx); err != nil { - return err - } - log.Infof("canary task '%s' is registered to service discovery instance '%s'", *c.taskArn, *c.instId) - log.Infof("😷 ensuring canary service instance to become healthy...") - if err := c.waitUntilSrvInstHelthy(ctx); err != nil { - return err - } - log.Info("🤩 canary service instance is healthy!") - return nil -} - -func (c *srvTask) Stop(ctx context.Context) error { - c.deregisterSrvInst(ctx) - return c.stopTask(ctx) -} - -func (c *srvTask) registerToSrvDiscovery(ctx context.Context) error { - log.Infof("registring canary task '%s' to service discovery...", *c.taskArn) - var targetPort int32 - if c.registry.Port != nil { - targetPort = *c.registry.Port - } else { - targetPort = 80 - } - target, err := c.describeTaskTarget(ctx, targetPort) - if err != nil { - return err - } - srvId := ArnToId(*c.registry.RegistryArn) - srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) - env := c.di.Get(key.Env).(*env.Envars) - var svc *srvtypes.Service - var ns *srvtypes.Namespace - if o, err := srvCli.GetService(ctx, &servicediscovery.GetServiceInput{ - Id: &srvId, - }); err != nil { - return xerrors.Errorf("failed to get the service: %w", err) - } else { - svc = o.Service - } - if o, err := srvCli.GetNamespace(ctx, &servicediscovery.GetNamespaceInput{ - Id: svc.NamespaceId, - }); err != nil { - return xerrors.Errorf("failed to get the namespace: %w", err) - } else { - ns = o.Namespace - } - attrs := map[string]string{ - "AWS_INSTANCE_IPV4": target.targetIpv4, - "AVAILABILITY_ZONE": target.availabilityZone, - "AWS_INIT_HEALTH_STATUS": "UNHEALTHY", - "ECS_CLUSTER_NAME": env.Cluster, - "ECS_SERVICE_NAME": env.Service, - "ECS_TASK_DEFINITION_FAMILY": *c.TaskDefinition.Family, - "REGION": env.Region, - "CAGE_TASK_ID": ArnToId(*c.taskArn), - } - taskId := ArnToId(*c.taskArn) - if _, err := srvCli.RegisterInstance(ctx, &servicediscovery.RegisterInstanceInput{ - ServiceId: &srvId, - InstanceId: &taskId, - Attributes: attrs, - }); err != nil { - return xerrors.Errorf("failed to register the canary task to service discovery: %w", err) - } - c.srv = svc - c.instId = &taskId - c.ns = ns - return nil -} - -func (c *srvTask) waitUntilSrvInstHelthy( - ctx context.Context, -) error { - env := c.di.Get(key.Env).(*env.Envars) - timer := c.di.Get(key.Time).(types.Time) - srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) - var rest = env.TargetHealthCheck() - var waitPeriod = 15 * time.Second - for rest > 0 { - if rest < waitPeriod { - waitPeriod = rest - } - select { - case <-ctx.Done(): - return ctx.Err() - case <-timer.NewTimer(time.Duration(waitPeriod) * time.Second).C: - if list, err := srvCli.DiscoverInstances(ctx, &servicediscovery.DiscoverInstancesInput{ - NamespaceName: c.ns.Name, - ServiceName: c.srv.Name, - HealthStatus: srvtypes.HealthStatusFilterHealthy, - QueryParameters: map[string]string{ - "CAGE_TASK_ID": ArnToId(*c.taskArn), - }, - }); err != nil { - return xerrors.Errorf("failed to discover instances: %w", err) - } else { - for _, inst := range list.Instances { - if *inst.InstanceId == *c.instId { - return nil - } - } - rest -= waitPeriod - } - } - } - return xerrors.Errorf("timed out waiting for healthy instances") -} - -func (c *srvTask) deregisterSrvInst( - ctx context.Context, -) { - if c.instId == nil { - return - } - log.Info("deregistering the canary task from service discovery...") - srvCli := c.di.Get(key.SrvCli).(awsiface.SrvClient) - if _, err := srvCli.DeregisterInstance(ctx, &servicediscovery.DeregisterInstanceInput{ - ServiceId: c.srv.Id, - InstanceId: c.instId, - }); err != nil { - log.Errorf("failed to deregister the canary task from service discovery: %v", err) - log.Errorf("continuing to stop the canary task...") - } - log.Infof("canary task '%s' is deregistered from service discovery", *c.taskArn) -} diff --git a/task/srv_task_test.go b/task/srv_task_test.go deleted file mode 100644 index 5ddefc8..0000000 --- a/task/srv_task_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package task_test - -import ( - "context" - "testing" - - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/task" - "github.com/loilo-inc/canarycage/test" - "github.com/loilo-inc/logos/di" - "github.com/stretchr/testify/assert" -) - -func TestSrvTask(t *testing.T) { - srvSvcName := "internal" - srvNsName := "dev.local" - registryArn := "arn://aaa/srv-" + srvSvcName - mocker := test.NewMockContext() - env := test.DefaultEnvars() - - ctx := context.TODO() - td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) - env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn - ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) - d := di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, mocker.Ecs) - b.Set(key.Ec2Cli, mocker.Ec2) - b.Set(key.AlbCli, mocker.Alb) - b.Set(key.SrvCli, mocker.Srv) - b.Set(key.Time, test.NewFakeTime()) - }) - stask := task.NewSrvTask(d, &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - }, &ecstypes.ServiceRegistry{RegistryArn: ®istryArn}) - _ = mocker.CreateSrvService(srvNsName, srvSvcName) - err := stask.Start(ctx) - assert.NoError(t, err) - taskId := task.ArnToId(*stask.TaskArn()) - mocker.PutSrvInstHealth(taskId, "healthy") - err = stask.Wait(ctx) - assert.NoError(t, err) - err = stask.Stop(ctx) - assert.NoError(t, err) - assert.Equal(t, 1, mocker.RunningTaskSize()) - assert.Equal(t, 0, len(mocker.SrvInsts[srvSvcName])) -} diff --git a/taskset/taskset.go b/taskset/taskset.go index d658187..df1b0a5 100644 --- a/taskset/taskset.go +++ b/taskset/taskset.go @@ -19,8 +19,7 @@ type set struct { type Input struct { *task.Input - LoadBalancers []ecstypes.LoadBalancer - ServiceRegistries []ecstypes.ServiceRegistry + LoadBalancers []ecstypes.LoadBalancer } func NewSet( @@ -36,10 +35,6 @@ func NewSet( task := factory.NewAlbTask(taskInput, &lb) results = append(results, task) } - for _, srv := range input.ServiceRegistries { - task := factory.NewSrvTask(taskInput, &srv) - results = append(results, task) - } if len(results) == 0 { task := factory.NewSimpleTask(taskInput) results = append(results, task) diff --git a/taskset/taskset_test.go b/taskset/taskset_test.go index 7a2d653..f433384 100644 --- a/taskset/taskset_test.go +++ b/taskset/taskset_test.go @@ -18,25 +18,16 @@ func TestSet(t *testing.T) { ctrl := gomock.NewController(t) factory := mock_task.NewMockFactory(ctrl) albTask := mock_task.NewMockTask(ctrl) - srvTask := mock_task.NewMockTask(ctrl) lb := ecstypes.LoadBalancer{} - srg := ecstypes.ServiceRegistry{} gomock.InOrder( factory.EXPECT().NewAlbTask(gomock.Any(), &lb).Return(albTask), - factory.EXPECT().NewSrvTask(gomock.Any(), &srg).Return(srvTask), - ) - gomock.InOrder( albTask.EXPECT().Start(gomock.Any()).Return(nil), - srvTask.EXPECT().Start(gomock.Any()).Return(nil), ) albTask.EXPECT().Wait(gomock.Any()).Return(nil) - srvTask.EXPECT().Wait(gomock.Any()).Return(nil) albTask.EXPECT().Stop(gomock.Any()).Return(nil) - srvTask.EXPECT().Stop(gomock.Any()).Return(nil) input := &taskset.Input{ - Input: &task.Input{}, - LoadBalancers: []ecstypes.LoadBalancer{lb}, - ServiceRegistries: []ecstypes.ServiceRegistry{srg}, + Input: &task.Input{}, + LoadBalancers: []ecstypes.LoadBalancer{lb}, } set := taskset.NewSet(factory, input) ctx := context.TODO() @@ -63,23 +54,15 @@ func TestSet(t *testing.T) { ctrl := gomock.NewController(t) factory := mock_task.NewMockFactory(ctrl) albTask := mock_task.NewMockTask(ctrl) - srvTask := mock_task.NewMockTask(ctrl) lb := ecstypes.LoadBalancer{} - srg := ecstypes.ServiceRegistry{} gomock.InOrder( factory.EXPECT().NewAlbTask(gomock.Any(), &lb).Return(albTask), - factory.EXPECT().NewSrvTask(gomock.Any(), &srg).Return(srvTask), - ) - gomock.InOrder( albTask.EXPECT().Start(gomock.Any()).Return(nil), - srvTask.EXPECT().Start(gomock.Any()).Return(nil), + albTask.EXPECT().Wait(gomock.Any()).Return(fmt.Errorf("error")), ) - albTask.EXPECT().Wait(gomock.Any()).Return(fmt.Errorf("error")) - srvTask.EXPECT().Wait(gomock.Any()).Return(nil) input := &taskset.Input{ - Input: &task.Input{}, - LoadBalancers: []ecstypes.LoadBalancer{lb}, - ServiceRegistries: []ecstypes.ServiceRegistry{srg}, + Input: &task.Input{}, + LoadBalancers: []ecstypes.LoadBalancer{lb}, } set := taskset.NewSet(factory, input) ctx := context.TODO() diff --git a/test/context.go b/test/context.go index 3f2e359..f0c9c02 100644 --- a/test/context.go +++ b/test/context.go @@ -4,7 +4,6 @@ import ( "sync" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" "github.com/loilo-inc/canarycage/awsiface" ) @@ -13,13 +12,7 @@ type commons struct { Tasks map[string]*ecstypes.Task TaskDefinitions *TaskDefinitionRepository TargetGroups map[string]struct{} - SrvNamespaces []*srvtypes.Namespace - SrvServices []*srvtypes.Service - // service.Name -> []*instance - SrvInsts map[string][]*srvtypes.Instance - // instance.Id -> HealthStatus - SrvInstHelths map[string]srvtypes.HealthStatus - mux sync.Mutex + mux sync.Mutex } type MockContext struct { @@ -27,7 +20,6 @@ type MockContext struct { Ecs awsiface.EcsClient Alb awsiface.AlbClient Ec2 awsiface.Ec2Client - Srv awsiface.SrvClient } func NewMockContext() *MockContext { @@ -37,16 +29,12 @@ func NewMockContext() *MockContext { TaskDefinitions: &TaskDefinitionRepository{ families: make(map[string]*TaskDefinitionFamily), }, - TargetGroups: make(map[string]struct{}), - SrvServices: make([]*srvtypes.Service, 0), - SrvInsts: make(map[string][]*srvtypes.Instance), - SrvInstHelths: make(map[string]srvtypes.HealthStatus), + TargetGroups: make(map[string]struct{}), } return &MockContext{ commons: cm, Ecs: &EcsServer{commons: cm}, Ec2: &Ec2Server{commons: cm}, - Srv: &SrvServer{commons: cm}, Alb: &AlbServer{commons: cm}, } } diff --git a/test/srv.go b/test/srv.go deleted file mode 100644 index 08887e3..0000000 --- a/test/srv.go +++ /dev/null @@ -1,139 +0,0 @@ -package test - -import ( - "context" - "fmt" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/servicediscovery" - srvtypes "github.com/aws/aws-sdk-go-v2/service/servicediscovery/types" - "golang.org/x/xerrors" -) - -type SrvServer struct { - *commons -} - -func (s *SrvServer) getServiceById(id string) *srvtypes.Service { - for _, svc := range s.SrvServices { - if *svc.Id == id { - return svc - } - } - return nil -} - -func (s *commons) CreateSrvService( - namepsaceName string, - serviceName string) *srvtypes.Service { - nsId := fmt.Sprintf("ns-%s", namepsaceName) - ns := &srvtypes.Namespace{ - Id: &nsId, - Name: &namepsaceName, - Arn: aws.String(fmt.Sprintf("arn:aws:servicediscovery:ap-northeast-1:123456789012:namespace/%s", nsId)), - } - svId := fmt.Sprintf("srv-%s", serviceName) - svc := &srvtypes.Service{ - NamespaceId: ns.Id, - Id: &svId, - Name: &serviceName, - Arn: aws.String(fmt.Sprintf("arn:aws:servicediscovery:ap-northeast-1:123456789012:service/%s", svId)), - InstanceCount: aws.Int32(0), - } - s.SrvNamespaces = append(s.SrvNamespaces, ns) - s.SrvServices = append(s.SrvServices, svc) - return svc -} - -func (s *commons) PutSrvInstHealth(id string, health srvtypes.HealthStatus) { - s.SrvInstHelths[id] = health -} - -func (s *SrvServer) DiscoverInstances(ctx context.Context, params *servicediscovery.DiscoverInstancesInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DiscoverInstancesOutput, error) { - insts, ok := s.SrvInsts[*params.ServiceName] - if !ok { - return nil, xerrors.Errorf("service not found: %s", *params.ServiceName) - } - var summories []srvtypes.HttpInstanceSummary - for _, inst := range insts { - health := s.SrvInstHelths[*inst.Id] - if !matchInst(inst, params) { - continue - } - switch params.HealthStatus { - case srvtypes.HealthStatusFilterHealthy: - if health != srvtypes.HealthStatusHealthy { - continue - } - case srvtypes.HealthStatusFilterUnhealthy: - if health != srvtypes.HealthStatusUnhealthy { - continue - } - } - summories = append(summories, srvtypes.HttpInstanceSummary{ - Attributes: inst.Attributes, - InstanceId: inst.Id, - ServiceName: params.ServiceName, - NamespaceName: params.NamespaceName, - HealthStatus: health, - }) - } - return &servicediscovery.DiscoverInstancesOutput{Instances: summories}, nil -} - -func matchInst(inst *srvtypes.Instance, params *servicediscovery.DiscoverInstancesInput) bool { - for k, v := range params.QueryParameters { - if act, ok := inst.Attributes[k]; !ok || act != v { - return false - } - } - return true -} - -func (s *SrvServer) RegisterInstance(ctx context.Context, params *servicediscovery.RegisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.RegisterInstanceOutput, error) { - if srv := s.getServiceById(*params.ServiceId); srv == nil { - return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) - } else { - inst := &srvtypes.Instance{ - Id: params.InstanceId, - Attributes: params.Attributes, - } - s.SrvInsts[*srv.Name] = append(s.SrvInsts[*params.ServiceId], inst) - s.SrvInstHelths[*params.InstanceId] = srvtypes.HealthStatusHealthy - return &servicediscovery.RegisterInstanceOutput{}, nil - } -} - -func (s *SrvServer) DeregisterInstance(ctx context.Context, params *servicediscovery.DeregisterInstanceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.DeregisterInstanceOutput, error) { - if srv := s.getServiceById(*params.ServiceId); srv == nil { - return nil, xerrors.Errorf("service not found: %s", *params.ServiceId) - } else { - insts := s.SrvInsts[*srv.Name] - for i, inst := range insts { - if *inst.Id == *params.InstanceId { - insts = append(insts[:i], insts[i+1:]...) - s.SrvInsts[*srv.Name] = insts - delete(s.SrvInstHelths, *params.InstanceId) - return &servicediscovery.DeregisterInstanceOutput{}, nil - } - } - return nil, xerrors.Errorf("instance not found: %s", *params.InstanceId) - } -} - -func (s *SrvServer) GetService(ctx context.Context, params *servicediscovery.GetServiceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetServiceOutput, error) { - svc := s.getServiceById(*params.Id) - if svc == nil { - return nil, xerrors.Errorf("service not found: %s", *params.Id) - } - return &servicediscovery.GetServiceOutput{Service: svc}, nil -} - -func (s *SrvServer) GetNamespace(ctx context.Context, params *servicediscovery.GetNamespaceInput, optFns ...func(*servicediscovery.Options)) (*servicediscovery.GetNamespaceOutput, error) { - for _, ns := range s.SrvNamespaces { - if *ns.Id == *params.Id { - return &servicediscovery.GetNamespaceOutput{Namespace: ns}, nil - } - } - return nil, xerrors.Errorf("namespace not found: %s", *params.Id) -} From bfd3b71eb69fd454abefdc959466daf50996ce7d Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 20:09:15 +0900 Subject: [PATCH 23/47] a --- cli/cage/commands/flags.go | 11 ----------- codecov.yml | 1 + 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/cli/cage/commands/flags.go b/cli/cage/commands/flags.go index f60cf4a..0e228b5 100644 --- a/cli/cage/commands/flags.go +++ b/cli/cage/commands/flags.go @@ -92,14 +92,3 @@ func ServiceStableWaitFlag(dest *int) *cli.IntFlag { Value: 900, } } - -func TargetHealthCheckWaitFlag(dest *int) *cli.IntFlag { - return &cli.IntFlag{ - Name: "targetHealthCheckTimeout", - EnvVars: []string{env.TargetHealthCheckTimeout}, - Usage: "max duration seconds for waiting target health check", - Destination: dest, - Category: "ADVANCED", - Value: 900, - } -} diff --git a/codecov.yml b/codecov.yml index f932a60..34c0338 100644 --- a/codecov.yml +++ b/codecov.yml @@ -2,3 +2,4 @@ ignore: - "test" - "mocks" - "cli/cage/main.go" + - "timeout/time.go" From 810cffacc52cf2ea54193a6075796985e5174af4 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 20:18:26 +0900 Subject: [PATCH 24/47] rename --- env/timeout.go | 8 ++++---- env/timeout_test.go | 16 ++++++++-------- rollout.go | 2 +- run.go | 4 ++-- task/common.go | 6 +++--- up.go | 2 +- 6 files changed, 19 insertions(+), 19 deletions(-) diff --git a/env/timeout.go b/env/timeout.go index b08175c..245bfa3 100644 --- a/env/timeout.go +++ b/env/timeout.go @@ -6,7 +6,7 @@ import ( const defaultTimeout = 15 * time.Minute -func (t *Envars) TaskRunning() time.Duration { +func (t *Envars) GetTaskRunningWait() time.Duration { wait := t.CanaryTaskRunningWait if wait > 0 { return time.Duration(wait) * time.Second @@ -14,7 +14,7 @@ func (t *Envars) TaskRunning() time.Duration { return defaultTimeout } -func (t *Envars) TaskHealthCheck() time.Duration { +func (t *Envars) GetTaskHealthCheckWait() time.Duration { wait := t.CanaryTaskHealthCheckWait if wait > 0 { return time.Duration(wait) * time.Second @@ -22,7 +22,7 @@ func (t *Envars) TaskHealthCheck() time.Duration { return defaultTimeout } -func (t *Envars) TaskStopped() time.Duration { +func (t *Envars) GetTaskStoppedWait() time.Duration { wait := t.CanaryTaskStoppedWait if wait > 0 { return time.Duration(wait) * time.Second @@ -30,7 +30,7 @@ func (t *Envars) TaskStopped() time.Duration { return defaultTimeout } -func (t *Envars) ServiceStable() time.Duration { +func (t *Envars) GetServiceStableWait() time.Duration { wait := t.ServiceStableWait if wait > 0 { return time.Duration(wait) * time.Second diff --git a/env/timeout_test.go b/env/timeout_test.go index d7f1467..101b1d2 100644 --- a/env/timeout_test.go +++ b/env/timeout_test.go @@ -11,10 +11,10 @@ import ( func TestManager(t *testing.T) { t.Run("no config", func(t *testing.T) { man := &env.Envars{} - assert.Equal(t, 15*time.Minute, man.TaskRunning()) - assert.Equal(t, 15*time.Minute, man.TaskStopped()) - assert.Equal(t, 15*time.Minute, man.TaskHealthCheck()) - assert.Equal(t, 15*time.Minute, man.ServiceStable()) + assert.Equal(t, 15*time.Minute, man.GetTaskRunningWait()) + assert.Equal(t, 15*time.Minute, man.GetTaskStoppedWait()) + assert.Equal(t, 15*time.Minute, man.GetTaskHealthCheckWait()) + assert.Equal(t, 15*time.Minute, man.GetServiceStableWait()) }) t.Run("with config", func(t *testing.T) { man := &env.Envars{ @@ -23,9 +23,9 @@ func TestManager(t *testing.T) { CanaryTaskHealthCheckWait: 3, ServiceStableWait: 4, } - assert.Equal(t, 1*time.Second, man.TaskRunning()) - assert.Equal(t, 2*time.Second, man.TaskStopped()) - assert.Equal(t, 3*time.Second, man.TaskHealthCheck()) - assert.Equal(t, 4*time.Second, man.ServiceStable()) + assert.Equal(t, 1*time.Second, man.GetTaskRunningWait()) + assert.Equal(t, 2*time.Second, man.GetTaskStoppedWait()) + assert.Equal(t, 3*time.Second, man.GetTaskHealthCheckWait()) + assert.Equal(t, 4*time.Second, man.GetServiceStableWait()) }) } diff --git a/rollout.go b/rollout.go index fb4f041..efc6822 100644 --- a/rollout.go +++ b/rollout.go @@ -90,7 +90,7 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ Cluster: &env.Cluster, Services: []string{env.Service}, - }, env.ServiceStable()); err != nil { + }, env.GetServiceStableWait()); err != nil { return result, err } log.Infof("🥴 service '%s' has become to be stable!", env.Service) diff --git a/run.go b/run.go index dde8692..21a1c8e 100644 --- a/run.go +++ b/run.go @@ -50,7 +50,7 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, env.TaskRunning()); err != nil { + }, env.GetTaskRunningWait()); err != nil { return nil, xerrors.Errorf("task failed to start: %w", err) } log.Infof("task '%s' is running", *taskArn) @@ -58,7 +58,7 @@ func (c *cage) Run(ctx context.Context, input *types.RunInput) (*types.RunResult if result, err := ecs.NewTasksStoppedWaiter(ecsCli).WaitForOutput(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*taskArn}, - }, env.TaskStopped()); err != nil { + }, env.GetTaskStoppedWait()); err != nil { return nil, xerrors.Errorf("task failed to stop: %w", err) } else { task := result.Tasks[0] diff --git a/task/common.go b/task/common.go index 2a352ba..4d076ad 100644 --- a/task/common.go +++ b/task/common.go @@ -86,7 +86,7 @@ func (c *common) waitForTask(ctx context.Context) error { if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, env.TaskRunning()); err != nil { + }, env.GetTaskRunningWait()); err != nil { return err } log.Infof("🐣 canary task '%s' is running!", *c.taskArn) @@ -109,7 +109,7 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { containerHasHealthChecks[*definition.Name] = struct{}{} } } - rest := env.TaskHealthCheck() + rest := env.GetTaskHealthCheckWait() healthCheckPeriod := 15 * time.Second for rest > 0 { if rest < healthCheckPeriod { @@ -257,7 +257,7 @@ func (c *common) stopTask(ctx context.Context) error { if err := ecs.NewTasksStoppedWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*c.taskArn}, - }, env.TaskStopped()); err != nil { + }, env.GetTaskStoppedWait()); err != nil { return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) diff --git a/up.go b/up.go index a66d45c..f7a460c 100644 --- a/up.go +++ b/up.go @@ -52,7 +52,7 @@ func (c *cage) createService(ctx context.Context, serviceDefinitionInput *ecs.Cr if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ Cluster: &env.Cluster, Services: []string{*serviceDefinitionInput.ServiceName}, - }, env.ServiceStable()); err != nil { + }, env.GetServiceStableWait()); err != nil { return nil, xerrors.Errorf("failed to wait for service '%s' to be STABLE: %w", *serviceDefinitionInput.ServiceName, err) } return o.Service, nil From 3172b6e0b51916cba86fc82beede822087c1cd1f Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 2 Jul 2024 20:26:39 +0900 Subject: [PATCH 25/47] add test --- env/timeout.go | 8 ++++++++ env/timeout_test.go | 25 ++++++++++++++----------- task/simple_task.go | 2 +- task/simple_task_test.go | 1 + 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/env/timeout.go b/env/timeout.go index 245bfa3..0d32fa0 100644 --- a/env/timeout.go +++ b/env/timeout.go @@ -6,6 +6,14 @@ import ( const defaultTimeout = 15 * time.Minute +func (t *Envars) GetCanaryTaskIdleWait() time.Duration { + wait := t.CanaryTaskIdleDuration + if wait > 0 { + return time.Duration(wait) * time.Second + } + return 0 +} + func (t *Envars) GetTaskRunningWait() time.Duration { wait := t.CanaryTaskRunningWait if wait > 0 { diff --git a/env/timeout_test.go b/env/timeout_test.go index 101b1d2..b8990a3 100644 --- a/env/timeout_test.go +++ b/env/timeout_test.go @@ -8,24 +8,27 @@ import ( "github.com/stretchr/testify/assert" ) -func TestManager(t *testing.T) { +func TestEnv_Timeout(t *testing.T) { t.Run("no config", func(t *testing.T) { - man := &env.Envars{} - assert.Equal(t, 15*time.Minute, man.GetTaskRunningWait()) - assert.Equal(t, 15*time.Minute, man.GetTaskStoppedWait()) - assert.Equal(t, 15*time.Minute, man.GetTaskHealthCheckWait()) - assert.Equal(t, 15*time.Minute, man.GetServiceStableWait()) + e := &env.Envars{} + assert.Equal(t, 15*time.Minute, e.GetTaskRunningWait()) + assert.Equal(t, 15*time.Minute, e.GetTaskStoppedWait()) + assert.Equal(t, 15*time.Minute, e.GetTaskHealthCheckWait()) + assert.Equal(t, 15*time.Minute, e.GetServiceStableWait()) + assert.Equal(t, time.Duration(0), e.GetCanaryTaskIdleWait()) }) t.Run("with config", func(t *testing.T) { - man := &env.Envars{ + e := &env.Envars{ CanaryTaskRunningWait: 1, CanaryTaskStoppedWait: 2, CanaryTaskHealthCheckWait: 3, ServiceStableWait: 4, + CanaryTaskIdleDuration: 5, } - assert.Equal(t, 1*time.Second, man.GetTaskRunningWait()) - assert.Equal(t, 2*time.Second, man.GetTaskStoppedWait()) - assert.Equal(t, 3*time.Second, man.GetTaskHealthCheckWait()) - assert.Equal(t, 4*time.Second, man.GetServiceStableWait()) + assert.Equal(t, 1*time.Second, e.GetTaskRunningWait()) + assert.Equal(t, 2*time.Second, e.GetTaskStoppedWait()) + assert.Equal(t, 3*time.Second, e.GetTaskHealthCheckWait()) + assert.Equal(t, 4*time.Second, e.GetServiceStableWait()) + assert.Equal(t, 5*time.Second, e.GetCanaryTaskIdleWait()) }) } diff --git a/task/simple_task.go b/task/simple_task.go index e40d2c4..12f6916 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -38,7 +38,7 @@ func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { env := c.di.Get(key.Env).(*env.Envars) timer := c.di.Get(key.Time).(types.Time) log.Infof("wait %d seconds for canary task to be stable...", env.CanaryTaskIdleDuration) - rest := time.Duration(env.CanaryTaskIdleDuration) * time.Second + rest := env.GetCanaryTaskIdleWait() waitPeriod := 15 * time.Second for rest > 0 { if rest < waitPeriod { diff --git a/task/simple_task_test.go b/task/simple_task_test.go index da41c55..fb477fb 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -17,6 +17,7 @@ func TestSimpleTask(t *testing.T) { env := test.DefaultEnvars() td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn + env.CanaryTaskIdleDuration = 10 ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) d := di.NewDomain(func(b *di.B) { b.Set(key.Env, env) From 730399f10be4426f8b5228129cfc5672df61dff5 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 17:12:26 +0900 Subject: [PATCH 26/47] Update alb_task_test.go --- task/alb_task_test.go | 72 ++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 512f8ae..4dbf3e2 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -5,6 +5,7 @@ import ( "testing" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" @@ -13,31 +14,52 @@ import ( ) func TestAlbTask(t *testing.T) { - mocker := test.NewMockContext() - env := test.DefaultEnvars() - ctx := context.TODO() - td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) - env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn - ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) - d := di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, mocker.Ecs) - b.Set(key.Ec2Cli, mocker.Ec2) - b.Set(key.AlbCli, mocker.Alb) - b.Set(key.Time, test.NewFakeTime()) + setup := func(env *env.Envars) (task.Task, *test.MockContext) { + mocker := test.NewMockContext() + ctx := context.TODO() + td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn + ecsSvc, _ := mocker.Ecs.CreateService(ctx, env.ServiceDefinitionInput) + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.AlbCli, mocker.Alb) + b.Set(key.Time, test.NewFakeTime()) + }) + stask := task.NewAlbTask(d, &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + }, &ecsSvc.Service.LoadBalancers[0]) + mocker.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ + TargetGroupArn: ecsSvc.Service.LoadBalancers[0].TargetGroupArn, + }) + return stask, mocker + } + t.Run("fargate", func(t *testing.T) { + env := test.DefaultEnvars() + stask, mocker := setup(env) + ctx := context.TODO() + err := stask.Start(ctx) + assert.NoError(t, err) + err = stask.Wait(ctx) + assert.NoError(t, err) + err = stask.Stop(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, mocker.RunningTaskSize()) }) - stask := task.NewAlbTask(d, &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - }, &ecsSvc.Service.LoadBalancers[0]) - mocker.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ - TargetGroupArn: ecsSvc.Service.LoadBalancers[0].TargetGroupArn, + t.Run("ec2", func(t *testing.T) { + env := test.DefaultEnvars() + env.CanaryInstanceArn = "arn://ec2" + stask, mocker := setup(env) + ctx := context.TODO() + err := stask.Start(ctx) + assert.NoError(t, err) + err = stask.Wait(ctx) + assert.NoError(t, err) + err = stask.Stop(ctx) + assert.NoError(t, err) + assert.Equal(t, 1, mocker.RunningTaskSize()) }) - err := stask.Start(ctx) - assert.NoError(t, err) - err = stask.Wait(ctx) - assert.NoError(t, err) - err = stask.Stop(ctx) - assert.NoError(t, err) - assert.Equal(t, 1, mocker.RunningTaskSize()) + } From ef4755e491e62833e239c7d283c160a0ce3c533b Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 17:52:14 +0900 Subject: [PATCH 27/47] add tests --- Makefile | 2 +- env/timeout.go | 2 +- mocks/mock_task/task.go | 14 ---- mocks/mock_types/iface.go | 14 ---- task/alb_task.go | 157 ++++++++++++++++++++++++++++++-------- task/alb_task_test.go | 1 - task/common.go | 97 ----------------------- task/common_test.go | 4 - task/export_test.go | 15 +++- task/simple_task.go | 4 +- task/simple_task_test.go | 91 ++++++++++++++++++++++ task/task.go | 1 - 12 files changed, 235 insertions(+), 167 deletions(-) diff --git a/Makefile b/Makefile index a901369..e96d26f 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ mocks: mocks/mock_awsiface/iface.go \ mocks/mock_task/factory.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go -mocks/mock_types/iface.go: cage.go +mocks/mock_types/iface.go: types/iface.go $(MOCKGEN) -source=./types/iface.go > mocks/mock_types/iface.go mocks/mock_upgrade/upgrade.go: cli/cage/upgrade/upgrade.go $(MOCKGEN) -source=./cli/cage/upgrade/upgrade.go > mocks/mock_upgrade/upgrade.go diff --git a/env/timeout.go b/env/timeout.go index 0d32fa0..53d0956 100644 --- a/env/timeout.go +++ b/env/timeout.go @@ -11,7 +11,7 @@ func (t *Envars) GetCanaryTaskIdleWait() time.Duration { if wait > 0 { return time.Duration(wait) * time.Second } - return 0 + return 15 * time.Second } func (t *Envars) GetTaskRunningWait() time.Duration { diff --git a/mocks/mock_task/task.go b/mocks/mock_task/task.go index efc19e6..fa6ea25 100644 --- a/mocks/mock_task/task.go +++ b/mocks/mock_task/task.go @@ -62,20 +62,6 @@ func (mr *MockTaskMockRecorder) Stop(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Stop", reflect.TypeOf((*MockTask)(nil).Stop), ctx) } -// TaskArn mocks base method. -func (m *MockTask) TaskArn() *string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "TaskArn") - ret0, _ := ret[0].(*string) - return ret0 -} - -// TaskArn indicates an expected call of TaskArn. -func (mr *MockTaskMockRecorder) TaskArn() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TaskArn", reflect.TypeOf((*MockTask)(nil).TaskArn)) -} - // Wait mocks base method. func (m *MockTask) Wait(ctx context.Context) error { m.ctrl.T.Helper() diff --git a/mocks/mock_types/iface.go b/mocks/mock_types/iface.go index e473b4d..29f83b1 100644 --- a/mocks/mock_types/iface.go +++ b/mocks/mock_types/iface.go @@ -117,17 +117,3 @@ func (mr *MockTimeMockRecorder) NewTimer(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewTimer", reflect.TypeOf((*MockTime)(nil).NewTimer), arg0) } - -// Now mocks base method. -func (m *MockTime) Now() time.Time { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Now") - ret0, _ := ret[0].(time.Time) - return ret0 -} - -// Now indicates an expected call of Now. -func (mr *MockTimeMockRecorder) Now() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Now", reflect.TypeOf((*MockTime)(nil).Now)) -} diff --git a/task/alb_task.go b/task/alb_task.go index 2cfaabe..c0ff35a 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -6,10 +6,14 @@ import ( "time" "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/types" "github.com/loilo-inc/logos/di" @@ -19,8 +23,8 @@ import ( // albTask is a task that is attached to an Application Load Balancer type albTask struct { *common - lb *ecstypes.LoadBalancer - target *CanaryTarget + Lb *ecstypes.LoadBalancer + Target *CanaryTarget } func NewAlbTask( @@ -30,7 +34,7 @@ func NewAlbTask( ) Task { return &albTask{ common: &common{Input: input, di: di}, - lb: lb, + Lb: lb, } } @@ -41,7 +45,7 @@ func (c *albTask) Wait(ctx context.Context) error { if err := c.registerToTargetGroup(ctx); err != nil { return err } - log.Infof("canary task '%s' is registered to target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) + log.Infof("canary task '%s' is registered to target group '%s'", c.Target.targetId, *c.Lb.TargetGroupArn) log.Infof("😷 waiting canary target to be healthy...") if err := c.waitUntilTargetHealthy(ctx); err != nil { return err @@ -55,9 +59,100 @@ func (c *albTask) Stop(ctx context.Context) error { return c.stopTask(ctx) } +func (c *albTask) describeTaskTarget( + ctx context.Context, + targetPort int32, +) (*CanaryTarget, error) { + env := c.di.Get(key.Env).(*env.Envars) + target := CanaryTarget{targetPort: targetPort} + if env.CanaryInstanceArn == "" { // Fargate + if err := c.getFargateTarget(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } else { + if err := c.getEc2Target(ctx, &target); err != nil { + return nil, err + } + log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + } + return &target, nil +} + +func (c *albTask) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { + var task ecstypes.Task + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) + if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ + Cluster: &env.Cluster, + Tasks: []string{*c.taskArn}, + }); err != nil { + return err + } else { + task = o.Tasks[0] + } + details := task.Attachments[0].Details + var subnetId *string + var privateIp *string + for _, v := range details { + if *v.Name == "subnetId" { + subnetId = v.Value + } else if *v.Name == "privateIPv4Address" { + privateIp = v.Value + } + } + if subnetId == nil || privateIp == nil { + return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + } + if o, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*subnetId}, + }); err != nil { + return err + } else { + dest.targetId = *privateIp + dest.targetIpv4 = *privateIp + dest.availabilityZone = *o.Subnets[0].AvailabilityZone + } + return nil +} + +func (c *albTask) getEc2Target(ctx context.Context, dest *CanaryTarget) error { + var containerInstance ecstypes.ContainerInstance + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) + if outputs, err := ecsCli.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ + Cluster: &env.Cluster, + ContainerInstances: []string{env.CanaryInstanceArn}, + }); err != nil { + return err + } else { + containerInstance = outputs.ContainerInstances[0] + } + var ec2Instance ec2types.Instance + if o, err := ec2Cli.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + InstanceIds: []string{*containerInstance.Ec2InstanceId}, + }); err != nil { + return err + } else { + ec2Instance = o.Reservations[0].Instances[0] + } + if sn, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*ec2Instance.SubnetId}, + }); err != nil { + return err + } else { + dest.targetId = *containerInstance.Ec2InstanceId + dest.targetIpv4 = *ec2Instance.PrivateIpAddress + dest.availabilityZone = *sn.Subnets[0].AvailabilityZone + } + return nil +} + func (c *albTask) getTargetPort() (int32, error) { for _, container := range c.TaskDefinition.ContainerDefinitions { - if *container.Name == *c.lb.ContainerName { + if *container.Name == *c.Lb.ContainerName { return *container.PortMappings[0].HostPort, nil } } @@ -65,21 +160,21 @@ func (c *albTask) getTargetPort() (int32, error) { } func (c *albTask) registerToTargetGroup(ctx context.Context) error { - log.Infof("registering the canary task to target group '%s'...", *c.lb.TargetGroupArn) + log.Infof("registering the canary task to target group '%s'...", *c.Lb.TargetGroupArn) if targetPort, err := c.getTargetPort(); err != nil { return err } else if target, err := c.describeTaskTarget(ctx, targetPort); err != nil { return err } else { - c.target = target + c.Target = target } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if _, err := albCli.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ - TargetGroupArn: c.lb.TargetGroupArn, + TargetGroupArn: c.Lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.target.availabilityZone, - Id: &c.target.targetId, - Port: &c.target.targetPort, + AvailabilityZone: &c.Target.availabilityZone, + Id: &c.Target.targetId, + Port: &c.Target.targetPort, }}, }); err != nil { return err @@ -102,24 +197,24 @@ func (c *albTask) waitUntilTargetHealthy( return ctx.Err() case <-timer.NewTimer(waitPeriod).C: if o, err := albCli.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.lb.TargetGroupArn, + TargetGroupArn: c.Lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - Id: &c.target.targetId, - Port: &c.target.targetPort, - AvailabilityZone: &c.target.availabilityZone, + Id: &c.Target.targetId, + Port: &c.Target.targetPort, + AvailabilityZone: &c.Target.availabilityZone, }}, }); err != nil { return err } else { for _, desc := range o.TargetHealthDescriptions { - if *desc.Target.Id == c.target.targetId && *desc.Target.Port == c.target.targetPort { + if *desc.Target.Id == c.Target.targetId && *desc.Target.Port == c.Target.targetPort { recentState = &desc.TargetHealth.State } } if recentState == nil { - return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.target.targetId, *c.lb.TargetGroupArn) + return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.Target.targetId, *c.Lb.TargetGroupArn) } - log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.target.targetId, c.target.targetPort, *recentState) + log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.Target.targetId, c.Target.targetPort, *recentState) switch *recentState { case elbv2types.TargetHealthStateEnumHealthy: return nil @@ -133,7 +228,7 @@ func (c *albTask) waitUntilTargetHealthy( log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) return xerrors.Errorf( "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *c.taskArn, c.target.targetId, c.target.targetPort, *recentState, + *c.taskArn, c.Target.targetId, c.Target.targetPort, *recentState, ) } @@ -141,7 +236,7 @@ func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, deregistrationDelay := 300 * time.Second albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if o, err := albCli.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ - TargetGroupArn: c.lb.TargetGroupArn, + TargetGroupArn: c.Lb.TargetGroupArn, }); err != nil { return deregistrationDelay, err } else { @@ -160,7 +255,7 @@ func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, } func (c *albTask) deregisterTarget(ctx context.Context) { - if c.target == nil { + if c.Target == nil { return } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) @@ -169,13 +264,13 @@ func (c *albTask) deregisterTarget(ctx context.Context) { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) } - log.Infof("deregistering the canary task from target group '%s'...", c.target.targetId) + log.Infof("deregistering the canary task from target group '%s'...", c.Target.targetId) if _, err := albCli.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ - TargetGroupArn: c.lb.TargetGroupArn, + TargetGroupArn: c.Lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.target.availabilityZone, - Id: &c.target.targetId, - Port: &c.target.targetPort, + AvailabilityZone: &c.Target.availabilityZone, + Id: &c.Target.targetId, + Port: &c.Target.targetPort, }}, }); err != nil { log.Errorf("failed to deregister the canary task from target group: %v", err) @@ -184,11 +279,11 @@ func (c *albTask) deregisterTarget(ctx context.Context) { log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.lb.TargetGroupArn, + TargetGroupArn: c.Lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.target.availabilityZone, - Id: &c.target.targetId, - Port: &c.target.targetPort, + AvailabilityZone: &c.Target.availabilityZone, + Id: &c.Target.targetId, + Port: &c.Target.targetPort, }}, }, deregisterWait); err != nil { log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) @@ -196,7 +291,7 @@ func (c *albTask) deregisterTarget(ctx context.Context) { } else { log.Infof( "canary task '%s' has successfully been deregistered from target group '%s'", - *c.taskArn, c.target.targetId, + *c.taskArn, c.Target.targetId, ) } } diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 4dbf3e2..61e1c9f 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -61,5 +61,4 @@ func TestAlbTask(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, mocker.RunningTaskSize()) }) - } diff --git a/task/common.go b/task/common.go index 4d076ad..7af52aa 100644 --- a/task/common.go +++ b/task/common.go @@ -6,8 +6,6 @@ import ( "time" "github.com/apex/log" - "github.com/aws/aws-sdk-go-v2/service/ec2" - ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/loilo-inc/canarycage/awsiface" @@ -72,10 +70,6 @@ func (c *common) Start(ctx context.Context) error { return nil } -func (c *common) TaskArn() *string { - return c.taskArn -} - func (c *common) waitForTask(ctx context.Context) error { if c.taskArn == nil { return xerrors.New("task is not started") @@ -150,97 +144,6 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { return xerrors.Errorf("😨 canary task hasn't become to be healthy") } -func (c *common) describeTaskTarget( - ctx context.Context, - targetPort int32, -) (*CanaryTarget, error) { - env := c.di.Get(key.Env).(*env.Envars) - target := CanaryTarget{targetPort: targetPort} - if env.CanaryInstanceArn == "" { // Fargate - if err := c.getFargateTarget(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) - } else { - if err := c.getEc2Target(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) - } - return &target, nil -} - -func (c *common) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { - var task ecstypes.Task - env := c.di.Get(key.Env).(*env.Envars) - ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) - if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ - Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, - }); err != nil { - return err - } else { - task = o.Tasks[0] - } - details := task.Attachments[0].Details - var subnetId *string - var privateIp *string - for _, v := range details { - if *v.Name == "subnetId" { - subnetId = v.Value - } else if *v.Name == "privateIPv4Address" { - privateIp = v.Value - } - } - if subnetId == nil || privateIp == nil { - return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") - } - if o, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*subnetId}, - }); err != nil { - return err - } else { - dest.targetId = *privateIp - dest.targetIpv4 = *privateIp - dest.availabilityZone = *o.Subnets[0].AvailabilityZone - } - return nil -} - -func (c *common) getEc2Target(ctx context.Context, dest *CanaryTarget) error { - var containerInstance ecstypes.ContainerInstance - env := c.di.Get(key.Env).(*env.Envars) - ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) - if outputs, err := ecsCli.DescribeContainerInstances(ctx, &ecs.DescribeContainerInstancesInput{ - Cluster: &env.Cluster, - ContainerInstances: []string{env.CanaryInstanceArn}, - }); err != nil { - return err - } else { - containerInstance = outputs.ContainerInstances[0] - } - var ec2Instance ec2types.Instance - if o, err := ec2Cli.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ - InstanceIds: []string{*containerInstance.Ec2InstanceId}, - }); err != nil { - return err - } else { - ec2Instance = o.Reservations[0].Instances[0] - } - if sn, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*ec2Instance.SubnetId}, - }); err != nil { - return err - } else { - dest.targetId = *containerInstance.Ec2InstanceId - dest.targetIpv4 = *ec2Instance.PrivateIpAddress - dest.availabilityZone = *sn.Subnets[0].AvailabilityZone - } - return nil -} - func (c *common) stopTask(ctx context.Context) error { if c.taskArn == nil { return nil diff --git a/task/common_test.go b/task/common_test.go index cf68cda..3f643bf 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -33,7 +33,6 @@ func TestCommon_Start(t *testing.T) { }), &task.Input{TaskDefinition: td}) err := cm.Start(context.TODO()) assert.NoError(t, err) - assert.Equal(t, "task-arn", *cm.TaskArn()) }) t.Run("should error if task failed to start", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -47,7 +46,6 @@ func TestCommon_Start(t *testing.T) { }), &task.Input{TaskDefinition: td}) err := cm.Start(context.TODO()) assert.EqualError(t, err, "error") - assert.Nil(t, cm.TaskArn()) }) }) t.Run("EC2", func(t *testing.T) { @@ -66,7 +64,6 @@ func TestCommon_Start(t *testing.T) { }), &task.Input{TaskDefinition: td}) err := cm.Start(context.TODO()) assert.NoError(t, err) - assert.Equal(t, "task-arn", *cm.TaskArn()) }) t.Run("should error if task failed to start", func(t *testing.T) { ctrl := gomock.NewController(t) @@ -81,7 +78,6 @@ func TestCommon_Start(t *testing.T) { }), &task.Input{TaskDefinition: td}) err := cm.Start(context.TODO()) assert.EqualError(t, err, "error") - assert.Nil(t, cm.TaskArn()) }) }) } diff --git a/task/export_test.go b/task/export_test.go index 2b73a37..b193a42 100644 --- a/task/export_test.go +++ b/task/export_test.go @@ -1,9 +1,22 @@ package task -import "github.com/loilo-inc/logos/di" +import ( + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/logos/di" +) type CommonExport = common +type AlbTaskExport = albTask +type SimpleTaskExport = simpleTask func NewCommonExport(di *di.D, input *Input) *common { return &common{di: di, Input: input} } + +func NewAlbTaskExport(di *di.D, input *Input, lb *ecstypes.LoadBalancer) *albTask { + return &albTask{common: &common{di: di, Input: input}, Lb: lb} +} + +func NewSimpleTaskExport(di *di.D, input *Input) *simpleTask { + return &simpleTask{common: &common{di: di, Input: input}} +} diff --git a/task/simple_task.go b/task/simple_task.go index 12f6916..0ac8c31 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -27,14 +27,14 @@ func (c *simpleTask) Wait(ctx context.Context) error { if err := c.waitForTask(ctx); err != nil { return err } - return c.waitForIdleDuration(ctx) + return c.WaitForIdleDuration(ctx) } func (c *simpleTask) Stop(ctx context.Context) error { return c.stopTask(ctx) } -func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { +func (c *simpleTask) WaitForIdleDuration(ctx context.Context) error { env := c.di.Get(key.Env).(*env.Envars) timer := c.di.Get(key.Time).(types.Time) log.Infof("wait %d seconds for canary task to be stable...", env.CanaryTaskIdleDuration) diff --git a/task/simple_task_test.go b/task/simple_task_test.go index fb477fb..493642a 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -3,8 +3,15 @@ package task_test import ( "context" "testing" + "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/golang/mock/gomock" "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/mocks/mock_awsiface" + "github.com/loilo-inc/canarycage/mocks/mock_types" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" @@ -38,3 +45,87 @@ func TestSimpleTask(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, mocker.RunningTaskSize()) } + +func TestSimpleTask_WaitForIdleDuration(t *testing.T) { + t.Run("should call DescribeTasks periodically", func(t *testing.T) { + ctrl := gomock.NewController(t) + mocker := test.NewMockContext() + envars := test.DefaultEnvars() + envars.CanaryTaskIdleDuration = 35 // sec + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + fakeTimer := test.NewFakeTime() + timerMock := mock_types.NewMockTime(ctrl) + gomock.InOrder( + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.RunTask). + Times(1), + timerMock.EXPECT().NewTimer(15*time.Second). + DoAndReturn(fakeTimer.NewTimer). + Times(2), + timerMock.EXPECT().NewTimer(5*time.Second). + DoAndReturn(fakeTimer.NewTimer). + Times(1), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.DescribeTasks). + Times(1), + ) + cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, timerMock) + }), &task.Input{TaskDefinition: td.TaskDefinition}) + err := cm.Start(context.TODO()) + assert.NoError(t, err) + err = cm.WaitForIdleDuration(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if DescribeTasks failed", func(t *testing.T) { + ctrl := gomock.NewController(t) + mocker := test.NewMockContext() + envars := test.DefaultEnvars() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + gomock.InOrder( + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.RunTask). + Times(1), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). + Return(nil, assert.AnError). + Times(1), + ) + cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + }), &task.Input{TaskDefinition: td.TaskDefinition}) + cm.Start(context.TODO()) + err := cm.WaitForIdleDuration(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if task is not started", func(t *testing.T) { + ctrl := gomock.NewController(t) + mocker := test.NewMockContext() + envars := test.DefaultEnvars() + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.RunTask). + Times(1) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("STOPPED"), + StoppedReason: aws.String("reason"), + }}, + }, nil) + cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + }), &task.Input{TaskDefinition: td.TaskDefinition}) + cm.Start(context.TODO()) + err := cm.WaitForIdleDuration(context.TODO()) + assert.EqualError(t, err, "😫 canary task has stopped: reason") + }) +} diff --git a/task/task.go b/task/task.go index dd21a00..b5fed47 100644 --- a/task/task.go +++ b/task/task.go @@ -6,5 +6,4 @@ type Task interface { Start(ctx context.Context) error Wait(ctx context.Context) error Stop(ctx context.Context) error - TaskArn() *string } From 5b7295c67bd453e04d514c7a3d60a414f7bb02b4 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 18:42:43 +0900 Subject: [PATCH 28/47] add tests --- task/alb_task.go | 129 +++++++++++++++++------------------------- task/alb_task_test.go | 89 +++++++++++++++++++++++++++++ task/common.go | 37 +++++------- task/simple_task.go | 2 +- 4 files changed, 156 insertions(+), 101 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index c0ff35a..b7d4e67 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -24,7 +24,7 @@ import ( type albTask struct { *common Lb *ecstypes.LoadBalancer - Target *CanaryTarget + Target *elbv2types.TargetDescription } func NewAlbTask( @@ -45,9 +45,9 @@ func (c *albTask) Wait(ctx context.Context) error { if err := c.registerToTargetGroup(ctx); err != nil { return err } - log.Infof("canary task '%s' is registered to target group '%s'", c.Target.targetId, *c.Lb.TargetGroupArn) + log.Infof("canary task '%s' is registered to target group '%s'", *c.Target.Id, *c.Lb.TargetGroupArn) log.Infof("😷 waiting canary target to be healthy...") - if err := c.waitUntilTargetHealthy(ctx); err != nil { + if err := c.WaitUntilTargetHealthy(ctx); err != nil { return err } log.Info("🤩 canary target is healthy!") @@ -61,34 +61,43 @@ func (c *albTask) Stop(ctx context.Context) error { func (c *albTask) describeTaskTarget( ctx context.Context, - targetPort int32, -) (*CanaryTarget, error) { +) (*elbv2types.TargetDescription, error) { env := c.di.Get(key.Env).(*env.Envars) - target := CanaryTarget{targetPort: targetPort} + targetPort, err := c.getTargetPort() + if err != nil { + return nil, err + } + target := &elbv2types.TargetDescription{Port: targetPort} + var subnetId *string if env.CanaryInstanceArn == "" { // Fargate - if err := c.getFargateTarget(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: privateIp = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + target.Id, subnetId, err = c.getFargateTargetNetwork(ctx) } else { - if err := c.getEc2Target(ctx, &target); err != nil { - return nil, err - } - log.Infof("canary task was placed: instanceId = '%s', hostPort = '%d', az = '%s'", target.targetId, target.targetPort, target.availabilityZone) + target.Id, subnetId, err = c.getEc2TargetNetwork(ctx) + } + if err != nil { + return nil, err + } + ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) + if o, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ + SubnetIds: []string{*subnetId}, + }); err != nil { + return nil, err + } else { + target.AvailabilityZone = o.Subnets[0].AvailabilityZone } - return &target, nil + log.Infof("canary task was placed: id = '%s', hostPort = '%d', az = '%s'", *target.Id, *target.Port, *target.AvailabilityZone) + return target, nil } -func (c *albTask) getFargateTarget(ctx context.Context, dest *CanaryTarget) error { +func (c *albTask) getFargateTargetNetwork(ctx context.Context) (*string, *string, error) { var task ecstypes.Task env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - ec2Cli := c.di.Get(key.Ec2Cli).(awsiface.Ec2Client) if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, + Tasks: []string{*c.TaskArn}, }); err != nil { - return err + return nil, nil, err } else { task = o.Tasks[0] } @@ -103,21 +112,12 @@ func (c *albTask) getFargateTarget(ctx context.Context, dest *CanaryTarget) erro } } if subnetId == nil || privateIp == nil { - return xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + return nil, nil, xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") } - if o, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*subnetId}, - }); err != nil { - return err - } else { - dest.targetId = *privateIp - dest.targetIpv4 = *privateIp - dest.availabilityZone = *o.Subnets[0].AvailabilityZone - } - return nil + return privateIp, subnetId, nil } -func (c *albTask) getEc2Target(ctx context.Context, dest *CanaryTarget) error { +func (c *albTask) getEc2TargetNetwork(ctx context.Context) (*string, *string, error) { var containerInstance ecstypes.ContainerInstance env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) @@ -126,7 +126,7 @@ func (c *albTask) getEc2Target(ctx context.Context, dest *CanaryTarget) error { Cluster: &env.Cluster, ContainerInstances: []string{env.CanaryInstanceArn}, }); err != nil { - return err + return nil, nil, err } else { containerInstance = outputs.ContainerInstances[0] } @@ -134,36 +134,25 @@ func (c *albTask) getEc2Target(ctx context.Context, dest *CanaryTarget) error { if o, err := ec2Cli.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ InstanceIds: []string{*containerInstance.Ec2InstanceId}, }); err != nil { - return err + return nil, nil, err } else { ec2Instance = o.Reservations[0].Instances[0] } - if sn, err := ec2Cli.DescribeSubnets(ctx, &ec2.DescribeSubnetsInput{ - SubnetIds: []string{*ec2Instance.SubnetId}, - }); err != nil { - return err - } else { - dest.targetId = *containerInstance.Ec2InstanceId - dest.targetIpv4 = *ec2Instance.PrivateIpAddress - dest.availabilityZone = *sn.Subnets[0].AvailabilityZone - } - return nil + return ec2Instance.PrivateIpAddress, ec2Instance.SubnetId, nil } -func (c *albTask) getTargetPort() (int32, error) { +func (c *albTask) getTargetPort() (*int32, error) { for _, container := range c.TaskDefinition.ContainerDefinitions { if *container.Name == *c.Lb.ContainerName { - return *container.PortMappings[0].HostPort, nil + return container.PortMappings[0].HostPort, nil } } - return 0, xerrors.Errorf("couldn't find host port in container definition") + return nil, xerrors.Errorf("couldn't find host port in container definition") } func (c *albTask) registerToTargetGroup(ctx context.Context) error { log.Infof("registering the canary task to target group '%s'...", *c.Lb.TargetGroupArn) - if targetPort, err := c.getTargetPort(); err != nil { - return err - } else if target, err := c.describeTaskTarget(ctx, targetPort); err != nil { + if target, err := c.describeTaskTarget(ctx); err != nil { return err } else { c.Target = target @@ -171,18 +160,14 @@ func (c *albTask) registerToTargetGroup(ctx context.Context) error { albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if _, err := albCli.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.Target.availabilityZone, - Id: &c.Target.targetId, - Port: &c.Target.targetPort, - }}, - }); err != nil { + Targets: []elbv2types.TargetDescription{*c.Target}}, + ); err != nil { return err } return nil } -func (c *albTask) waitUntilTargetHealthy( +func (c *albTask) WaitUntilTargetHealthy( ctx context.Context, ) error { albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) @@ -198,23 +183,19 @@ func (c *albTask) waitUntilTargetHealthy( case <-timer.NewTimer(waitPeriod).C: if o, err := albCli.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - Id: &c.Target.targetId, - Port: &c.Target.targetPort, - AvailabilityZone: &c.Target.availabilityZone, - }}, + Targets: []elbv2types.TargetDescription{*c.Target}, }); err != nil { return err } else { for _, desc := range o.TargetHealthDescriptions { - if *desc.Target.Id == c.Target.targetId && *desc.Target.Port == c.Target.targetPort { + if *desc.Target.Id == *c.Target.Id && *desc.Target.Port == *c.Target.Port { recentState = &desc.TargetHealth.State } } if recentState == nil { - return xerrors.Errorf("'%s' is not registered to the target group '%s'", c.Target.targetId, *c.Lb.TargetGroupArn) + return xerrors.Errorf("'%s' is not registered to the target group '%s'", *c.Target.Id, *c.Lb.TargetGroupArn) } - log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, c.Target.targetId, c.Target.targetPort, *recentState) + log.Infof("canary task '%s' (%s:%d) state is: %s", *c.TaskArn, *c.Target.Id, *c.Target.Port, *recentState) switch *recentState { case elbv2types.TargetHealthStateEnumHealthy: return nil @@ -225,10 +206,10 @@ func (c *albTask) waitUntilTargetHealthy( } } // unhealthy, draining, unused - log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) + log.Errorf("😨 canary task '%s' is unhealthy", *c.TaskArn) return xerrors.Errorf( "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *c.taskArn, c.Target.targetId, c.Target.targetPort, *recentState, + *c.TaskArn, *c.Target.Id, *c.Target.Port, *recentState, ) } @@ -264,14 +245,10 @@ func (c *albTask) deregisterTarget(ctx context.Context) { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) } - log.Infof("deregistering the canary task from target group '%s'...", c.Target.targetId) + log.Infof("deregistering the canary task from target group '%s'...", *c.Target.Id) if _, err := albCli.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.Target.availabilityZone, - Id: &c.Target.targetId, - Port: &c.Target.targetPort, - }}, + Targets: []elbv2types.TargetDescription{*c.Target}, }); err != nil { log.Errorf("failed to deregister the canary task from target group: %v", err) log.Errorf("continuing to stop the canary task...") @@ -280,18 +257,14 @@ func (c *albTask) deregisterTarget(ctx context.Context) { deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{{ - AvailabilityZone: &c.Target.availabilityZone, - Id: &c.Target.targetId, - Port: &c.Target.targetPort, - }}, + Targets: []elbv2types.TargetDescription{*c.Target}, }, deregisterWait); err != nil { log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) log.Errorf("continuing to stop the canary task...") } else { log.Infof( "canary task '%s' has successfully been deregistered from target group '%s'", - *c.taskArn, c.Target.targetId, + *c.TaskArn, *c.Target.Id, ) } } diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 61e1c9f..05ca25a 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -2,11 +2,16 @@ package task_test import ( "context" + "fmt" "testing" + "github.com/aws/aws-sdk-go-v2/aws" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + "github.com/golang/mock/gomock" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" @@ -62,3 +67,87 @@ func TestAlbTask(t *testing.T) { assert.Equal(t, 1, mocker.RunningTaskSize()) }) } + +func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { + target := &elbv2types.TargetDescription{ + Id: aws.String("127.0.0.1"), + Port: aws.Int32(80), + AvailabilityZone: aws.String("ap-northeast-1a"), + } + setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + ctrl := gomock.NewController(t) + env := test.DefaultEnvars() + mocker := test.NewMockContext() + albMock := mock_awsiface.NewMockAlbClient(ctrl) + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + b.Set(key.Time, test.NewFakeTime()) + }), &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }, &env.ServiceDefinitionInput.LoadBalancers[0]) + atask.TaskArn = aws.String("arn://task") + atask.Target = target + return albMock, atask + } + t.Run("should call DescribeTargetHealth periodically", func(t *testing.T) { + albMock, atask := setup(t) + gomock.InOrder( + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + {TargetHealth: &elbv2types.TargetHealth{State: elbv2types.TargetHealthStateEnumUnused}, + Target: target, + }, + }, + }, nil).Times(1), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + {TargetHealth: &elbv2types.TargetHealth{State: elbv2types.TargetHealthStateEnumHealthy}, + Target: target, + }, + }, + }, nil).Times(1), + ) + err := atask.WaitUntilTargetHealthy(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if DescribeTargetHealth failed", func(t *testing.T) { + albMock, atask := setup(t) + gomock.InOrder( + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), + ) + err := atask.WaitUntilTargetHealthy(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if target is not registered", func(t *testing.T) { + albMock, atask := setup(t) + gomock.InOrder( + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{}, + }, nil).Times(1), + ) + err := atask.WaitUntilTargetHealthy(context.TODO()) + assert.EqualError(t, err, fmt.Sprintf( + "'%s' is not registered to the target group '%s'", *target.Id, *atask.Lb.TargetGroupArn), + ) + }) + t.Run("should error if target unhelthy counts exceed the limit", func(t *testing.T) { + albMock, atask := setup(t) + gomock.InOrder( + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + {TargetHealth: &elbv2types.TargetHealth{State: elbv2types.TargetHealthStateEnumUnhealthy}, + Target: target, + }, + }, + }, nil).Times(5), + ) + err := atask.WaitUntilTargetHealthy(context.TODO()) + assert.EqualError(t, err, fmt.Sprintf( + "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", + *atask.TaskArn, *target.Id, *target.Port, elbv2types.TargetHealthStateEnumUnhealthy, + ), + ) + }) +} diff --git a/task/common.go b/task/common.go index 7af52aa..90d274b 100644 --- a/task/common.go +++ b/task/common.go @@ -16,13 +16,6 @@ import ( "golang.org/x/xerrors" ) -type CanaryTarget struct { - targetId string - targetIpv4 string - targetPort int32 - availabilityZone string -} - type Input struct { TaskDefinition *ecstypes.TaskDefinition NetworkConfiguration *ecstypes.NetworkConfiguration @@ -32,7 +25,7 @@ type Input struct { type common struct { *Input di *di.D - taskArn *string + TaskArn *string } func (c *common) Start(ctx context.Context) error { @@ -50,7 +43,7 @@ func (c *common) Start(ctx context.Context) error { }); err != nil { return err } else { - c.taskArn = o.Tasks[0].TaskArn + c.TaskArn = o.Tasks[0].TaskArn } } else { // fargate @@ -64,31 +57,31 @@ func (c *common) Start(ctx context.Context) error { }); err != nil { return err } else { - c.taskArn = o.Tasks[0].TaskArn + c.TaskArn = o.Tasks[0].TaskArn } } return nil } func (c *common) waitForTask(ctx context.Context) error { - if c.taskArn == nil { + if c.TaskArn == nil { return xerrors.New("task is not started") } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) + log.Infof("🥚 waiting for canary task '%s' is running...", *c.TaskArn) if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, + Tasks: []string{*c.TaskArn}, }, env.GetTaskRunningWait()); err != nil { return err } - log.Infof("🐣 canary task '%s' is running!", *c.taskArn) + log.Infof("🐣 canary task '%s' is running!", *c.TaskArn) if err := c.waitContainerHealthCheck(ctx); err != nil { return err } log.Info("🤩 canary task container(s) is healthy!") - log.Infof("canary task '%s' ensured.", *c.taskArn) + log.Infof("canary task '%s' ensured.", *c.TaskArn) return nil } @@ -113,10 +106,10 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case <-timer.NewTimer(healthCheckPeriod).C: - log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) + log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.TaskArn, len(containerHasHealthChecks)) if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, + Tasks: []string{*c.TaskArn}, }); err != nil { return err } else { @@ -145,24 +138,24 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { } func (c *common) stopTask(ctx context.Context) error { - if c.taskArn == nil { + if c.TaskArn == nil { return nil } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - log.Infof("stopping the canary task '%s'...", *c.taskArn) + log.Infof("stopping the canary task '%s'...", *c.TaskArn) if _, err := ecsCli.StopTask(ctx, &ecs.StopTaskInput{ Cluster: &env.Cluster, - Task: c.taskArn, + Task: c.TaskArn, }); err != nil { return xerrors.Errorf("failed to stop canary task: %w", err) } if err := ecs.NewTasksStoppedWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, + Tasks: []string{*c.TaskArn}, }, env.GetTaskStoppedWait()); err != nil { return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } - log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) + log.Infof("canary task '%s' has successfully been stopped", *c.TaskArn) return nil } diff --git a/task/simple_task.go b/task/simple_task.go index 0ac8c31..ec3644a 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -55,7 +55,7 @@ func (c *simpleTask) WaitForIdleDuration(ctx context.Context) error { ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.taskArn}, + Tasks: []string{*c.TaskArn}, }) if err != nil { return err From 8140df145993643a62380581c1b633a148dd2f80 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 18:51:42 +0900 Subject: [PATCH 29/47] Update timeout_test.go --- env/timeout_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/env/timeout_test.go b/env/timeout_test.go index b8990a3..16a6ea7 100644 --- a/env/timeout_test.go +++ b/env/timeout_test.go @@ -15,7 +15,7 @@ func TestEnv_Timeout(t *testing.T) { assert.Equal(t, 15*time.Minute, e.GetTaskStoppedWait()) assert.Equal(t, 15*time.Minute, e.GetTaskHealthCheckWait()) assert.Equal(t, 15*time.Minute, e.GetServiceStableWait()) - assert.Equal(t, time.Duration(0), e.GetCanaryTaskIdleWait()) + assert.Equal(t, 15*time.Second, e.GetCanaryTaskIdleWait()) }) t.Run("with config", func(t *testing.T) { e := &env.Envars{ From 7a73ae0c142956a1f09af638b84b85975fe86086 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 19:03:03 +0900 Subject: [PATCH 30/47] Update simple_task_test.go --- task/alb_task_test.go | 7 +++ task/simple_task_test.go | 96 ++++++++++++++++++++-------------------- 2 files changed, 54 insertions(+), 49 deletions(-) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 05ca25a..3d55952 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -120,6 +120,13 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { err := atask.WaitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) + t.Run("should error if context is canceled", func(t *testing.T) { + _, atask := setup(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := atask.WaitUntilTargetHealthy(ctx) + assert.EqualError(t, err, "context canceled") + }) t.Run("should error if target is not registered", func(t *testing.T) { albMock, atask := setup(t) gomock.InOrder( diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 493642a..220b7ba 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -47,19 +47,26 @@ func TestSimpleTask(t *testing.T) { } func TestSimpleTask_WaitForIdleDuration(t *testing.T) { - t.Run("should call DescribeTasks periodically", func(t *testing.T) { + setup := func(t *testing.T, idle int) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, *task.SimpleTaskExport) { ctrl := gomock.NewController(t) mocker := test.NewMockContext() envars := test.DefaultEnvars() - envars.CanaryTaskIdleDuration = 35 // sec + envars.CanaryTaskIdleDuration = idle td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - fakeTimer := test.NewFakeTime() timerMock := mock_types.NewMockTime(ctrl) + cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, timerMock) + }), &task.Input{TaskDefinition: td.TaskDefinition}) + cm.TaskArn = aws.String("task-arn") + return ecsMock, timerMock, cm + } + t.Run("should call DescribeTasks periodically", func(t *testing.T) { + ecsMock, timerMock, cm := setup(t, 35) + fakeTimer := test.NewFakeTime() gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). - DoAndReturn(mocker.Ecs.RunTask). - Times(1), timerMock.EXPECT().NewTimer(15*time.Second). DoAndReturn(fakeTimer.NewTimer). Times(2), @@ -67,64 +74,55 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { DoAndReturn(fakeTimer.NewTimer). Times(1), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). - DoAndReturn(mocker.Ecs.DescribeTasks). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING")}}, + }, nil). Times(1), ) - cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.Time, timerMock) - }), &task.Input{TaskDefinition: td.TaskDefinition}) - err := cm.Start(context.TODO()) - assert.NoError(t, err) - err = cm.WaitForIdleDuration(context.TODO()) + err := cm.WaitForIdleDuration(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeTasks failed", func(t *testing.T) { - ctrl := gomock.NewController(t) - mocker := test.NewMockContext() - envars := test.DefaultEnvars() - td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) - ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + ecsMock, timerMock, cm := setup(t, 15) + fakeTimer := test.NewFakeTime() gomock.InOrder( - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). - DoAndReturn(mocker.Ecs.RunTask). + timerMock.EXPECT().NewTimer(15*time.Second). + DoAndReturn(fakeTimer.NewTimer). Times(1), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). Return(nil, assert.AnError). Times(1), ) - cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.Time, test.NewFakeTime()) - }), &task.Input{TaskDefinition: td.TaskDefinition}) - cm.Start(context.TODO()) err := cm.WaitForIdleDuration(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) + t.Run("sholud error if ctx is canceled", func(t *testing.T) { + _, timerMock, cm := setup(t, 15) + gomock.InOrder( + timerMock.EXPECT().NewTimer(15 * time.Second). + DoAndReturn(time.NewTimer). + Times(1), + ) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := cm.WaitForIdleDuration(ctx) + assert.EqualError(t, err, "context canceled") + }) t.Run("should error if task is not started", func(t *testing.T) { - ctrl := gomock.NewController(t) - mocker := test.NewMockContext() - envars := test.DefaultEnvars() - ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any()). - DoAndReturn(mocker.Ecs.RunTask). - Times(1) - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). - Return(&ecs.DescribeTasksOutput{ - Tasks: []ecstypes.Task{{ - LastStatus: aws.String("STOPPED"), - StoppedReason: aws.String("reason"), - }}, - }, nil) - cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.Time, test.NewFakeTime()) - }), &task.Input{TaskDefinition: td.TaskDefinition}) - cm.Start(context.TODO()) + ecsMock, timerMock, cm := setup(t, 15) + fakeTimer := test.NewFakeTime() + gomock.InOrder( + timerMock.EXPECT().NewTimer(15*time.Second). + DoAndReturn(fakeTimer.NewTimer). + Times(1), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("STOPPED"), + StoppedReason: aws.String("reason"), + }}, + }, nil), + ) err := cm.WaitForIdleDuration(context.TODO()) assert.EqualError(t, err, "😫 canary task has stopped: reason") }) From f51a5a369dd27c6e2506b40b1a691397a2078ac4 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 19:59:58 +0900 Subject: [PATCH 31/47] test test test --- task/alb_task.go | 45 +++++----- task/alb_task_test.go | 74 +++++++++++++++ task/common.go | 58 ++++++------ task/common_test.go | 203 ++++++++++++++++++++++++++++++++++++++++++ task/simple_task.go | 7 +- 5 files changed, 336 insertions(+), 51 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index b7d4e67..8ac283b 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -39,10 +39,13 @@ func NewAlbTask( } func (c *albTask) Wait(ctx context.Context) error { - if err := c.waitForTask(ctx); err != nil { + if err := c.WaitForTaskRunning(ctx); err != nil { return err } - if err := c.registerToTargetGroup(ctx); err != nil { + if err := c.WaitContainerHealthCheck(ctx); err != nil { + return err + } + if err := c.RegisterToTargetGroup(ctx); err != nil { return err } log.Infof("canary task '%s' is registered to target group '%s'", *c.Target.Id, *c.Lb.TargetGroupArn) @@ -55,8 +58,8 @@ func (c *albTask) Wait(ctx context.Context) error { } func (c *albTask) Stop(ctx context.Context) error { - c.deregisterTarget(ctx) - return c.stopTask(ctx) + c.DeregisterTarget(ctx) + return c.StopTask(ctx) } func (c *albTask) describeTaskTarget( @@ -150,7 +153,7 @@ func (c *albTask) getTargetPort() (*int32, error) { return nil, xerrors.Errorf("couldn't find host port in container definition") } -func (c *albTask) registerToTargetGroup(ctx context.Context) error { +func (c *albTask) RegisterToTargetGroup(ctx context.Context) error { log.Infof("registering the canary task to target group '%s'...", *c.Lb.TargetGroupArn) if target, err := c.describeTaskTarget(ctx); err != nil { return err @@ -235,7 +238,7 @@ func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, return deregistrationDelay, nil } -func (c *albTask) deregisterTarget(ctx context.Context) { +func (c *albTask) DeregisterTarget(ctx context.Context) { if c.Target == nil { return } @@ -252,20 +255,20 @@ func (c *albTask) deregisterTarget(ctx context.Context) { }); err != nil { log.Errorf("failed to deregister the canary task from target group: %v", err) log.Errorf("continuing to stop the canary task...") - } else { - log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") - deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety - if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{*c.Target}, - }, deregisterWait); err != nil { - log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) - log.Errorf("continuing to stop the canary task...") - } else { - log.Infof( - "canary task '%s' has successfully been deregistered from target group '%s'", - *c.TaskArn, *c.Target.Id, - ) - } + return } + log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") + deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety + if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ + TargetGroupArn: c.Lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{*c.Target}, + }, deregisterWait); err != nil { + log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) + log.Errorf("continuing to stop the canary task...") + return + } + log.Infof( + "canary task '%s' has successfully been deregistered from target group '%s'", + *c.TaskArn, *c.Target.Id, + ) } diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 3d55952..2d92397 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -158,3 +158,77 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { ) }) } + +func TestAlbTask_DeregisterTarget(t *testing.T) { + target := &elbv2types.TargetDescription{ + Id: aws.String("127.0.0.1"), + Port: aws.Int32(80), + AvailabilityZone: aws.String("ap-northeast-1a"), + } + setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + ctrl := gomock.NewController(t) + mocker := test.NewMockContext() + albMock := mock_awsiface.NewMockAlbClient(ctrl) + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + }), &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }, &env.ServiceDefinitionInput.LoadBalancers[0]) + atask.TaskArn = aws.String("arn://task") + atask.Target = target + return albMock, atask + } + t.Run("should do nothing if target is nil", func(t *testing.T) { + atask := task.NewAlbTaskExport(di.EmptyDomain(), &task.Input{}, nil) + atask.DeregisterTarget(context.TODO()) + }) + t.Run("should call DeregisterTargets and wait", func(t *testing.T) { + env := test.DefaultEnvars() + albMock, atask := setup(t, env) + gomock.InOrder( + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("300")}, + }, + }, nil).Times(1), + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + {TargetHealth: &elbv2types.TargetHealth{State: elbv2types.TargetHealthStateEnumUnused}, + Target: target, + }, + }, + }, nil).Times(1), + ) + atask.DeregisterTarget(context.TODO()) + }) + t.Run("should return even if DeregisterTargets failed", func(t *testing.T) { + env := test.DefaultEnvars() + albMock, atask := setup(t, env) + gomock.InOrder( + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("300")}, + }, + }, nil).Times(1), + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), + ) + atask.DeregisterTarget(context.TODO()) + }) + t.Run("should return even if deregistration wait counts exceed the limit", func(t *testing.T) { + env := test.DefaultEnvars() + albMock, atask := setup(t, env) + gomock.InOrder( + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("1")}, + }, + }, nil).Times(1), + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), + ) + atask.DeregisterTarget(context.TODO()) + }) +} diff --git a/task/common.go b/task/common.go index 90d274b..7dca95e 100644 --- a/task/common.go +++ b/task/common.go @@ -63,7 +63,7 @@ func (c *common) Start(ctx context.Context) error { return nil } -func (c *common) waitForTask(ctx context.Context) error { +func (c *common) WaitForTaskRunning(ctx context.Context) error { if c.TaskArn == nil { return xerrors.New("task is not started") } @@ -74,31 +74,30 @@ func (c *common) waitForTask(ctx context.Context) error { Cluster: &env.Cluster, Tasks: []string{*c.TaskArn}, }, env.GetTaskRunningWait()); err != nil { - return err + return xerrors.Errorf("failed to wait for canary task to be running: %w", err) } log.Infof("🐣 canary task '%s' is running!", *c.TaskArn) - if err := c.waitContainerHealthCheck(ctx); err != nil { - return err - } - log.Info("🤩 canary task container(s) is healthy!") - log.Infof("canary task '%s' ensured.", *c.TaskArn) return nil } -func (c *common) waitContainerHealthCheck(ctx context.Context) error { +func (c *common) WaitContainerHealthCheck(ctx context.Context) error { log.Infof("😷 ensuring canary task container(s) to become healthy...") - env := c.di.Get(key.Env).(*env.Envars) - timer := c.di.Get(key.Time).(types.Time) - ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) containerHasHealthChecks := map[string]struct{}{} for _, definition := range c.TaskDefinition.ContainerDefinitions { if definition.HealthCheck != nil { containerHasHealthChecks[*definition.Name] = struct{}{} } } + if len(containerHasHealthChecks) == 0 { + log.Info("no container has health check, skipped.") + return nil + } + env := c.di.Get(key.Env).(*env.Envars) + timer := c.di.Get(key.Time).(types.Time) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) rest := env.GetTaskHealthCheckWait() healthCheckPeriod := 15 * time.Second - for rest > 0 { + for rest > 0 && len(containerHasHealthChecks) > 0 { if rest < healthCheckPeriod { healthCheckPeriod = rest } @@ -107,37 +106,40 @@ func (c *common) waitContainerHealthCheck(ctx context.Context) error { return ctx.Err() case <-timer.NewTimer(healthCheckPeriod).C: log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.TaskArn, len(containerHasHealthChecks)) + var task ecstypes.Task if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, Tasks: []string{*c.TaskArn}, }); err != nil { return err } else { - task := o.Tasks[0] - if *task.LastStatus != "RUNNING" { - return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) - } - for _, container := range task.Containers { - if _, ok := containerHasHealthChecks[*container.Name]; !ok { - continue - } - if container.HealthStatus != ecstypes.HealthStatusHealthy { - log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) - continue - } - delete(containerHasHealthChecks, *container.Name) + task = o.Tasks[0] + } + if *task.LastStatus != "RUNNING" { + return xerrors.Errorf("😫 canary task has stopped: %s", *task.StoppedReason) + } + for _, container := range task.Containers { + if _, ok := containerHasHealthChecks[*container.Name]; !ok { + continue } - if len(containerHasHealthChecks) == 0 { - return nil + if container.HealthStatus != ecstypes.HealthStatusHealthy { + log.Infof("container '%s' is not healthy: %s", *container.Name, container.HealthStatus) + continue } + delete(containerHasHealthChecks, *container.Name) } } rest -= healthCheckPeriod } + if len(containerHasHealthChecks) == 0 { + log.Info("🤩 canary task container(s) is healthy!") + log.Infof("canary task '%s' ensured.", *c.TaskArn) + return nil + } return xerrors.Errorf("😨 canary task hasn't become to be healthy") } -func (c *common) stopTask(ctx context.Context) error { +func (c *common) StopTask(ctx context.Context) error { if c.TaskArn == nil { return nil } diff --git a/task/common_test.go b/task/common_test.go index 3f643bf..ec0f8f0 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -4,13 +4,16 @@ import ( "context" "fmt" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/golang/mock/gomock" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" + "github.com/loilo-inc/canarycage/mocks/mock_types" "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" @@ -81,3 +84,203 @@ func TestCommon_Start(t *testing.T) { }) }) } + +func TestCommon_WaitForTaskRunning(t *testing.T) { + setup := func(t *testing.T, envars *env.Envars) (*mock_awsiface.MockEcsClient, *task.CommonExport) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + td := &ecstypes.TaskDefinition{} + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{TaskDefinition: td}) + cm.TaskArn = aws.String("task-arn") + return ecsMock, cm + } + t.Run("should call ecs.NewTasksRunningWaiter", func(t *testing.T) { + ecsMock, cm := setup(t, test.DefaultEnvars()) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING")}}, + }, nil) + err := cm.WaitForTaskRunning(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if task is not started", func(t *testing.T) { + cm := task.NewCommonExport(di.EmptyDomain(), nil) + err := cm.WaitForTaskRunning(context.TODO()) + assert.EqualError(t, err, "task is not started") + }) + t.Run("should error if ecs.NewTasksRunningWaiter failed", func(t *testing.T) { + envars := test.DefaultEnvars() + envars.CanaryTaskRunningWait = 15 + ecsMock, cm := setup(t, envars) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).Return( + &ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED"), StoppedReason: aws.String("reason")}}, + }, nil) + err := cm.WaitForTaskRunning(context.TODO()) + assert.ErrorContains(t, err, "failed to wait for canary task to be running:") + }) +} + +func TestCommon_WaitContainerHealthCheck(t *testing.T) { + setup := func(t *testing.T, envars *env.Envars) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, + *ecstypes.TaskDefinition, + *task.CommonExport) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + mocker := test.NewMockContext() + timerMock := mock_types.NewMockTime(ctrl) + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, timerMock) + }), &task.Input{TaskDefinition: td.TaskDefinition}) + cm.TaskArn = aws.String("task-arn") + return ecsMock, timerMock, td.TaskDefinition, cm + } + t.Run("should call DescribeTasks periodically", func(t *testing.T) { + env := test.DefaultEnvars() + ecsMock, timerMock, td, cm := setup(t, env) + faketime := test.NewFakeTime() + gomock.InOrder( + timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), + Containers: []ecstypes.Container{ + {Name: td.ContainerDefinitions[0].Name, + HealthStatus: ecstypes.HealthStatusUnknown}, + {Name: td.ContainerDefinitions[1].Name, + HealthStatus: ecstypes.HealthStatusUnknown}, + }, + }}, + }, nil), + timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), + Containers: []ecstypes.Container{ + {Name: td.ContainerDefinitions[0].Name, + HealthStatus: ecstypes.HealthStatusHealthy}, + {Name: td.ContainerDefinitions[1].Name, + HealthStatus: ecstypes.HealthStatusUnknown}, + }, + }}, + }, nil), + ) + err := cm.WaitContainerHealthCheck(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should do nothing if no container has health check", func(t *testing.T) { + env := test.DefaultEnvars() + env.TaskDefinitionInput.ContainerDefinitions[0].HealthCheck = nil + _, _, _, cm := setup(t, env) + err := cm.WaitContainerHealthCheck(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if DescribeTasks failed", func(t *testing.T) { + env := test.DefaultEnvars() + ecsMock, timerMock, _, cm := setup(t, env) + faketime := test.NewFakeTime() + gomock.InOrder( + timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")), + ) + err := cm.WaitContainerHealthCheck(context.TODO()) + assert.EqualError(t, err, "error") + }) + t.Run("should error if context is canceled", func(t *testing.T) { + env := test.DefaultEnvars() + _, timerMock, _, cm := setup(t, env) + timerMock.EXPECT().NewTimer(15 * time.Second).DoAndReturn(time.NewTimer) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := cm.WaitContainerHealthCheck(ctx) + assert.EqualError(t, err, "context canceled") + }) + t.Run("should error if task is not running", func(t *testing.T) { + env := test.DefaultEnvars() + ecsMock, timerMock, _, cm := setup(t, env) + faketime := test.NewFakeTime() + gomock.InOrder( + timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED"), + StoppedReason: aws.String("reason")}, + }, + }, + nil), + ) + err := cm.WaitContainerHealthCheck(context.TODO()) + assert.EqualError(t, err, "😫 canary task has stopped: reason") + }) + t.Run("shold error if unhealth counts exceed the limit", func(t *testing.T) { + env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 15 + ecsMock, timerMock, td, cm := setup(t, env) + faketime := test.NewFakeTime() + gomock.InOrder( + timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), + Containers: []ecstypes.Container{ + {Name: td.ContainerDefinitions[0].Name, + HealthStatus: ecstypes.HealthStatusUnhealthy}, + {Name: td.ContainerDefinitions[1].Name, + HealthStatus: ecstypes.HealthStatusUnknown}, + }, + }}, + }, nil), + ) + err := cm.WaitContainerHealthCheck(context.TODO()) + assert.EqualError(t, err, "😨 canary task hasn't become to be healthy") + }) +} + +func TestCommon_StopTask(t *testing.T) { + setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockEcsClient, *task.CommonExport) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { + b.Set(key.EcsCli, ecsMock) + b.Set(key.Env, env) + }), nil) + cm.TaskArn = aws.String("task-arn") + return ecsMock, cm + } + t.Run("should call ecscCli.StopTask and wait", func(t *testing.T) { + ecsMock, cm := setup(t, test.DefaultEnvars()) + gomock.InOrder( + ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any()).Return(&ecs.StopTaskOutput{}, nil), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED")}}, + }, nil), + ) + err := cm.StopTask(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should do nothing if task is not started", func(t *testing.T) { + cm := task.NewCommonExport(di.EmptyDomain(), nil) + err := cm.StopTask(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if StopTask failed", func(t *testing.T) { + ecsMock, cm := setup(t, test.DefaultEnvars()) + ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) + err := cm.StopTask(context.TODO()) + assert.EqualError(t, err, "failed to stop canary task: error") + }) + t.Run("should error wait time exceeds the limit", func(t *testing.T) { + env := test.DefaultEnvars() + env.CanaryTaskStoppedWait = 1 + ecsMock, cm := setup(t, env) + gomock.InOrder( + ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any()).Return(&ecs.StopTaskOutput{}, nil), + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING")}}, + }, nil), + ) + err := cm.StopTask(context.TODO()) + assert.ErrorContains(t, err, "failed to wait for canary task to be stopped") + }) +} diff --git a/task/simple_task.go b/task/simple_task.go index ec3644a..0911886 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -24,14 +24,17 @@ func NewSimpleTask(di *di.D, input *Input) Task { } func (c *simpleTask) Wait(ctx context.Context) error { - if err := c.waitForTask(ctx); err != nil { + if err := c.WaitForTaskRunning(ctx); err != nil { + return err + } + if err := c.WaitContainerHealthCheck(ctx); err != nil { return err } return c.WaitForIdleDuration(ctx) } func (c *simpleTask) Stop(ctx context.Context) error { - return c.stopTask(ctx) + return c.StopTask(ctx) } func (c *simpleTask) WaitForIdleDuration(ctx context.Context) error { From 4c3761e64f61b5f7d4fe4e91bc5bdf17f9464e09 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 20:32:08 +0900 Subject: [PATCH 32/47] test --- env/env.go | 4 +- task/alb_task_test.go | 191 ++++++++++++++++++++++++++++++++++++++++++ task/common_test.go | 3 +- 3 files changed, 194 insertions(+), 4 deletions(-) diff --git a/env/env.go b/env/env.go index 2bc152f..cb0c05e 100644 --- a/env/env.go +++ b/env/env.go @@ -7,7 +7,6 @@ import ( "regexp" "strings" - "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/service/ecs" "golang.org/x/xerrors" ) @@ -45,7 +44,6 @@ const TaskRunningTimeout = "CAGE_TASK_RUNNING_TIMEOUT" const TaskHealthCheckTimeout = "CAGE_TASK_HEALTH_CHECK_TIMEOUT" const TaskStoppedTimeout = "CAGE_TASK_STOPPED_TIMEOUT" const ServiceStableTimeout = "CAGE_SERVICE_STABLE_TIMEOUT" -const TargetHealthCheckTimeout = "CAGE_TARGET_HEALTH_CHECK_TIMEOUT" func EnsureEnvars( dest *Envars, @@ -138,7 +136,7 @@ func ReadFileAndApplyEnvars(path string) ([]byte, error) { if envar, ok := os.LookupEnv(m[1]); ok { str = strings.Replace(str, m[0], envar, -1) } else { - log.Fatalf("envar literal '%s' found in %s but was not defined", m[0], path) + return nil, xerrors.Errorf("envar literal '%s' found in %s but was not defined", m[0], path) } } return []byte(str), nil diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 2d92397..42d7a14 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -6,6 +6,10 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/golang/mock/gomock" @@ -159,6 +163,193 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { }) } +func TestAlbTask_RegisterToTargetGroup(t *testing.T) { + t.Run("should error if port mapping is not found", func(t *testing.T) { + env := test.DefaultEnvars() + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + }), &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }, &ecstypes.LoadBalancer{ + TargetGroupArn: aws.String("arn://target-group"), + ContainerName: aws.String("unknown")}) + atask.TaskArn = aws.String("arn://task") + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, "couldn't find host port in container definition") + }) + t.Run("Fargate", func(t *testing.T) { + attachments := []ecstypes.Attachment{{ + Details: []ecstypes.KeyValuePair{{ + Name: aws.String("networkInterfaceId"), + Value: aws.String("eni-123456"), + }, { + Name: aws.String("subnetId"), + Value: aws.String("subnet-123456"), + }, { + Name: aws.String("privateIPv4Address"), + Value: aws.String("127.0.0.1"), + }, + }}} + subnets := []ec2types.Subnet{{ + AvailabilityZone: aws.String("ap-northeast-1a"), + }} + setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *task.AlbTaskExport) { + ctrl := gomock.NewController(t) + envars := test.DefaultEnvars() + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) + albMock := mock_awsiface.NewMockAlbClient(ctrl) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.AlbCli, albMock) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, + }, &envars.ServiceDefinitionInput.LoadBalancers[0]) + atask.TaskArn = aws.String("arn://task") + return ec2Mock, albMock, ecsMock, atask + } + t.Run("should call RegisterTargets", func(t *testing.T) { + ec2Mock, albMock, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("RUNNING"), + Attachments: attachments}, + }}, nil) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSubnetsOutput{ + Subnets: subnets, + }, nil) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil) + atask.TaskArn = aws.String("arn://task") + err := atask.RegisterToTargetGroup(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if DescribeTasks failed", func(t *testing.T) { + _, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if DescribeSubnets failed", func(t *testing.T) { + ec2Mock, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("RUNNING"), + Attachments: attachments}, + }}, nil) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if task is not attached to the network interface", func(t *testing.T) { + _, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("RUNNING"), + }}, + }, nil) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, "couldn't find subnetId or privateIPv4Address in task details") + }) + t.Run("should error if RegisterTargets failed", func(t *testing.T) { + ec2Mock, albMock, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{ + LastStatus: aws.String("RUNNING"), + Attachments: attachments}, + }}, nil) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSubnetsOutput{ + Subnets: subnets, + }, nil) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + }) + t.Run("EC2", func(t *testing.T) { + containerInstances := []ecstypes.ContainerInstance{{ + ContainerInstanceArn: aws.String("arn://container"), + }} + reservations := []ec2types.Reservation{{ + Instances: []ec2types.Instance{{ + InstanceId: aws.String("i-123456"), + }}, + }} + subnets := []ec2types.Subnet{{ + AvailabilityZone: aws.String("ap-northeast-1a"), + }} + setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *task.AlbTaskExport) { + ctrl := gomock.NewController(t) + envars := test.DefaultEnvars() + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) + albMock := mock_awsiface.NewMockAlbClient(ctrl) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.AlbCli, albMock) + b.Set(key.EcsCli, ecsMock) + }), &task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, + }, &envars.ServiceDefinitionInput.LoadBalancers[0]) + atask.TaskArn = aws.String("arn://task") + return ec2Mock, albMock, ecsMock, atask + } + t.Run("should call RegisterTargets", func(t *testing.T) { + ec2Mock, albMock, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any()).Return(&ecs.DescribeContainerInstancesOutput{ + ContainerInstances: containerInstances, + }, nil) + ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any()).Return(&ec2.DescribeInstancesOutput{ + Reservations: reservations, + }, nil) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSubnetsOutput{ + Subnets: subnets, + }, nil) + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.NoError(t, err) + }) + t.Run("should error if DescribeContainerInstances failed", func(t *testing.T) { + _, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if DescribeInstances failed", func(t *testing.T) { + ec2Mock, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any()).Return(&ecs.DescribeContainerInstancesOutput{ + ContainerInstances: containerInstances, + }, nil) + ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if DescribeSubnets failed", func(t *testing.T) { + ec2Mock, _, ecsMock, atask := setup(t) + ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any()).Return(&ecs.DescribeContainerInstancesOutput{ + ContainerInstances: containerInstances, + }, nil) + ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any()).Return(&ec2.DescribeInstancesOutput{ + Reservations: reservations, + }, nil) + ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := atask.RegisterToTargetGroup(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + }) +} + func TestAlbTask_DeregisterTarget(t *testing.T) { target := &elbv2types.TargetDescription{ Id: aws.String("127.0.0.1"), diff --git a/task/common_test.go b/task/common_test.go index ec0f8f0..4759e6a 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -142,6 +142,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { } t.Run("should call DescribeTasks periodically", func(t *testing.T) { env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 20 ecsMock, timerMock, td, cm := setup(t, env) faketime := test.NewFakeTime() gomock.InOrder( @@ -156,7 +157,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { }, }}, }, nil), - timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), + timerMock.EXPECT().NewTimer(5*time.Second).DoAndReturn(faketime.NewTimer), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), Containers: []ecstypes.Container{ From 305c2f214e9b25f73b91ad1bd7794ad768b9c8b1 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 20:41:47 +0900 Subject: [PATCH 33/47] a --- task/alb_task.go | 17 ++++++++++------- task/alb_task_test.go | 4 +++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index 8ac283b..f290268 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -104,18 +104,21 @@ func (c *albTask) getFargateTargetNetwork(ctx context.Context) (*string, *string } else { task = o.Tasks[0] } - details := task.Attachments[0].Details var subnetId *string var privateIp *string - for _, v := range details { - if *v.Name == "subnetId" { - subnetId = v.Value - } else if *v.Name == "privateIPv4Address" { - privateIp = v.Value + for _, attachment := range task.Attachments { + if *attachment.Status == "ATTACHED" && *attachment.Type == "ElasticNetworkInterface" { + for _, v := range attachment.Details { + if *v.Name == "subnetId" { + subnetId = v.Value + } else if *v.Name == "privateIPv4Address" { + privateIp = v.Value + } + } } } if subnetId == nil || privateIp == nil { - return nil, nil, xerrors.Errorf("couldn't find subnetId or privateIPv4Address in task details") + return nil, nil, xerrors.Errorf("couldn't find ElasticNetworkInterface attachment in task") } return privateIp, subnetId, nil } diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 42d7a14..daeba6c 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -182,6 +182,8 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { }) t.Run("Fargate", func(t *testing.T) { attachments := []ecstypes.Attachment{{ + Status: aws.String("ATTACHED"), + Type: aws.String("ElasticNetworkInterface"), Details: []ecstypes.KeyValuePair{{ Name: aws.String("networkInterfaceId"), Value: aws.String("eni-123456"), @@ -256,7 +258,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { }}, }, nil) err := atask.RegisterToTargetGroup(context.TODO()) - assert.EqualError(t, err, "couldn't find subnetId or privateIPv4Address in task details") + assert.EqualError(t, err, "couldn't find ElasticNetworkInterface attachment in task") }) t.Run("should error if RegisterTargets failed", func(t *testing.T) { ec2Mock, albMock, ecsMock, atask := setup(t) From 14d22d788d78ad6b5a1bf1971a4e10309262c6bd Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 3 Jul 2024 20:49:07 +0900 Subject: [PATCH 34/47] test --- task/alb_task.go | 2 +- task/alb_task_test.go | 22 +++++++++++++++++++--- test/ecs.go | 2 ++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index f290268..c6acb85 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -144,7 +144,7 @@ func (c *albTask) getEc2TargetNetwork(ctx context.Context) (*string, *string, er } else { ec2Instance = o.Reservations[0].Instances[0] } - return ec2Instance.PrivateIpAddress, ec2Instance.SubnetId, nil + return ec2Instance.InstanceId, ec2Instance.SubnetId, nil } func (c *albTask) getTargetPort() (*int32, error) { diff --git a/task/alb_task_test.go b/task/alb_task_test.go index daeba6c..e0ed052 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -228,7 +228,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSubnetsOutput{ Subnets: subnets, }, nil) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil) + albMock.EXPECT().RegisterTargets(gomock.Any(), &elbv2.RegisterTargetsInput{ + TargetGroupArn: atask.Lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + Id: aws.String("127.0.0.1"), + Port: aws.Int32(80), + AvailabilityZone: subnets[0].AvailabilityZone}, + }}).Return(nil, nil) atask.TaskArn = aws.String("arn://task") err := atask.RegisterToTargetGroup(context.TODO()) assert.NoError(t, err) @@ -278,10 +284,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { t.Run("EC2", func(t *testing.T) { containerInstances := []ecstypes.ContainerInstance{{ ContainerInstanceArn: aws.String("arn://container"), + Ec2InstanceId: aws.String("i-123456"), }} reservations := []ec2types.Reservation{{ Instances: []ec2types.Instance{{ - InstanceId: aws.String("i-123456"), + InstanceId: aws.String("i-123456"), + SubnetId: aws.String("subnet-123456"), + PrivateIpAddress: aws.String("127.0.0.1"), }}, }} subnets := []ec2types.Subnet{{ @@ -290,6 +299,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *task.AlbTaskExport) { ctrl := gomock.NewController(t) envars := test.DefaultEnvars() + envars.CanaryInstanceArn = "arn://container" mocker := test.NewMockContext() td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) @@ -318,7 +328,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(&ec2.DescribeSubnetsOutput{ Subnets: subnets, }, nil) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil) + albMock.EXPECT().RegisterTargets(gomock.Any(), &elbv2.RegisterTargetsInput{ + TargetGroupArn: atask.Lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{{ + Id: containerInstances[0].Ec2InstanceId, + Port: aws.Int32(80), + AvailabilityZone: subnets[0].AvailabilityZone}, + }}).Return(nil, nil) err := atask.RegisterToTargetGroup(context.TODO()) assert.NoError(t, err) }) diff --git a/test/ecs.go b/test/ecs.go index 727bd1a..4c337fe 100644 --- a/test/ecs.go +++ b/test/ecs.go @@ -168,6 +168,8 @@ func (ctx *EcsServer) StartTask(_ context.Context, input *ecs.StartTaskInput, _ } taskArn := fmt.Sprintf("arn:aws:ecs:us-west-2:012345678910:task/%s", uuid.New().String()) attachment := types.Attachment{ + Status: aws.String("ATTACHED"), + Type: aws.String("ElasticNetworkInterface"), Details: []types.KeyValuePair{ { Name: aws.String("privateIPv4Address"), From be89a35041e7103e189d05cb4952c6c72d3c92a3 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 16:32:16 +0900 Subject: [PATCH 35/47] test --- env/env.go | 16 ++-- env/env_test.go | 38 +++++++++- env/testdata/invalid/service.json | 1 + env/testdata/invalid/task-definition.json | 1 + {fixtures => env/testdata}/service.json | 0 .../testdata}/task-definition.json | 74 ------------------- env/{fixtures => testdata}/template.txt | 0 7 files changed, 46 insertions(+), 84 deletions(-) create mode 100644 env/testdata/invalid/service.json create mode 100644 env/testdata/invalid/task-definition.json rename {fixtures => env/testdata}/service.json (100%) rename {fixtures => env/testdata}/task-definition.json (63%) rename env/{fixtures => testdata}/template.txt (100%) diff --git a/env/env.go b/env/env.go index cb0c05e..39330d7 100644 --- a/env/env.go +++ b/env/env.go @@ -70,7 +70,7 @@ func LoadServiceDefiniton(dir string) (*ecs.CreateServiceInput, error) { if noSvc != nil { return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) } - if _, err := ReadAndUnmarshalJson(svcPath, &service); err != nil { + if err := ReadAndUnmarshalJson(svcPath, &service); err != nil { return nil, xerrors.Errorf("failed to read and unmarshal service.json: %s", err) } return &service, nil @@ -83,7 +83,7 @@ func LoadTaskDefiniton(dir string) (*ecs.RegisterTaskDefinitionInput, error) { if noTd != nil { return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) } - if _, err := ReadAndUnmarshalJson(tdPath, &td); err != nil { + if err := ReadAndUnmarshalJson(tdPath, &td); err != nil { return nil, xerrors.Errorf("failed to read and unmarshal task-definition.json: %s", err) } return &td, nil @@ -113,15 +113,13 @@ func MergeEnvars(dest *Envars, src *Envars) { } } -func ReadAndUnmarshalJson(path string, dest interface{}) ([]byte, error) { +func ReadAndUnmarshalJson(path string, dest interface{}) error { if d, err := ReadFileAndApplyEnvars(path); err != nil { - return d, err - } else { - if err := json.Unmarshal(d, dest); err != nil { - return d, err - } - return d, nil + return err + } else if err := json.Unmarshal(d, dest); err != nil { + return err } + return nil } func ReadFileAndApplyEnvars(path string) ([]byte, error) { diff --git a/env/env_test.go b/env/env_test.go index 1645f8c..000330c 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -86,10 +86,46 @@ func TestMergeEnvars(t *testing.T) { assert.Equal(t, e1.Service, "fuga") } +func TestLoadServiceDefinition(t *testing.T) { + t.Run("basic", func(t *testing.T) { + d, err := env.LoadServiceDefiniton("./testdata") + if err != nil { + t.Fatalf(err.Error()) + } + assert.Equal(t, *d.ServiceName, "service") + }) + t.Run("should error if service.json is not found", func(t *testing.T) { + _, err := env.LoadServiceDefiniton("../") + assert.EqualError(t, err, "roll out context specified at '../' but no 'service.json' or 'task-definition.json'") + }) + t.Run("should error if service.json is invalid", func(t *testing.T) { + _, err := env.LoadServiceDefiniton("./testdata/invalid") + assert.ErrorContains(t, err, "failed to read and unmarshal service.json:") + }) +} + +func TestLoadTaskDefinition(t *testing.T) { + t.Run("basic", func(t *testing.T) { + d, err := env.LoadTaskDefiniton("./testdata") + if err != nil { + t.Fatalf(err.Error()) + } + assert.Equal(t, *d.Family, "test-task") + }) + t.Run("should error if task-definition.json is not found", func(t *testing.T) { + _, err := env.LoadTaskDefiniton("../") + assert.EqualError(t, err, "roll out context specified at '../' but no 'service.json' or 'task-definition.json'") + }) + t.Run("should error if task-definition.json is invalid", func(t *testing.T) { + _, err := env.LoadTaskDefiniton("./testdata/invalid") + assert.ErrorContains(t, err, "failed to read and unmarshal task-definition.json:") + }) +} + func TestReadFileAndApplyEnvars(t *testing.T) { os.Setenv("HOGE", "hogehoge") os.Setenv("FUGA", "fugafuga") - d, err := env.ReadFileAndApplyEnvars("./fixtures/template.txt") + d, err := env.ReadFileAndApplyEnvars("./testdata/template.txt") if err != nil { t.Fatalf(err.Error()) } diff --git a/env/testdata/invalid/service.json b/env/testdata/invalid/service.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/env/testdata/invalid/service.json @@ -0,0 +1 @@ +[] diff --git a/env/testdata/invalid/task-definition.json b/env/testdata/invalid/task-definition.json new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/env/testdata/invalid/task-definition.json @@ -0,0 +1 @@ +[] diff --git a/fixtures/service.json b/env/testdata/service.json similarity index 100% rename from fixtures/service.json rename to env/testdata/service.json diff --git a/fixtures/task-definition.json b/env/testdata/task-definition.json similarity index 63% rename from fixtures/task-definition.json rename to env/testdata/task-definition.json index 23ecc4e..756806f 100644 --- a/fixtures/task-definition.json +++ b/env/testdata/task-definition.json @@ -21,80 +21,6 @@ } ], "essential": true, - "entryPoint": [""], - "command": [""], - "environment": [ - { - "name": "", - "value": "" - } - ], - "mountPoints": [ - { - "sourceVolume": "", - "containerPath": "", - "readOnly": true - } - ], - "volumesFrom": [ - { - "sourceContainer": "", - "readOnly": true - } - ], - "linuxParameters": { - "capabilities": { - "add": [""], - "drop": [""] - }, - "devices": [ - { - "hostPath": "", - "containerPath": "", - "permissions": ["mknod"] - } - ], - "initProcessEnabled": true, - "sharedMemorySize": 0, - "tmpfs": [ - { - "containerPath": "", - "size": 0, - "mountOptions": [""] - } - ] - }, - "hostname": "", - "user": "", - "workingDirectory": "", - "disableNetworking": true, - "privileged": true, - "readonlyRootFilesystem": true, - "dnsServers": [""], - "dnsSearchDomains": [""], - "extraHosts": [ - { - "hostname": "", - "ipAddress": "" - } - ], - "dockerSecurityOptions": [""], - "dockerLabels": { - "KeyName": "" - }, - "ulimits": [ - { - "name": "core", - "softLimit": 0, - "hardLimit": 0 - } - ], - "logConfiguration": { - "logDriver": "gelf", - "options": { - "KeyName": "" - } - }, "healthCheck": { "command": [""], "interval": 0, diff --git a/env/fixtures/template.txt b/env/testdata/template.txt similarity index 100% rename from env/fixtures/template.txt rename to env/testdata/template.txt From 5a13995bdf5d550b8df603b3dc253035d9a1296f Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 16:36:09 +0900 Subject: [PATCH 36/47] a --- env/env_test.go | 12 ++++++------ {env/testdata => fixtures}/service.json | 0 {env/testdata => fixtures}/task-definition.json | 0 3 files changed, 6 insertions(+), 6 deletions(-) rename {env/testdata => fixtures}/service.json (100%) rename {env/testdata => fixtures}/task-definition.json (100%) diff --git a/env/env_test.go b/env/env_test.go index 000330c..2100420 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -88,15 +88,15 @@ func TestMergeEnvars(t *testing.T) { func TestLoadServiceDefinition(t *testing.T) { t.Run("basic", func(t *testing.T) { - d, err := env.LoadServiceDefiniton("./testdata") + d, err := env.LoadServiceDefiniton("../fixtures") if err != nil { t.Fatalf(err.Error()) } assert.Equal(t, *d.ServiceName, "service") }) t.Run("should error if service.json is not found", func(t *testing.T) { - _, err := env.LoadServiceDefiniton("../") - assert.EqualError(t, err, "roll out context specified at '../' but no 'service.json' or 'task-definition.json'") + _, err := env.LoadServiceDefiniton("./testdata") + assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") }) t.Run("should error if service.json is invalid", func(t *testing.T) { _, err := env.LoadServiceDefiniton("./testdata/invalid") @@ -106,15 +106,15 @@ func TestLoadServiceDefinition(t *testing.T) { func TestLoadTaskDefinition(t *testing.T) { t.Run("basic", func(t *testing.T) { - d, err := env.LoadTaskDefiniton("./testdata") + d, err := env.LoadTaskDefiniton("../fixtures") if err != nil { t.Fatalf(err.Error()) } assert.Equal(t, *d.Family, "test-task") }) t.Run("should error if task-definition.json is not found", func(t *testing.T) { - _, err := env.LoadTaskDefiniton("../") - assert.EqualError(t, err, "roll out context specified at '../' but no 'service.json' or 'task-definition.json'") + _, err := env.LoadTaskDefiniton("./testdata") + assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") }) t.Run("should error if task-definition.json is invalid", func(t *testing.T) { _, err := env.LoadTaskDefiniton("./testdata/invalid") diff --git a/env/testdata/service.json b/fixtures/service.json similarity index 100% rename from env/testdata/service.json rename to fixtures/service.json diff --git a/env/testdata/task-definition.json b/fixtures/task-definition.json similarity index 100% rename from env/testdata/task-definition.json rename to fixtures/task-definition.json From aa7ad3701ef3026d275b5bbb554a69f51b4293e7 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 16:48:57 +0900 Subject: [PATCH 37/47] test --- rollout_test.go | 4 ++-- task/alb_task.go | 4 ++-- task/alb_task_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/rollout_test.go b/rollout_test.go index 9f409e2..bf763c2 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -422,8 +422,8 @@ func TestCage_RollOut_EC2_no_attribute(t *testing.T) { } ecsMock.EXPECT().ListAttributes(gomock.Any(), gomock.Any()).Return(&ecs.ListAttributesOutput{ Attributes: []ecstypes.Attribute{}, - }, nil).AnyTimes() - ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil).AnyTimes() + }, nil) + ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil) cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { b.Set(key.Env, envars) b.Set(key.EcsCli, ecsMock) diff --git a/task/alb_task.go b/task/alb_task.go index c6acb85..5e9cd00 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -219,7 +219,7 @@ func (c *albTask) WaitUntilTargetHealthy( ) } -func (c *albTask) targetDeregistrationDelay(ctx context.Context) (time.Duration, error) { +func (c *albTask) GetTargetDeregistrationDelay(ctx context.Context) (time.Duration, error) { deregistrationDelay := 300 * time.Second albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if o, err := albCli.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ @@ -246,7 +246,7 @@ func (c *albTask) DeregisterTarget(ctx context.Context) { return } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) - deregistrationDelay, err := c.targetDeregistrationDelay(ctx) + deregistrationDelay, err := c.GetTargetDeregistrationDelay(ctx) if err != nil { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index e0ed052..e7072de 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "testing" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" @@ -368,6 +369,56 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { }) } +func TestAlbTask_GetTargetDeregistrationDelay(t *testing.T) { + setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + ctrl := gomock.NewController(t) + env := test.DefaultEnvars() + albMock := mock_awsiface.NewMockAlbClient(ctrl) + atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + }), &task.Input{}, &env.ServiceDefinitionInput.LoadBalancers[0]) + return albMock, atask + } + t.Run("should return deregistration delay", func(t *testing.T) { + albMock, atask := setup(t) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("100")}, + }, + }, nil) + delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, 100*time.Second, delay) + }) + t.Run("should return default delay if deregistration_delay is not found", func(t *testing.T) { + albMock, atask := setup(t) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{}, + }, nil) + delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + assert.NoError(t, err) + assert.Equal(t, 300*time.Second, delay) + }) + t.Run("should return default delay if deregistration_delay is not a number", func(t *testing.T) { + albMock, atask := setup(t) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ + Attributes: []elbv2types.TargetGroupAttribute{ + {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("invalid")}, + }, + }, nil) + delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + assert.Error(t, err) + assert.Equal(t, 300*time.Second, delay) + }) + t.Run("should error if DescribeTargetGroupAttributes failed", func(t *testing.T) { + albMock, atask := setup(t) + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + assert.Equal(t, 300*time.Second, delay) + }) +} + func TestAlbTask_DeregisterTarget(t *testing.T) { target := &elbv2types.TargetDescription{ Id: aws.String("127.0.0.1"), From 9f18e7b33af0dbd9070e8838774612512a728e1e Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 18:31:14 +0900 Subject: [PATCH 38/47] test --- go.mod | 1 - go.sum | 8 -- rollout.go | 107 ++--------------------- rollout/rollout.go | 139 +++++++++++++++++++++++++++++ rollout/rollout_test.go | 189 ++++++++++++++++++++++++++++++++++++++++ rollout_test.go | 3 + test/context.go | 3 + 7 files changed, 340 insertions(+), 110 deletions(-) create mode 100644 rollout/rollout.go create mode 100644 rollout/rollout_test.go diff --git a/go.mod b/go.mod index 3361920..4639c0f 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 // indirect - github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1 github.com/aws/aws-sdk-go-v2/service/sso v1.20.9 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.28.10 // indirect diff --git a/go.sum b/go.sum index 4ed95c0..7ea82f3 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,6 @@ github.com/apex/logs v1.0.0/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDw github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy8kCu4PNA+aP7WUV72eXWJeP9/r3/K9aLE= github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= -github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= -github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2 v1.30.0 h1:6qAwtzlfcTtcL8NHtbDQAqgM5s6NDipQTkPxyH/6kAA= github.com/aws/aws-sdk-go-v2 v1.30.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= github.com/aws/aws-sdk-go-v2/config v1.27.16 h1:knpCuH7laFVGYTNd99Ns5t+8PuRjDn4HnnZK48csipM= @@ -16,12 +14,8 @@ github.com/aws/aws-sdk-go-v2/credentials v1.17.16 h1:7d2QxY83uYl0l58ceyiSpxg9bSb github.com/aws/aws-sdk-go-v2/credentials v1.17.16/go.mod h1:Ae6li/6Yc6eMzysRL2BXlPYvnrLLBg3D11/AmOjw50k= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3 h1:dQLK4TjtnlRGb0czOht2CevZ5l6RSyRWAnKeGd7VAFE= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.3/go.mod h1:TL79f2P6+8Q7dTsILpiVST+AL9lkF6PPGI167Ny0Cjw= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 h1:lf/8VTF2cM+N4SLzaYJERKEWAXq8MOMpZfU6wEPWsPk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7/go.mod h1:4SjkU7QiqK2M9oozyMzfZ/23LmUY+h3oFqhdeP5OMiI= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12 h1:SJ04WXGTwnHlWIODtC5kJzKbeuHt+OUNOgKg7nfnUGw= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.12/go.mod h1:FkpvXhA92gb3GE9LD6Og0pHHycTxW7xGpnEh5E7Opwo= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 h1:4OYVp0705xu8yjdyoWix0r9wPIRXnIzzOoUpQVHIJ/g= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7/go.mod h1:vd7ESTEvI76T2Na050gODNmNU7+OyKrIKroYTu4ABiI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12 h1:hb5KgeYfObi5MHkSSZMEudnIvX30iB+E21evI4r6BnQ= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.12/go.mod h1:CroKe/eWJdyfy9Vx4rljP5wTUjNJfb+fPz1uMYUhEGM= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= @@ -36,8 +30,6 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1x github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9 h1:Wx0rlZoEJR7JwlSZcHnEa7CNjrSIyVxMFWGAaXy4fJY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.9/go.mod h1:aVMHdE0aHO3v+f/iw01fmXV/5DbfQ3Bi9nN7nd9bE9Y= -github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1 h1:N19/J0IqsoNlkbXLe+JYWLjOyGmRijt6dw0+MaL/9wE= -github.com/aws/aws-sdk-go-v2/service/servicediscovery v1.30.1/go.mod h1:uuMsqZ2ATDqrzaAldWWuEUd9KGqi1NmnjroG6Eoe7W4= github.com/aws/aws-sdk-go-v2/service/sso v1.20.9 h1:aD7AGQhvPuAxlSUfo0CWU7s6FpkbyykMhGYMvlqTjVs= github.com/aws/aws-sdk-go-v2/service/sso v1.20.9/go.mod h1:c1qtZUWtygI6ZdvKppzCSXsDOq5I4luJPZ0Ud3juFCA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.24.3 h1:Pav5q3cA260Zqez42T9UhIlsd9QeypszRPwC9LdSSsQ= diff --git a/rollout.go b/rollout.go index efc6822..32e6872 100644 --- a/rollout.go +++ b/rollout.go @@ -9,16 +9,13 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "github.com/loilo-inc/canarycage/task" - "github.com/loilo-inc/canarycage/taskset" + "github.com/loilo-inc/canarycage/rollout" "github.com/loilo-inc/canarycage/types" "golang.org/x/xerrors" ) func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) { - result := &types.RollOutResult{ - ServiceIntact: true, - } + result := &types.RollOutResult{ServiceIntact: true} env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) if out, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ @@ -44,100 +41,8 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R } else { nextTaskDefinition = o } - if input.UpdateService { - log.Info("--updateService flag is set. use provided service configurations for canary test instead of current service") - } - canaryTasks, startCanaryTaskErr := c.StartCanaryTasks(ctx, nextTaskDefinition, input) - // ensure canary task stopped after rolling out either success or failure - defer func() { - _ = recover() - if canaryTasks == nil { - return - } else if err := canaryTasks.Cleanup(ctx); err != nil { - log.Errorf("failed to cleanup canary tasks due to: %s", err) - } - }() - if startCanaryTaskErr != nil { - return result, xerrors.Errorf("failed to start canary task due to: %w", startCanaryTaskErr) - } - log.Infof("executing canary tasks...") - if err := canaryTasks.Exec(ctx); err != nil { - return result, xerrors.Errorf("failed to exec canary task due to: %w", err) - } - log.Infof("canary tasks have been executed successfully!") - log.Infof( - "updating the task definition of '%s' into '%s:%d'...", - env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, - ) - updateInput := &ecs.UpdateServiceInput{ - Cluster: &env.Cluster, - Service: &env.Service, - TaskDefinition: nextTaskDefinition.TaskDefinitionArn, - } - if input.UpdateService { - updateInput.LoadBalancers = env.ServiceDefinitionInput.LoadBalancers - updateInput.NetworkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration - updateInput.ServiceConnectConfiguration = env.ServiceDefinitionInput.ServiceConnectConfiguration - updateInput.ServiceRegistries = env.ServiceDefinitionInput.ServiceRegistries - updateInput.PlatformVersion = env.ServiceDefinitionInput.PlatformVersion - updateInput.VolumeConfigurations = env.ServiceDefinitionInput.VolumeConfigurations - } - if _, err := ecsCli.UpdateService(ctx, updateInput); err != nil { - return result, err - } - result.ServiceIntact = false - log.Infof("waiting for service '%s' to be stable...", env.Service) - if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ - Cluster: &env.Cluster, - Services: []string{env.Service}, - }, env.GetServiceStableWait()); err != nil { - return result, err - } - log.Infof("🥴 service '%s' has become to be stable!", env.Service) - log.Infof( - "🐥 service '%s' successfully rolled out to '%s:%d'!", - env.Service, *nextTaskDefinition.Family, nextTaskDefinition.Revision, - ) - return result, nil -} - -func (c *cage) StartCanaryTasks( - ctx context.Context, - nextTaskDefinition *ecstypes.TaskDefinition, - input *types.RollOutInput, -) (taskset.Set, error) { - var networkConfiguration *ecstypes.NetworkConfiguration - var platformVersion *string - var loadBalancers []ecstypes.LoadBalancer - env := c.di.Get(key.Env).(*env.Envars) - if input.UpdateService { - networkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration - platformVersion = env.ServiceDefinitionInput.PlatformVersion - loadBalancers = env.ServiceDefinitionInput.LoadBalancers - } else { - ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - if o, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ - Cluster: &env.Cluster, - Services: []string{env.Service}, - }); err != nil { - return nil, err - } else { - service := o.Services[0] - networkConfiguration = service.NetworkConfiguration - platformVersion = service.PlatformVersion - loadBalancers = service.LoadBalancers - } - } - factory := c.di.Get(key.TaskFactory).(task.Factory) - return taskset.NewSet( - factory, - &taskset.Input{ - Input: &task.Input{ - NetworkConfiguration: networkConfiguration, - TaskDefinition: nextTaskDefinition, - PlatformVersion: platformVersion, - }, - LoadBalancers: loadBalancers, - }, - ), nil + executor := rollout.NewExecutor(c.di, nextTaskDefinition) + result.ServiceIntact = !executor.ServiceUpdated() + err := executor.RollOut(ctx, input) + return result, err } diff --git a/rollout/rollout.go b/rollout/rollout.go new file mode 100644 index 0000000..6378b6c --- /dev/null +++ b/rollout/rollout.go @@ -0,0 +1,139 @@ +package rollout + +import ( + "context" + + "github.com/apex/log" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/loilo-inc/canarycage/awsiface" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/taskset" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" +) + +type Executor interface { + RollOut(ctx context.Context, input *types.RollOutInput) error + ServiceUpdated() bool +} + +type executor struct { + di *di.D + td *ecstypes.TaskDefinition + serviceUpdated bool +} + +func NewExecutor(di *di.D, td *ecstypes.TaskDefinition) Executor { + return &executor{di: di, td: td} +} + +func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) error { + env := c.di.Get(key.Env).(*env.Envars) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + if input.UpdateService { + log.Info("--updateService flag is set. use provided service configurations for canary test instead of current service") + } + canaryTasks, startCanaryTaskErr := c.startCanaryTasks(ctx, input) + // ensure canary task stopped after rolling out either success or failure + defer func() { + _ = recover() + if canaryTasks == nil { + return + } else if err := canaryTasks.Cleanup(ctx); err != nil { + log.Errorf("failed to cleanup canary tasks due to: %s", err) + } + }() + if startCanaryTaskErr != nil { + log.Errorf("😨 failed to start canary task due to: %w", startCanaryTaskErr) + return startCanaryTaskErr + } + log.Infof("executing canary tasks...") + if err := canaryTasks.Exec(ctx); err != nil { + log.Errorf("😨 failed to exec canary tasks: %s", err) + return err + } + log.Infof("canary tasks have been executed successfully!") + log.Infof( + "updating the task definition of '%s' into '%s:%d'...", + env.Service, *c.td.Family, c.td.Revision, + ) + updateInput := &ecs.UpdateServiceInput{ + Cluster: &env.Cluster, + Service: &env.Service, + TaskDefinition: c.td.TaskDefinitionArn, + } + if input.UpdateService { + updateInput.LoadBalancers = env.ServiceDefinitionInput.LoadBalancers + updateInput.NetworkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration + updateInput.ServiceConnectConfiguration = env.ServiceDefinitionInput.ServiceConnectConfiguration + updateInput.ServiceRegistries = env.ServiceDefinitionInput.ServiceRegistries + updateInput.PlatformVersion = env.ServiceDefinitionInput.PlatformVersion + updateInput.VolumeConfigurations = env.ServiceDefinitionInput.VolumeConfigurations + } + if _, err := ecsCli.UpdateService(ctx, updateInput); err != nil { + log.Errorf("😨 failed to update service: %s", err) + return err + } + c.serviceUpdated = false + log.Infof("waiting for service '%s' to be stable...", env.Service) + if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, + }, env.GetServiceStableWait()); err != nil { + log.Errorf("😨 failed to wait for service to be stable: %s", err) + return err + } + log.Infof("🥴 service '%s' has become to be stable!", env.Service) + log.Infof( + "🐥 service '%s' successfully rolled out to '%s:%d'!", + env.Service, *c.td.Family, c.td.Revision, + ) + return nil +} + +func (c *executor) ServiceUpdated() bool { + return c.serviceUpdated +} + +func (c *executor) startCanaryTasks( + ctx context.Context, + input *types.RollOutInput, +) (taskset.Set, error) { + var networkConfiguration *ecstypes.NetworkConfiguration + var platformVersion *string + var loadBalancers []ecstypes.LoadBalancer + env := c.di.Get(key.Env).(*env.Envars) + factory := c.di.Get(key.TaskFactory).(task.Factory) + ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) + if input.UpdateService { + networkConfiguration = env.ServiceDefinitionInput.NetworkConfiguration + platformVersion = env.ServiceDefinitionInput.PlatformVersion + loadBalancers = env.ServiceDefinitionInput.LoadBalancers + } else { + if o, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ + Cluster: &env.Cluster, + Services: []string{env.Service}, + }); err != nil { + return nil, err + } else { + service := o.Services[0] + networkConfiguration = service.NetworkConfiguration + platformVersion = service.PlatformVersion + loadBalancers = service.LoadBalancers + } + } + return taskset.NewSet( + factory, + &taskset.Input{ + Input: &task.Input{ + NetworkConfiguration: networkConfiguration, + TaskDefinition: c.td, + PlatformVersion: platformVersion, + }, + LoadBalancers: loadBalancers, + }, + ), nil +} diff --git a/rollout/rollout_test.go b/rollout/rollout_test.go new file mode 100644 index 0000000..1834e3a --- /dev/null +++ b/rollout/rollout_test.go @@ -0,0 +1,189 @@ +package rollout + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ecs" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + "github.com/golang/mock/gomock" + "github.com/loilo-inc/canarycage/env" + "github.com/loilo-inc/canarycage/key" + "github.com/loilo-inc/canarycage/mocks/mock_awsiface" + "github.com/loilo-inc/canarycage/mocks/mock_task" + "github.com/loilo-inc/canarycage/task" + "github.com/loilo-inc/canarycage/test" + "github.com/loilo-inc/canarycage/types" + "github.com/loilo-inc/logos/di" + "github.com/stretchr/testify/assert" +) + +func TestExecutor_Rollout(t *testing.T) { + setup := func(t *testing.T) ( + *executor, + *env.Envars, + *test.MockContext, + *mock_awsiface.MockEcsClient, + *mock_task.MockTask, + *ecstypes.TaskDefinition, + ) { + ctrl := gomock.NewController(t) + envars := test.DefaultEnvars() + factoryMock := mock_task.NewMockFactory(ctrl) + taskMock := mock_task.NewMockTask(ctrl) + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + srv, _ := mocker.Ecs.CreateService(context.TODO(), envars.ServiceDefinitionInput) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.TaskFactory, factoryMock) + }) + factoryMock.EXPECT().NewAlbTask(&task.Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: srv.Service.NetworkConfiguration, + PlatformVersion: srv.Service.PlatformVersion, + }, &srv.Service.LoadBalancers[0]).Return(taskMock) + e := &executor{di: d, td: td.TaskDefinition} + return e, envars, mocker, ecsMock, taskMock, td.TaskDefinition + } + t.Run("basic", func(t *testing.T) { + e, envars, mocker, ecsMock, taskMock, td := setup(t) + gomock.InOrder( + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.DescribeServices), + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(nil), + ecsMock.EXPECT().UpdateService(gomock.Any(), &ecs.UpdateServiceInput{ + Cluster: &envars.Cluster, + Service: &envars.Service, + TaskDefinition: td.TaskDefinitionArn, + ServiceConnectConfiguration: nil, + LoadBalancers: nil, + NetworkConfiguration: nil, + PlatformVersion: nil, + VolumeConfigurations: nil, + ServiceRegistries: nil, + }). + DoAndReturn(mocker.Ecs.UpdateService), + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.DescribeServices), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{}) + if err != nil { + t.Errorf("RollOut() error = %v", err) + } + srv, _ := mocker.GetEcsService(envars.Service) + assert.Equal(t, *srv.TaskDefinition, *td.TaskDefinitionArn) + }) + t.Run("updateService", func(t *testing.T) { + e, envars, mocker, ecsMock, taskMock, td := setup(t) + gomock.InOrder( + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(nil), + ecsMock.EXPECT().UpdateService(gomock.Any(), &ecs.UpdateServiceInput{ + Cluster: &envars.Cluster, + Service: &envars.Service, + TaskDefinition: td.TaskDefinitionArn, + ServiceConnectConfiguration: envars.ServiceDefinitionInput.ServiceConnectConfiguration, + LoadBalancers: envars.ServiceDefinitionInput.LoadBalancers, + NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, + PlatformVersion: envars.ServiceDefinitionInput.PlatformVersion, + VolumeConfigurations: envars.ServiceDefinitionInput.VolumeConfigurations, + ServiceRegistries: envars.ServiceDefinitionInput.ServiceRegistries, + }). + DoAndReturn(mocker.Ecs.UpdateService), + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(mocker.Ecs.DescribeServices), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + if err != nil { + t.Errorf("RollOut() error = %v", err) + } + srv, _ := mocker.GetEcsService(envars.Service) + assert.Equal(t, *srv.TaskDefinition, *td.TaskDefinitionArn) + }) +} + +func TestExecutor_Rollout_Failure(t *testing.T) { + setup := func(t *testing.T) (*executor, *env.Envars, *mock_awsiface.MockEcsClient, *mock_task.MockFactory, *mock_task.MockTask) { + ctrl := gomock.NewController(t) + envars := test.DefaultEnvars() + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + taskMock := mock_task.NewMockTask(ctrl) + factoryMock := mock_task.NewMockFactory(ctrl) + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.TaskFactory, factoryMock) + }) + td := &ecstypes.TaskDefinition{ + TaskDefinitionArn: aws.String("arn://aaa"), + Family: aws.String("family"), + Revision: 1, + } + e := &executor{di: d, td: td} + return e, envars, ecsMock, factoryMock, taskMock + } + t.Run("should not call task.Task.Stop() if task not created", func(t *testing.T) { + e, _, ecsCli, _, _ := setup(t) + ecsCli.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(nil, test.Err) + err := e.RollOut(context.TODO(), &types.RollOutInput{}) + assert.EqualError(t, err, "error") + }) + t.Run("should call task.Task.Stop() even if task.Task.Start() failed", func(t *testing.T) { + e, _, _, factoryMock, taskMock := setup(t) + gomock.InOrder( + factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), + taskMock.EXPECT().Start(gomock.Any()).Return(test.Err), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + assert.EqualError(t, err, "error") + }) + t.Run("should call task.Task.Stop() even if task.Task.Wait() failed", func(t *testing.T) { + e, _, _, factoryMock, taskMock := setup(t) + gomock.InOrder( + factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(test.Err), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + assert.EqualError(t, err, "error") + }) + t.Run("should call task.Task.Stop() even if ecs.UpdateService() failed", func(t *testing.T) { + e, _, ecsMock, factoryMock, taskMock := setup(t) + gomock.InOrder( + factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(nil), + ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any()). + Return(nil, test.Err), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + assert.EqualError(t, err, "error") + }) + t.Run("should call task.Task.Stop() even if ecs.DescribeServices() failed", func(t *testing.T) { + e, _, ecsMock, factoryMock, taskMock := setup(t) + gomock.InOrder( + factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(nil), + ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any()). + Return(&ecs.UpdateServiceOutput{}, nil), + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{{Status: aws.String("INACTIVE")}}, + }, nil), + taskMock.EXPECT().Stop(gomock.Any()).Return(nil), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + assert.EqualError(t, err, "waiter state transitioned to Failure") + }) +} diff --git a/rollout_test.go b/rollout_test.go index bf763c2..c7e50eb 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -22,6 +22,9 @@ import ( "github.com/stretchr/testify/assert" ) +func TestCage_Rollout(t *testing.T) { +} + func TestCage_RollOut_FARGATE(t *testing.T) { log.SetLevel(log.DebugLevel) t.Run("basic", func(t *testing.T) { diff --git a/test/context.go b/test/context.go index f0c9c02..b4060b7 100644 --- a/test/context.go +++ b/test/context.go @@ -1,6 +1,7 @@ package test import ( + "fmt" "sync" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" @@ -70,3 +71,5 @@ func (ctx *commons) ActiveServiceSize() (count int) { } return } + +var Err = fmt.Errorf("error") From 382bfbb188d1ed59b4f79cacc8bc8259ff9af8ac Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 20:10:09 +0900 Subject: [PATCH 39/47] add tests --- Makefile | 5 +- cli/cage/commands/rollout.go | 2 +- mocks/mock_rollout/executor.go | 64 ++ rollout.go | 19 +- rollout/{rollout.go => executor.go} | 2 +- rollout/{rollout_test.go => executor_test.go} | 1 + rollout_test.go | 554 +++++------------- task_definition.go | 5 +- test/alb.go | 71 ++- test/context.go | 4 +- test/ecs.go | 10 +- types/iface.go | 2 +- 12 files changed, 287 insertions(+), 452 deletions(-) create mode 100644 mocks/mock_rollout/executor.go rename rollout/{rollout.go => executor.go} (99%) rename rollout/{rollout_test.go => executor_test.go} (98%) diff --git a/Makefile b/Makefile index e96d26f..75ea833 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,8 @@ mocks: mocks/mock_awsiface/iface.go \ mocks/mock_upgrade/upgrade.go \ mocks/mock_task/task.go \ mocks/mock_taskset/taskset.go \ - mocks/mock_task/factory.go + mocks/mock_task/factory.go \ + mocks/mock_rollout/executor.go mocks/mock_awsiface/iface.go: awsiface/iface.go $(MOCKGEN) -source=./awsiface/iface.go > mocks/mock_awsiface/iface.go mocks/mock_types/iface.go: types/iface.go @@ -27,4 +28,6 @@ mocks/mock_taskset/taskset.go: taskset/taskset.go $(MOCKGEN) -source=./taskset/taskset.go > mocks/mock_taskset/taskset.go mocks/mock_task/factory.go: task/factory.go $(MOCKGEN) -source=./task/factory.go > mocks/mock_task/factory.go +mocks/mock_rollout/executor.go: rollout/executor.go + $(MOCKGEN) -source=./rollout/executor.go > mocks/mock_rollout/executor.go .PHONY: mocks diff --git a/cli/cage/commands/rollout.go b/cli/cage/commands/rollout.go index 87d6129..0b2151c 100644 --- a/cli/cage/commands/rollout.go +++ b/cli/cage/commands/rollout.go @@ -56,7 +56,7 @@ func (c *CageCommands) RollOut( } result, err := cagecli.RollOut(context.Background(), &types.RollOutInput{UpdateService: updateServiceConf}) if err != nil { - if result.ServiceIntact { + if !result.ServiceUpdated { log.Errorf("🤕 failed to roll out new tasks but service '%s' is not changed", envars.Service) } else { log.Errorf("😭 failed to roll out new tasks and service '%s' might be changed. CHECK ECS CONSOLE NOW!", envars.Service) diff --git a/mocks/mock_rollout/executor.go b/mocks/mock_rollout/executor.go new file mode 100644 index 0000000..97296be --- /dev/null +++ b/mocks/mock_rollout/executor.go @@ -0,0 +1,64 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./rollout/executor.go + +// Package mock_rollout is a generated GoMock package. +package mock_rollout + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + types "github.com/loilo-inc/canarycage/types" +) + +// MockExecutor is a mock of Executor interface. +type MockExecutor struct { + ctrl *gomock.Controller + recorder *MockExecutorMockRecorder +} + +// MockExecutorMockRecorder is the mock recorder for MockExecutor. +type MockExecutorMockRecorder struct { + mock *MockExecutor +} + +// NewMockExecutor creates a new mock instance. +func NewMockExecutor(ctrl *gomock.Controller) *MockExecutor { + mock := &MockExecutor{ctrl: ctrl} + mock.recorder = &MockExecutorMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockExecutor) EXPECT() *MockExecutorMockRecorder { + return m.recorder +} + +// RollOut mocks base method. +func (m *MockExecutor) RollOut(ctx context.Context, input *types.RollOutInput) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RollOut", ctx, input) + ret0, _ := ret[0].(error) + return ret0 +} + +// RollOut indicates an expected call of RollOut. +func (mr *MockExecutorMockRecorder) RollOut(ctx, input interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RollOut", reflect.TypeOf((*MockExecutor)(nil).RollOut), ctx, input) +} + +// ServiceUpdated mocks base method. +func (m *MockExecutor) ServiceUpdated() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ServiceUpdated") + ret0, _ := ret[0].(bool) + return ret0 +} + +// ServiceUpdated indicates an expected call of ServiceUpdated. +func (mr *MockExecutorMockRecorder) ServiceUpdated() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ServiceUpdated", reflect.TypeOf((*MockExecutor)(nil).ServiceUpdated)) +} diff --git a/rollout.go b/rollout.go index 32e6872..133910a 100644 --- a/rollout.go +++ b/rollout.go @@ -15,7 +15,7 @@ import ( ) func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.RollOutResult, error) { - result := &types.RollOutResult{ServiceIntact: true} + result := &types.RollOutResult{} env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) if out, err := ecsCli.DescribeServices(ctx, &ecs.DescribeServicesInput{ @@ -23,12 +23,19 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R Services: []string{env.Service}, }); err != nil { return result, xerrors.Errorf("failed to describe current service due to: %w", err) - } else if len(out.Services) == 0 { - return result, xerrors.Errorf("service '%s' doesn't exist. Run 'cage up' or create service before rolling out", env.Service) } else { - service := out.Services[0] + var service *ecstypes.Service + for _, s := range out.Services { + if *s.ServiceName == env.Service { + service = &s + break + } + } + if service == nil { + return result, xerrors.Errorf("service '%s' doesn't exist. Run 'cage up' or create service before rolling out", env.Service) + } if *service.Status != "ACTIVE" { - return result, xerrors.Errorf("😵 '%s' status is '%s'. Stop rolling out", env.Service, *service.Status) + return result, xerrors.Errorf("😵 service '%s' status is '%s'. Stop rolling out", env.Service, *service.Status) } if service.LaunchType == ecstypes.LaunchTypeEc2 && env.CanaryInstanceArn == "" { return result, xerrors.Errorf("🥺 --canaryInstanceArn is required when LaunchType = 'EC2'") @@ -42,7 +49,7 @@ func (c *cage) RollOut(ctx context.Context, input *types.RollOutInput) (*types.R nextTaskDefinition = o } executor := rollout.NewExecutor(c.di, nextTaskDefinition) - result.ServiceIntact = !executor.ServiceUpdated() err := executor.RollOut(ctx, input) + result.ServiceUpdated = executor.ServiceUpdated() return result, err } diff --git a/rollout/rollout.go b/rollout/executor.go similarity index 99% rename from rollout/rollout.go rename to rollout/executor.go index 6378b6c..c44aad2 100644 --- a/rollout/rollout.go +++ b/rollout/executor.go @@ -77,7 +77,7 @@ func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) error log.Errorf("😨 failed to update service: %s", err) return err } - c.serviceUpdated = false + c.serviceUpdated = true log.Infof("waiting for service '%s' to be stable...", env.Service) if err := ecs.NewServicesStableWaiter(ecsCli).Wait(ctx, &ecs.DescribeServicesInput{ Cluster: &env.Cluster, diff --git a/rollout/rollout_test.go b/rollout/executor_test.go similarity index 98% rename from rollout/rollout_test.go rename to rollout/executor_test.go index 1834e3a..97552cc 100644 --- a/rollout/rollout_test.go +++ b/rollout/executor_test.go @@ -34,6 +34,7 @@ func TestExecutor_Rollout(t *testing.T) { taskMock := mock_task.NewMockTask(ctrl) mocker := test.NewMockContext() td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) + envars.ServiceDefinitionInput.TaskDefinition = td.TaskDefinition.TaskDefinitionArn srv, _ := mocker.Ecs.CreateService(context.TODO(), envars.ServiceDefinitionInput) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) d := di.NewDomain(func(b *di.B) { diff --git a/rollout_test.go b/rollout_test.go index c7e50eb..d7c1136 100644 --- a/rollout_test.go +++ b/rollout_test.go @@ -1,18 +1,17 @@ -package cage_test +package cage import ( "context" - "strings" + "fmt" "testing" - "github.com/apex/log" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - elbv2 "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" - elbv2types "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" + alb "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2" + albtypes "github.com/aws/aws-sdk-go-v2/service/elasticloadbalancingv2/types" "github.com/golang/mock/gomock" - cage "github.com/loilo-inc/canarycage" + "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/task" @@ -22,423 +21,172 @@ import ( "github.com/stretchr/testify/assert" ) -func TestCage_Rollout(t *testing.T) { -} - -func TestCage_RollOut_FARGATE(t *testing.T) { - log.SetLevel(log.DebugLevel) - t.Run("basic", func(t *testing.T) { - for _, v := range []int{1, 2, 15} { - log.Info("====") - envars := test.DefaultEnvars() - ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, v, "FARGATE") - - if mctx.ActiveServiceSize() != 1 { - t.Fatalf("current service not setup") - } - - if taskCnt := mctx.RunningTaskSize(); taskCnt != v { - t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) - } - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NoError(t, err) - assert.False(t, result.ServiceIntact) - assert.Equal(t, 1, mctx.ActiveServiceSize()) - assert.Equal(t, v, mctx.RunningTaskSize()) +// fake integration test with test.MockContext +func TestCage_RollOut(t *testing.T) { + for i := 0b000; i < 0b1000; i++ { + env := test.DefaultEnvars() + isEc2 := i&0b001 > 0 + useExistingTd := i&0b010 > 0 + updateService := i&0b100 > 0 + if isEc2 { + env.CanaryInstanceArn = "arn:aws:ecs:us-west-2:123456789012:container-instance/123456789012" } - }) - t.Run("multiple load balancers", func(t *testing.T) { - log.Info("====") - envars := test.DefaultEnvars() - lb := envars.ServiceDefinitionInput.LoadBalancers[0] - envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{lb, lb} - ctrl := gomock.NewController(t) - - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NoError(t, err) - assert.False(t, result.ServiceIntact) - assert.Equal(t, 1, mctx.ActiveServiceSize()) - assert.Equal(t, 1, mctx.RunningTaskSize()) - }) - t.Run("wait until canary task is registered to target group", func(t *testing.T) { - envars := test.DefaultEnvars() - ctrl := gomock.NewController(t) - mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") + if useExistingTd { + env.TaskDefinitionArn = "arn:aws:ecs:us-west-2:123456789012:task-definition/td" + env.TaskDefinitionInput = nil + } + for j := 0; j < 3; j++ { + t.Run(fmt.Sprintf("isEc2=%t, useTd=%t, lbcount=%d", isEc2, useExistingTd, j), func(t *testing.T) { + integrationTest(t, env, j, &types.RollOutInput{UpdateService: updateService}) + }) + } + } +} - albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.RegisterTargets).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DeregisterTargets).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroupAttributes).Times(1) - gomock.InOrder( - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ - TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ - { - Target: &elbv2types.TargetDescription{ - Id: aws.String("127.0.0.1"), - Port: aws.Int32(80), - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumUnused, - }, - }}, - }, nil).Times(2), - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), - ) - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NoError(t, err) - assert.NotNil(t, result) - }) - t.Run("stop rolloing out when canary task is not registered to target group", func(t *testing.T) { - envars := test.DefaultEnvars() - ctrl := gomock.NewController(t) - mocker, ecsMock, _, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - albMock := mock_awsiface.NewMockAlbClient(ctrl) - albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.RegisterTargets).Times(1) - albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DeregisterTargets).Times(1) - albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetGroupAttributes).Times(1) - gomock.InOrder( - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ - TargetHealthDescriptions: []elbv2types.TargetHealthDescription{{ - Target: &elbv2types.TargetDescription{ - Id: aws.String("192.0.0.1"), - Port: aws.Int32(8000), - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumUnhealthy, - }, - }, { - Target: &elbv2types.TargetDescription{ - Id: aws.String("127.0.0.1"), - Port: aws.Int32(8000), - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumUnused, - }, - }}, - }, nil), - albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Alb.DescribeTargetHealth).AnyTimes(), - ) - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NotNil(t, err) - }) - t.Run("update service", func(t *testing.T) { - envars := test.DefaultEnvars() - ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - newLb := ecstypes.LoadBalancer{ +func integrationTest(t *testing.T, env *env.Envars, lbcount int, input *types.RollOutInput) { + mocker := test.NewMockContext() + var td *ecstypes.TaskDefinition + if env.TaskDefinitionArn == "" { + o, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) + td = o.TaskDefinition + } else { + o, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), test.DefaultEnvars().TaskDefinitionInput) + td = o.TaskDefinition + env.TaskDefinitionArn = *td.TaskDefinitionArn + } + env.ServiceDefinitionInput.TaskDefinition = td.TaskDefinitionArn + _, _ = mocker.Ecs.CreateService(context.TODO(), env.ServiceDefinitionInput) + var lbs []ecstypes.LoadBalancer + for i := 0; i < lbcount; i++ { + tgArn := aws.String(fmt.Sprintf("tg%d", i+1)) + lbs = append(lbs, ecstypes.LoadBalancer{ ContainerName: aws.String("container"), ContainerPort: aws.Int32(80), - TargetGroupArn: aws.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/new-target-group/abcdefg"), - } - newNetwork := &ecstypes.NetworkConfiguration{ - AwsvpcConfiguration: &ecstypes.AwsVpcConfiguration{ - Subnets: []string{"subnet-1234567890abcdefg"}, - SecurityGroups: []string{"sg-12345678"}, + TargetGroupArn: tgArn, + }) + mocker.Alb.RegisterTargets(context.TODO(), &alb.RegisterTargetsInput{ + TargetGroupArn: tgArn, + Targets: []albtypes.TargetDescription{ + {Id: aws.String(fmt.Sprintf("127.0.0.%d", i+1))}, }, - } - envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{newLb} - envars.ServiceDefinitionInput.NetworkConfiguration = newNetwork - envars.ServiceDefinitionInput.PlatformVersion = aws.String("LATEST") - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - service, _ := mctx.GetEcsService(envars.Service) - assert.Equal(t, "1.4.0", *service.PlatformVersion) - assert.NotNil(t, service.NetworkConfiguration) - assert.NotNil(t, service.LoadBalancers) - _, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) - assert.NoError(t, err) - service, _ = mctx.GetEcsService(envars.Service) - assert.Equal(t, "LATEST", *service.PlatformVersion) - assert.Equal(t, *newNetwork, *service.NetworkConfiguration) - assert.Equal(t, *service.LoadBalancers[0].ContainerName, *newLb.ContainerName) - }) - t.Run("should stop canary task if error occurs before registering target", func(t *testing.T) { - envars := test.DefaultEnvars() + }) + } + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.AlbCli, mocker.Alb) + b.Set(key.Time, test.NewFakeTime()) + b.Set(key.TaskFactory, task.NewFactory(b.Future())) + })} + assert.Equal(t, 1, mocker.RunningTaskSize()) + assert.Equal(t, 1, len(mocker.TaskDefinitions.List())) + for _, lb := range lbs { + assert.Equal(t, 1, len(mocker.TargetGroups[*lb.TargetGroupArn].Targets)) + } + result, err := c.RollOut(context.TODO(), input) + if err != nil { + t.Fatal(err) + } + assert.True(t, result.ServiceUpdated) + assert.Equal(t, 1, mocker.RunningTaskSize()) + assert.Equal(t, 1, len(mocker.Services)) + if env.TaskDefinitionArn != "" { + assert.Equal(t, 1, len(mocker.TaskDefinitions.List())) + } else { + assert.Equal(t, 2, len(mocker.TaskDefinitions.List())) + } + for _, lb := range lbs { + assert.Equal(t, 1, len(mocker.TargetGroups[*lb.TargetGroupArn].Targets)) + } + updatedService, _ := mocker.GetEcsService(env.Service) + if env.TaskDefinitionArn != "" { + assert.Equal(t, env.TaskDefinitionArn, *updatedService.TaskDefinition) + } else { + assert.NotEqual(t, *td.TaskDefinitionArn, *updatedService.TaskDefinition) + } +} + +func TestCage_Rollout_Failure(t *testing.T) { + t.Run("should error if DescribeServices failed", func(t *testing.T) { ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "FARGATE") - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) + env := test.DefaultEnvars() + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - envars.ServiceDefinitionInput.LoadBalancers = []ecstypes.LoadBalancer{ - { - ContainerName: aws.String("missing-container"), - ContainerPort: aws.Int32(80), - TargetGroupArn: aws.String("arn:aws:elasticloadbalancing:us-west-2:123456789012:targetgroup/new-target-group/abcdefg"), - }, - } - result, err := cagecli.RollOut(ctx, &types.RollOutInput{UpdateService: true}) - assert.EqualError(t, err, "failed to exec canary task due to: couldn't find host port in container definition") - assert.Equal(t, result.ServiceIntact, true) - assert.Equal(t, 1, mctx.RunningTaskSize()) + })} + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(nil, test.Err) + result, err := c.RollOut(context.TODO(), &types.RollOutInput{}) + assert.EqualError(t, err, "failed to describe current service due to: error") + assert.False(t, result.ServiceUpdated) }) - t.Run("Show error if service doesn't exist", func(t *testing.T) { - envars := test.DefaultEnvars() + t.Run("should error if service doesn't exist", func(t *testing.T) { ctrl := gomock.NewController(t) - mocker, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - delete(mocker.Services, envars.Service) - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) + env := test.DefaultEnvars() + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - _, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.EqualError(t, err, "service 'service' doesn't exist. Run 'cage up' or create service before rolling out") + })} + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{}, + }, nil) + result, err := c.RollOut(context.TODO(), &types.RollOutInput{}) + assert.ErrorContains(t, err, "service 'service' doesn't exist") + assert.False(t, result.ServiceUpdated) }) - t.Run("Roll out even if the service does not have a load balancer", func(t *testing.T) { - envars := test.DefaultEnvars() - envars.ServiceDefinitionInput.LoadBalancers = nil - envars.CanaryTaskIdleDuration = 1 + t.Run("should error if service status is not ACTIVE", func(t *testing.T) { ctrl := gomock.NewController(t) - _, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) + env := test.DefaultEnvars() + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - if res, err := cagecli.RollOut(ctx, &types.RollOutInput{}); err != nil { - t.Fatalf(err.Error()) - } else if res.ServiceIntact { - t.Fatalf("no") - } + })} + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{{ + ServiceName: aws.String("service"), + Status: aws.String("INACTIVE")}}, + }, nil) + result, err := c.RollOut(context.TODO(), &types.RollOutInput{}) + assert.ErrorContains(t, err, "😵 service 'service' status is 'INACTIVE'") + assert.False(t, result.ServiceUpdated) }) - - t.Run("stop rolloing out when service status is inactive", func(t *testing.T) { - envars := test.DefaultEnvars() + t.Run("should error if LaunchType is EC2 and --canaryInstanceArn is not provided", func(t *testing.T) { ctrl := gomock.NewController(t) + env := test.DefaultEnvars() ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return( - &ecs.DescribeServicesOutput{ - Services: []ecstypes.Service{ - {Status: aws.String("INACTIVE")}, - }, - }, nil, - ) - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - _, err := cagecli.RollOut(context.Background(), &types.RollOutInput{}) - assert.EqualError(t, err, "😵 'service' status is 'INACTIVE'. Stop rolling out") + })} + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{{ + ServiceName: aws.String("service"), + Status: aws.String("ACTIVE"), + LaunchType: ecstypes.LaunchTypeEc2}}, + }, nil) + result, err := c.RollOut(context.TODO(), &types.RollOutInput{}) + assert.ErrorContains(t, err, "--canaryInstanceArn is required when LaunchType = 'EC2'") + assert.False(t, result.ServiceUpdated) }) - t.Run("Stop rolling out if the canary task container does not become healthy", func(t *testing.T) { - envars := test.DefaultEnvars() + t.Run("should error if CreateNextTaskDefinition failed", func(t *testing.T) { ctrl := gomock.NewController(t) - mocker, _, albMock, ec2Mock := test.Setup(ctrl, envars, 2, "FARGATE") - + env := test.DefaultEnvars() ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - ecsMock.EXPECT().CreateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.CreateService).AnyTimes() - ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.UpdateService).AnyTimes() - ecsMock.EXPECT().DeleteService(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DeleteService).AnyTimes() - ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StartTask).AnyTimes() - ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RegisterTaskDefinition).AnyTimes() - ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeServices).AnyTimes() - ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, input *ecs.DescribeTasksInput, opts ...func(options *ecs.Options)) (*ecs.DescribeTasksOutput, error) { - out, err := mocker.Ecs.DescribeTasks(ctx, input, opts...) - if err != nil { - return out, err - } - task := mocker.Tasks[input.Tasks[0]] - if strings.Contains(*task.Group, "canary-task") { - for i := range out.Tasks { - for i2 := range out.Tasks[i].Containers { - out.Tasks[i].Containers[i2].HealthStatus = ecstypes.HealthStatusUnknown - } - } - } - return out, err - }, - ).AnyTimes() - ecsMock.EXPECT().ListTasks(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.ListTasks).AnyTimes() - ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.DescribeContainerInstances).AnyTimes() - ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.RunTask).AnyTimes() - ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(mocker.Ecs.StopTask).AnyTimes() - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) + c := &cage{di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - res, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NotNil(t, res) - assert.NotNil(t, err) - - for _, task := range mocker.Tasks { - if strings.Contains(*task.Group, "canary-task") { - assert.Equal(t, "STOPPED", *task.LastStatus) - } - } + })} + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{{ + ServiceName: aws.String("service"), + Status: aws.String("ACTIVE")}}, + }, nil) + ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).Return(nil, test.Err) + result, err := c.RollOut(context.TODO(), &types.RollOutInput{}) + assert.EqualError(t, err, "failed to register next task definition due to: error") + assert.False(t, result.ServiceUpdated) }) - -} - -func TestCage_RollOut_EC2(t *testing.T) { - log.SetLevel(log.DebugLevel) - for _, v := range []int{1, 2, 15} { - log.Info("====") - canaryInstanceArn := "arn:aws:ecs:us-west-2:1234567689012:container-instance/abcdefg-hijk-lmn-opqrstuvwxyz" - attributeValue := "true" - envars := test.DefaultEnvars() - envars.CanaryInstanceArn = canaryInstanceArn - ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, v, "ec2") - ecsMock.EXPECT().ListAttributes(gomock.Any(), gomock.Any()).Return(&ecs.ListAttributesOutput{ - Attributes: []ecstypes.Attribute{ - { - Name: &envars.Service, - Value: &attributeValue, - TargetId: &canaryInstanceArn, - }, - }, - }, nil).AnyTimes() - if mctx.ActiveServiceSize() != 1 { - t.Fatalf("current service not setup") - } - if taskCnt := mctx.RunningTaskSize(); taskCnt != v { - t.Fatalf("current tasks not setup: %d/%d", v, taskCnt) - } - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - if err != nil { - t.Fatalf("%s", err) - } - assert.False(t, result.ServiceIntact) - assert.Equal(t, 1, mctx.ActiveServiceSize()) - assert.Equal(t, v, mctx.RunningTaskSize()) - } -} - -func TestCage_RollOut_EC2_without_ContainerInstanceArn(t *testing.T) { - log.SetLevel(log.DebugLevel) - log.Info("====") - envars := test.DefaultEnvars() - ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "EC2") - if mctx.ActiveServiceSize() != 1 { - t.Fatalf("current service not setup") - } - if taskCnt := mctx.RunningTaskSize(); taskCnt != 1 { - t.Fatalf("current tasks not setup: %d/%d", 1, taskCnt) - } - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.ErrorContains(t, err, "canaryInstanceArn is required") - assert.NotNil(t, result) -} - -func TestCage_RollOut_EC2_no_attribute(t *testing.T) { - log.SetLevel(log.DebugLevel) - log.Info("====") - canaryInstanceArn := "arn:aws:ecs:us-west-2:1234567689012:container-instance/abcdefg-hijk-lmn-opqrstuvwxyz" - envars := test.DefaultEnvars() - envars.CanaryInstanceArn = canaryInstanceArn - ctrl := gomock.NewController(t) - mctx, ecsMock, albMock, ec2Mock := test.Setup(ctrl, envars, 1, "EC2") - if mctx.ActiveServiceSize() != 1 { - t.Fatalf("current service not setup") - } - if taskCnt := mctx.RunningTaskSize(); taskCnt != 1 { - t.Fatalf("current tasks not setup: %d/%d", 1, taskCnt) - } - ecsMock.EXPECT().ListAttributes(gomock.Any(), gomock.Any()).Return(&ecs.ListAttributesOutput{ - Attributes: []ecstypes.Attribute{}, - }, nil) - ecsMock.EXPECT().PutAttributes(gomock.Any(), gomock.Any()).Return(&ecs.PutAttributesOutput{}, nil) - cagecli := cage.NewCage(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.AlbCli, albMock) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.TaskFactory, task.NewFactory(b.Future())) - b.Set(key.Time, test.NewFakeTime()) - })) - ctx := context.Background() - result, err := cagecli.RollOut(ctx, &types.RollOutInput{}) - assert.NoError(t, err) - assert.False(t, result.ServiceIntact) - assert.Equal(t, 1, mctx.ActiveServiceSize()) - assert.Equal(t, 1, mctx.RunningTaskSize()) } diff --git a/task_definition.go b/task_definition.go index ee94e87..f0a7ea7 100644 --- a/task_definition.go +++ b/task_definition.go @@ -9,7 +9,6 @@ import ( "github.com/loilo-inc/canarycage/awsiface" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" - "golang.org/x/xerrors" ) func (c *cage) CreateNextTaskDefinition(ctx context.Context) (*ecstypes.TaskDefinition, error) { @@ -21,13 +20,13 @@ func (c *cage) CreateNextTaskDefinition(ctx context.Context) (*ecstypes.TaskDefi TaskDefinition: &env.TaskDefinitionArn, }) if err != nil { - return nil, xerrors.Errorf("failed to describe next task definition: %w", err) + return nil, err } return o.TaskDefinition, nil } else { log.Infof("creating next task definition...") if out, err := ecsCli.RegisterTaskDefinition(ctx, env.TaskDefinitionInput); err != nil { - return nil, xerrors.Errorf("failed to register next task definition: %w", err) + return nil, err } else { log.Infof( "task definition '%s:%d' has been registered", diff --git a/test/alb.go b/test/alb.go index 57782ec..6e4fb39 100644 --- a/test/alb.go +++ b/test/alb.go @@ -2,6 +2,7 @@ package test import ( "context" + "fmt" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" @@ -14,6 +15,10 @@ type AlbServer struct { *commons } +type TargetGroup struct { + Targets map[string]elbv2types.TargetDescription +} + func (ctx *AlbServer) DescribeTargetGroups(_ context.Context, input *elbv2.DescribeTargetGroupsInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetGroupsOutput, error) { return &elbv2.DescribeTargetGroupsOutput{ TargetGroups: []elbv2types.TargetGroup{ @@ -37,38 +42,29 @@ func (ctx *AlbServer) DescribeTargetGroupAttributes(_ context.Context, input *el }, }, nil } -func (ctx *AlbServer) DescribeTargetHealth(_ context.Context, input *elbv2.DescribeTargetHealthInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { - if _, ok := ctx.TargetGroups[*input.TargetGroupArn]; !ok { - return &elbv2.DescribeTargetHealthOutput{ - TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ - { - Target: &elbv2types.TargetDescription{ - Id: input.Targets[0].Id, - Port: input.Targets[0].Port, - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumUnused, - }, - }, - }, - }, nil - } +func (ctx *AlbServer) DescribeTargetHealth(_ context.Context, input *elbv2.DescribeTargetHealthInput, _ ...func(options *elbv2.Options)) (*elbv2.DescribeTargetHealthOutput, error) { var ret []elbv2types.TargetHealthDescription - for _, task := range ctx.Tasks { - if task.LastStatus != nil && *task.LastStatus == "RUNNING" { - ret = append(ret, elbv2types.TargetHealthDescription{ - Target: &elbv2types.TargetDescription{ - Id: input.Targets[0].Id, - Port: input.Targets[0].Port, - AvailabilityZone: aws.String("us-west-2"), - }, - TargetHealth: &elbv2types.TargetHealth{ - State: elbv2types.TargetHealthStateEnumHealthy, - }, - }) + tg, ok := ctx.TargetGroups[*input.TargetGroupArn] + if !ok { + return nil, fmt.Errorf("target group not found") + } + for _, t := range input.Targets { + _, ok := tg.Targets[*t.Id] + var health elbv2types.TargetHealth + if !ok { + health = elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumUnused, + } + } else { + health = elbv2types.TargetHealth{ + State: elbv2types.TargetHealthStateEnumHealthy, + } } + ret = append(ret, elbv2types.TargetHealthDescription{ + Target: &t, + TargetHealth: &health, + }) } return &elbv2.DescribeTargetHealthOutput{ TargetHealthDescriptions: ret, @@ -76,12 +72,25 @@ func (ctx *AlbServer) DescribeTargetHealth(_ context.Context, input *elbv2.Descr } func (ctx *AlbServer) RegisterTargets(_ context.Context, input *elbv2.RegisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.RegisterTargetsOutput, error) { - ctx.TargetGroups[*input.TargetGroupArn] = struct{}{} + tg, ok := ctx.TargetGroups[*input.TargetGroupArn] + if !ok { + tg = &TargetGroup{Targets: make(map[string]elbv2types.TargetDescription)} + ctx.TargetGroups[*input.TargetGroupArn] = tg + } + for _, t := range input.Targets { + tg.Targets[*t.Id] = t + } return &elbv2.RegisterTargetsOutput{}, nil } func (ctx *AlbServer) DeregisterTargets(_ context.Context, input *elbv2.DeregisterTargetsInput, _ ...func(options *elbv2.Options)) (*elbv2.DeregisterTargetsOutput, error) { - delete(ctx.TargetGroups, *input.TargetGroupArn) + tg, ok := ctx.TargetGroups[*input.TargetGroupArn] + if !ok { + return nil, fmt.Errorf("target group not found") + } + for _, t := range input.Targets { + delete(tg.Targets, *t.Id) + } return &elbv2.DeregisterTargetsOutput{}, nil } diff --git a/test/context.go b/test/context.go index b4060b7..f7eb701 100644 --- a/test/context.go +++ b/test/context.go @@ -12,7 +12,7 @@ type commons struct { Services map[string]*ecstypes.Service Tasks map[string]*ecstypes.Task TaskDefinitions *TaskDefinitionRepository - TargetGroups map[string]struct{} + TargetGroups map[string]*TargetGroup mux sync.Mutex } @@ -30,7 +30,7 @@ func NewMockContext() *MockContext { TaskDefinitions: &TaskDefinitionRepository{ families: make(map[string]*TaskDefinitionFamily), }, - TargetGroups: make(map[string]struct{}), + TargetGroups: make(map[string]*TargetGroup), } return &MockContext{ commons: cm, diff --git a/test/ecs.go b/test/ecs.go index 4c337fe..b25d116 100644 --- a/test/ecs.go +++ b/test/ecs.go @@ -14,6 +14,7 @@ import ( type EcsServer struct { *commons + ipv4 int } func (ctx *EcsServer) CreateService(c context.Context, input *ecs.CreateServiceInput, _ ...func(options *ecs.Options)) (*ecs.CreateServiceOutput, error) { @@ -52,12 +53,14 @@ func (ctx *EcsServer) CreateService(c context.Context, input *ecs.CreateServiceI ctx.mux.Unlock() log.Debugf("%s: running=%d, desired=%d", *input.ServiceName, ret.RunningCount, *input.DesiredCount) for i := 0; i < int(*input.DesiredCount); i++ { - ctx.StartTask(c, &ecs.StartTaskInput{ + if _, err := ctx.StartTask(c, &ecs.StartTaskInput{ Cluster: input.Cluster, Group: aws.String(fmt.Sprintf("service:%s", *input.ServiceName)), NetworkConfiguration: input.NetworkConfiguration, TaskDefinition: input.TaskDefinition, - }) + }); err != nil { + log.Fatalf("failed to start task: %v", err) + } } ctx.mux.Lock() ctx.Services[*input.ServiceName].RunningCount = *input.DesiredCount @@ -167,13 +170,14 @@ func (ctx *EcsServer) StartTask(_ context.Context, input *ecs.StartTaskInput, _ return nil, fmt.Errorf("task definition not found: %s", *input.TaskDefinition) } taskArn := fmt.Sprintf("arn:aws:ecs:us-west-2:012345678910:task/%s", uuid.New().String()) + ctx.ipv4++ attachment := types.Attachment{ Status: aws.String("ATTACHED"), Type: aws.String("ElasticNetworkInterface"), Details: []types.KeyValuePair{ { Name: aws.String("privateIPv4Address"), - Value: aws.String("127.0.0.1"), + Value: aws.String(fmt.Sprintf("127.0.0.%d", ctx.ipv4)), }, }, } diff --git a/types/iface.go b/types/iface.go index d24bbcc..e4ca1f0 100644 --- a/types/iface.go +++ b/types/iface.go @@ -32,7 +32,7 @@ type RollOutInput struct { } type RollOutResult struct { - ServiceIntact bool + ServiceUpdated bool } type UpResult struct { From 7481c34c488a39aa0662906114367416fe7df2b7 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 20:20:43 +0900 Subject: [PATCH 40/47] test --- rollout/executor.go | 6 +++--- rollout/executor_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/rollout/executor.go b/rollout/executor.go index c44aad2..8aa243c 100644 --- a/rollout/executor.go +++ b/rollout/executor.go @@ -30,7 +30,7 @@ func NewExecutor(di *di.D, td *ecstypes.TaskDefinition) Executor { return &executor{di: di, td: td} } -func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) error { +func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) (lastErr error) { env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) if input.UpdateService { @@ -42,8 +42,8 @@ func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) error _ = recover() if canaryTasks == nil { return - } else if err := canaryTasks.Cleanup(ctx); err != nil { - log.Errorf("failed to cleanup canary tasks due to: %s", err) + } else if lastErr = canaryTasks.Cleanup(ctx); lastErr != nil { + log.Errorf("failed to cleanup canary tasks due to: %s", lastErr) } }() if startCanaryTaskErr != nil { diff --git a/rollout/executor_test.go b/rollout/executor_test.go index 97552cc..1fe01be 100644 --- a/rollout/executor_test.go +++ b/rollout/executor_test.go @@ -19,6 +19,16 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewExecutor(t *testing.T) { + td := &ecstypes.TaskDefinition{} + di := di.EmptyDomain() + e := NewExecutor(di, td) + v, ok := e.(*executor) + assert.True(t, ok) + assert.Equal(t, td, v.td) + assert.Equal(t, di, v.di) +} + func TestExecutor_Rollout(t *testing.T) { setup := func(t *testing.T) ( *executor, @@ -79,6 +89,7 @@ func TestExecutor_Rollout(t *testing.T) { } srv, _ := mocker.GetEcsService(envars.Service) assert.Equal(t, *srv.TaskDefinition, *td.TaskDefinitionArn) + assert.True(t, e.ServiceUpdated()) }) t.Run("updateService", func(t *testing.T) { e, envars, mocker, ecsMock, taskMock, td := setup(t) @@ -107,6 +118,7 @@ func TestExecutor_Rollout(t *testing.T) { } srv, _ := mocker.GetEcsService(envars.Service) assert.Equal(t, *srv.TaskDefinition, *td.TaskDefinitionArn) + assert.True(t, e.ServiceUpdated()) }) } @@ -135,6 +147,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ecsCli.EXPECT().DescribeServices(gomock.Any(), gomock.Any()).Return(nil, test.Err) err := e.RollOut(context.TODO(), &types.RollOutInput{}) assert.EqualError(t, err, "error") + assert.False(t, e.ServiceUpdated()) }) t.Run("should call task.Task.Stop() even if task.Task.Start() failed", func(t *testing.T) { e, _, _, factoryMock, taskMock := setup(t) @@ -145,6 +158,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ) err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "error") + assert.False(t, e.ServiceUpdated()) }) t.Run("should call task.Task.Stop() even if task.Task.Wait() failed", func(t *testing.T) { e, _, _, factoryMock, taskMock := setup(t) @@ -156,6 +170,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ) err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "error") + assert.False(t, e.ServiceUpdated()) }) t.Run("should call task.Task.Stop() even if ecs.UpdateService() failed", func(t *testing.T) { e, _, ecsMock, factoryMock, taskMock := setup(t) @@ -169,6 +184,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ) err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "error") + assert.False(t, e.ServiceUpdated()) }) t.Run("should call task.Task.Stop() even if ecs.DescribeServices() failed", func(t *testing.T) { e, _, ecsMock, factoryMock, taskMock := setup(t) @@ -186,5 +202,24 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ) err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "waiter state transitioned to Failure") + assert.False(t, e.ServiceUpdated()) + }) + t.Run("should log error if task.Task.Stop() failed", func(t *testing.T) { + e, _, ecsMock, factoryMock, taskMock := setup(t) + gomock.InOrder( + factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), + taskMock.EXPECT().Start(gomock.Any()).Return(nil), + taskMock.EXPECT().Wait(gomock.Any()).Return(nil), + ecsMock.EXPECT().UpdateService(gomock.Any(), gomock.Any()). + Return(&ecs.UpdateServiceOutput{}, nil), + ecsMock.EXPECT().DescribeServices(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeServicesOutput{ + Services: []ecstypes.Service{{Status: aws.String("INACTIVE")}}, + }, nil), + taskMock.EXPECT().Stop(gomock.Any()).Return(test.Err), + ) + err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) + assert.EqualError(t, err, "error") + assert.True(t, e.ServiceUpdated()) }) } From 88b20b89820565fe6961b25d272daf727092e570 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 20:21:31 +0900 Subject: [PATCH 41/47] removed unused ecs api --- README.md | 2 -- awsiface/iface.go | 2 -- mocks/mock_awsiface/iface.go | 40 ------------------------------------ 3 files changed, 44 deletions(-) diff --git a/README.md b/README.md index 5e258c9..0c91991 100644 --- a/README.md +++ b/README.md @@ -113,8 +113,6 @@ By default, `cage rollout` will only update the task definition of the service. "ecs:ListTasks", "ecs:RunTask", "ecs:StopTask", - "ecs:ListAttributes", - "ecs:PutAttributes", "ecs:DescribeTaskDefinition" ], "Resource": "*" diff --git a/awsiface/iface.go b/awsiface/iface.go index d377f3e..94cd650 100644 --- a/awsiface/iface.go +++ b/awsiface/iface.go @@ -21,8 +21,6 @@ type ( ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) RunTask(ctx context.Context, params *ecs.RunTaskInput, optFns ...func(*ecs.Options)) (*ecs.RunTaskOutput, error) StopTask(ctx context.Context, params *ecs.StopTaskInput, optFns ...func(*ecs.Options)) (*ecs.StopTaskOutput, error) - ListAttributes(ctx context.Context, params *ecs.ListAttributesInput, optFns ...func(*ecs.Options)) (*ecs.ListAttributesOutput, error) - PutAttributes(ctx context.Context, params *ecs.PutAttributesInput, optFns ...func(*ecs.Options)) (*ecs.PutAttributesOutput, error) DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error) } AlbClient interface { diff --git a/mocks/mock_awsiface/iface.go b/mocks/mock_awsiface/iface.go index b582870..4fac71b 100644 --- a/mocks/mock_awsiface/iface.go +++ b/mocks/mock_awsiface/iface.go @@ -157,26 +157,6 @@ func (mr *MockEcsClientMockRecorder) DescribeTasks(ctx, params interface{}, optF return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeTasks", reflect.TypeOf((*MockEcsClient)(nil).DescribeTasks), varargs...) } -// ListAttributes mocks base method. -func (m *MockEcsClient) ListAttributes(ctx context.Context, params *ecs.ListAttributesInput, optFns ...func(*ecs.Options)) (*ecs.ListAttributesOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "ListAttributes", varargs...) - ret0, _ := ret[0].(*ecs.ListAttributesOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListAttributes indicates an expected call of ListAttributes. -func (mr *MockEcsClientMockRecorder) ListAttributes(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAttributes", reflect.TypeOf((*MockEcsClient)(nil).ListAttributes), varargs...) -} - // ListTasks mocks base method. func (m *MockEcsClient) ListTasks(ctx context.Context, params *ecs.ListTasksInput, optFns ...func(*ecs.Options)) (*ecs.ListTasksOutput, error) { m.ctrl.T.Helper() @@ -197,26 +177,6 @@ func (mr *MockEcsClientMockRecorder) ListTasks(ctx, params interface{}, optFns . return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockEcsClient)(nil).ListTasks), varargs...) } -// PutAttributes mocks base method. -func (m *MockEcsClient) PutAttributes(ctx context.Context, params *ecs.PutAttributesInput, optFns ...func(*ecs.Options)) (*ecs.PutAttributesOutput, error) { - m.ctrl.T.Helper() - varargs := []interface{}{ctx, params} - for _, a := range optFns { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "PutAttributes", varargs...) - ret0, _ := ret[0].(*ecs.PutAttributesOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// PutAttributes indicates an expected call of PutAttributes. -func (mr *MockEcsClientMockRecorder) PutAttributes(ctx, params interface{}, optFns ...interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]interface{}{ctx, params}, optFns...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutAttributes", reflect.TypeOf((*MockEcsClient)(nil).PutAttributes), varargs...) -} - // RegisterTaskDefinition mocks base method. func (m *MockEcsClient) RegisterTaskDefinition(ctx context.Context, params *ecs.RegisterTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.RegisterTaskDefinitionOutput, error) { m.ctrl.T.Helper() From 53dd2b1c5e7effd55f8dfe77ff7ff99b1bfe91d0 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 20:40:19 +0900 Subject: [PATCH 42/47] test --- rollout/executor.go | 5 +- rollout/executor_test.go | 4 +- task/alb_task.go | 50 +++++------ task/alb_task_test.go | 186 +++++++++++++++++++++++++-------------- task/common.go | 30 +++---- task/common_test.go | 97 ++++++++++++-------- task/export_test.go | 22 ----- task/factory_test.go | 19 ++-- task/simple_task.go | 2 +- task/simple_task_test.go | 48 +++++++--- 10 files changed, 271 insertions(+), 192 deletions(-) delete mode 100644 task/export_test.go diff --git a/rollout/executor.go b/rollout/executor.go index 8aa243c..03d6700 100644 --- a/rollout/executor.go +++ b/rollout/executor.go @@ -42,8 +42,9 @@ func (c *executor) RollOut(ctx context.Context, input *types.RollOutInput) (last _ = recover() if canaryTasks == nil { return - } else if lastErr = canaryTasks.Cleanup(ctx); lastErr != nil { - log.Errorf("failed to cleanup canary tasks due to: %s", lastErr) + } else if err := canaryTasks.Cleanup(ctx); err != nil { + log.Errorf("failed to cleanup canary tasks due to: %s", err) + lastErr = err } }() if startCanaryTaskErr != nil { diff --git a/rollout/executor_test.go b/rollout/executor_test.go index 1fe01be..8d60197 100644 --- a/rollout/executor_test.go +++ b/rollout/executor_test.go @@ -186,7 +186,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { assert.EqualError(t, err, "error") assert.False(t, e.ServiceUpdated()) }) - t.Run("should call task.Task.Stop() even if ecs.DescribeServices() failed", func(t *testing.T) { + t.Run("should call task.Task.Stop() even if ecs.NewServicesStableWaiter.Wait() failed", func(t *testing.T) { e, _, ecsMock, factoryMock, taskMock := setup(t) gomock.InOrder( factoryMock.EXPECT().NewAlbTask(gomock.Any(), gomock.Any()).Return(taskMock), @@ -202,7 +202,7 @@ func TestExecutor_Rollout_Failure(t *testing.T) { ) err := e.RollOut(context.TODO(), &types.RollOutInput{UpdateService: true}) assert.EqualError(t, err, "waiter state transitioned to Failure") - assert.False(t, e.ServiceUpdated()) + assert.True(t, e.ServiceUpdated()) }) t.Run("should log error if task.Task.Stop() failed", func(t *testing.T) { e, _, ecsMock, factoryMock, taskMock := setup(t) diff --git a/task/alb_task.go b/task/alb_task.go index 5e9cd00..83b85b2 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -23,8 +23,8 @@ import ( // albTask is a task that is attached to an Application Load Balancer type albTask struct { *common - Lb *ecstypes.LoadBalancer - Target *elbv2types.TargetDescription + lb *ecstypes.LoadBalancer + target *elbv2types.TargetDescription } func NewAlbTask( @@ -34,7 +34,7 @@ func NewAlbTask( ) Task { return &albTask{ common: &common{Input: input, di: di}, - Lb: lb, + lb: lb, } } @@ -48,7 +48,7 @@ func (c *albTask) Wait(ctx context.Context) error { if err := c.RegisterToTargetGroup(ctx); err != nil { return err } - log.Infof("canary task '%s' is registered to target group '%s'", *c.Target.Id, *c.Lb.TargetGroupArn) + log.Infof("canary task '%s' is registered to target group '%s'", *c.target.Id, *c.lb.TargetGroupArn) log.Infof("😷 waiting canary target to be healthy...") if err := c.WaitUntilTargetHealthy(ctx); err != nil { return err @@ -98,7 +98,7 @@ func (c *albTask) getFargateTargetNetwork(ctx context.Context) (*string, *string ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.TaskArn}, + Tasks: []string{*c.taskArn}, }); err != nil { return nil, nil, err } else { @@ -149,7 +149,7 @@ func (c *albTask) getEc2TargetNetwork(ctx context.Context) (*string, *string, er func (c *albTask) getTargetPort() (*int32, error) { for _, container := range c.TaskDefinition.ContainerDefinitions { - if *container.Name == *c.Lb.ContainerName { + if *container.Name == *c.lb.ContainerName { return container.PortMappings[0].HostPort, nil } } @@ -157,16 +157,16 @@ func (c *albTask) getTargetPort() (*int32, error) { } func (c *albTask) RegisterToTargetGroup(ctx context.Context) error { - log.Infof("registering the canary task to target group '%s'...", *c.Lb.TargetGroupArn) + log.Infof("registering the canary task to target group '%s'...", *c.lb.TargetGroupArn) if target, err := c.describeTaskTarget(ctx); err != nil { return err } else { - c.Target = target + c.target = target } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if _, err := albCli.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ - TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{*c.Target}}, + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{*c.target}}, ); err != nil { return err } @@ -188,20 +188,20 @@ func (c *albTask) WaitUntilTargetHealthy( return ctx.Err() case <-timer.NewTimer(waitPeriod).C: if o, err := albCli.DescribeTargetHealth(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{*c.Target}, + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{*c.target}, }); err != nil { return err } else { for _, desc := range o.TargetHealthDescriptions { - if *desc.Target.Id == *c.Target.Id && *desc.Target.Port == *c.Target.Port { + if *desc.Target.Id == *c.target.Id && *desc.Target.Port == *c.target.Port { recentState = &desc.TargetHealth.State } } if recentState == nil { - return xerrors.Errorf("'%s' is not registered to the target group '%s'", *c.Target.Id, *c.Lb.TargetGroupArn) + return xerrors.Errorf("'%s' is not registered to the target group '%s'", *c.target.Id, *c.lb.TargetGroupArn) } - log.Infof("canary task '%s' (%s:%d) state is: %s", *c.TaskArn, *c.Target.Id, *c.Target.Port, *recentState) + log.Infof("canary task '%s' (%s:%d) state is: %s", *c.taskArn, *c.target.Id, *c.target.Port, *recentState) switch *recentState { case elbv2types.TargetHealthStateEnumHealthy: return nil @@ -212,10 +212,10 @@ func (c *albTask) WaitUntilTargetHealthy( } } // unhealthy, draining, unused - log.Errorf("😨 canary task '%s' is unhealthy", *c.TaskArn) + log.Errorf("😨 canary task '%s' is unhealthy", *c.taskArn) return xerrors.Errorf( "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *c.TaskArn, *c.Target.Id, *c.Target.Port, *recentState, + *c.taskArn, *c.target.Id, *c.target.Port, *recentState, ) } @@ -223,7 +223,7 @@ func (c *albTask) GetTargetDeregistrationDelay(ctx context.Context) (time.Durati deregistrationDelay := 300 * time.Second albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if o, err := albCli.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ - TargetGroupArn: c.Lb.TargetGroupArn, + TargetGroupArn: c.lb.TargetGroupArn, }); err != nil { return deregistrationDelay, err } else { @@ -242,7 +242,7 @@ func (c *albTask) GetTargetDeregistrationDelay(ctx context.Context) (time.Durati } func (c *albTask) DeregisterTarget(ctx context.Context) { - if c.Target == nil { + if c.target == nil { return } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) @@ -251,10 +251,10 @@ func (c *albTask) DeregisterTarget(ctx context.Context) { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) } - log.Infof("deregistering the canary task from target group '%s'...", *c.Target.Id) + log.Infof("deregistering the canary task from target group '%s'...", *c.target.Id) if _, err := albCli.DeregisterTargets(ctx, &elbv2.DeregisterTargetsInput{ - TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{*c.Target}, + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{*c.target}, }); err != nil { log.Errorf("failed to deregister the canary task from target group: %v", err) log.Errorf("continuing to stop the canary task...") @@ -263,8 +263,8 @@ func (c *albTask) DeregisterTarget(ctx context.Context) { log.Infof("deregister operation accepted. waiting for the canary task to be deregistered...") deregisterWait := deregistrationDelay + time.Minute // add 1 minute for safety if err := elbv2.NewTargetDeregisteredWaiter(albCli).Wait(ctx, &elbv2.DescribeTargetHealthInput{ - TargetGroupArn: c.Lb.TargetGroupArn, - Targets: []elbv2types.TargetDescription{*c.Target}, + TargetGroupArn: c.lb.TargetGroupArn, + Targets: []elbv2types.TargetDescription{*c.target}, }, deregisterWait); err != nil { log.Errorf("failed to wait for the canary task deregistered from target group: %v", err) log.Errorf("continuing to stop the canary task...") @@ -272,6 +272,6 @@ func (c *albTask) DeregisterTarget(ctx context.Context) { } log.Infof( "canary task '%s' has successfully been deregistered from target group '%s'", - *c.TaskArn, *c.Target.Id, + *c.taskArn, *c.target.Id, ) } diff --git a/task/alb_task_test.go b/task/alb_task_test.go index e7072de..b34c2a5 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -1,4 +1,4 @@ -package task_test +package task import ( "context" @@ -17,14 +17,25 @@ import ( "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" - "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) +func TestNewAlbTask(t *testing.T) { + d := &di.D{} + input := &Input{} + lb := &ecstypes.LoadBalancer{} + task := NewAlbTask(d, input, lb) + v, ok := task.(*albTask) + assert.NotNil(t, task) + assert.True(t, ok) + assert.Equal(t, input, v.Input) + assert.Equal(t, lb, v.lb) +} + func TestAlbTask(t *testing.T) { - setup := func(env *env.Envars) (task.Task, *test.MockContext) { + setup := func(env *env.Envars) (*albTask, *test.MockContext) { mocker := test.NewMockContext() ctx := context.TODO() td, _ := mocker.Ecs.RegisterTaskDefinition(ctx, env.TaskDefinitionInput) @@ -37,10 +48,16 @@ func TestAlbTask(t *testing.T) { b.Set(key.AlbCli, mocker.Alb) b.Set(key.Time, test.NewFakeTime()) }) - stask := task.NewAlbTask(d, &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - }, &ecsSvc.Service.LoadBalancers[0]) + stask := &albTask{ + lb: &ecsSvc.Service.LoadBalancers[0], + common: &common{ + di: d, + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + }, + }, + } mocker.Alb.RegisterTargets(ctx, &elbv2.RegisterTargetsInput{ TargetGroupArn: ecsSvc.Service.LoadBalancers[0].TargetGroupArn, }) @@ -79,21 +96,26 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { Port: aws.Int32(80), AvailabilityZone: aws.String("ap-northeast-1a"), } - setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *albTask) { ctrl := gomock.NewController(t) env := test.DefaultEnvars() mocker := test.NewMockContext() albMock := mock_awsiface.NewMockAlbClient(ctrl) td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.AlbCli, albMock) - b.Set(key.Time, test.NewFakeTime()) - }), &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, - }, &env.ServiceDefinitionInput.LoadBalancers[0]) - atask.TaskArn = aws.String("arn://task") - atask.Target = target + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + b.Set(key.Time, test.NewFakeTime()) + }), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }}, + lb: &env.ServiceDefinitionInput.LoadBalancers[0], + } + atask.taskArn = aws.String("arn://task") + atask.target = target return albMock, atask } t.Run("should call DescribeTargetHealth periodically", func(t *testing.T) { @@ -141,7 +163,7 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { ) err := atask.WaitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, fmt.Sprintf( - "'%s' is not registered to the target group '%s'", *target.Id, *atask.Lb.TargetGroupArn), + "'%s' is not registered to the target group '%s'", *target.Id, *atask.lb.TargetGroupArn), ) }) t.Run("should error if target unhelthy counts exceed the limit", func(t *testing.T) { @@ -158,7 +180,7 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { err := atask.WaitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, fmt.Sprintf( "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", - *atask.TaskArn, *target.Id, *target.Port, elbv2types.TargetHealthStateEnumUnhealthy, + *atask.taskArn, *target.Id, *target.Port, elbv2types.TargetHealthStateEnumUnhealthy, ), ) }) @@ -169,15 +191,21 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { env := test.DefaultEnvars() mocker := test.NewMockContext() td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - }), &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, - }, &ecstypes.LoadBalancer{ - TargetGroupArn: aws.String("arn://target-group"), - ContainerName: aws.String("unknown")}) - atask.TaskArn = aws.String("arn://task") + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + }), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }, + }, + lb: &ecstypes.LoadBalancer{ + TargetGroupArn: aws.String("arn://target-group"), + ContainerName: aws.String("unknown")}, + } + atask.taskArn = aws.String("arn://task") err := atask.RegisterToTargetGroup(context.TODO()) assert.EqualError(t, err, "couldn't find host port in container definition") }) @@ -199,7 +227,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { subnets := []ec2types.Subnet{{ AvailabilityZone: aws.String("ap-northeast-1a"), }} - setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *task.AlbTaskExport) { + setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *albTask) { ctrl := gomock.NewController(t) envars := test.DefaultEnvars() mocker := test.NewMockContext() @@ -207,16 +235,22 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) albMock := mock_awsiface.NewMockAlbClient(ctrl) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.AlbCli, albMock) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, - }, &envars.ServiceDefinitionInput.LoadBalancers[0]) - atask.TaskArn = aws.String("arn://task") + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.AlbCli, albMock) + b.Set(key.EcsCli, ecsMock) + }), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, + }, + }, + lb: &envars.ServiceDefinitionInput.LoadBalancers[0], + } + atask.taskArn = aws.String("arn://task") return ec2Mock, albMock, ecsMock, atask } t.Run("should call RegisterTargets", func(t *testing.T) { @@ -230,13 +264,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Subnets: subnets, }, nil) albMock.EXPECT().RegisterTargets(gomock.Any(), &elbv2.RegisterTargetsInput{ - TargetGroupArn: atask.Lb.TargetGroupArn, + TargetGroupArn: atask.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ Id: aws.String("127.0.0.1"), Port: aws.Int32(80), AvailabilityZone: subnets[0].AvailabilityZone}, }}).Return(nil, nil) - atask.TaskArn = aws.String("arn://task") + atask.taskArn = aws.String("arn://task") err := atask.RegisterToTargetGroup(context.TODO()) assert.NoError(t, err) }) @@ -297,7 +331,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { subnets := []ec2types.Subnet{{ AvailabilityZone: aws.String("ap-northeast-1a"), }} - setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *task.AlbTaskExport) { + setup := func(t *testing.T) (*mock_awsiface.MockEc2Client, *mock_awsiface.MockAlbClient, *mock_awsiface.MockEcsClient, *albTask) { ctrl := gomock.NewController(t) envars := test.DefaultEnvars() envars.CanaryInstanceArn = "arn://container" @@ -306,16 +340,22 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ec2Mock := mock_awsiface.NewMockEc2Client(ctrl) albMock := mock_awsiface.NewMockAlbClient(ctrl) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.Ec2Cli, ec2Mock) - b.Set(key.AlbCli, albMock) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, - }, &envars.ServiceDefinitionInput.LoadBalancers[0]) - atask.TaskArn = aws.String("arn://task") + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.Ec2Cli, ec2Mock) + b.Set(key.AlbCli, albMock) + b.Set(key.EcsCli, ecsMock) + }), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: envars.ServiceDefinitionInput.NetworkConfiguration, + }, + }, + lb: &envars.ServiceDefinitionInput.LoadBalancers[0], + } + atask.taskArn = aws.String("arn://task") return ec2Mock, albMock, ecsMock, atask } t.Run("should call RegisterTargets", func(t *testing.T) { @@ -330,7 +370,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Subnets: subnets, }, nil) albMock.EXPECT().RegisterTargets(gomock.Any(), &elbv2.RegisterTargetsInput{ - TargetGroupArn: atask.Lb.TargetGroupArn, + TargetGroupArn: atask.lb.TargetGroupArn, Targets: []elbv2types.TargetDescription{{ Id: containerInstances[0].Ec2InstanceId, Port: aws.Int32(80), @@ -370,13 +410,19 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { } func TestAlbTask_GetTargetDeregistrationDelay(t *testing.T) { - setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + setup := func(t *testing.T) (*mock_awsiface.MockAlbClient, *albTask) { ctrl := gomock.NewController(t) env := test.DefaultEnvars() albMock := mock_awsiface.NewMockAlbClient(ctrl) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.AlbCli, albMock) - }), &task.Input{}, &env.ServiceDefinitionInput.LoadBalancers[0]) + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + }), + Input: &Input{}, + }, + lb: &env.ServiceDefinitionInput.LoadBalancers[0], + } return albMock, atask } t.Run("should return deregistration delay", func(t *testing.T) { @@ -425,23 +471,29 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { Port: aws.Int32(80), AvailabilityZone: aws.String("ap-northeast-1a"), } - setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockAlbClient, *task.AlbTaskExport) { + setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockAlbClient, *albTask) { ctrl := gomock.NewController(t) mocker := test.NewMockContext() albMock := mock_awsiface.NewMockAlbClient(ctrl) td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), env.TaskDefinitionInput) - atask := task.NewAlbTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.AlbCli, albMock) - }), &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, - }, &env.ServiceDefinitionInput.LoadBalancers[0]) - atask.TaskArn = aws.String("arn://task") - atask.Target = target + atask := &albTask{ + common: &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.AlbCli, albMock) + }), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration, + }, + }, + lb: &env.ServiceDefinitionInput.LoadBalancers[0], + } + atask.taskArn = aws.String("arn://task") + atask.target = target return albMock, atask } t.Run("should do nothing if target is nil", func(t *testing.T) { - atask := task.NewAlbTaskExport(di.EmptyDomain(), &task.Input{}, nil) + atask := &albTask{} atask.DeregisterTarget(context.TODO()) }) t.Run("should call DeregisterTargets and wait", func(t *testing.T) { diff --git a/task/common.go b/task/common.go index 7dca95e..eecea12 100644 --- a/task/common.go +++ b/task/common.go @@ -25,7 +25,7 @@ type Input struct { type common struct { *Input di *di.D - TaskArn *string + taskArn *string } func (c *common) Start(ctx context.Context) error { @@ -43,7 +43,7 @@ func (c *common) Start(ctx context.Context) error { }); err != nil { return err } else { - c.TaskArn = o.Tasks[0].TaskArn + c.taskArn = o.Tasks[0].TaskArn } } else { // fargate @@ -57,26 +57,26 @@ func (c *common) Start(ctx context.Context) error { }); err != nil { return err } else { - c.TaskArn = o.Tasks[0].TaskArn + c.taskArn = o.Tasks[0].TaskArn } } return nil } func (c *common) WaitForTaskRunning(ctx context.Context) error { - if c.TaskArn == nil { + if c.taskArn == nil { return xerrors.New("task is not started") } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - log.Infof("🥚 waiting for canary task '%s' is running...", *c.TaskArn) + log.Infof("🥚 waiting for canary task '%s' is running...", *c.taskArn) if err := ecs.NewTasksRunningWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.TaskArn}, + Tasks: []string{*c.taskArn}, }, env.GetTaskRunningWait()); err != nil { return xerrors.Errorf("failed to wait for canary task to be running: %w", err) } - log.Infof("🐣 canary task '%s' is running!", *c.TaskArn) + log.Infof("🐣 canary task '%s' is running!", *c.taskArn) return nil } @@ -105,11 +105,11 @@ func (c *common) WaitContainerHealthCheck(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() case <-timer.NewTimer(healthCheckPeriod).C: - log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.TaskArn, len(containerHasHealthChecks)) + log.Infof("canary task '%s' waits until %d container(s) become healthy", *c.taskArn, len(containerHasHealthChecks)) var task ecstypes.Task if o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.TaskArn}, + Tasks: []string{*c.taskArn}, }); err != nil { return err } else { @@ -133,31 +133,31 @@ func (c *common) WaitContainerHealthCheck(ctx context.Context) error { } if len(containerHasHealthChecks) == 0 { log.Info("🤩 canary task container(s) is healthy!") - log.Infof("canary task '%s' ensured.", *c.TaskArn) + log.Infof("canary task '%s' ensured.", *c.taskArn) return nil } return xerrors.Errorf("😨 canary task hasn't become to be healthy") } func (c *common) StopTask(ctx context.Context) error { - if c.TaskArn == nil { + if c.taskArn == nil { return nil } env := c.di.Get(key.Env).(*env.Envars) ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) - log.Infof("stopping the canary task '%s'...", *c.TaskArn) + log.Infof("stopping the canary task '%s'...", *c.taskArn) if _, err := ecsCli.StopTask(ctx, &ecs.StopTaskInput{ Cluster: &env.Cluster, - Task: c.TaskArn, + Task: c.taskArn, }); err != nil { return xerrors.Errorf("failed to stop canary task: %w", err) } if err := ecs.NewTasksStoppedWaiter(ecsCli).Wait(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.TaskArn}, + Tasks: []string{*c.taskArn}, }, env.GetTaskStoppedWait()); err != nil { return xerrors.Errorf("failed to wait for canary task to be stopped: %w", err) } - log.Infof("canary task '%s' has successfully been stopped", *c.TaskArn) + log.Infof("canary task '%s' has successfully been stopped", *c.taskArn) return nil } diff --git a/task/common_test.go b/task/common_test.go index 4759e6a..a8bb5c1 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -1,4 +1,4 @@ -package task_test +package task import ( "context" @@ -14,7 +14,6 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/mocks/mock_types" - "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" @@ -30,10 +29,13 @@ func TestCommon_Start(t *testing.T) { Tasks: []ecstypes.Task{{TaskArn: aws.String("task-arn")}}, }, nil) envars := test.DefaultEnvars() - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{TaskDefinition: td}) + cm := &common{ + Input: &Input{TaskDefinition: td}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), + } err := cm.Start(context.TODO()) assert.NoError(t, err) }) @@ -43,10 +45,13 @@ func TestCommon_Start(t *testing.T) { ecsMock := mock_awsiface.NewMockEcsClient(ctrl) ecsMock.EXPECT().RunTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) envars := test.DefaultEnvars() - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{TaskDefinition: td}) + cm := &common{ + Input: &Input{TaskDefinition: td}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), + } err := cm.Start(context.TODO()) assert.EqualError(t, err, "error") }) @@ -61,10 +66,13 @@ func TestCommon_Start(t *testing.T) { }, nil) envars := test.DefaultEnvars() envars.CanaryInstanceArn = "instance-arn" - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{TaskDefinition: td}) + cm := &common{ + Input: &Input{TaskDefinition: td}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), + } err := cm.Start(context.TODO()) assert.NoError(t, err) }) @@ -75,10 +83,13 @@ func TestCommon_Start(t *testing.T) { ecsMock.EXPECT().StartTask(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) envars := test.DefaultEnvars() envars.CanaryInstanceArn = "instance-arn" - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{TaskDefinition: td}) + cm := &common{ + Input: &Input{TaskDefinition: td}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), + } err := cm.Start(context.TODO()) assert.EqualError(t, err, "error") }) @@ -86,15 +97,18 @@ func TestCommon_Start(t *testing.T) { } func TestCommon_WaitForTaskRunning(t *testing.T) { - setup := func(t *testing.T, envars *env.Envars) (*mock_awsiface.MockEcsClient, *task.CommonExport) { + setup := func(t *testing.T, envars *env.Envars) (*mock_awsiface.MockEcsClient, *common) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) td := &ecstypes.TaskDefinition{} - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - }), &task.Input{TaskDefinition: td}) - cm.TaskArn = aws.String("task-arn") + cm := &common{ + Input: &Input{TaskDefinition: td}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + }), + } + cm.taskArn = aws.String("task-arn") return ecsMock, cm } t.Run("should call ecs.NewTasksRunningWaiter", func(t *testing.T) { @@ -106,7 +120,7 @@ func TestCommon_WaitForTaskRunning(t *testing.T) { assert.NoError(t, err) }) t.Run("should error if task is not started", func(t *testing.T) { - cm := task.NewCommonExport(di.EmptyDomain(), nil) + cm := &common{} err := cm.WaitForTaskRunning(context.TODO()) assert.EqualError(t, err, "task is not started") }) @@ -126,18 +140,21 @@ func TestCommon_WaitForTaskRunning(t *testing.T) { func TestCommon_WaitContainerHealthCheck(t *testing.T) { setup := func(t *testing.T, envars *env.Envars) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, *ecstypes.TaskDefinition, - *task.CommonExport) { + *common) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) mocker := test.NewMockContext() timerMock := mock_types.NewMockTime(ctrl) td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.Time, timerMock) - }), &task.Input{TaskDefinition: td.TaskDefinition}) - cm.TaskArn = aws.String("task-arn") + cm := &common{ + Input: &Input{TaskDefinition: td.TaskDefinition}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, timerMock) + }), + } + cm.taskArn = aws.String("task-arn") return ecsMock, timerMock, td.TaskDefinition, cm } t.Run("should call DescribeTasks periodically", func(t *testing.T) { @@ -239,14 +256,16 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { } func TestCommon_StopTask(t *testing.T) { - setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockEcsClient, *task.CommonExport) { + setup := func(t *testing.T, env *env.Envars) (*mock_awsiface.MockEcsClient, *common) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) - cm := task.NewCommonExport(di.NewDomain(func(b *di.B) { - b.Set(key.EcsCli, ecsMock) - b.Set(key.Env, env) - }), nil) - cm.TaskArn = aws.String("task-arn") + cm := &common{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.EcsCli, ecsMock) + b.Set(key.Env, env) + }), + } + cm.taskArn = aws.String("task-arn") return ecsMock, cm } t.Run("should call ecscCli.StopTask and wait", func(t *testing.T) { @@ -261,7 +280,7 @@ func TestCommon_StopTask(t *testing.T) { assert.NoError(t, err) }) t.Run("should do nothing if task is not started", func(t *testing.T) { - cm := task.NewCommonExport(di.EmptyDomain(), nil) + cm := &common{} err := cm.StopTask(context.TODO()) assert.NoError(t, err) }) diff --git a/task/export_test.go b/task/export_test.go deleted file mode 100644 index b193a42..0000000 --- a/task/export_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package task - -import ( - ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/logos/di" -) - -type CommonExport = common -type AlbTaskExport = albTask -type SimpleTaskExport = simpleTask - -func NewCommonExport(di *di.D, input *Input) *common { - return &common{di: di, Input: input} -} - -func NewAlbTaskExport(di *di.D, input *Input, lb *ecstypes.LoadBalancer) *albTask { - return &albTask{common: &common{di: di, Input: input}, Lb: lb} -} - -func NewSimpleTaskExport(di *di.D, input *Input) *simpleTask { - return &simpleTask{common: &common{di: di, Input: input}} -} diff --git a/task/factory_test.go b/task/factory_test.go index 642fd65..272eec9 100644 --- a/task/factory_test.go +++ b/task/factory_test.go @@ -1,10 +1,9 @@ -package task_test +package task import ( "testing" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) @@ -12,16 +11,24 @@ import ( func TestFactory(t *testing.T) { d := &di.D{} t.Run("NewAlbTask", func(t *testing.T) { - f := task.NewFactory(d) - input := &task.Input{} + f := NewFactory(d) + input := &Input{} lb := &ecstypes.LoadBalancer{} task := f.NewAlbTask(input, lb) + v, ok := task.(*albTask) assert.NotNil(t, task) + assert.True(t, ok) + assert.Equal(t, input, v.Input) + assert.Equal(t, lb, v.lb) }) t.Run("NewSimpleTask", func(t *testing.T) { - f := task.NewFactory(d) - input := &task.Input{} + f := NewFactory(d) + input := &Input{} task := f.NewSimpleTask(input) + v, ok := task.(*simpleTask) assert.NotNil(t, task) + assert.True(t, ok) + assert.Equal(t, input, v.Input) + assert.Equal(t, d, v.di) }) } diff --git a/task/simple_task.go b/task/simple_task.go index 0911886..b2b25cd 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -58,7 +58,7 @@ func (c *simpleTask) WaitForIdleDuration(ctx context.Context) error { ecsCli := c.di.Get(key.EcsCli).(awsiface.EcsClient) o, err := ecsCli.DescribeTasks(ctx, &ecs.DescribeTasksInput{ Cluster: &env.Cluster, - Tasks: []string{*c.TaskArn}, + Tasks: []string{*c.taskArn}, }) if err != nil { return err diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 220b7ba..233be89 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -1,4 +1,4 @@ -package task_test +package task import ( "context" @@ -12,12 +12,22 @@ import ( "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" "github.com/loilo-inc/canarycage/mocks/mock_types" - "github.com/loilo-inc/canarycage/task" "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/logos/di" "github.com/stretchr/testify/assert" ) +func TestNewSimpleTask(t *testing.T) { + d := di.NewDomain(func(b *di.B) { + b.Set(key.Env, test.DefaultEnvars()) + }) + task := NewSimpleTask(d, &Input{}) + v, ok := task.(*simpleTask) + assert.NotNil(t, task) + assert.True(t, ok) + assert.Equal(t, d, v.di) +} + func TestSimpleTask(t *testing.T) { ctx := context.TODO() mocker := test.NewMockContext() @@ -33,10 +43,15 @@ func TestSimpleTask(t *testing.T) { b.Set(key.AlbCli, mocker.Alb) b.Set(key.Time, test.NewFakeTime()) }) - stask := task.NewSimpleTask(d, &task.Input{ - TaskDefinition: td.TaskDefinition, - NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, - }) + stask := &simpleTask{ + common: &common{ + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: ecsSvc.Service.NetworkConfiguration, + }, + di: d, + }, + } err := stask.Start(ctx) assert.NoError(t, err) err = stask.Wait(ctx) @@ -47,7 +62,7 @@ func TestSimpleTask(t *testing.T) { } func TestSimpleTask_WaitForIdleDuration(t *testing.T) { - setup := func(t *testing.T, idle int) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, *task.SimpleTaskExport) { + setup := func(t *testing.T, idle int) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, *simpleTask) { ctrl := gomock.NewController(t) mocker := test.NewMockContext() envars := test.DefaultEnvars() @@ -55,12 +70,19 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), envars.TaskDefinitionInput) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) timerMock := mock_types.NewMockTime(ctrl) - cm := task.NewSimpleTaskExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, envars) - b.Set(key.EcsCli, ecsMock) - b.Set(key.Time, timerMock) - }), &task.Input{TaskDefinition: td.TaskDefinition}) - cm.TaskArn = aws.String("task-arn") + cm := &simpleTask{ + common: &common{ + Input: &Input{ + TaskDefinition: td.TaskDefinition, + }, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, envars) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, timerMock) + }), + }, + } + cm.taskArn = aws.String("task-arn") return ecsMock, timerMock, cm } t.Run("should call DescribeTasks periodically", func(t *testing.T) { From 1afb781e218019410bba2b79ab60d85cf48684a7 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Thu, 4 Jul 2024 21:01:19 +0900 Subject: [PATCH 43/47] test --- task/alb_task.go | 22 +++--- task/alb_task_test.go | 162 +++++++++++++++++++++++++++++++++------ task/common.go | 6 +- task/common_test.go | 26 +++---- task/simple_task.go | 10 +-- task/simple_task_test.go | 61 ++++++++++++++- 6 files changed, 228 insertions(+), 59 deletions(-) diff --git a/task/alb_task.go b/task/alb_task.go index 83b85b2..07fee99 100644 --- a/task/alb_task.go +++ b/task/alb_task.go @@ -39,18 +39,18 @@ func NewAlbTask( } func (c *albTask) Wait(ctx context.Context) error { - if err := c.WaitForTaskRunning(ctx); err != nil { + if err := c.waitForTaskRunning(ctx); err != nil { return err } - if err := c.WaitContainerHealthCheck(ctx); err != nil { + if err := c.waitContainerHealthCheck(ctx); err != nil { return err } - if err := c.RegisterToTargetGroup(ctx); err != nil { + if err := c.registerToTargetGroup(ctx); err != nil { return err } log.Infof("canary task '%s' is registered to target group '%s'", *c.target.Id, *c.lb.TargetGroupArn) log.Infof("😷 waiting canary target to be healthy...") - if err := c.WaitUntilTargetHealthy(ctx); err != nil { + if err := c.waitUntilTargetHealthy(ctx); err != nil { return err } log.Info("🤩 canary target is healthy!") @@ -58,8 +58,8 @@ func (c *albTask) Wait(ctx context.Context) error { } func (c *albTask) Stop(ctx context.Context) error { - c.DeregisterTarget(ctx) - return c.StopTask(ctx) + c.deregisterTarget(ctx) + return c.stopTask(ctx) } func (c *albTask) describeTaskTarget( @@ -156,7 +156,7 @@ func (c *albTask) getTargetPort() (*int32, error) { return nil, xerrors.Errorf("couldn't find host port in container definition") } -func (c *albTask) RegisterToTargetGroup(ctx context.Context) error { +func (c *albTask) registerToTargetGroup(ctx context.Context) error { log.Infof("registering the canary task to target group '%s'...", *c.lb.TargetGroupArn) if target, err := c.describeTaskTarget(ctx); err != nil { return err @@ -173,7 +173,7 @@ func (c *albTask) RegisterToTargetGroup(ctx context.Context) error { return nil } -func (c *albTask) WaitUntilTargetHealthy( +func (c *albTask) waitUntilTargetHealthy( ctx context.Context, ) error { albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) @@ -219,7 +219,7 @@ func (c *albTask) WaitUntilTargetHealthy( ) } -func (c *albTask) GetTargetDeregistrationDelay(ctx context.Context) (time.Duration, error) { +func (c *albTask) getTargetDeregistrationDelay(ctx context.Context) (time.Duration, error) { deregistrationDelay := 300 * time.Second albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) if o, err := albCli.DescribeTargetGroupAttributes(ctx, &elbv2.DescribeTargetGroupAttributesInput{ @@ -241,12 +241,12 @@ func (c *albTask) GetTargetDeregistrationDelay(ctx context.Context) (time.Durati return deregistrationDelay, nil } -func (c *albTask) DeregisterTarget(ctx context.Context) { +func (c *albTask) deregisterTarget(ctx context.Context) { if c.target == nil { return } albCli := c.di.Get(key.AlbCli).(awsiface.AlbClient) - deregistrationDelay, err := c.GetTargetDeregistrationDelay(ctx) + deregistrationDelay, err := c.getTargetDeregistrationDelay(ctx) if err != nil { log.Errorf("failed to get deregistration delay: %v", err) log.Errorf("deregistration delay is set to %d seconds", deregistrationDelay) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index b34c2a5..4281bbc 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -90,6 +90,122 @@ func TestAlbTask(t *testing.T) { }) } +func TestAlbTask_Wait(t *testing.T) { + t.Run("should error if task is not running", func(t *testing.T) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + cm := &albTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, test.DefaultEnvars()) + b.Set(key.EcsCli, ecsMock) + }), + }, + } + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED")}}, + }, nil) + err := cm.Wait(context.TODO()) + assert.ErrorContains(t, err, "failed to wait for canary task to be running") + }) + t.Run("should error if container is not healthy", func(t *testing.T) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), test.DefaultEnvars().TaskDefinitionInput) + env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 1 + cm := &albTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{TaskDefinition: td.TaskDefinition}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + }), + }, + } + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), + Containers: []ecstypes.Container{{ + Name: env.TaskDefinitionInput.ContainerDefinitions[0].Name, + HealthStatus: ecstypes.HealthStatusUnhealthy, + }}, + }}, + }, nil).Times(2) + err := cm.Wait(context.TODO()) + assert.ErrorContains(t, err, "canary task hasn't become to be healthy") + }) + t.Run("should erro if RegisterToTargetGroup failed", func(t *testing.T) { + ctrl := gomock.NewController(t) + albMock := mock_awsiface.NewMockAlbClient(ctrl) + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), test.DefaultEnvars().TaskDefinitionInput) + env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 1 + cm := &albTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.Time, test.NewFakeTime()) + }), + }, + lb: &env.ServiceDefinitionInput.LoadBalancers[0], + } + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := cm.Start(context.TODO()) + if err != nil { + t.Fatal(err) + } + err = cm.Wait(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) + t.Run("should error if waitUntilTargetHealthy failed", func(t *testing.T) { + ctrl := gomock.NewController(t) + albMock := mock_awsiface.NewMockAlbClient(ctrl) + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), test.DefaultEnvars().TaskDefinitionInput) + env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 1 + cm := &albTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{ + TaskDefinition: td.TaskDefinition, + NetworkConfiguration: env.ServiceDefinitionInput.NetworkConfiguration}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, mocker.Ecs) + b.Set(key.AlbCli, albMock) + b.Set(key.Ec2Cli, mocker.Ec2) + b.Set(key.Time, test.NewFakeTime()) + }), + }, + lb: &env.ServiceDefinitionInput.LoadBalancers[0], + } + albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil) + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) + err := cm.Start(context.TODO()) + if err != nil { + t.Fatal(err) + } + err = cm.Wait(context.TODO()) + assert.EqualError(t, err, assert.AnError.Error()) + }) +} + func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { target := &elbv2types.TargetDescription{ Id: aws.String("127.0.0.1"), @@ -136,7 +252,7 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { }, }, nil).Times(1), ) - err := atask.WaitUntilTargetHealthy(context.TODO()) + err := atask.waitUntilTargetHealthy(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeTargetHealth failed", func(t *testing.T) { @@ -144,14 +260,14 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { gomock.InOrder( albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), ) - err := atask.WaitUntilTargetHealthy(context.TODO()) + err := atask.waitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("should error if context is canceled", func(t *testing.T) { _, atask := setup(t) ctx, cancel := context.WithCancel(context.Background()) cancel() - err := atask.WaitUntilTargetHealthy(ctx) + err := atask.waitUntilTargetHealthy(ctx) assert.EqualError(t, err, "context canceled") }) t.Run("should error if target is not registered", func(t *testing.T) { @@ -161,7 +277,7 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { TargetHealthDescriptions: []elbv2types.TargetHealthDescription{}, }, nil).Times(1), ) - err := atask.WaitUntilTargetHealthy(context.TODO()) + err := atask.waitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, fmt.Sprintf( "'%s' is not registered to the target group '%s'", *target.Id, *atask.lb.TargetGroupArn), ) @@ -177,7 +293,7 @@ func TestAlbTask_WaitUntilTargetHealthy(t *testing.T) { }, }, nil).Times(5), ) - err := atask.WaitUntilTargetHealthy(context.TODO()) + err := atask.waitUntilTargetHealthy(context.TODO()) assert.EqualError(t, err, fmt.Sprintf( "canary task '%s' (%s:%d) hasn't become to be healthy. The most recent state: %s", *atask.taskArn, *target.Id, *target.Port, elbv2types.TargetHealthStateEnumUnhealthy, @@ -206,7 +322,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ContainerName: aws.String("unknown")}, } atask.taskArn = aws.String("arn://task") - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, "couldn't find host port in container definition") }) t.Run("Fargate", func(t *testing.T) { @@ -271,13 +387,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { AvailabilityZone: subnets[0].AvailabilityZone}, }}).Return(nil, nil) atask.taskArn = aws.String("arn://task") - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeTasks failed", func(t *testing.T) { _, _, ecsMock, atask := setup(t) ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("should error if DescribeSubnets failed", func(t *testing.T) { @@ -288,7 +404,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Attachments: attachments}, }}, nil) ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("should error if task is not attached to the network interface", func(t *testing.T) { @@ -298,7 +414,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { LastStatus: aws.String("RUNNING"), }}, }, nil) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, "couldn't find ElasticNetworkInterface attachment in task") }) t.Run("should error if RegisterTargets failed", func(t *testing.T) { @@ -312,7 +428,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Subnets: subnets, }, nil) albMock.EXPECT().RegisterTargets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) }) @@ -376,13 +492,13 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Port: aws.Int32(80), AvailabilityZone: subnets[0].AvailabilityZone}, }}).Return(nil, nil) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeContainerInstances failed", func(t *testing.T) { _, _, ecsMock, atask := setup(t) ecsMock.EXPECT().DescribeContainerInstances(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("should error if DescribeInstances failed", func(t *testing.T) { @@ -391,7 +507,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { ContainerInstances: containerInstances, }, nil) ec2Mock.EXPECT().DescribeInstances(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("should error if DescribeSubnets failed", func(t *testing.T) { @@ -403,7 +519,7 @@ func TestAlbTask_RegisterToTargetGroup(t *testing.T) { Reservations: reservations, }, nil) ec2Mock.EXPECT().DescribeSubnets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - err := atask.RegisterToTargetGroup(context.TODO()) + err := atask.registerToTargetGroup(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) }) @@ -432,7 +548,7 @@ func TestAlbTask_GetTargetDeregistrationDelay(t *testing.T) { {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("100")}, }, }, nil) - delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + delay, err := atask.getTargetDeregistrationDelay(context.TODO()) assert.NoError(t, err) assert.Equal(t, 100*time.Second, delay) }) @@ -441,7 +557,7 @@ func TestAlbTask_GetTargetDeregistrationDelay(t *testing.T) { albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetGroupAttributesOutput{ Attributes: []elbv2types.TargetGroupAttribute{}, }, nil) - delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + delay, err := atask.getTargetDeregistrationDelay(context.TODO()) assert.NoError(t, err) assert.Equal(t, 300*time.Second, delay) }) @@ -452,14 +568,14 @@ func TestAlbTask_GetTargetDeregistrationDelay(t *testing.T) { {Key: aws.String("deregistration_delay.timeout_seconds"), Value: aws.String("invalid")}, }, }, nil) - delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + delay, err := atask.getTargetDeregistrationDelay(context.TODO()) assert.Error(t, err) assert.Equal(t, 300*time.Second, delay) }) t.Run("should error if DescribeTargetGroupAttributes failed", func(t *testing.T) { albMock, atask := setup(t) albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(nil, assert.AnError) - delay, err := atask.GetTargetDeregistrationDelay(context.TODO()) + delay, err := atask.getTargetDeregistrationDelay(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) assert.Equal(t, 300*time.Second, delay) }) @@ -494,7 +610,7 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { } t.Run("should do nothing if target is nil", func(t *testing.T) { atask := &albTask{} - atask.DeregisterTarget(context.TODO()) + atask.deregisterTarget(context.TODO()) }) t.Run("should call DeregisterTargets and wait", func(t *testing.T) { env := test.DefaultEnvars() @@ -514,7 +630,7 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { }, }, nil).Times(1), ) - atask.DeregisterTarget(context.TODO()) + atask.deregisterTarget(context.TODO()) }) t.Run("should return even if DeregisterTargets failed", func(t *testing.T) { env := test.DefaultEnvars() @@ -527,7 +643,7 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { }, nil).Times(1), albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), ) - atask.DeregisterTarget(context.TODO()) + atask.deregisterTarget(context.TODO()) }) t.Run("should return even if deregistration wait counts exceed the limit", func(t *testing.T) { env := test.DefaultEnvars() @@ -541,6 +657,6 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1), albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), ) - atask.DeregisterTarget(context.TODO()) + atask.deregisterTarget(context.TODO()) }) } diff --git a/task/common.go b/task/common.go index eecea12..95d64a7 100644 --- a/task/common.go +++ b/task/common.go @@ -63,7 +63,7 @@ func (c *common) Start(ctx context.Context) error { return nil } -func (c *common) WaitForTaskRunning(ctx context.Context) error { +func (c *common) waitForTaskRunning(ctx context.Context) error { if c.taskArn == nil { return xerrors.New("task is not started") } @@ -80,7 +80,7 @@ func (c *common) WaitForTaskRunning(ctx context.Context) error { return nil } -func (c *common) WaitContainerHealthCheck(ctx context.Context) error { +func (c *common) waitContainerHealthCheck(ctx context.Context) error { log.Infof("😷 ensuring canary task container(s) to become healthy...") containerHasHealthChecks := map[string]struct{}{} for _, definition := range c.TaskDefinition.ContainerDefinitions { @@ -139,7 +139,7 @@ func (c *common) WaitContainerHealthCheck(ctx context.Context) error { return xerrors.Errorf("😨 canary task hasn't become to be healthy") } -func (c *common) StopTask(ctx context.Context) error { +func (c *common) stopTask(ctx context.Context) error { if c.taskArn == nil { return nil } diff --git a/task/common_test.go b/task/common_test.go index a8bb5c1..1d0113a 100644 --- a/task/common_test.go +++ b/task/common_test.go @@ -116,12 +116,12 @@ func TestCommon_WaitForTaskRunning(t *testing.T) { ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()).Return(&ecs.DescribeTasksOutput{ Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING")}}, }, nil) - err := cm.WaitForTaskRunning(context.TODO()) + err := cm.waitForTaskRunning(context.TODO()) assert.NoError(t, err) }) t.Run("should error if task is not started", func(t *testing.T) { cm := &common{} - err := cm.WaitForTaskRunning(context.TODO()) + err := cm.waitForTaskRunning(context.TODO()) assert.EqualError(t, err, "task is not started") }) t.Run("should error if ecs.NewTasksRunningWaiter failed", func(t *testing.T) { @@ -132,7 +132,7 @@ func TestCommon_WaitForTaskRunning(t *testing.T) { &ecs.DescribeTasksOutput{ Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED"), StoppedReason: aws.String("reason")}}, }, nil) - err := cm.WaitForTaskRunning(context.TODO()) + err := cm.waitForTaskRunning(context.TODO()) assert.ErrorContains(t, err, "failed to wait for canary task to be running:") }) } @@ -186,14 +186,14 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { }}, }, nil), ) - err := cm.WaitContainerHealthCheck(context.TODO()) + err := cm.waitContainerHealthCheck(context.TODO()) assert.NoError(t, err) }) t.Run("should do nothing if no container has health check", func(t *testing.T) { env := test.DefaultEnvars() env.TaskDefinitionInput.ContainerDefinitions[0].HealthCheck = nil _, _, _, cm := setup(t, env) - err := cm.WaitContainerHealthCheck(context.TODO()) + err := cm.waitContainerHealthCheck(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeTasks failed", func(t *testing.T) { @@ -204,7 +204,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { timerMock.EXPECT().NewTimer(15*time.Second).DoAndReturn(faketime.NewTimer), ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")), ) - err := cm.WaitContainerHealthCheck(context.TODO()) + err := cm.waitContainerHealthCheck(context.TODO()) assert.EqualError(t, err, "error") }) t.Run("should error if context is canceled", func(t *testing.T) { @@ -213,7 +213,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { timerMock.EXPECT().NewTimer(15 * time.Second).DoAndReturn(time.NewTimer) ctx, cancel := context.WithCancel(context.Background()) cancel() - err := cm.WaitContainerHealthCheck(ctx) + err := cm.waitContainerHealthCheck(ctx) assert.EqualError(t, err, "context canceled") }) t.Run("should error if task is not running", func(t *testing.T) { @@ -229,7 +229,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { }, nil), ) - err := cm.WaitContainerHealthCheck(context.TODO()) + err := cm.waitContainerHealthCheck(context.TODO()) assert.EqualError(t, err, "😫 canary task has stopped: reason") }) t.Run("shold error if unhealth counts exceed the limit", func(t *testing.T) { @@ -250,7 +250,7 @@ func TestCommon_WaitContainerHealthCheck(t *testing.T) { }}, }, nil), ) - err := cm.WaitContainerHealthCheck(context.TODO()) + err := cm.waitContainerHealthCheck(context.TODO()) assert.EqualError(t, err, "😨 canary task hasn't become to be healthy") }) } @@ -276,18 +276,18 @@ func TestCommon_StopTask(t *testing.T) { Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED")}}, }, nil), ) - err := cm.StopTask(context.TODO()) + err := cm.stopTask(context.TODO()) assert.NoError(t, err) }) t.Run("should do nothing if task is not started", func(t *testing.T) { cm := &common{} - err := cm.StopTask(context.TODO()) + err := cm.stopTask(context.TODO()) assert.NoError(t, err) }) t.Run("should error if StopTask failed", func(t *testing.T) { ecsMock, cm := setup(t, test.DefaultEnvars()) ecsMock.EXPECT().StopTask(gomock.Any(), gomock.Any()).Return(nil, fmt.Errorf("error")) - err := cm.StopTask(context.TODO()) + err := cm.stopTask(context.TODO()) assert.EqualError(t, err, "failed to stop canary task: error") }) t.Run("should error wait time exceeds the limit", func(t *testing.T) { @@ -300,7 +300,7 @@ func TestCommon_StopTask(t *testing.T) { Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING")}}, }, nil), ) - err := cm.StopTask(context.TODO()) + err := cm.stopTask(context.TODO()) assert.ErrorContains(t, err, "failed to wait for canary task to be stopped") }) } diff --git a/task/simple_task.go b/task/simple_task.go index b2b25cd..ad2de30 100644 --- a/task/simple_task.go +++ b/task/simple_task.go @@ -24,20 +24,20 @@ func NewSimpleTask(di *di.D, input *Input) Task { } func (c *simpleTask) Wait(ctx context.Context) error { - if err := c.WaitForTaskRunning(ctx); err != nil { + if err := c.waitForTaskRunning(ctx); err != nil { return err } - if err := c.WaitContainerHealthCheck(ctx); err != nil { + if err := c.waitContainerHealthCheck(ctx); err != nil { return err } - return c.WaitForIdleDuration(ctx) + return c.waitForIdleDuration(ctx) } func (c *simpleTask) Stop(ctx context.Context) error { - return c.StopTask(ctx) + return c.stopTask(ctx) } -func (c *simpleTask) WaitForIdleDuration(ctx context.Context) error { +func (c *simpleTask) waitForIdleDuration(ctx context.Context) error { env := c.di.Get(key.Env).(*env.Envars) timer := c.di.Get(key.Time).(types.Time) log.Infof("wait %d seconds for canary task to be stable...", env.CanaryTaskIdleDuration) diff --git a/task/simple_task_test.go b/task/simple_task_test.go index 233be89..3d164e1 100644 --- a/task/simple_task_test.go +++ b/task/simple_task_test.go @@ -61,6 +61,59 @@ func TestSimpleTask(t *testing.T) { assert.Equal(t, 1, mocker.RunningTaskSize()) } +func TestSimpleTask_Wait(t *testing.T) { + t.Run("should error if task is not running", func(t *testing.T) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + cm := &simpleTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, test.DefaultEnvars()) + b.Set(key.EcsCli, ecsMock) + }), + }, + } + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("STOPPED")}}, + }, nil) + err := cm.Wait(context.TODO()) + assert.ErrorContains(t, err, "failed to wait for canary task to be running") + }) + t.Run("should error if container is not healthy", func(t *testing.T) { + ctrl := gomock.NewController(t) + ecsMock := mock_awsiface.NewMockEcsClient(ctrl) + mocker := test.NewMockContext() + td, _ := mocker.Ecs.RegisterTaskDefinition(context.TODO(), test.DefaultEnvars().TaskDefinitionInput) + env := test.DefaultEnvars() + env.CanaryTaskHealthCheckWait = 1 + cm := &simpleTask{ + common: &common{ + taskArn: aws.String("task-arn"), + Input: &Input{TaskDefinition: td.TaskDefinition}, + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + b.Set(key.Time, test.NewFakeTime()) + }), + }, + } + ecsMock.EXPECT().DescribeTasks(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ecs.DescribeTasksOutput{ + Tasks: []ecstypes.Task{{LastStatus: aws.String("RUNNING"), + Containers: []ecstypes.Container{{ + Name: env.TaskDefinitionInput.ContainerDefinitions[0].Name, + HealthStatus: ecstypes.HealthStatusUnhealthy, + }}, + }}, + }, nil).Times(2) + err := cm.Wait(context.TODO()) + assert.ErrorContains(t, err, "canary task hasn't become to be healthy") + }) +} + func TestSimpleTask_WaitForIdleDuration(t *testing.T) { setup := func(t *testing.T, idle int) (*mock_awsiface.MockEcsClient, *mock_types.MockTime, *simpleTask) { ctrl := gomock.NewController(t) @@ -101,7 +154,7 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { }, nil). Times(1), ) - err := cm.WaitForIdleDuration(context.TODO()) + err := cm.waitForIdleDuration(context.TODO()) assert.NoError(t, err) }) t.Run("should error if DescribeTasks failed", func(t *testing.T) { @@ -115,7 +168,7 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { Return(nil, assert.AnError). Times(1), ) - err := cm.WaitForIdleDuration(context.TODO()) + err := cm.waitForIdleDuration(context.TODO()) assert.EqualError(t, err, assert.AnError.Error()) }) t.Run("sholud error if ctx is canceled", func(t *testing.T) { @@ -127,7 +180,7 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { ) ctx, cancel := context.WithCancel(context.Background()) cancel() - err := cm.WaitForIdleDuration(ctx) + err := cm.waitForIdleDuration(ctx) assert.EqualError(t, err, "context canceled") }) t.Run("should error if task is not started", func(t *testing.T) { @@ -145,7 +198,7 @@ func TestSimpleTask_WaitForIdleDuration(t *testing.T) { }}, }, nil), ) - err := cm.WaitForIdleDuration(context.TODO()) + err := cm.waitForIdleDuration(context.TODO()) assert.EqualError(t, err, "😫 canary task has stopped: reason") }) } From f55317bb26588dfb3784e6da5bce615537f24283 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 8 Jul 2024 17:01:57 +0900 Subject: [PATCH 44/47] Update alb_task_test.go --- task/alb_task_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/task/alb_task_test.go b/task/alb_task_test.go index 4281bbc..e652d34 100644 --- a/task/alb_task_test.go +++ b/task/alb_task_test.go @@ -632,6 +632,22 @@ func TestAlbTask_DeregisterTarget(t *testing.T) { ) atask.deregisterTarget(context.TODO()) }) + t.Run("should call DeregisterTargets even if getTargetDeregistrationDelay failed", func(t *testing.T) { + env := test.DefaultEnvars() + albMock, atask := setup(t, env) + gomock.InOrder( + albMock.EXPECT().DescribeTargetGroupAttributes(gomock.Any(), gomock.Any()).Return(nil, assert.AnError).Times(1), + albMock.EXPECT().DeregisterTargets(gomock.Any(), gomock.Any()).Return(nil, nil).Times(1), + albMock.EXPECT().DescribeTargetHealth(gomock.Any(), gomock.Any(), gomock.Any()).Return(&elbv2.DescribeTargetHealthOutput{ + TargetHealthDescriptions: []elbv2types.TargetHealthDescription{ + {TargetHealth: &elbv2types.TargetHealth{State: elbv2types.TargetHealthStateEnumUnused}, + Target: target, + }, + }, + }, nil).Times(1), + ) + atask.deregisterTarget(context.TODO()) + }) t.Run("should return even if DeregisterTargets failed", func(t *testing.T) { env := test.DefaultEnvars() albMock, atask := setup(t, env) From 436391b487c774fee39620698f5cb2cc09bf24a6 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Mon, 8 Jul 2024 17:14:04 +0900 Subject: [PATCH 45/47] tests Update flags.go --- cli/cage/commands/command_test.go | 49 +++++++++++++++++++++++++++++-- cli/cage/commands/flags.go | 2 +- env/env.go | 3 +- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/cli/cage/commands/command_test.go b/cli/cage/commands/command_test.go index 776c39b..804e702 100644 --- a/cli/cage/commands/command_test.go +++ b/cli/cage/commands/command_test.go @@ -1,4 +1,4 @@ -package commands_test +package commands import ( "fmt" @@ -6,9 +6,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/loilo-inc/canarycage/cli/cage/commands" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/mocks/mock_types" + "github.com/loilo-inc/canarycage/test" "github.com/loilo-inc/canarycage/types" "github.com/stretchr/testify/assert" "github.com/urfave/cli/v2" @@ -25,7 +25,7 @@ func TestCommands(t *testing.T) { stdin := strings.NewReader(input) cagecli := mock_types.NewMockCage(ctrl) app := cli.NewApp() - cmds := commands.NewCageCommands(stdin, func(envars *env.Envars) (types.Cage, error) { + cmds := NewCageCommands(stdin, func(envars *env.Envars) (types.Cage, error) { return cagecli, nil }) envars := env.Envars{CI: input == ""} @@ -103,3 +103,46 @@ func TestCommands(t *testing.T) { }) }) } + +func TestSetupCage(t *testing.T) { + t.Run("basic", func(t *testing.T) { + envars := &env.Envars{Region: "us-west-2"} + cageCli := mock_types.NewMockCage(gomock.NewController(t)) + cmd := NewCageCommands(nil, func(envars *env.Envars) (types.Cage, error) { + return cageCli, nil + }) + v, err := cmd.setupCage(envars, "../../../fixtures") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, v, cageCli) + assert.Equal(t, envars.Service, "service") + assert.Equal(t, envars.Cluster, "cluster") + assert.NotNil(t, envars.ServiceDefinitionInput) + assert.NotNil(t, envars.TaskDefinitionInput) + }) + t.Run("should skip load task definition if --taskDefinitionArn provided", func(t *testing.T) { + envars := &env.Envars{Region: "us-west-2", TaskDefinitionArn: "arn"} + cageCli := mock_types.NewMockCage(gomock.NewController(t)) + cmd := NewCageCommands(nil, func(envars *env.Envars) (types.Cage, error) { + return cageCli, nil + }) + v, err := cmd.setupCage(envars, "../../../fixtures") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, v, cageCli) + assert.Equal(t, envars.Service, "service") + assert.Equal(t, envars.Cluster, "cluster") + assert.NotNil(t, envars.ServiceDefinitionInput) + assert.Nil(t, envars.TaskDefinitionInput) + }) + t.Run("should error if error returned from NewCage", func(t *testing.T) { + envars := &env.Envars{Region: "us-west-2"} + cmd := NewCageCommands(nil, func(envars *env.Envars) (types.Cage, error) { + return nil, test.Err + }) + _, err := cmd.setupCage(envars, "../../../fixtures") + assert.EqualError(t, err, "error") + }) +} diff --git a/cli/cage/commands/flags.go b/cli/cage/commands/flags.go index 0e228b5..eb991bb 100644 --- a/cli/cage/commands/flags.go +++ b/cli/cage/commands/flags.go @@ -45,7 +45,7 @@ func CanaryTaskIdleDurationFlag(dest *int) *cli.IntFlag { EnvVars: []string{env.CanaryTaskIdleDuration}, Usage: "duration seconds for waiting canary task that isn't attached to target group considered as ready for serving traffic", Destination: dest, - Value: 10, + Value: 15, } } diff --git a/env/env.go b/env/env.go index 39330d7..b46ecf7 100644 --- a/env/env.go +++ b/env/env.go @@ -54,7 +54,8 @@ func EnsureEnvars( } if dest.Cluster == "" { return xerrors.Errorf("--cluster [%s] is required", ClusterKey) - } else if dest.Service == "" { + } + if dest.Service == "" { return xerrors.Errorf("--service [%s] is required", ServiceKey) } if dest.TaskDefinitionArn == "" && dest.TaskDefinitionInput == nil { From b6f9c4517954de3dd108fa237bc8ccd59fcc94fb Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Tue, 9 Jul 2024 18:05:18 +0900 Subject: [PATCH 46/47] PR --- cli/cage/commands/command.go | 4 ++-- env/env.go | 4 ++-- env/env_test.go | 12 +++++------ export_test.go | 9 --------- task_definition_test.go | 39 +++++++++++++++++++----------------- 5 files changed, 31 insertions(+), 37 deletions(-) delete mode 100644 export_test.go diff --git a/cli/cage/commands/command.go b/cli/cage/commands/command.go index 62812e6..c8c162e 100644 --- a/cli/cage/commands/command.go +++ b/cli/cage/commands/command.go @@ -49,13 +49,13 @@ func (c *CageCommands) setupCage( ) (types.Cage, error) { var service *ecs.CreateServiceInput var taskDefinition *ecs.RegisterTaskDefinitionInput - if srv, err := env.LoadServiceDefiniton(dir); err != nil { + if srv, err := env.LoadServiceDefinition(dir); err != nil { return nil, err } else { service = srv } if envars.TaskDefinitionArn == "" { - if td, err := env.LoadTaskDefiniton(dir); err != nil { + if td, err := env.LoadTaskDefinition(dir); err != nil { return nil, err } else { taskDefinition = td diff --git a/env/env.go b/env/env.go index b46ecf7..54bdff5 100644 --- a/env/env.go +++ b/env/env.go @@ -64,7 +64,7 @@ func EnsureEnvars( return nil } -func LoadServiceDefiniton(dir string) (*ecs.CreateServiceInput, error) { +func LoadServiceDefinition(dir string) (*ecs.CreateServiceInput, error) { svcPath := filepath.Join(dir, "service.json") _, noSvc := os.Stat(svcPath) var service ecs.CreateServiceInput @@ -77,7 +77,7 @@ func LoadServiceDefiniton(dir string) (*ecs.CreateServiceInput, error) { return &service, nil } -func LoadTaskDefiniton(dir string) (*ecs.RegisterTaskDefinitionInput, error) { +func LoadTaskDefinition(dir string) (*ecs.RegisterTaskDefinitionInput, error) { tdPath := filepath.Join(dir, "task-definition.json") _, noTd := os.Stat(tdPath) var td ecs.RegisterTaskDefinitionInput diff --git a/env/env_test.go b/env/env_test.go index 2100420..90de950 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -88,36 +88,36 @@ func TestMergeEnvars(t *testing.T) { func TestLoadServiceDefinition(t *testing.T) { t.Run("basic", func(t *testing.T) { - d, err := env.LoadServiceDefiniton("../fixtures") + d, err := env.LoadServiceDefinition("../fixtures") if err != nil { t.Fatalf(err.Error()) } assert.Equal(t, *d.ServiceName, "service") }) t.Run("should error if service.json is not found", func(t *testing.T) { - _, err := env.LoadServiceDefiniton("./testdata") + _, err := env.LoadServiceDefinition("./testdata") assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") }) t.Run("should error if service.json is invalid", func(t *testing.T) { - _, err := env.LoadServiceDefiniton("./testdata/invalid") + _, err := env.LoadServiceDefinition("./testdata/invalid") assert.ErrorContains(t, err, "failed to read and unmarshal service.json:") }) } func TestLoadTaskDefinition(t *testing.T) { t.Run("basic", func(t *testing.T) { - d, err := env.LoadTaskDefiniton("../fixtures") + d, err := env.LoadTaskDefinition("../fixtures") if err != nil { t.Fatalf(err.Error()) } assert.Equal(t, *d.Family, "test-task") }) t.Run("should error if task-definition.json is not found", func(t *testing.T) { - _, err := env.LoadTaskDefiniton("./testdata") + _, err := env.LoadTaskDefinition("./testdata") assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") }) t.Run("should error if task-definition.json is invalid", func(t *testing.T) { - _, err := env.LoadTaskDefiniton("./testdata/invalid") + _, err := env.LoadTaskDefinition("./testdata/invalid") assert.ErrorContains(t, err, "failed to read and unmarshal task-definition.json:") }) } diff --git a/export_test.go b/export_test.go deleted file mode 100644 index fd50637..0000000 --- a/export_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package cage - -import "github.com/loilo-inc/logos/di" - -type CageExport = cage - -func NewCageExport(di *di.D) *cage { - return &cage{di} -} diff --git a/task_definition_test.go b/task_definition_test.go index db5f8ac..dc96c02 100644 --- a/task_definition_test.go +++ b/task_definition_test.go @@ -1,4 +1,4 @@ -package cage_test +package cage import ( "context" @@ -7,7 +7,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ecs" ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" "github.com/golang/mock/gomock" - cage "github.com/loilo-inc/canarycage" "github.com/loilo-inc/canarycage/env" "github.com/loilo-inc/canarycage/key" "github.com/loilo-inc/canarycage/mocks/mock_awsiface" @@ -24,10 +23,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } - c := cage.NewCageExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, ecsMock) - })) + c := &cage{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })} ecsMock.EXPECT().DescribeTaskDefinition(gomock.Any(), gomock.Any()).Return(&ecs.DescribeTaskDefinitionOutput{ TaskDefinition: &ecstypes.TaskDefinition{}, }, nil) @@ -41,10 +41,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { env := &env.Envars{ TaskDefinitionArn: "arn://aaa", } - c := cage.NewCageExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, ecsMock) - })) + c := &cage{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })} ecsMock.EXPECT().DescribeTaskDefinition(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("error")) td, err := c.CreateNextTaskDefinition(context.Background()) assert.Errorf(t, err, "failed to describe next task definition: error") @@ -54,10 +55,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() - c := cage.NewCageExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, ecsMock) - })) + c := &cage{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })} ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).Return(&ecs.RegisterTaskDefinitionOutput{ TaskDefinition: &ecstypes.TaskDefinition{ Family: env.TaskDefinitionInput.Family, @@ -72,10 +74,11 @@ func TestCage_CreateNextTaskDefinition(t *testing.T) { ctrl := gomock.NewController(t) ecsMock := mock_awsiface.NewMockEcsClient(ctrl) env := test.DefaultEnvars() - c := cage.NewCageExport(di.NewDomain(func(b *di.B) { - b.Set(key.Env, env) - b.Set(key.EcsCli, ecsMock) - })) + c := &cage{ + di: di.NewDomain(func(b *di.B) { + b.Set(key.Env, env) + b.Set(key.EcsCli, ecsMock) + })} ecsMock.EXPECT().RegisterTaskDefinition(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("error")) td, err := c.CreateNextTaskDefinition(context.Background()) assert.Errorf(t, err, "failed to register next task definition: error") From 65afc07d5ec9cc0e7e3b0c8c827f1eef06ebdcf9 Mon Sep 17 00:00:00 2001 From: Yusuke Sakurai Date: Wed, 10 Jul 2024 20:55:54 +0900 Subject: [PATCH 47/47] PR --- env/env.go | 8 ++++---- env/env_test.go | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/env/env.go b/env/env.go index 54bdff5..0922204 100644 --- a/env/env.go +++ b/env/env.go @@ -69,10 +69,10 @@ func LoadServiceDefinition(dir string) (*ecs.CreateServiceInput, error) { _, noSvc := os.Stat(svcPath) var service ecs.CreateServiceInput if noSvc != nil { - return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) + return nil, xerrors.Errorf("no 'service.json' found in %s", dir) } if err := ReadAndUnmarshalJson(svcPath, &service); err != nil { - return nil, xerrors.Errorf("failed to read and unmarshal service.json: %s", err) + return nil, xerrors.Errorf("failed to read and unmarshal 'service.json': %s", err) } return &service, nil } @@ -82,10 +82,10 @@ func LoadTaskDefinition(dir string) (*ecs.RegisterTaskDefinitionInput, error) { _, noTd := os.Stat(tdPath) var td ecs.RegisterTaskDefinitionInput if noTd != nil { - return nil, xerrors.Errorf("roll out context specified at '%s' but no 'service.json' or 'task-definition.json'", dir) + return nil, xerrors.Errorf("no 'task-definition.json' found in %s", dir) } if err := ReadAndUnmarshalJson(tdPath, &td); err != nil { - return nil, xerrors.Errorf("failed to read and unmarshal task-definition.json: %s", err) + return nil, xerrors.Errorf("failed to read and unmarshal 'task-definition.json': %s", err) } return &td, nil } diff --git a/env/env_test.go b/env/env_test.go index 90de950..425999e 100644 --- a/env/env_test.go +++ b/env/env_test.go @@ -96,11 +96,11 @@ func TestLoadServiceDefinition(t *testing.T) { }) t.Run("should error if service.json is not found", func(t *testing.T) { _, err := env.LoadServiceDefinition("./testdata") - assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") + assert.EqualError(t, err, "no 'service.json' found in ./testdata") }) t.Run("should error if service.json is invalid", func(t *testing.T) { _, err := env.LoadServiceDefinition("./testdata/invalid") - assert.ErrorContains(t, err, "failed to read and unmarshal service.json:") + assert.ErrorContains(t, err, "failed to read and unmarshal 'service.json':") }) } @@ -114,11 +114,11 @@ func TestLoadTaskDefinition(t *testing.T) { }) t.Run("should error if task-definition.json is not found", func(t *testing.T) { _, err := env.LoadTaskDefinition("./testdata") - assert.EqualError(t, err, "roll out context specified at './testdata' but no 'service.json' or 'task-definition.json'") + assert.EqualError(t, err, "no 'task-definition.json' found in ./testdata") }) t.Run("should error if task-definition.json is invalid", func(t *testing.T) { _, err := env.LoadTaskDefinition("./testdata/invalid") - assert.ErrorContains(t, err, "failed to read and unmarshal task-definition.json:") + assert.ErrorContains(t, err, "failed to read and unmarshal 'task-definition.json':") }) }