diff --git a/internal/service/advancedcluster/resource_advanced_cluster_migration_test.go b/internal/service/advancedcluster/resource_advanced_cluster_migration_test.go index 0104376391..71f9af4dfe 100644 --- a/internal/service/advancedcluster/resource_advanced_cluster_migration_test.go +++ b/internal/service/advancedcluster/resource_advanced_cluster_migration_test.go @@ -29,7 +29,6 @@ func TestMigAdvancedCluster_singleShardedMultiCloud(t *testing.T) { } func TestMigAdvancedCluster_symmetricGeoShardedOldSchema(t *testing.T) { - acc.SkipIfAdvancedClusterV2Schema(t) // unexpected update and then: error operation not permitted, nums_shards from 1 -> > 1 testCase := symmetricGeoShardedOldSchemaTestCase(t, false) mig.CreateAndRunTest(t, &testCase) } diff --git a/internal/service/advancedclustertpf/move_upgrade_state.go b/internal/service/advancedclustertpf/move_upgrade_state.go index 9e7d41bb3d..2154d7b5e2 100644 --- a/internal/service/advancedclustertpf/move_upgrade_state.go +++ b/internal/service/advancedclustertpf/move_upgrade_state.go @@ -3,12 +3,12 @@ package advancedclustertpf import ( "context" "fmt" + "math/big" "strings" "github.com/hashicorp/terraform-plugin-framework-timeouts/resource/timeouts" "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/diag" - "github.com/hashicorp/terraform-plugin-framework/path" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" @@ -18,10 +18,12 @@ import ( "go.mongodb.org/atlas-sdk/v20241113004/admin" ) +// MoveState is used with moved block to upgrade from cluster to adv_cluster func (r *rs) MoveState(context.Context) []resource.StateMover { return []resource.StateMover{{StateMover: stateMover}} } +// UpgradeState is used to upgrade from adv_cluster schema v1 (SDKv2) to v2 (TPF) func (r *rs) UpgradeState(ctx context.Context) map[int64]resource.StateUpgrader { return map[int64]resource.StateUpgrader{ 1: {StateUpgrader: stateUpgraderFromV1}, @@ -39,91 +41,95 @@ func stateUpgraderFromV1(ctx context.Context, req resource.UpgradeStateRequest, setStateResponse(ctx, &resp.Diagnostics, req.RawState, &resp.State) } -func setStateResponse(ctx context.Context, diags *diag.Diagnostics, stateIn *tfprotov6.RawState, stateOut *tfsdk.State) { - rawStateValue, err := stateIn.UnmarshalWithOpts(tftypes.Object{ +// stateAttrs has the attributes needed from source schema. +// Filling these attributes in the destination will prevent plan changes when moving/upgrading state. +// Read will fill in the rest. +var stateAttrs = map[string]tftypes.Type{ + "project_id": tftypes.String, // project_id and name to identify the cluster + "name": tftypes.String, + "retain_backups_enabled": tftypes.Bool, // TF specific so can't be got in Read + "mongo_db_major_version": tftypes.String, // Has special logic in overrideAttributesWithPrevStateValue that needs the previous state + "timeouts": tftypes.Object{ // TF specific so can't be got in Read AttributeTypes: map[string]tftypes.Type{ - "project_id": tftypes.String, - "name": tftypes.String, - "retain_backups_enabled": tftypes.Bool, - "mongo_db_major_version": tftypes.String, - "timeouts": tftypes.Object{ - AttributeTypes: map[string]tftypes.Type{ - "create": tftypes.String, - "update": tftypes.String, - "delete": tftypes.String, - }, + "create": tftypes.String, + "update": tftypes.String, + "delete": tftypes.String, + }, + }, + "replication_specs": tftypes.List{ // Needed to send num_shards to Read so it can decide if it's using the legacy schema. + ElementType: tftypes.Object{ + AttributeTypes: map[string]tftypes.Type{ + "num_shards": tftypes.Number, }, }, + }, +} + +func setStateResponse(ctx context.Context, diags *diag.Diagnostics, stateIn *tfprotov6.RawState, stateOut *tfsdk.State) { + rawStateValue, err := stateIn.UnmarshalWithOpts(tftypes.Object{ + AttributeTypes: stateAttrs, }, tfprotov6.UnmarshalOpts{ValueFromJSONOpts: tftypes.ValueFromJSONOpts{IgnoreUndefinedAttributes: true}}) if err != nil { diags.AddError("Unable to Unmarshal state", err.Error()) return } - var rawState map[string]tftypes.Value - if err := rawStateValue.As(&rawState); err != nil { + var stateObj map[string]tftypes.Value + if err := rawStateValue.As(&stateObj); err != nil { diags.AddError("Unable to Parse state", err.Error()) return } - - projectID := getAttrFromRawState[string](diags, rawState, "project_id") - name := getAttrFromRawState[string](diags, rawState, "name") + projectID, name := getProjectIDNameFromStateObj(diags, stateObj) if diags.HasError() { return } - if !conversion.IsStringPresent(projectID) || !conversion.IsStringPresent(name) { - diags.AddError("Unable to read project_id or name from state", fmt.Sprintf("project_id: %s, name: %s", - conversion.SafeString(projectID), conversion.SafeString(name))) - return - } - model := NewTFModel(ctx, &admin.ClusterDescription20240805{ GroupId: projectID, Name: name, - }, getAttrTimeout(diags, rawState), diags, ExtraAPIInfo{}) + }, getTimeoutFromStateObj(stateObj), diags, ExtraAPIInfo{}) if diags.HasError() { return } - - if retainBackupsEnabled := getAttrFromRawState[bool](diags, rawState, "retain_backups_enabled"); retainBackupsEnabled != nil { - model.RetainBackupsEnabled = types.BoolPointerValue(retainBackupsEnabled) - } - if mongoDBMajorVersion := getAttrFromRawState[string](diags, rawState, "mongo_db_major_version"); mongoDBMajorVersion != nil { - model.MongoDBMajorVersion = types.StringPointerValue(mongoDBMajorVersion) - } - if diags.HasError() { - return - } - AddAdvancedConfig(ctx, model, nil, nil, diags) if diags.HasError() { return } + setOptionalModelAttrs(ctx, stateObj, model) diags.Append(stateOut.Set(ctx, model)...) } -func getAttrFromRawState[T any](diags *diag.Diagnostics, rawState map[string]tftypes.Value, attrName string) *T { +func getAttrFromStateObj[T any](rawState map[string]tftypes.Value, attrName string) *T { var ret *T if err := rawState[attrName].As(&ret); err != nil { - diags.AddAttributeError(path.Root(attrName), fmt.Sprintf("Unable to read cluster %s", attrName), err.Error()) return nil } return ret } -func getAttrTimeout(diags *diag.Diagnostics, rawState map[string]tftypes.Value) timeouts.Value { +func getProjectIDNameFromStateObj(diags *diag.Diagnostics, stateObj map[string]tftypes.Value) (projectID, name *string) { + projectID = getAttrFromStateObj[string](stateObj, "project_id") + name = getAttrFromStateObj[string](stateObj, "name") + if !conversion.IsStringPresent(projectID) || !conversion.IsStringPresent(name) { + diags.AddError("Unable to read project_id or name from state", fmt.Sprintf("project_id: %s, name: %s", + conversion.SafeString(projectID), conversion.SafeString(name))) + return + } + return projectID, name +} + +func getTimeoutFromStateObj(stateObj map[string]tftypes.Value) timeouts.Value { attrTypes := map[string]attr.Type{ "create": types.StringType, "update": types.StringType, "delete": types.StringType, } nullObj := timeouts.Value{Object: types.ObjectNull(attrTypes)} - timeoutState := getAttrFromRawState[map[string]tftypes.Value](diags, rawState, "timeouts") - if diags.HasError() || timeoutState == nil { + timeoutState := getAttrFromStateObj[map[string]tftypes.Value](stateObj, "timeouts") + if timeoutState == nil { return nullObj } timeoutMap := make(map[string]attr.Value) for action := range attrTypes { - actionTimeout := getAttrFromRawState[string](diags, *timeoutState, action) + actionTimeout := getAttrFromStateObj[string](*timeoutState, action) if actionTimeout == nil { timeoutMap[action] = types.StringNull() } else { @@ -131,9 +137,49 @@ func getAttrTimeout(diags *diag.Diagnostics, rawState map[string]tftypes.Value) } } obj, d := types.ObjectValue(attrTypes, timeoutMap) - diags.Append(d...) - if diags.HasError() { + if d.HasError() { return nullObj } return timeouts.Value{Object: obj} } + +func setOptionalModelAttrs(ctx context.Context, stateObj map[string]tftypes.Value, model *TFModel) { + if retainBackupsEnabled := getAttrFromStateObj[bool](stateObj, "retain_backups_enabled"); retainBackupsEnabled != nil { + model.RetainBackupsEnabled = types.BoolPointerValue(retainBackupsEnabled) + } + if mongoDBMajorVersion := getAttrFromStateObj[string](stateObj, "mongo_db_major_version"); mongoDBMajorVersion != nil { + model.MongoDBMajorVersion = types.StringPointerValue(mongoDBMajorVersion) + } + if specsVal := getAttrFromStateObj[[]tftypes.Value](stateObj, "replication_specs"); specsVal != nil { + var specModels []TFReplicationSpecsModel + for _, specVal := range *specsVal { + var specObj map[string]tftypes.Value + if err := specVal.As(&specObj); err != nil { + continue + } + if specModel := replicationSpecModelWithNumShards(specObj["num_shards"]); specModel != nil { + specModels = append(specModels, *specModel) + } + } + if len(specModels) > 0 { + model.ReplicationSpecs, _ = types.ListValueFrom(ctx, ReplicationSpecsObjType, specModels) + } + } +} + +func replicationSpecModelWithNumShards(numShardsVal tftypes.Value) *TFReplicationSpecsModel { + var numShardsFloat *big.Float + if err := numShardsVal.As(&numShardsFloat); err != nil || numShardsFloat == nil { + return nil + } + numShards, _ := numShardsFloat.Int64() + return &TFReplicationSpecsModel{ + NumShards: types.Int64Value(numShards), + RegionConfigs: types.ListNull(RegionConfigsObjType), + ContainerId: types.MapNull(types.StringType), + Id: types.StringNull(), + ExternalId: types.StringNull(), + ZoneId: types.StringNull(), + ZoneName: types.StringNull(), + } +}