S3 relay interface #833
Changes from 20 commits
@@ -3,20 +3,46 @@ package aws
 import (
 	"github.com/Layr-Labs/eigenda/common"
 	"github.com/urfave/cli"
+	"time"
 )

 var (
-	RegionFlagName          = "aws.region"
-	AccessKeyIdFlagName     = "aws.access-key-id"
-	SecretAccessKeyFlagName = "aws.secret-access-key"
-	EndpointURLFlagName     = "aws.endpoint-url"
+	RegionFlagName                      = "aws.region"
+	AccessKeyIdFlagName                 = "aws.access-key-id"
+	SecretAccessKeyFlagName             = "aws.secret-access-key"
+	EndpointURLFlagName                 = "aws.endpoint-url"
+	FragmentPrefixCharsFlagName         = "aws.fragment-prefix-chars"
+	FragmentParallelismFactorFlagName   = "aws.fragment-parallelism-factor"
+	FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant"
+	FragmentReadTimeoutFlagName         = "aws.fragment-read-timeout"
+	FragmentWriteTimeoutFlagName        = "aws.fragment-write-timeout"
 )

 type ClientConfig struct {
-	Region          string
-	AccessKey       string
+	// The region to use when interacting with S3. Default is "us-east-2".
+	Region string
+	// The access key to use when interacting with S3.
+	AccessKey string
+	// The secret key to use when interacting with S3.
 	SecretAccessKey string
-	EndpointURL     string
+	// The URL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used.
+	EndpointURL string
+
+	// The number of characters of the key to use as the prefix for fragmented files.
+	// A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3.
+	FragmentPrefixChars int
+	// This framework utilizes a pool of workers to help upload/download files. A non-zero value for this parameter
+	// adds a number of workers equal to the number of cores times this value. Default is 8. In general, the number
+	// of workers here can be a lot larger than the number of cores because the workers will be blocked on I/O most
+	// of the time.
+	FragmentParallelismFactor int
+	// This framework utilizes a pool of workers to help upload/download files. A non-zero value for this parameter
+	// adds a constant number of workers. Default is 0.
+	FragmentParallelismConstant int

[Review comment on FragmentParallelismConstant] Why not just use a single field to explicitly define the number of workers?

[Author] This is based on @dmanc's request to scale the number of workers based on the number of cores on the machine. The pattern here allows the user to specify either a fixed number of threads or a number that varies with the number of cores. If you think this is overcomplicated, I'm willing to go back to having a constant number of worker threads.

+	// If a single fragmented read takes longer than this value then the read will be aborted. Default is 30 seconds.
+	FragmentReadTimeout time.Duration
+	// If a single fragmented write takes longer than this value then the write will be aborted. Default is 30 seconds.
+	FragmentWriteTimeout time.Duration
 }

 func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
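A hedged aside on how the two parallelism knobs interact: the sketch below mirrors the pool-sizing rule added in NewClient further down in this diff (the helper name effectiveWorkers is illustrative, not part of the PR).

package main

import (
	"fmt"
	"runtime"
)

// effectiveWorkers mirrors NewClient's pool sizing: a constant worker
// count, overridden by a per-core factor when one is set, with a floor
// of one worker so the pool can always make progress.
func effectiveWorkers(constant, factor int) int {
	workers := 0
	if constant > 0 {
		workers = constant
	}
	if factor > 0 {
		workers = factor * runtime.NumCPU()
	}
	if workers == 0 {
		workers = 1
	}
	return workers
}

func main() {
	fmt.Println(effectiveWorkers(0, 8))  // defaults: 8 workers per core
	fmt.Println(effectiveWorkers(16, 0)) // fixed pool of 16 workers
}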
@@ -48,14 +74,66 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
 			Value:    "",
 			EnvVar:   common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"),
 		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName),
+			Usage:    "The number of characters of the key to use as the prefix for fragmented files",
+			Required: false,
+			Value:    3,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"),
+		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName),
+			Usage:    "Add this many threads times the number of cores to the worker pool",
+			Required: false,
+			Value:    8,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"),
+		},
+		cli.IntFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName),
+			Usage:    "Add this many threads to the worker pool",
+			Required: false,
+			Value:    0,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"),
+		},
+		cli.DurationFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName),
+			Usage:    "The maximum time to wait for a single fragmented read",
+			Required: false,
+			Value:    30 * time.Second,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"),
+		},
+		cli.DurationFlag{
+			Name:     common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName),
+			Usage:    "The maximum time to wait for a single fragmented write",
+			Required: false,
+			Value:    30 * time.Second,
+			EnvVar:   common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"),
+		},
 	}
 }

 func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
 	return ClientConfig{
-		Region:          ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
-		AccessKey:       ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
-		SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
-		EndpointURL:     ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		Region:                      ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
+		AccessKey:                   ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
+		SecretAccessKey:             ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
+		EndpointURL:                 ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		FragmentPrefixChars:         ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)),
+		FragmentParallelismFactor:   ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
+		FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
+		FragmentReadTimeout:         ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
+		FragmentWriteTimeout:        ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
 	}
 }
+
+// DefaultClientConfig returns a new ClientConfig with default values.
+func DefaultClientConfig() *ClientConfig {
+	return &ClientConfig{
+		Region:                      "us-east-2",
+		FragmentPrefixChars:         3,
+		FragmentParallelismFactor:   8,
+		FragmentParallelismConstant: 0,
+		FragmentReadTimeout:         30 * time.Second,
+		FragmentWriteTimeout:        30 * time.Second,
+	}
+}
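A hedged usage sketch for this configuration surface, assuming the package is imported under its declared name aws and that a caller wants the documented defaults with a couple of overrides (the specific values are illustrative):

	cfg := aws.DefaultClientConfig()
	cfg.Region = "us-west-2"                    // override the default region
	cfg.FragmentWriteTimeout = 60 * time.Second // allow slower writes
	// All other fields keep their defaults: prefix chars 3,
	// parallelism factor 8, read timeout 30s.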
(Second file in the diff: the S3 client implementation.)
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"github.com/gammazero/workerpool"
+	"runtime"
 	"sync"

 	commonaws "github.com/Layr-Labs/eigenda/common/aws"

@@ -27,7 +29,9 @@ type Object struct {
 }

 type client struct {
+	cfg      *commonaws.ClientConfig
 	s3Client *s3.Client
+	pool     *workerpool.WorkerPool
 	logger   logging.Logger
 }

@@ -36,18 +40,19 @@ var _ Client = (*client)(nil)
 func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) {
 	var err error
 	once.Do(func() {
-		customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
-			if cfg.EndpointURL != "" {
-				return aws.Endpoint{
-					PartitionID:   "aws",
-					URL:           cfg.EndpointURL,
-					SigningRegion: cfg.Region,
-				}, nil
-			}
-
-			// returning EndpointNotFoundError will allow the service to fallback to its default resolution
-			return aws.Endpoint{}, &aws.EndpointNotFoundError{}
-		})
+		customResolver := aws.EndpointResolverWithOptionsFunc(
+			func(service, region string, options ...interface{}) (aws.Endpoint, error) {
+				if cfg.EndpointURL != "" {
+					return aws.Endpoint{
+						PartitionID:   "aws",
+						URL:           cfg.EndpointURL,
+						SigningRegion: cfg.Region,
+					}, nil
+				}
+
+				// returning EndpointNotFoundError will allow the service to fallback to its default resolution
+				return aws.Endpoint{}, &aws.EndpointNotFoundError{}
+			})

 		options := [](func(*config.LoadOptions) error){
 			config.WithRegion(cfg.Region),
@@ -56,18 +61,40 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L
 		}
 		// If access key and secret access key are not provided, use the default credential provider
 		if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 {
-			options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
+			options = append(options,
+				config.WithCredentialsProvider(
+					credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
 		}
 		awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...)

 		if errCfg != nil {
 			err = errCfg
 			return
 		}

 		s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
 			o.UsePathStyle = true
 		})
-		ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")}
+
+		workers := 0
+		if cfg.FragmentParallelismConstant > 0 {
+			workers = cfg.FragmentParallelismConstant
+		}
+		if cfg.FragmentParallelismFactor > 0 {
+			workers = cfg.FragmentParallelismFactor * runtime.NumCPU()
+		}
+
+		if workers == 0 {
+			workers = 1
+		}
+		pool := workerpool.New(workers)
+
+		ref = &client{
+			cfg:      &cfg,
+			s3Client: s3Client,
+			pool:     pool,
+			logger:   logger.With("component", "S3Client"),
+		}
 	})
 	return ref, err
 }
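Because construction happens inside once.Do and the result is stored in the package-level ref, the client behaves as a process-wide singleton. A hedged sketch of the consequence for callers (the import alias s3client is illustrative):

	a, _ := s3client.NewClient(ctx, cfgA, logger)
	b, _ := s3client.NewClient(ctx, cfgB, logger)
	// a == b: the second call returns the client built from cfgA, and
	// cfgB (including its worker-pool settings) is silently ignored.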
@@ -148,3 +175,162 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string)
 	}
 	return objects, nil
 }
+
+func (s *client) CreateBucket(ctx context.Context, bucket string) error {
+	_, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{
+		Bucket: aws.String(bucket),
+	})
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (s *client) FragmentedUploadObject(
+	ctx context.Context,
+	bucket string,
+	key string,
+	data []byte,
+	fragmentSize int) error {
+
+	fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize)
+	if err != nil {
+		return err
+	}
+	resultChannel := make(chan error, len(fragments))
+
+	ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
+	defer cancel()
+
+	for _, fragment := range fragments {
+		fragmentCapture := fragment
+		s.pool.Submit(func() {
+			s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket)
+		})
+	}
+
+	for range fragments {
+		err := <-resultChannel
+		if err != nil {
+			return err
+		}
+	}
+	return ctx.Err()
+}

[Review comment on the shared worker pool] This pool is shared across all ongoing uploads and downloads, which means that when there are many uploads/downloads in flight it can build backpressure. Is that problematic?

[Author] My previous iteration had a config parameter that allowed the client to have a configurable-size work queue (i.e. when I was using my own implementation of workerpool). From the implementation: "By default, channels have a buffer size of 0. This means that if the number of read/write tasks exceeds the number of available workers, the caller will block until all tasks are accepted by a worker. This will provide back pressure if more work is scheduled than there are workers to handle that work." Are you ok with the way this is configured by default? If not, we will probably need our own implementation of workerpool.
+
+// fragmentedWriteTask writes a single file to S3.
+func (s *client) fragmentedWriteTask(
+	ctx context.Context,
+	resultChannel chan error,
+	fragment *Fragment,
+	bucket string) {
+
+	_, err := s.s3Client.PutObject(ctx,
+		&s3.PutObjectInput{
+			Bucket: aws.String(bucket),
+			Key:    aws.String(fragment.FragmentKey),
+			Body:   bytes.NewReader(fragment.Data),
+		})
+
+	resultChannel <- err
+}
+
+func (s *client) FragmentedDownloadObject(
+	ctx context.Context,
+	bucket string,
+	key string,
+	fileSize int,
+	fragmentSize int) ([]byte, error) {
+
+	if fragmentSize <= 0 {
+		return nil, errors.New("fragmentSize must be greater than 0")
+	}
+
+	fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize))
+	if err != nil {
+		return nil, err
+	}
+	resultChannel := make(chan *readResult, len(fragmentKeys))
+
+	ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentReadTimeout)
+	defer cancel()
+
+	for i, fragmentKey := range fragmentKeys {
+		boundFragmentKey := fragmentKey
+		boundI := i
+		s.pool.Submit(func() {
+			s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI)
+		})
+	}
+
+	fragments := make([]*Fragment, len(fragmentKeys))
+	for i := 0; i < len(fragmentKeys); i++ {
+		result := <-resultChannel
+		if result.err != nil {
+			return nil, result.err
+		}
+		fragments[result.fragment.Index] = result.fragment
+	}
+
+	if ctx.Err() != nil {
+		return nil, ctx.Err()
+	}
+
+	return RecombineFragments(fragments)
+}
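A hedged round-trip sketch of the fragmented upload and download calls above, assuming client is the value returned by NewClient; the bucket and key are illustrative. Note that the caller must remember the original file size and fragment size, since the download path recomputes the fragment keys from them:

	data := make([]byte, 1<<20) // 1 MiB payload
	fragmentSize := 256 * 1024  // 256 KiB fragments

	if err := client.FragmentedUploadObject(ctx, "relay-bucket", "blob/abc123", data, fragmentSize); err != nil {
		return err
	}

	readBack, err := client.FragmentedDownloadObject(ctx, "relay-bucket", "blob/abc123", len(data), fragmentSize)
	if err != nil {
		return err
	}
	// readBack now holds the recombined 1 MiB payload.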
+
+// readResult is the result of a read task.
+type readResult struct {
+	fragment *Fragment
+	err      error
+}
+
+// readTask reads a single file from S3.
+func (s *client) readTask(
+	ctx context.Context,
+	resultChannel chan *readResult,
+	bucket string,
+	key string,
+	index int) {
+
+	result := &readResult{}
+	defer func() {
+		resultChannel <- result
+	}()
+
+	ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+	})
+
+	if err != nil {
+		result.err = err
+		return
+	}
+
+	data := make([]byte, *ret.ContentLength)
+	bytesRead := 0
+
+	for bytesRead < len(data) && ctx.Err() == nil {
+		count, err := ret.Body.Read(data[bytesRead:])
+		if err != nil && err.Error() != "EOF" {
+			result.err = err
+			return
+		}
+		bytesRead += count
+	}
+
+	result.fragment = &Fragment{
+		FragmentKey: key,
+		Data:        data,
+		Index:       index,
+	}
+
+	err = ret.Body.Close()
+	if err != nil {
+		result.err = err
+	}
+}
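To make the fragment key layout concrete: the helpers BreakIntoFragments, GetFragmentKeys, and GetFragmentCount are used above but their bodies are not part of this diff, so the following is a hedged sketch based only on the documented behavior of FragmentPrefixChars (a value of 3 for the key "ABCDEFG" yields the prefix "ABC") and on the usual ceiling-division reading of a fragment count:

// fragmentPrefix returns the first prefixChars characters of the key,
// e.g. fragmentPrefix("ABCDEFG", 3) == "ABC"; short shared prefixes
// spread fragments across S3's key space.
func fragmentPrefix(key string, prefixChars int) string {
	if prefixChars >= len(key) {
		return key
	}
	return key[:prefixChars]
}

// fragmentCount presumably matches GetFragmentCount:
// ceil(fileSize / fragmentSize) in integer arithmetic.
func fragmentCount(fileSize, fragmentSize int) int {
	return (fileSize + fragmentSize - 1) / fragmentSize
}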
[Review comment] nit: start comment with the name of the field.

[Author] fixed