Skip to content

Commit

Permalink
feat(go): Add MultipartWriter w/ Content-Type (#5131)
Browse files Browse the repository at this point in the history
  • Loading branch information
amckinney authored Nov 8, 2024
1 parent a125be8 commit d0fb1b4
Show file tree
Hide file tree
Showing 296 changed files with 47,481 additions and 187 deletions.
67 changes: 66 additions & 1 deletion generators/go/internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,13 @@ func (g *Generator) generate(ir *fernir.IntermediateRepresentation, mode Mode) (
}
files = append(files, newCoreFile(g.coordinator))
files = append(files, newCoreTestFile(g.coordinator))
files = append(files, newFileParamFile(g.coordinator, rootPackageName, generatedNames))
files = append(files, newMultipartFile(g.coordinator))
files = append(files, newMultipartTestFile(g.coordinator))
files = append(files, newPointerFile(g.coordinator, rootPackageName, generatedNames))
files = append(files, newRetrierFile(g.coordinator))
files = append(files, newQueryFile(g.coordinator))
files = append(files, newQueryTestFile(g.coordinator))
files = append(files, newRetrierFile(g.coordinator))
if ir.SdkConfig.HasStreamingEndpoints {
files = append(files, newStreamFile(g.coordinator))
}
Expand Down Expand Up @@ -924,6 +927,52 @@ func newPointerFile(coordinator *coordinator.Client, rootPackageName string, gen
)
}

// newFileParamFile returns a *File containing the FileParam helper type
// for multipart file uploads.
//
// In general, this file is deposited at the root of the SDK so that users can
// access the helpers alongside the rest of the top-level definitions. However,
// if any naming conflict exists between the generated types, this file is
// deposited in the core package.
func newFileParamFile(coordinator *coordinator.Client, rootPackageName string, generatedNames map[string]struct{}) *File {
// First determine whether or not we need to generate the type in the
// core package.
var useCorePackage bool
for generatedName := range generatedNames {
if _, ok := pointerFunctionNames[generatedName]; ok {
useCorePackage = true
break
}
}
if useCorePackage {
return NewFile(
coordinator,
"core/file_param.go",
[]byte(fileParamFile),
)
}
// We're going to generate the pointers at the root of the repository,
// so now we need to determine whether or not we can use the standard
// filename, or if it needs a prefix.
filename := "file_param.go"
if _, ok := generatedNames["FileParam"]; ok {
filename = "_file_param.go"
}
// Finally, we need to replace the package declaration so that it matches
// the root package declaration of the generated SDK.
content := strings.Replace(
fileParamFile,
"package core",
fmt.Sprintf("package %s", rootPackageName),
1,
)
return NewFile(
coordinator,
filename,
[]byte(content),
)
}

func newClientTestFile(
baseImportPath string,
coordinator *coordinator.Client,
Expand Down Expand Up @@ -959,6 +1008,22 @@ func newCoreTestFile(coordinator *coordinator.Client) *File {
)
}

func newMultipartFile(coordinator *coordinator.Client) *File {
return NewFile(
coordinator,
"core/multipart.go",
[]byte(multipartFile),
)
}

func newMultipartTestFile(coordinator *coordinator.Client) *File {
return NewFile(
coordinator,
"core/multipart_test.go",
[]byte(multipartTestFile),
)
}

func newOptionalFile(coordinator *coordinator.Client) *File {
return NewFile(
coordinator,
Expand Down
160 changes: 114 additions & 46 deletions generators/go/internal/generator/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ var (
//go:embed sdk/core/extra_properties_test.go
extraPropertiesTestFile string

//go:embed sdk/core/file_param.go
fileParamFile string

//go:embed sdk/core/multipart.go
multipartFile string

//go:embed sdk/core/multipart_test.go
multipartTestFile string

//go:embed sdk/core/optional.go
optionalFile string

Expand Down Expand Up @@ -1123,50 +1132,51 @@ func (f *fileWriter) WriteClient(
}

if len(endpoint.FileProperties) > 0 || len(endpoint.FileBodyProperties) > 0 {
f.P("requestBuffer := bytes.NewBuffer(nil)")
f.P("writer := multipart.NewWriter(requestBuffer)")
f.P("writer := core.NewMultipartWriter()")
for _, fileProperty := range endpoint.FileProperties {
filePropertyInfo, err := filePropertyToInfo(fileProperty)
if err != nil {
return nil, err
}
var (
fileVariable = filePropertyInfo.Key.Name.CamelCase.SafeName
filenameVariable = filePropertyInfo.Key.Name.CamelCase.UnsafeName + "Filename"
filenameValue = filePropertyInfo.Key.Name.CamelCase.UnsafeName + "_filename"
partVariable = filePropertyInfo.Key.Name.CamelCase.UnsafeName + "Part"
fileVariable = filePropertyInfo.Key.Name.CamelCase.SafeName
contentTypeVariableName = filePropertyInfo.Key.Name.CamelCase.SafeName + "ContentType"
)
if filePropertyInfo.IsArray {
// We don't care whether the file array is optional or not; the range
// handles that for us.
f.P("for i, f := range ", fileVariable, "{")
f.P(filenameVariable, ` := fmt.Sprintf("`, filenameValue, `_%d", i)`)
f.P("if named, ok := f.(interface{ Name() string }); ok {")
f.P(fmt.Sprintf("%s = named.Name()", filenameVariable))
f.P("}")
f.P(partVariable, `, err := writer.CreateFormFile("`, filePropertyInfo.Key.WireValue, `", `, filenameVariable, ")")
f.P("if err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
f.P("if _, err := io.Copy(", partVariable, ", f); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
f.P("for _, f := range ", fileVariable, "{")
if filePropertyInfo.ContentType != "" {
f.P(contentTypeVariableName, " := \"", filePropertyInfo.ContentType, "\"")
f.P("if contentTyped, ok := f.(core.ContentTyped); ok {")
f.P(contentTypeVariableName, " = contentTyped.ContentType()")
f.P("}")
f.P("if err := writer.WriteFile(\"", filePropertyInfo.Key.WireValue, "\", f, core.WithMultipartContentType(", contentTypeVariableName, ")); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
} else {
f.P("if err := writer.WriteFile(\"", filePropertyInfo.Key.WireValue, "\", f); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
}
f.P("}")
} else {
if filePropertyInfo.IsOptional {
f.P("if ", fileVariable, " != nil {")
}
f.P(fmt.Sprintf("%s := %q", filenameVariable, filenameValue))
f.P("if named, ok := ", fileVariable, ".(interface{ Name() string }); ok {")
f.P(fmt.Sprintf("%s = named.Name()", filenameVariable))
f.P("}")
f.P(partVariable, `, err := writer.CreateFormFile("`, filePropertyInfo.Key.WireValue, `", `, filenameVariable, ")")
f.P("if err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
f.P("if _, err := io.Copy(", partVariable, ", ", fileVariable, "); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
if filePropertyInfo.ContentType != "" {
f.P(contentTypeVariableName, " := \"", filePropertyInfo.ContentType, "\"")
f.P("if contentTyped, ok := ", fileVariable, ".(core.ContentTyped); ok {")
f.P(contentTypeVariableName, " = contentTyped.ContentType()")
f.P("}")
f.P("if err := writer.WriteFile(\"", filePropertyInfo.Key.WireValue, "\", ", fileVariable, ", core.WithMultipartContentType(", contentTypeVariableName, ")); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
} else {
f.P("if err := writer.WriteFile(\"", filePropertyInfo.Key.WireValue, "\", ", fileVariable, "); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
}
if filePropertyInfo.IsOptional {
f.P("}")
}
Expand All @@ -1186,28 +1196,39 @@ func (f *fileWriter) WriteClient(
// Encapsulate the multipart form WriteField in a closure so that we can easily
// wrap it with an optional nil check below.
writeField := func() {
if !valueTypeFormat.IsPrimitive {
// Non-primitive types need to be JSON-serialized (e.g. lists, objects, etc).
f.P(`if err := core.WriteMultipartJSON(writer, "`, fileBodyProperty.Name.WireValue, `", `, requestField, "); err != nil {")
field := requestField
if valueTypeFormat.IsIterable {
field = "part"
}
if valueTypeFormat.IsPrimitive {
f.P(`if err := writer.WriteField("`, fileBodyProperty.Name.WireValue, `", fmt.Sprintf("%v", `, field, ")); err != nil {")
} else if fileBodyProperty.ContentType != nil {
f.P(`if err := writer.WriteJSON("`, fileBodyProperty.Name.WireValue, `", `, field, `, core.WithMultipartContentType("`, *fileBodyProperty.ContentType, `")); err != nil {`)
} else {
f.P(`if err := writer.WriteField("`, fileBodyProperty.Name.WireValue, `", fmt.Sprintf("%v", `, requestField, ")); err != nil {")
f.P(`if err := writer.WriteJSON("`, fileBodyProperty.Name.WireValue, `", `, field, "); err != nil {")
}
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
}

if valueTypeFormat.IsOptional {
f.P("if ", endpoint.RequestParameterName, ".", fileBodyProperty.Name.Name.PascalCase.UnsafeName, "!= nil {")
writeField()
}
if valueTypeFormat.IsIterable {
f.P("for _, part := range ", endpoint.RequestParameterName, ".", fileBodyProperty.Name.Name.PascalCase.UnsafeName, " {")
}
writeField()
if valueTypeFormat.IsIterable {
f.P("}")
}
if valueTypeFormat.IsOptional {
f.P("}")
} else {
writeField()
}
}
f.P("if err := writer.Close(); err != nil {")
f.P("return ", endpoint.ErrorReturnValues)
f.P("}")
f.P(headersParameter, `.Set("Content-Type", writer.FormDataContentType())`)
f.P(headersParameter, `.Set("Content-Type", writer.ContentType())`)
f.P()
}

Expand Down Expand Up @@ -2014,23 +2035,34 @@ func generatedClientInstantiation(
}

type filePropertyInfo struct {
Key *ir.NameAndWireValue
IsOptional bool
IsArray bool
Key *ir.NameAndWireValue
IsOptional bool
IsArray bool
ContentType string
}

func filePropertyToInfo(fileProperty *ir.FileProperty) (*filePropertyInfo, error) {
switch fileProperty.Type {
case "file":
var contentType string
if fileProperty.File.ContentType != nil {
contentType = *fileProperty.File.ContentType
}
return &filePropertyInfo{
Key: fileProperty.File.Key,
IsOptional: fileProperty.File.IsOptional,
Key: fileProperty.File.Key,
IsOptional: fileProperty.File.IsOptional,
ContentType: contentType,
}, nil
case "fileArray":
var contentType string
if fileProperty.FileArray.ContentType != nil {
contentType = *fileProperty.FileArray.ContentType
}
return &filePropertyInfo{
Key: fileProperty.FileArray.Key,
IsOptional: fileProperty.FileArray.IsOptional,
IsArray: true,
Key: fileProperty.FileArray.Key,
IsOptional: fileProperty.FileArray.IsOptional,
IsArray: true,
ContentType: contentType,
}, nil
}
return nil, fmt.Errorf("file property %s is not yet supported", fileProperty.Type)
Expand Down Expand Up @@ -2235,7 +2267,7 @@ func (f *fileWriter) endpointFromIR(
if irEndpoint.RequestBody != nil && irEndpoint.RequestBody.FileUpload != nil {
// This is a file upload request, so we prepare a buffer for the request body
// instead of just using the request specified by the function signature.
requestValueName = "requestBuffer"
requestValueName = "writer.Buffer()"
}
}

Expand Down Expand Up @@ -3085,6 +3117,7 @@ type valueTypeFormat struct {
ZeroValue string
IsOptional bool
IsPrimitive bool
IsIterable bool
}

func formatForValueType(typeReference *ir.TypeReference, types map[ir.TypeId]*ir.TypeDeclaration) *valueTypeFormat {
Expand All @@ -3094,6 +3127,18 @@ func formatForValueType(typeReference *ir.TypeReference, types map[ir.TypeId]*ir
isOptional bool
isPrimitive bool
)
iterableType := maybeIterableType(typeReference, types)
if iterableType != nil {
value := formatForValueType(iterableType, types)
return &valueTypeFormat{
Prefix: value.Prefix,
Suffix: value.Suffix,
ZeroValue: value.ZeroValue,
IsOptional: value.IsOptional,
IsPrimitive: value.IsPrimitive,
IsIterable: true,
}
}
if typeReference.Container != nil && typeReference.Container.Optional != nil {
isOptional = true
if needsOptionalDereference(typeReference.Container.Optional, types) {
Expand Down Expand Up @@ -3269,6 +3314,29 @@ func isOptionalType(typeReference *ir.TypeReference, types map[ir.TypeId]*ir.Typ
return typeReference.Container != nil && typeReference.Container.Optional != nil
}

// maybeIterableType returns the given type reference's iterable type, if any.
func maybeIterableType(typeReference *ir.TypeReference, types map[ir.TypeId]*ir.TypeDeclaration) *ir.TypeReference {
if typeReference.Named != nil {
typeDeclaration := types[typeReference.Named.TypeId]
if typeDeclaration.Shape.Alias != nil {
return maybeIterableType(typeDeclaration.Shape.Alias.AliasOf, types)
}
return nil
}
if typeReference.Container != nil {
if typeReference.Container.Optional != nil {
return maybeIterableType(typeReference.Container.Optional, types)
}
if typeReference.Container.List != nil {
return typeReference.Container.List
}
if typeReference.Container.Set != nil {
return typeReference.Container.Set
}
}
return nil
}

// needsOptionalDereference returns true if the optional type needs to be referenced.
//
// Container types like lists, maps, and sets are already nil-able, so they don't
Expand Down
41 changes: 41 additions & 0 deletions generators/go/internal/generator/sdk/core/file_param.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package core

import (
"io"
)

// FileParam is a file type suitable for multipart/form-data uploads.
type FileParam struct {
io.Reader
filename string
contentType string
}

// FileParamOption adapts the behavior of the FileParam. No options are
// implemented yet, but this interface allows for future extensibility.
type FileParamOption interface {
apply()
}

// NewFileParam returns a *FileParam type suitable for multipart/form-data uploads. All file
// upload endpoints accept a simple io.Reader, which is usually created by opening a file
// via os.Open.
//
// However, some endpoints require additional metadata about the file such as a specific
// Content-Type or custom filename. FileParam makes it easier to create the correct type
// signature for these endpoints.
func NewFileParam(
reader io.Reader,
filename string,
contentType string,
opts ...FileParamOption,
) *FileParam {
return &FileParam{
Reader: reader,
filename: filename,
contentType: contentType,
}
}

func (f *FileParam) Name() string { return f.filename }
func (f *FileParam) ContentType() string { return f.contentType }
Loading

0 comments on commit d0fb1b4

Please sign in to comment.