Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(spinner): add Output method to expose overwriting the output #405

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion spinner/spinner.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ func (s *Spinner) Accessible(accessible bool) *Spinner {
return s
}

// Output sets the output of the spinner.
func (s *Spinner) Output(output *termenv.Output) *Spinner {
s.output = output
return s
}

// New creates a new spinner.
func New() *Spinner {
s := spinner.New()
Expand Down Expand Up @@ -174,7 +180,7 @@ func (s *Spinner) runAccessible() error {
s.output.HideCursor()
frame := s.spinner.Style.Render("...")
title := s.titleStyle.Render(strings.TrimSuffix(s.title, "..."))
fmt.Println(title + frame)
fmt.Fprintln(s.output, title+frame)

if s.ctx == nil {
s.action()
Expand Down
88 changes: 88 additions & 0 deletions spinner/spinner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package spinner

import (
"context"
"errors"
"io"
"os"
"reflect"
"strings"
"testing"
"time"

"github.com/charmbracelet/bubbles/spinner"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/termenv"
)

func TestNewSpinner(t *testing.T) {
Expand Down Expand Up @@ -112,3 +117,86 @@ func TestAccessibleSpinner(t *testing.T) {
t.Errorf("Run() in accessible mode returned an error: %v", err)
}
}

func TestSpinnerOutput(t *testing.T) {
tests := []struct {
name string
wantStdout bool
wantStderr bool
}{
{
name: "stdout",
wantStdout: true,
wantStderr: false,
},
{
name: "stderr",
wantStdout: false,
wantStderr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
const title = "Test Output"

// Save original stderr and stdout
oldStderr := os.Stderr
oldStdout := os.Stdout

// Create pipes for stderr and stdout
stderrReader, stderrWriter, _ := os.Pipe()
stdoutReader, stdoutWriter, _ := os.Pipe()

// Set global stderr and stdout to our pipes
os.Stderr = stderrWriter
os.Stdout = stdoutWriter

// Create a spinner and set its output
s := New().Title(title).Accessible(true)
if tc.wantStderr {
s.Output(termenv.NewOutput(os.Stderr))
}
if tc.wantStdout {
s.Output(termenv.NewOutput(os.Stdout))
}
s.action = func() { time.Sleep(100 * time.Millisecond) }
if err := s.Run(); err != nil {
t.Errorf("Spinner.Run() returned an error: %v", err)
}

// Restore original stderr and stdout
os.Stderr = oldStderr
os.Stdout = oldStdout

// Close the pipes
if err := errors.Join(stderrWriter.Close(), stdoutWriter.Close()); err != nil {
t.Errorf("Failed to close pipes: %v", err)
}

// Read from the pipes
stderrOutput, stderrErr := io.ReadAll(stderrReader)
stdoutOutput, stdoutErr := io.ReadAll(stdoutReader)
if err := errors.Join(stderrErr, stdoutErr); err != nil {
t.Errorf("Failed to read from pipes: %v", err)
}

// Check the output
if tc.wantStderr {
if !strings.Contains(string(stderrOutput), title) {
t.Errorf("Stderr got %q, but wanted %q", stderrOutput, title)
}
if len(stdoutOutput) > 0 {
t.Errorf("Expected no output on stdout, but got %q", stdoutOutput)
}
}
if tc.wantStdout {
if !strings.Contains(string(stdoutOutput), title) {
t.Errorf("Stdout got %q, but wanted %q", stdoutOutput, title)
}
if len(stderrOutput) > 0 {
t.Errorf("Expected no output on stderr, but got %q", stderrOutput)
}
}
})
}
}