diff --git a/.env.sample b/.env.sample
index cf74f98..8dd947b 100644
--- a/.env.sample
+++ b/.env.sample
@@ -16,6 +16,7 @@ BEMIDB_STORAGE_PATH=./iceberg
 # AWS_REGION=us-west-1
 # AWS_ENDPOINT=s3.amazonaws.com
 # AWS_S3_BUCKET=[REPLACE_ME]
+# AWS_CREDENTIALS_TYPE=STATIC
 # AWS_ACCESS_KEY_ID=[REPLACE_ME]
 # AWS_SECRET_ACCESS_KEY=[REPLACE_ME]
 
diff --git a/README.md b/README.md
index 2181944..031f333 100644
--- a/README.md
+++ b/README.md
@@ -209,16 +209,17 @@ psql postgres://localhost:54321/bemidb -c \
 
 #### Other common options
 
-| CLI argument              | Environment variable    | Default value                   | Description                                          |
-|---------------------------|-------------------------|---------------------------------|------------------------------------------------------|
-| `--storage-type`          | `BEMIDB_STORAGE_TYPE`   | `LOCAL`                         | Storage type: `LOCAL` or `S3`                        |
-| `--storage-path`          | `BEMIDB_STORAGE_PATH`   | `iceberg`                       | Path to the storage folder                           |
-| `--log-level`             | `BEMIDB_LOG_LEVEL`      | `INFO`                          | Log level: `ERROR`, `WARN`, `INFO`, `DEBUG`, `TRACE` |
-| `--aws-s3-endpoint`       | `AWS_S3_ENDPOINT`       | `s3.amazonaws.com`              | AWS S3 endpoint                                      |
-| `--aws-region`            | `AWS_REGION`            | Required with `S3` storage type | AWS region                                           |
-| `--aws-s3-bucket`         | `AWS_S3_BUCKET`         | Required with `S3` storage type | AWS S3 bucket name                                   |
-| `--aws-access-key-id`     | `AWS_ACCESS_KEY_ID`     | Required with `S3` storage type | AWS access key ID                                    |
-| `--aws-secret-access-key` | `AWS_SECRET_ACCESS_KEY` | Required with `S3` storage type | AWS secret access key                                |
+| CLI argument              | Environment variable    | Default value                                                      | Description                                          |
+|---------------------------|-------------------------|--------------------------------------------------------------------|------------------------------------------------------|
+| `--storage-type`          | `BEMIDB_STORAGE_TYPE`   | `LOCAL`                                                            | Storage type: `LOCAL` or `S3`                        |
+| `--storage-path`          | `BEMIDB_STORAGE_PATH`   | `iceberg`                                                          | Path to the storage folder                           |
+| `--log-level`             | `BEMIDB_LOG_LEVEL`      | `INFO`                                                             | Log level: `ERROR`, `WARN`, `INFO`, `DEBUG`, `TRACE` |
+| `--aws-s3-endpoint`       | `AWS_S3_ENDPOINT`       | `s3.amazonaws.com`                                                 | AWS S3 endpoint                                      |
+| `--aws-region`            | `AWS_REGION`            | Required with `S3` storage type                                    | AWS region                                           |
+| `--aws-s3-bucket`         | `AWS_S3_BUCKET`         | Required with `S3` storage type                                    | AWS S3 bucket name                                   |
+| `--aws-credentials-type`  | `AWS_CREDENTIALS_TYPE`  | `STATIC`                                                           | AWS credentials type: `STATIC` or `DEFAULT`          |
+| `--aws-access-key-id`     | `AWS_ACCESS_KEY_ID`     | Required with `S3` storage type and AWS credentials type `STATIC` | AWS access key ID                                    |
+| `--aws-secret-access-key` | `AWS_SECRET_ACCESS_KEY` | Required with `S3` storage type and AWS credentials type `STATIC` | AWS secret access key                                |
 
 Note that CLI arguments take precedence over environment variables. I.e. you can override the environment variables with CLI arguments.
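As an aside (not part of the patch itself): a minimal `.env` for `S3` storage that relies on the new default credential chain, rather than static keys, could look like the sketch below. The region value and bucket placeholder are illustrative and mirror `.env.sample`.

```sh
BEMIDB_STORAGE_TYPE=S3
AWS_REGION=us-west-1
AWS_S3_BUCKET=[REPLACE_ME]
AWS_CREDENTIALS_TYPE=DEFAULT
# AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY can be omitted: credentials are
# resolved from env vars, shared config files, or an attached IAM role.
```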
diff --git a/src/config.go b/src/config.go
index 8c4272c..ea92aa0 100644
--- a/src/config.go
+++ b/src/config.go
@@ -21,6 +21,7 @@ const (
 	ENV_AWS_REGION            = "AWS_REGION"
 	ENV_AWS_S3_ENDPOINT       = "AWS_S3_ENDPOINT"
 	ENV_AWS_S3_BUCKET         = "AWS_S3_BUCKET"
+	ENV_AWS_CREDENTIALS_TYPE  = "AWS_CREDENTIALS_TYPE"
 	ENV_AWS_ACCESS_KEY_ID     = "AWS_ACCESS_KEY_ID"
 	ENV_AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY"
 
@@ -42,7 +43,11 @@ const (
 	DEFAULT_LOG_LEVEL       = "INFO"
 	DEFAULT_DB_STORAGE_TYPE = "LOCAL"
 
-	DEFAULT_AWS_S3_ENDPOINT = "s3.amazonaws.com"
+	DEFAULT_AWS_S3_ENDPOINT      = "s3.amazonaws.com"
+	DEFAULT_AWS_CREDENTIALS_TYPE = "STATIC"
+
+	AWS_CREDENTIALS_TYPE_STATIC  = "STATIC"
+	AWS_CREDENTIALS_TYPE_DEFAULT = "DEFAULT"
 
 	STORAGE_TYPE_LOCAL = "LOCAL"
 	STORAGE_TYPE_S3    = "S3"
@@ -52,6 +57,7 @@ type AwsConfig struct {
 	Region          string
 	S3Endpoint      string // optional
 	S3Bucket        string
+	CredentialsType string // optional
 	AccessKeyId     string
 	SecretAccessKey string
 }
@@ -115,6 +121,7 @@ func registerFlags() {
 	flag.StringVar(&_config.Aws.Region, "aws-region", os.Getenv(ENV_AWS_REGION), "AWS region")
 	flag.StringVar(&_config.Aws.S3Endpoint, "aws-s3-endpoint", os.Getenv(ENV_AWS_S3_ENDPOINT), "AWS S3 endpoint. Default: \""+DEFAULT_AWS_S3_ENDPOINT+"\"")
 	flag.StringVar(&_config.Aws.S3Bucket, "aws-s3-bucket", os.Getenv(ENV_AWS_S3_BUCKET), "AWS S3 bucket name")
+	flag.StringVar(&_config.Aws.CredentialsType, "aws-credentials-type", os.Getenv(ENV_AWS_CREDENTIALS_TYPE), "AWS credentials type: \"STATIC\", \"DEFAULT\". Default: \""+DEFAULT_AWS_CREDENTIALS_TYPE+"\"")
 	flag.StringVar(&_config.Aws.AccessKeyId, "aws-access-key-id", os.Getenv(ENV_AWS_ACCESS_KEY_ID), "AWS access key ID")
 	flag.StringVar(&_config.Aws.SecretAccessKey, "aws-secret-access-key", os.Getenv(ENV_AWS_SECRET_ACCESS_KEY), "AWS secret access key")
 }
@@ -169,11 +176,18 @@ func parseFlags() {
 		if _config.Aws.S3Bucket == "" {
 			panic("AWS S3 bucket name is required")
 		}
-		if _config.Aws.AccessKeyId == "" {
-			panic("AWS access key ID is required")
+		if _config.Aws.CredentialsType == "" {
+			_config.Aws.CredentialsType = DEFAULT_AWS_CREDENTIALS_TYPE
+		} else if !slices.Contains(AWS_CREDENTIALS_TYPE, _config.Aws.CredentialsType) {
+			panic("Invalid AWS credentials type " + _config.Aws.CredentialsType + ". Must be one of " + strings.Join(AWS_CREDENTIALS_TYPE, ", "))
 		}
-		if _config.Aws.SecretAccessKey == "" {
-			panic("AWS secret access key is required")
+		if _config.Aws.CredentialsType == AWS_CREDENTIALS_TYPE_STATIC {
+			if _config.Aws.AccessKeyId == "" {
+				panic("AWS access key ID is required")
+			}
+			if _config.Aws.SecretAccessKey == "" {
+				panic("AWS secret access key is required")
+			}
 		}
 	}
 	if _configParseValues.pgIncludeSchemas != "" && _configParseValues.pgExcludeSchemas != "" {
diff --git a/src/duckdb.go b/src/duckdb.go
index f8cfea8..0c07447 100644
--- a/src/duckdb.go
+++ b/src/duckdb.go
@@ -7,6 +7,7 @@ import (
 	"os"
 	"regexp"
 	"strings"
+	"time"
 
 	_ "github.com/marcboeker/go-duckdb"
 )
@@ -20,8 +21,9 @@ var DEFAULT_BOOT_QUERIES = []string{
 }
 
 type Duckdb struct {
-	db     *sql.DB
-	config *Config
+	refreshQuit chan struct{}
+	db          *sql.DB
+	config      *Config
 }
 
 func NewDuckdb(config *Config) *Duckdb {
@@ -30,8 +32,9 @@ func NewDuckdb(config *Config) *Duckdb {
 	PanicIfError(err)
 
 	duckdb := &Duckdb{
-		db:     db,
-		config: config,
+		db:          db,
+		config:      config,
+		refreshQuit: make(chan struct{}),
 	}
 
 	bootQueries := readDuckdbInitFile(config)
@@ -45,15 +48,19 @@ func NewDuckdb(config *Config) *Duckdb {
 
 	switch config.StorageType {
 	case STORAGE_TYPE_S3:
-		query := "CREATE SECRET aws_s3_secret (TYPE S3, KEY_ID '$accessKeyId', SECRET '$secretAccessKey', REGION '$region', ENDPOINT '$endpoint', SCOPE '$s3Bucket')"
-		_, err = duckdb.ExecContext(ctx, query, map[string]string{
-			"accessKeyId":     config.Aws.AccessKeyId,
-			"secretAccessKey": config.Aws.SecretAccessKey,
-			"region":          config.Aws.Region,
-			"endpoint":        config.Aws.S3Endpoint,
-			"s3Bucket":        "s3://" + config.Aws.S3Bucket,
-		})
-		PanicIfError(err)
+		duckdb.setAwsCredentials(ctx)
+		ticker := time.NewTicker(10 * time.Minute)
+		go func() {
+			for {
+				select {
+				case <-ticker.C:
+					duckdb.setAwsCredentials(ctx)
+				case <-duckdb.refreshQuit:
+					ticker.Stop()
+					return
+				}
+			}
+		}()
 
 		if config.LogLevel == LOG_LEVEL_TRACE {
 			_, err = duckdb.ExecContext(ctx, "SET enable_http_logging=true", nil)
@@ -64,6 +71,30 @@ func NewDuckdb(config *Config) *Duckdb {
 	return duckdb
 }
 
+func (duckdb *Duckdb) setAwsCredentials(ctx context.Context) {
+	config := duckdb.config
+	switch config.Aws.CredentialsType {
+	case AWS_CREDENTIALS_TYPE_STATIC:
+		query := "CREATE OR REPLACE SECRET aws_s3_secret (TYPE S3, KEY_ID '$accessKeyId', SECRET '$secretAccessKey', REGION '$region', ENDPOINT '$endpoint', SCOPE '$s3Bucket')"
+		_, err := duckdb.ExecContext(ctx, query, map[string]string{
+			"accessKeyId":     config.Aws.AccessKeyId,
+			"secretAccessKey": config.Aws.SecretAccessKey,
+			"region":          config.Aws.Region,
+			"endpoint":        config.Aws.S3Endpoint,
+			"s3Bucket":        "s3://" + config.Aws.S3Bucket,
+		})
+		PanicIfError(err)
+	case AWS_CREDENTIALS_TYPE_DEFAULT:
+		query := "CREATE OR REPLACE SECRET aws_s3_secret (TYPE S3, PROVIDER CREDENTIAL_CHAIN, REGION '$region', ENDPOINT '$endpoint', SCOPE '$s3Bucket')"
+		_, err := duckdb.ExecContext(ctx, query, map[string]string{
+			"region":   config.Aws.Region,
+			"endpoint": config.Aws.S3Endpoint,
+			"s3Bucket": "s3://" + config.Aws.S3Bucket,
+		})
+		PanicIfError(err)
+	}
+}
+
 func (duckdb *Duckdb) ExecContext(ctx context.Context, query string, args map[string]string) (sql.Result, error) {
 	LogDebug(duckdb.config, "Querying DuckDB:", query, args)
 	return duckdb.db.ExecContext(ctx, replaceNamedStringArgs(query, args))
@@ -80,6 +111,7 @@ func (duckdb *Duckdb) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
 }
 
 func (duckdb *Duckdb) Close() {
+	close(duckdb.refreshQuit)
 	duckdb.db.Close()
 }
diff --git a/src/storage_s3.go b/src/storage_s3.go
index 46212b6..cbb6865 100644
--- a/src/storage_s3.go
+++ b/src/storage_s3.go
@@ -16,6 +16,8 @@ import (
 	"github.com/xitongsys/parquet-go-source/s3v2"
 )
+var AWS_CREDENTIALS_TYPE = []string{AWS_CREDENTIALS_TYPE_STATIC, AWS_CREDENTIALS_TYPE_DEFAULT}
+
 type StorageS3 struct {
 	s3Client *s3.Client
 	config   *Config
 }
@@ -23,22 +25,29 @@ type StorageS3 struct {
 func NewS3Storage(config *Config) *StorageS3 {
-	awsCredentials := credentials.NewStaticCredentialsProvider(
-		config.Aws.AccessKeyId,
-		config.Aws.SecretAccessKey,
-		"",
-	)
-
 	var logMode aws.ClientLogMode
 	// if config.LogLevel == LOG_LEVEL_DEBUG {
 	// 	logMode = aws.LogRequest | aws.LogResponse
 	// }
 
-	loadedAwsConfig, err := awsConfig.LoadDefaultConfig(
-		context.Background(),
+	var awsConfigOptions = []func(*awsConfig.LoadOptions) error{
 		awsConfig.WithRegion(config.Aws.Region),
-		awsConfig.WithCredentialsProvider(awsCredentials),
 		awsConfig.WithClientLogMode(logMode),
+	}
+
+	if config.Aws.CredentialsType == AWS_CREDENTIALS_TYPE_STATIC {
+		awsCredentials := credentials.NewStaticCredentialsProvider(
+			config.Aws.AccessKeyId,
+			config.Aws.SecretAccessKey,
+			"",
+		)
+
+		awsConfigOptions = append(awsConfigOptions, awsConfig.WithCredentialsProvider(awsCredentials))
+	}
+
+	loadedAwsConfig, err := awsConfig.LoadDefaultConfig(
+		context.Background(),
+		awsConfigOptions...,
 	)
 	PanicIfError(err)
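For context on what `AWS_CREDENTIALS_TYPE=DEFAULT` relies on: when `WithCredentialsProvider` is not passed, `LoadDefaultConfig` resolves credentials through the SDK's default chain (environment variables, shared config/credentials files, and IAM roles such as EC2/ECS task roles). A minimal standalone sketch, separate from this patch and using a hypothetical region, prints which source the chain picked:

```go
package main

import (
	"context"
	"fmt"

	awsConfig "github.com/aws/aws-sdk-go-v2/config"
)

func main() {
	// No WithCredentialsProvider: the default chain (env vars, shared
	// config/credentials files, IAM roles) supplies the credentials.
	cfg, err := awsConfig.LoadDefaultConfig(context.Background(), awsConfig.WithRegion("us-west-1"))
	if err != nil {
		panic(err)
	}

	creds, err := cfg.Credentials.Retrieve(context.Background())
	if err != nil {
		panic(err)
	}
	fmt.Println("credentials resolved from:", creds.Source)
}
```

Chain-resolved credentials (e.g. from an IAM role) expire, which is why `duckdb.go` now recreates the secret with `CREATE OR REPLACE SECRET` on a 10-minute ticker instead of creating it once at startup.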