Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to add features to existing feature groups in Sagemaker #27933

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 70 additions & 3 deletions internal/service/sagemaker/feature_group.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package sagemaker

import (
"context"
"fmt"
"log"
"regexp"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/sagemaker"
"github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
Expand All @@ -26,6 +28,20 @@ func ResourceFeatureGroup() *schema.Resource {
Importer: &schema.ResourceImporter{
State: schema.ImportStatePassthrough,
},
CustomizeDiff: customdiff.Sequence(
func(_ context.Context, d *schema.ResourceDiff, meta interface{}) error {
o, n := d.GetChange("feature_definition")
var featureDefinitionsOld = expandFeatureGroupFeatureDefinition(o.([]interface{}))
var featureDefinitionsNew = expandFeatureGroupFeatureDefinition(n.([]interface{}))

if !checkIfDefinitionsUnchanged(featureDefinitionsOld, featureDefinitionsNew) {
return fmt.Errorf("existing feature_definitions of SageMaker Feature Group (%s) can not "+
"be changed", d.Id())
}
return nil
},
verify.SetTagsDiff,
),

Schema: map[string]*schema.Schema{
"arn": {
Expand Down Expand Up @@ -77,7 +93,6 @@ func ResourceFeatureGroup() *schema.Resource {
"feature_definition": {
Type: schema.TypeList,
Required: true,
ForceNew: true,
MinItems: 1,
MaxItems: 2500,
Elem: &schema.Resource{
Expand Down Expand Up @@ -191,8 +206,6 @@ func ResourceFeatureGroup() *schema.Resource {
"tags": tftags.TagsSchema(),
"tags_all": tftags.TagsSchemaComputed(),
},

CustomizeDiff: verify.SetTagsDiff,
}
}

Expand Down Expand Up @@ -326,6 +339,40 @@ func resourceFeatureGroupUpdate(d *schema.ResourceData, meta interface{}) error
}
}

o, n := d.GetChange("feature_definition")

var featureDefinitionsOld = expandFeatureGroupFeatureDefinition(o.([]interface{}))
var featureDefinitionsNew = expandFeatureGroupFeatureDefinition(n.([]interface{}))

var newFeatures []*sagemaker.FeatureDefinition

for _, elem := range featureDefinitionsNew {
var elementExists = false

for _, existingElem := range featureDefinitionsOld {
if *existingElem.FeatureName == *elem.FeatureName {
elementExists = true
break
}
}

if !elementExists {
newFeatures = append(newFeatures, elem)
}
}

if len(newFeatures) > 0 {
input := &sagemaker.UpdateFeatureGroupInput{
FeatureGroupName: aws.String(d.Id()),
FeatureAdditions: newFeatures,
}

_, err := conn.UpdateFeatureGroup(input)
if err != nil {
return fmt.Errorf("adding new feature_definition for SageMaker Feature Group (%s): %w", d.Id(), err)
}
}

return resourceFeatureGroupRead(d, meta)
}

Expand Down Expand Up @@ -551,3 +598,23 @@ func flattenFeatureGroupOfflineStoreConfigDataCatalogConfig(config *sagemaker.Da

return []map[string]interface{}{m}
}

func checkIfDefinitionsUnchanged(o []*sagemaker.FeatureDefinition, n []*sagemaker.FeatureDefinition) bool {
var res = true
for _, elem := range o {
var elementExists = false
for _, newElem := range n {
if *newElem.FeatureName == *elem.FeatureName && *newElem.FeatureType == *elem.FeatureType {
elementExists = true
break
}
}

if !elementExists {
res = false
break
}
}

return res
}
111 changes: 111 additions & 0 deletions internal/service/sagemaker/feature_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package sagemaker_test

import (
"fmt"
"log"
"regexp"
"testing"
"time"

"github.com/aws/aws-sdk-go/service/sagemaker"
sdkacctest "github.com/hashicorp/terraform-plugin-sdk/v2/helper/acctest"
Expand All @@ -26,6 +28,8 @@ func TestAccSageMakerFeatureGroup_serial(t *testing.T) {
"offlineConfig_providedCatalog": TestAccSageMakerFeatureGroup_Offline_providedCatalog,
"onlineConfigSecurityConfig": testAccFeatureGroup_onlineConfigSecurityConfig,
"tags": testAccFeatureGroup_tags,
"update": testAccFeatureGroup_update,
"existingFeaturesCantChange": testAccFeatureGroup_existingFeaturesCantChange,
}

for name, tc := range testCases {
Expand Down Expand Up @@ -211,6 +215,93 @@ func testAccFeatureGroup_onlineConfigSecurityConfig(t *testing.T) {
})
}

func testAccFeatureGroup_update(t *testing.T) {
var featureGroup sagemaker.DescribeFeatureGroupOutput
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_feature_group.test"

propagationSleep := func() resource.TestCheckFunc {
return func(s *terraform.State) error {
// Feature addition happen asynchronously, hence we will wait a nominal amount of time
// to wait until it is complete
log.Print("[DEBUG] Test: Sleep to allow feature.")
time.Sleep(5 * time.Second)
return nil
}
}

resource.Test(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckFeatureGroupDestroy,
Steps: []resource.TestStep{
{
Config: testAccFeatureGroupConfig_basic(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckFeatureGroupExists(resourceName, &featureGroup),
resource.TestCheckResourceAttr(resourceName, "feature_group_name", rName),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
{
Config: testAccFeatureGroupConfig_multi(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckFeatureGroupExists(resourceName, &featureGroup),
propagationSleep(),
resource.TestCheckResourceAttr(resourceName, "feature_group_name", rName),
resource.TestCheckResourceAttr(resourceName, "feature_definition.#", "2"),
resource.TestCheckResourceAttr(resourceName, "feature_definition.0.feature_name", rName),
resource.TestCheckResourceAttr(resourceName, "feature_definition.0.feature_type", "String"),
resource.TestCheckResourceAttr(resourceName, "feature_definition.1.feature_name", fmt.Sprintf("%s-2", rName)),
resource.TestCheckResourceAttr(resourceName, "feature_definition.1.feature_type", "Integral"),
),
},
},
})
}

func testAccFeatureGroup_existingFeaturesCantChange(t *testing.T) {
var featureGroup sagemaker.DescribeFeatureGroupOutput
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_feature_group.test"

resource.Test(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckFeatureGroupDestroy,
Steps: []resource.TestStep{
{
Config: testAccFeatureGroupConfig_basic(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckFeatureGroupExists(resourceName, &featureGroup),
resource.TestCheckResourceAttr(resourceName, "feature_group_name", rName),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
{
Config: testAccFeatureGroupConfig_featureChanged(rName, fmt.Sprintf(`%[1]s-2`, rName), "String"),
ExpectError: regexp.MustCompile("existing feature_definitions of SageMaker Feature Group \\((.+)\\) can not be changed"),
PlanOnly: true,
},
{
Config: testAccFeatureGroupConfig_featureChanged(rName, rName, "Integral"),
ExpectError: regexp.MustCompile("existing feature_definitions of SageMaker Feature Group \\((.+)\\) can not be changed"),
PlanOnly: true,
},
},
})
}

func testAccFeatureGroup_offlineConfig_basic(t *testing.T) {
var featureGroup sagemaker.DescribeFeatureGroupOutput
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -485,6 +576,26 @@ resource "aws_sagemaker_feature_group" "test" {
`, rName))
}

func testAccFeatureGroupConfig_featureChanged(rName string, featureName string, featureType string) string {
return acctest.ConfigCompose(testAccFeatureGroupBaseConfig(rName), fmt.Sprintf(`
resource "aws_sagemaker_feature_group" "test" {
feature_group_name = %[1]q
record_identifier_feature_name = %[1]q
event_time_feature_name = %[1]q
role_arn = aws_iam_role.test.arn

feature_definition {
feature_name = %[2]q
feature_type = %[3]q
}

online_store_config {
enable_online_store = true
}
}
`, rName, featureName, featureType))
}

func testAccFeatureGroupConfig_multi(rName string) string {
return acctest.ConfigCompose(testAccFeatureGroupBaseConfig(rName), fmt.Sprintf(`
resource "aws_sagemaker_feature_group" "test" {
Expand Down