
S3 relay interface #833
Merged: 29 commits from branch s3-relay-interface, Oct 30, 2024

Changes from 20 commits. Commits (all by cody-littley):
6814f14  Created core framework, need to add unit tests (Oct 23, 2024)
c7ba7d6  Add unit tests for fragmentation logic. (Oct 23, 2024)
e5dcc78  Incremental progress. (Oct 23, 2024)
17ac4bc  Fix some bugs. (Oct 24, 2024)
356959c  Fixed bugs. (Oct 24, 2024)
958e6e1  Test against localstack. (Oct 24, 2024)
a2bb06f  lint (Oct 24, 2024)
95b9d75  Cleanup. (Oct 24, 2024)
9b9d57d  Merge branch 'master' into s3-relay-interface (Oct 24, 2024)
e6c7a0e  Remove TTL from upload (Oct 25, 2024)
b712c88  Merge branch 'master' into s3-relay-interface (Oct 25, 2024)
93e2310  Use v2 APIs. (Oct 28, 2024)
6e9aa2e  Made suggested changes. (Oct 28, 2024)
8323596  Made suggested changes. (Oct 28, 2024)
5d5f425  Incremental checkin. (Oct 28, 2024)
2ddd2c4  Finished migration. (Oct 28, 2024)
220d48f  Fix flags. (Oct 28, 2024)
4996fb0  Fix bug. (Oct 29, 2024)
06cb92e  Merge branch 'master' into s3-relay-interface (Oct 29, 2024)
dc5b6fc  Fix unit test. (Oct 29, 2024)
9104ab1  Fix unit test. (Oct 29, 2024)
319fca6  Fix unit test. (Oct 29, 2024)
fe9e08c  Fix unit test. (Oct 29, 2024)
8f269de  Make suggested changes. (Oct 29, 2024)
c5d5b96  Tweak unit test. (Oct 29, 2024)
a4c081d  Merge branch 'master' into s3-relay-interface (Oct 29, 2024)
c4c7e2d  Change localstack port. (Oct 29, 2024)
684dab8  Add debug code. (Oct 30, 2024)
958f593  Fiddle with inabox settings. (Oct 30, 2024)
100 changes: 89 additions & 11 deletions common/aws/cli.go
@@ -3,20 +3,46 @@ package aws
import (
	"github.com/Layr-Labs/eigenda/common"
	"github.com/urfave/cli"
+	"time"
)

var (
-	RegionFlagName          = "aws.region"
-	AccessKeyIdFlagName     = "aws.access-key-id"
-	SecretAccessKeyFlagName = "aws.secret-access-key"
-	EndpointURLFlagName     = "aws.endpoint-url"
+	RegionFlagName                      = "aws.region"
+	AccessKeyIdFlagName                 = "aws.access-key-id"
+	SecretAccessKeyFlagName             = "aws.secret-access-key"
+	EndpointURLFlagName                 = "aws.endpoint-url"
+	FragmentPrefixCharsFlagName         = "aws.fragment-prefix-chars"
+	FragmentParallelismFactorFlagName   = "aws.fragment-parallelism-factor"
+	FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant"
+	FragmentReadTimeoutFlagName         = "aws.fragment-read-timeout"
+	FragmentWriteTimeoutFlagName        = "aws.fragment-write-timeout"
)

type ClientConfig struct {
-	Region          string
-	AccessKey       string
-	SecretAccessKey string
-	EndpointURL     string
+	// The region to use when interacting with S3. Default is "us-east-2".
+	Region string
+	// The access key to use when interacting with S3.
+	AccessKey string
+	// The secret key to use when interacting with S3.
+	SecretAccessKey string
+	// The URL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used.
+	EndpointURL string
+
+	// The number of characters of the key to use as the prefix for fragmented files.
+	// A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3.
+	FragmentPrefixChars int
+	// This framework utilizes a pool of workers to help upload/download files. A non-zero value for this parameter
+	// adds a number of workers equal to the number of cores times this value. Default is 8. In general, the number
+	// of workers here can be a lot larger than the number of cores because the workers will be blocked on I/O most
+	// of the time.
+	FragmentParallelismFactor int
+	// This framework utilizes a pool of workers to help upload/download files. A non-zero value for this parameter
+	// adds a constant number of workers. Default is 0.
+	FragmentParallelismConstant int
+	// If a single fragmented read takes longer than this value then the read will be aborted. Default is 30 seconds.
+	FragmentReadTimeout time.Duration
+	// If a single fragmented write takes longer than this value then the write will be aborted. Default is 30 seconds.
+	FragmentWriteTimeout time.Duration
}

Review comment (Contributor), on the Region field comment:
nit: start the comment with the name of the field.
Reply (cody-littley): fixed

Review comment (Contributor), on FragmentParallelismConstant:
Why not just use a single field to explicitly define the number of workers?
Reply (cody-littley): This is based on @dmanc's request to scale the number of workers based on the number of cores on the machine. The pattern here allows the user to specify either a fixed number of threads or a number that varies with the number of cores. If you think this is overcomplicated, I'm willing to go back to having a constant number of worker threads.
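For concreteness, the pool sizing implemented in NewClient (in client.go below) gives the factor precedence over the constant rather than summing them, and floors the result at one worker:

// Mirrors the sizing rule in NewClient below; not a new API.
workers := 0
if cfg.FragmentParallelismConstant > 0 {
	workers = cfg.FragmentParallelismConstant // e.g. a fixed pool of 32 workers
}
if cfg.FragmentParallelismFactor > 0 {
	workers = cfg.FragmentParallelismFactor * runtime.NumCPU() // e.g. 8 * 16 cores = 128 workers
}
if workers == 0 {
	workers = 1 // always keep at least one worker
}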

func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
@@ -48,14 +74,66 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
Value: "",
EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName),
Usage: "The number of characters of the key to use as the prefix for fragmented files",
Required: false,
Value: 3,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName),
Usage: "Add this many threads times the number of cores to the worker pool",
Required: false,
Value: 8,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName),
Usage: "Add this many threads to the worker pool",
Required: false,
Value: 0,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"),
},
cli.DurationFlag{
Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName),
Usage: "The maximum time to wait for a single fragmented read",
Required: false,
Value: 30 * time.Second,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"),
},
cli.DurationFlag{
Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName),
Usage: "The maximum time to wait for a single fragmented write",
Required: false,
Value: 30 * time.Second,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"),
},
}
}

func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
return ClientConfig{
-		Region:          ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
-		AccessKey:       ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
-		SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
-		EndpointURL:     ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		Region:                      ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
+		AccessKey:                   ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
+		SecretAccessKey:             ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
+		EndpointURL:                 ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
+		FragmentPrefixChars:         ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)),
+		FragmentParallelismFactor:   ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
+		FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
+		FragmentReadTimeout:         ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
+		FragmentWriteTimeout:        ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
}
}
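To show how these helpers compose, here is a minimal sketch of wiring the flags into a urfave/cli (v1) app. The app name, env prefix, and action body are hypothetical, and the exact environment-variable names depend on common.PrefixEnvVar:

package main

import (
	"fmt"
	"os"

	"github.com/Layr-Labs/eigenda/common/aws"
	"github.com/urfave/cli"
)

func main() {
	app := cli.NewApp()
	app.Name = "s3-demo" // hypothetical demo binary
	// Register the AWS/S3 flags under an env prefix and no flag prefix.
	app.Flags = aws.ClientFlags("DEMO", "")
	app.Action = func(ctx *cli.Context) error {
		// Read the parsed flags back into a ClientConfig.
		cfg := aws.ReadClientConfig(ctx, "")
		fmt.Printf("aws client config: %+v\n", cfg)
		return nil
	}
	if err := app.Run(os.Args); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}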

// DefaultClientConfig returns a new ClientConfig with default values.
func DefaultClientConfig() *ClientConfig {
return &ClientConfig{
Region: "us-east-2",
FragmentPrefixChars: 3,
FragmentParallelismFactor: 8,
FragmentParallelismConstant: 0,
FragmentReadTimeout: 30 * time.Second,
FragmentWriteTimeout: 30 * time.Second,
}
}
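And a hedged sketch of turning those defaults into a client against a local S3 stand-in. The endpoint URL and credentials are placeholder values (the PR's tests run against localstack, but this exact wiring is illustrative), and logger is assumed to be in scope:

// Sketch: build an S3 client from the defaults, pointed at a local endpoint.
cfg := aws.DefaultClientConfig()
cfg.EndpointURL = "http://localhost:4566" // placeholder local endpoint
cfg.AccessKey = "localstack"              // placeholder credentials
cfg.SecretAccessKey = "localstack"

client, err := s3.NewClient(context.Background(), *cfg, logger)
if err != nil {
	panic(err)
}
_ = client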
214 changes: 200 additions & 14 deletions common/aws/s3/client.go
@@ -4,6 +4,8 @@ import (
	"bytes"
	"context"
	"errors"
+	"io"
+	"github.com/gammazero/workerpool"
+	"runtime"
	"sync"

commonaws "github.com/Layr-Labs/eigenda/common/aws"
@@ -27,7 +29,9 @@ type Object struct {
}

type client struct {
+	cfg      *commonaws.ClientConfig
	s3Client *s3.Client
+	pool     *workerpool.WorkerPool
	logger   logging.Logger
}

@@ -36,18 +40,19 @@ var _ Client = (*client)(nil)
func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) {
	var err error
	once.Do(func() {
-		customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
-			if cfg.EndpointURL != "" {
-				return aws.Endpoint{
-					PartitionID:   "aws",
-					URL:           cfg.EndpointURL,
-					SigningRegion: cfg.Region,
-				}, nil
-			}
-
-			// returning EndpointNotFoundError will allow the service to fall back to its default resolution
-			return aws.Endpoint{}, &aws.EndpointNotFoundError{}
-		})
+		customResolver := aws.EndpointResolverWithOptionsFunc(
+			func(service, region string, options ...interface{}) (aws.Endpoint, error) {
+				if cfg.EndpointURL != "" {
+					return aws.Endpoint{
+						PartitionID:   "aws",
+						URL:           cfg.EndpointURL,
+						SigningRegion: cfg.Region,
+					}, nil
+				}
+
+				// returning EndpointNotFoundError will allow the service to fall back to its default resolution
+				return aws.Endpoint{}, &aws.EndpointNotFoundError{}
+			})

options := [](func(*config.LoadOptions) error){
config.WithRegion(cfg.Region),
@@ -56,18 +61,40 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) {
}
// If access key and secret access key are not provided, use the default credential provider
if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 {
-		options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
+		options = append(options,
+			config.WithCredentialsProvider(
+				credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
}
awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...)

if errCfg != nil {
err = errCfg
return
}

s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
o.UsePathStyle = true
})
-		ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")}

+		workers := 0
+		if cfg.FragmentParallelismConstant > 0 {
+			workers = cfg.FragmentParallelismConstant
+		}
+		if cfg.FragmentParallelismFactor > 0 {
+			workers = cfg.FragmentParallelismFactor * runtime.NumCPU()
+		}
+
+		if workers == 0 {
+			workers = 1
+		}
+		pool := workerpool.New(workers)
+
+		ref = &client{
+			cfg:      &cfg,
+			s3Client: s3Client,
+			pool:     pool,
+			logger:   logger.With("component", "S3Client"),
+		}
})
return ref, err
}
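Worth noting for callers: construction happens inside once.Do, so only the first call's config is ever applied; later calls return the same cached instance regardless of what they pass. A small illustration (cfgA and cfgB are hypothetical configs):

// Because of once.Do, the first call constructs the client;
// subsequent calls return the same instance and ignore their cfg.
a, _ := s3.NewClient(context.Background(), cfgA, logger) // constructs the client
b, _ := s3.NewClient(context.Background(), cfgB, logger) // b == a; cfgB is not applied
_, _ = a, b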
@@ -148,3 +175,162 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string)
}
return objects, nil
}

func (s *client) CreateBucket(ctx context.Context, bucket string) error {
_, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: aws.String(bucket),
})
if err != nil {
return err
}

return nil
}

func (s *client) FragmentedUploadObject(
ctx context.Context,
bucket string,
key string,
data []byte,
fragmentSize int) error {

fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize)
if err != nil {
return err
}
resultChannel := make(chan error, len(fragments))

ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
defer cancel()

for _, fragment := range fragments {
fragmentCapture := fragment
		s.pool.Submit(func() {
			s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket)
		})
	}

Review comment (Contributor), on the s.pool.Submit call:
This pool is shared across all ongoing uploads and downloads, which means that when there are many uploads/downloads in flight, it can build backpressure. Is that problematic?

Reply (cody-littley): My previous iteration had a config parameter that allowed the client to have a configurable-size work queue (i.e. when I was using my own implementation of workerpool). Unfortunately, the workerpool library does not allow us to override the channel used to send data to the workers. From the implementation:

	workerQueue: make(chan func())

By default, channels have a buffer size of 0. This means that if the number of read/write tasks exceeds the number of available workers, the caller will block until all tasks are accepted by a worker. This provides back pressure if more work is scheduled than there are workers to handle it. Are you ok with the way this is configured by default? If not, we will probably need our own implementation of workerpool.

for range fragments {
err := <-resultChannel
if err != nil {
return err
}
}
return ctx.Err()

}
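The fragmentation helpers used here (BreakIntoFragments, GetFragmentKeys, GetFragmentCount, RecombineFragments) are defined elsewhere in the package and are not part of this diff. As a rough sketch of the arithmetic involved, under the assumption that a file is split into fixed-size fragments with a short final fragment:

// Hypothetical sketch, not the PR's implementation: how many fragments
// are needed to hold fileSize bytes at fragmentSize bytes apiece.
func fragmentCountSketch(fileSize int, fragmentSize int) int {
	count := fileSize / fragmentSize
	if fileSize%fragmentSize != 0 {
		count++ // the last fragment holds the remainder
	}
	return count
}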

// fragmentedWriteTask writes a single file to S3.
func (s *client) fragmentedWriteTask(
ctx context.Context,
resultChannel chan error,
fragment *Fragment,
bucket string) {

_, err := s.s3Client.PutObject(ctx,
&s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(fragment.FragmentKey),
Body: bytes.NewReader(fragment.Data),
})

resultChannel <- err
}

func (s *client) FragmentedDownloadObject(
ctx context.Context,
bucket string,
key string,
fileSize int,
fragmentSize int) ([]byte, error) {

if fragmentSize <= 0 {
return nil, errors.New("fragmentSize must be greater than 0")
}

fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize))
if err != nil {
return nil, err
}
resultChannel := make(chan *readResult, len(fragmentKeys))

	ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentReadTimeout) // downloads are bounded by the read timeout
defer cancel()

for i, fragmentKey := range fragmentKeys {
boundFragmentKey := fragmentKey
boundI := i
s.pool.Submit(func() {
s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI)
})
}

fragments := make([]*Fragment, len(fragmentKeys))
for i := 0; i < len(fragmentKeys); i++ {
result := <-resultChannel
if result.err != nil {
return nil, result.err
}
fragments[result.fragment.Index] = result.fragment
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

return RecombineFragments(fragments)

}
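For orientation, a minimal sketch of a round trip through the two fragmented calls. The bucket, key, payload size, and fragment size are illustrative values, and client and ctx are assumed to come from NewClient above:

// Sketch: write a payload as 1 MiB fragments, then read it back.
payload := make([]byte, 3_500_000) // ~3.3 MiB -> 4 fragments at 1 MiB each
fragmentSize := 1024 * 1024

if err := client.FragmentedUploadObject(ctx, "demo-bucket", "demo-key", payload, fragmentSize); err != nil {
	return err
}

roundTrip, err := client.FragmentedDownloadObject(ctx, "demo-bucket", "demo-key", len(payload), fragmentSize)
if err != nil {
	return err
}
// roundTrip should now be byte-for-byte equal to payload.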

// readResult is the result of a read task.
type readResult struct {
fragment *Fragment
err error
}

// readTask reads a single file from S3.
func (s *client) readTask(
ctx context.Context,
resultChannel chan *readResult,
bucket string,
key string,
index int) {

result := &readResult{}
defer func() {
resultChannel <- result
}()

ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})

if err != nil {
result.err = err
return
}

data := make([]byte, *ret.ContentLength)
bytesRead := 0

	for bytesRead < len(data) && ctx.Err() == nil {
		count, err := ret.Body.Read(data[bytesRead:])
		bytesRead += count
		if err != nil {
			if errors.Is(err, io.EOF) {
				break // end of body; compare against io.EOF rather than the error string
			}
			result.err = err
			return
		}
	}

result.fragment = &Fragment{
FragmentKey: key,
Data: data,
Index: index,
}

err = ret.Body.Close()
if err != nil {
result.err = err
}
}
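As an aside, the drain loop above could also be expressed with io.ReadFull. A sketch using the same ret and result variables; note it drops the per-iteration ctx.Err() check, though the body read is typically still bounded by the request context:

// Sketch: equivalent read of the response body using io.ReadFull.
data := make([]byte, *ret.ContentLength)
if _, err := io.ReadFull(ret.Body, data); err != nil {
	// io.ErrUnexpectedEOF here would mean the body was shorter than ContentLength.
	result.err = err
	return
}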