Use the /dbfs/put API endpoint to upload smaller DBFS files #1951

Open · wants to merge 33 commits into main
33 commits
c4cea1a
Use the `/dbfs/put` API endpoint to upload smaller DBFS files
shreyas-goenka Dec 2, 2024
91a2dfa
-
shreyas-goenka Dec 2, 2024
06af01c
refactor to make tests easier
shreyas-goenka Dec 2, 2024
4b484fd
todo
shreyas-goenka Dec 2, 2024
04cae2f
Merge remote-tracking branch 'origin' into multipart-dbfs
shreyas-goenka Dec 4, 2024
f690f0a
Merge remote-tracking branch 'origin' into multipart-dbfs
shreyas-goenka Dec 5, 2024
da46c14
Merge remote-tracking branch 'origin' into multipart-dbfs
shreyas-goenka Dec 30, 2024
78b9788
some cleanup
shreyas-goenka Dec 30, 2024
2717aca
Merge remote-tracking branch 'origin' into multipart-dbfs
shreyas-goenka Dec 31, 2024
0ce50fa
add unit test
shreyas-goenka Dec 31, 2024
63e599c
added integration test
shreyas-goenka Dec 31, 2024
932aeee
ignore linter
shreyas-goenka Dec 31, 2024
9d8ba09
fix fd lingering'
shreyas-goenka Dec 31, 2024
09bf4fa
add unit test
shreyas-goenka Dec 31, 2024
cf51636
overwrite fix
shreyas-goenka Dec 31, 2024
be62ead
lint
shreyas-goenka Dec 31, 2024
7084392
cleanup code
shreyas-goenka Dec 31, 2024
ee80173
fix size
shreyas-goenka Dec 31, 2024
69fdd97
reduce diff
shreyas-goenka Dec 31, 2024
92e97ad
-
shreyas-goenka Dec 31, 2024
95f41b1
add streaming uploads
shreyas-goenka Jan 2, 2025
8ec1e07
Revert "add streaming uploads"
shreyas-goenka Jan 2, 2025
890b48f
Reapply "add streaming uploads"
shreyas-goenka Jan 2, 2025
f70c472
calculate content length before upload
shreyas-goenka Jan 2, 2025
1e2545e
merge
shreyas-goenka Jan 2, 2025
6991dea
make content length work
shreyas-goenka Jan 2, 2025
583637a
lint
shreyas-goenka Jan 2, 2025
9552131
lint
shreyas-goenka Jan 2, 2025
7ab9fb7
simplify copy
shreyas-goenka Jan 2, 2025
ac37ca0
-
shreyas-goenka Jan 2, 2025
f4623eb
cleanup
shreyas-goenka Jan 2, 2025
ee9499b
use write testutil
shreyas-goenka Jan 2, 2025
e9b0afb
-
shreyas-goenka Jan 2, 2025
79 changes: 79 additions & 0 deletions integration/libs/filer/filer_test.go
@@ -6,7 +6,9 @@ import (
"encoding/json"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
"testing"

@@ -893,3 +895,80 @@ func TestWorkspaceFilesExtensions_ExportFormatIsPreserved(t *testing.T) {
})
}
}

func TestDbfsFilerForStreamingUploads(t *testing.T) {
ctx := context.Background()
f, _ := setupDbfsFiler(t)

// Set MaxDbfsPutFileSize to 1 to force streaming uploads
prevV := filer.MaxDbfsPutFileSize
filer.MaxDbfsPutFileSize = 1
t.Cleanup(func() {
filer.MaxDbfsPutFileSize = prevV
})

// Write a file to local disk.
tmpDir := t.TempDir()
testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")

fd, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
require.NoError(t, err)
defer fd.Close()

// Write a file with streaming upload
err = f.Write(ctx, "foo.txt", fd)
require.NoError(t, err)

// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")

// Overwrite the file with streaming upload, and fail
err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"))
require.ErrorIs(t, err, fs.ErrExist)

// Overwrite the file with streaming upload, and succeed
err = f.Write(ctx, "foo.txt", strings.NewReader("barfoo"), filer.OverwriteIfExists)
require.NoError(t, err)

// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
}

func TestDbfsFilerForPutUploads(t *testing.T) {
ctx := context.Background()
f, _ := setupDbfsFiler(t)

// Write a file to local disk.
tmpDir := t.TempDir()
testutil.WriteFile(t, filepath.Join(tmpDir, "foo.txt"), "foobar")
testutil.WriteFile(t, filepath.Join(tmpDir, "bar.txt"), "barfoo")
fdFoo, err := os.Open(filepath.Join(tmpDir, "foo.txt"))
require.NoError(t, err)
defer fdFoo.Close()

fdBar, err := os.Open(filepath.Join(tmpDir, "bar.txt"))
require.NoError(t, err)
defer fdBar.Close()

// Write a file with PUT upload
err = f.Write(ctx, "foo.txt", fdFoo)
require.NoError(t, err)

// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "foobar")

// Try to overwrite the file, and fail.
err = f.Write(ctx, "foo.txt", fdBar)
require.ErrorIs(t, err, fs.ErrExist)

// Reset the file descriptor.
_, err = fdBar.Seek(0, io.SeekStart)
require.NoError(t, err)

// Overwrite the file with OverwriteIfExists flag
err = f.Write(ctx, "foo.txt", fdBar, filer.OverwriteIfExists)
require.NoError(t, err)

// Assert contents
filerTest{t, f}.assertContents(ctx, "foo.txt", "barfoo")
}
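
These integration tests follow the repository's existing filer test setup; assuming a workspace is configured for integration tests, they can be run with something like `go test ./integration/libs/filer -run 'TestDbfsFilerFor(Streaming|Put)Uploads'`.
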
158 changes: 150 additions & 8 deletions libs/filer/dbfs_client.go
@@ -1,11 +1,15 @@
package filer

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"mime/multipart"
"net/http"
"os"
"path"
"slices"
"sort"
@@ -14,6 +18,7 @@ import (

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/apierr"
"github.com/databricks/databricks-sdk-go/client"
"github.com/databricks/databricks-sdk-go/service/files"
)

@@ -63,33 +68,142 @@ func (info dbfsFileInfo) Sys() any {
return info.fi
}

// Interface to allow mocking of the Databricks API client.
type databricksClient interface {
Do(ctx context.Context, method, path string, headers map[string]string,
requestBody, responseBody any, visitors ...func(*http.Request) error) error
}

// DbfsClient implements the [Filer] interface for the DBFS backend.
type DbfsClient struct {
workspaceClient *databricks.WorkspaceClient

apiClient databricksClient

// File operations will be relative to this path.
root WorkspaceRootPath
}

func NewDbfsClient(w *databricks.WorkspaceClient, root string) (Filer, error) {
apiClient, err := client.New(w.Config)
if err != nil {
return nil, fmt.Errorf("failed to create API client: %w", err)
}

return &DbfsClient{
workspaceClient: w,
apiClient: apiClient,

root: NewWorkspaceRootPath(root),
}, nil
}

// The PUT API for DBFS requires setting the content length header beforehand in the HTTP
// request.
func contentLength(path, overwriteField string, file *os.File) (int64, error) {
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
err := writer.WriteField("path", path)
if err != nil {
return 0, fmt.Errorf("failed to write field path field in multipart form: %w", err)
}
err = writer.WriteField("overwrite", overwriteField)
if err != nil {
return 0, fmt.Errorf("failed to write field overwrite field in multipart form: %w", err)
}
_, err = writer.CreateFormFile("contents", "")
if err != nil {
return 0, fmt.Errorf("failed to write contents field in multipart form: %w", err)
}
err = writer.Close()
if err != nil {
return 0, fmt.Errorf("failed to close multipart form writer: %w", err)
}

stat, err := file.Stat()
if err != nil {
return 0, fmt.Errorf("failed to stat file %s: %w", path, err)
}

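// At this point buf holds only the multipart framing (boundaries and part headers);
// the "contents" part was created but never written to. The total request size is
// therefore this framing plus the file's size on disk.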
return int64(buf.Len()) + stat.Size(), nil
}

func contentLengthVisitor(path, overwriteField string, file *os.File) func(*http.Request) error {
return func(r *http.Request) error {
cl, err := contentLength(path, overwriteField, file)
if err != nil {
return fmt.Errorf("failed to calculate content length: %w", err)
}
r.ContentLength = cl
return nil
}
}

func (w *DbfsClient) putFile(ctx context.Context, path string, overwrite bool, file *os.File) error {
overwriteField := "False"
if overwrite {
overwriteField = "True"
}

pr, pw := io.Pipe()
writer := multipart.NewWriter(pw)
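// The goroutine below writes the multipart form (fields plus the file contents) to
// the pipe while the HTTP client reads from the other end, so the file is streamed
// and never fully buffered in memory.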
go func() {
defer pw.Close()

err := writer.WriteField("path", path)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field path field in multipart form: %w", err))
return
}
err = writer.WriteField("overwrite", overwriteField)
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write field overwrite field in multipart form: %w", err))
return
}
contents, err := writer.CreateFormFile("contents", "")
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to write contents field in multipart form: %w", err))
return
}
_, err = io.Copy(contents, file)
if err != nil {
pw.CloseWithError(fmt.Errorf("error while streaming file to dbfs: %w", err))
return
}
err = writer.Close()
if err != nil {
pw.CloseWithError(fmt.Errorf("failed to close multipart form writer: %w", err))
return
}
}()

// Request bodies with Content-Type multipart/form-data are not supported by
// the Go SDK's DBFS bindings, so we call the API client's Do method directly.
err := w.apiClient.Do(ctx,
http.MethodPost,
"/api/2.0/dbfs/put",
map[string]string{"Content-Type": writer.FormDataContentType()},
pr,
nil,
contentLengthVisitor(path, overwriteField, file))
var aerr *apierr.APIError
if errors.As(err, &aerr) && aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
return FileAlreadyExistsError{path}
}
return err
}

// MaxDbfsPutFileSize is the maximum size in bytes of a file that can be uploaded
// using the /dbfs/put API. If the file is larger than this limit, the streaming
// API (/dbfs/create and /dbfs/add-block) will be used instead.
var MaxDbfsPutFileSize int64 = 2 * 1024 * 1024 * 1024

func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, mode ...WriteMode) error {
absPath, err := w.root.Join(name)
if err != nil {
return err
}

fileMode := files.FileModeWrite
if slices.Contains(mode, OverwriteIfExists) {
fileMode |= files.FileModeOverwrite
}

// Issue info call before write because it automatically creates parent directories.
//
// For discussion: we could decide this is actually convenient, remove the call below,
@@ -114,7 +228,36 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
}
}

handle, err := w.workspaceClient.Dbfs.Open(ctx, absPath, fileMode)
Contributor commented:

This new approach really belongs in the SDK.

There is an existing interface for dealing with DBFS files (that is called here).

The details of streaming a file or doing a single put call can be abstracted there and surfaced here with a dedicated FileMode to indicate whether it should be a single call or multiple calls.

The size of a file can be retrieved through the io.Seeker interface.

The change here should really be limited to determining the file mode and not the implementation.

The SDK guarantees the correctness of the implementation in either streaming or single-call mode.
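
(A rough illustration of the io.Seeker point above; this is a hypothetical sketch, not code from this PR or the SDK. The size of a seekable reader could be determined, and its read position restored, roughly like so:)

// Hypothetical helper: determine the size of any reader that also implements
// io.Seeker, without assuming it is an *os.File. Only the standard "io" package
// is needed.
func readerSize(r io.Reader) (int64, bool) {
	s, ok := r.(io.Seeker)
	if !ok {
		return 0, false
	}
	// Remember the current offset so the caller can keep reading from here afterwards.
	cur, err := s.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, false
	}
	end, err := s.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, false
	}
	// Restore the original position before the upload consumes the reader.
	if _, err := s.Seek(cur, io.SeekStart); err != nil {
		return 0, false
	}
	return end - cur, true
}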

Contributor Author replied:

Given that DBFS is a service that will be deprecated (at least the public-facing part), I think we should just keep this in the CLI rather than invest time defining the interfaces on the SDK side and using them here.

This PR is really meant to address the regression from the legacy Databricks CLI.

Contributor Author replied:

Happy to move it to the SDK if you disagree.

localFile, ok := reader.(*os.File)

// If the source is not a local file, we'll always use the streaming API endpoint.
if !ok {
return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), reader)
}

stat, err := localFile.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
}

// If the source is a local file but is too large, we'll use the streaming API endpoint.
if stat.Size() > MaxDbfsPutFileSize {
return w.streamFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
}

// Use the /dbfs/put API when the file is on the local filesystem
// and is small enough. This is the most common case when users use the
// `databricks fs cp` command.
return w.putFile(ctx, absPath, slices.Contains(mode, OverwriteIfExists), localFile)
}

func (w *DbfsClient) streamFile(ctx context.Context, path string, overwrite bool, reader io.Reader) error {
fileMode := files.FileModeWrite
if overwrite {
fileMode |= files.FileModeOverwrite
}

handle, err := w.workspaceClient.Dbfs.Open(ctx, path, fileMode)
if err != nil {
var aerr *apierr.APIError
if !errors.As(err, &aerr) {
@@ -124,7 +267,7 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
// This API returns a 400 if the file already exists.
if aerr.StatusCode == http.StatusBadRequest {
if aerr.ErrorCode == "RESOURCE_ALREADY_EXISTS" {
return FileAlreadyExistsError{absPath}
return FileAlreadyExistsError{path}
}
}

@@ -136,7 +279,6 @@ func (w *DbfsClient) Write(ctx context.Context, name string, reader io.Reader, m
if err == nil {
err = cerr
}

return err
}
