diff --git a/pkg/generator/generator.go b/pkg/generator/generator.go index 62f15513..af2ea2b8 100644 --- a/pkg/generator/generator.go +++ b/pkg/generator/generator.go @@ -1,8 +1,10 @@ package generator import ( + "bytes" "fmt" "go/ast" + "go/format" "go/parser" "go/token" "io" @@ -40,6 +42,37 @@ func DefaultWriter(filePath string) (*os.File, error) { return file, nil } +type FileWriter struct { + FilePath string + file *os.File +} + +func (fw *FileWriter) Write(p []byte) (int, error) { + if fw.file == nil { + file, err := utils.CreateFile(fw.FilePath, true) + if err != nil { + return 0, fmt.Errorf("failed create file %s : %v", fw.FilePath, err) + } + fw.file = file + defer fw.Close() + } + + formattedCode, err := format.Source(p) + if err != nil { + return 0, fmt.Errorf("error format code : %v", err) + } + + return fw.file.Write(formattedCode) +} + +// Close closes the underlying file +func (fw *FileWriter) Close() error { + if fw.file != nil { + return fw.file.Close() + } + return nil +} + func Generate(input GenerateInput, writer io.Writer) error { // set default writer if writer == nil { @@ -47,7 +80,6 @@ func Generate(input GenerateInput, writer io.Writer) error { if err != nil { return err } - defer file.Close() writer = file } @@ -61,7 +93,18 @@ func Generate(input GenerateInput, writer io.Writer) error { return fmt.Errorf("error parsing : %v", err) } - return tmpl.Execute(writer, input.BindData) + var renderedCode bytes.Buffer + err = tmpl.Execute(&renderedCode, input.BindData) + if err != nil { + return fmt.Errorf("error execute template : %v", err) + } + + _, err = writer.Write(renderedCode.Bytes()) + if err != nil { + return err + } + + return nil } func CreateInternalFolder(basePath string) (err error) { diff --git a/pkg/generator/generator_test.go b/pkg/generator/generator_test.go new file mode 100644 index 00000000..80183e99 --- /dev/null +++ b/pkg/generator/generator_test.go @@ -0,0 +1,76 @@ +package generator_test + +import ( + "testing" + + "github.com/sev-2/raiden/pkg/generator" + "github.com/stretchr/testify/assert" +) + +func TestGenerate_ErrorWritingToFile(t *testing.T) { + invalidPath := "/invalid_path/output.txt" + + tmpl := "{{ .Name }}" + input := generator.GenerateInput{ + BindData: struct{ Name string }{"John"}, + Template: tmpl, + TemplateName: "testTemplate", + OutputPath: invalidPath, + } + + err := generator.Generate(input, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed create file") +} + +func TestGenerate_ErrorParsingTemplate(t *testing.T) { + input := generator.GenerateInput{ + BindData: nil, + Template: "{{ .invalid}", + TemplateName: "testTemplate", + OutputPath: "test_output.txt", + } + + err := generator.Generate(input, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "error parsing") +} + +func TestGenerate_ErrorExecutingTemplate(t *testing.T) { + tmpl := "{{ .Name }}" + input := generator.GenerateInput{ + BindData: struct{}{}, + Template: tmpl, + TemplateName: "testTemplate", + OutputPath: "test_output.txt", + } + + err := generator.Generate(input, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "error execute template") +} + +// Test for FileWriter Write function +func TestFileWriter_Write_ErrorCreatingFile(t *testing.T) { + invalidPath := "/invalid_path/output.txt" // Simulating an invalid path + + fw := &generator.FileWriter{FilePath: invalidPath} + _, err := fw.Write([]byte("test content")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed create file") +} + +func TestFileWriter_Write_ErrorFormattingCode(t *testing.T) { + validPath := "test_err_output_formatting.txt" + + fw := &generator.FileWriter{FilePath: validPath} + + input := generator.GenerateInput{ + BindData: struct{ Name string }{"invalid code"}, + Template: "{{ .Name }}", + TemplateName: "testTemplate", + } + + err := generator.Generate(input, fw) + assert.Contains(t, err.Error(), "error format code") +} diff --git a/pkg/generator/model.go b/pkg/generator/model.go index dad0549f..c88bbf22 100644 --- a/pkg/generator/model.go +++ b/pkg/generator/model.go @@ -147,8 +147,11 @@ func GenerateModel(folderPath string, input *GenerateModelInput, generateFn Gene OutputPath: filePath, } + // setup writer + writer := FileWriter{FilePath: filePath} + ModelLogger.Debug("generate model", "path", generateInput.OutputPath) - return generateFn(generateInput, nil) + return generateFn(generateInput, &writer) } // map table to column, map pg type to go type and get dependency import path diff --git a/pkg/resource/import.go b/pkg/resource/import.go index dcf6b705..f35db986 100644 --- a/pkg/resource/import.go +++ b/pkg/resource/import.go @@ -305,7 +305,7 @@ func generateImportResource(config *raiden.Config, importState *state.LocalState func ImportDecorateFunc[T any](data []T, findFunc func(T, generator.GenerateInput) bool, stateChan chan any) generator.GenerateFn { return func(input generator.GenerateInput, writer io.Writer) error { - if err := generator.Generate(input, nil); err != nil { + if err := generator.Generate(input, writer); err != nil { return err } if rs, found := FindImportResource(data, input, findFunc); found {