diff --git a/go.mod b/go.mod index 9298fc4..8add649 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,9 @@ module github.com/picosh/utils go 1.23.1 -require golang.org/x/crypto v0.28.0 +require ( + github.com/antoniomika/syncmap v1.0.0 + golang.org/x/crypto v0.29.0 +) -require golang.org/x/sys v0.26.0 // indirect +require golang.org/x/sys v0.27.0 // indirect diff --git a/go.sum b/go.sum index 8063669..3f26f60 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,12 @@ +github.com/antoniomika/syncmap v1.0.0 h1:iFSfbQFQOvHZILFZF+hqWosO0no+W9+uF4y2VEyMKWU= +github.com/antoniomika/syncmap v1.0.0/go.mod h1:fK2829foEYnO4riNfyUn0SHQZt4ue3DStYjGU+sJj38= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= diff --git a/pipe/client.go b/pipe/client.go new file mode 100644 index 0000000..ad9391b --- /dev/null +++ b/pipe/client.go @@ -0,0 +1,243 @@ +package pipe + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/antoniomika/syncmap" + "golang.org/x/crypto/ssh" +) + +type Client struct { + Logger *slog.Logger + Info *SSHClientInfo + SSHClient *ssh.Client + Sessions *syncmap.Map[string, *Session] + + Done chan struct{} + connectMu sync.Mutex + closeDoneOnce sync.Once +} + +func NewClient(logger *slog.Logger, info *SSHClientInfo) (*Client, error) { + c := &Client{ + Logger: logger, + Info: info, + } + + err := c.Open() + if err != nil { + return nil, err + } + + return c, nil +} + +func (c *Client) Open() error { + c.Close() + c.Logger.Info("opening ssh conn", "info", c.Info) + + c.connectMu.Lock() + defer c.connectMu.Unlock() + + c.closeDoneOnce = sync.Once{} + c.Done = make(chan struct{}) + + if c.Sessions == nil { + c.Sessions = syncmap.New[string, *Session]() + } + + sshClient, err := NewSSHClient(c.Info) + if err != nil { + return err + } + + c.SSHClient = sshClient + + c.Sessions.Range(func(key string, value *Session) bool { + value.Open() + return true + }) + + return nil +} + +func (c *Client) Close() error { + c.Logger.Info("closing ssh conn", "info", c.Info) + + c.connectMu.Lock() + defer c.connectMu.Unlock() + + if c.Done != nil { + c.closeDoneOnce.Do(func() { + close(c.Done) + c.Done = nil + }) + } + + var errs []error + + if c.Sessions != nil { + c.Sessions.Range(func(key string, value *Session) bool { + errs = append(errs, value.Close()) + return true + }) + } + + if c.SSHClient != nil { + errs = append(errs, c.SSHClient.Close()) + } + + return errors.Join(errs...) +} + +func (c *Client) AddSession(id string, cmd string, buffer int, timeout time.Duration) (*Session, error) { + if c.SSHClient == nil { + return nil, fmt.Errorf("ssh client is not connected") + } + + if buffer < 0 { + buffer = 0 + } + + if timeout < 0 { + timeout = 10 * time.Millisecond + } + + session := &Session{ + Client: c, + Cmd: cmd, + BufferSize: buffer, + Timeout: timeout, + Done: make(chan struct{}), + In: make(chan []byte, buffer), + Out: make(chan []byte, buffer), + } + + err := session.Open() + if err != nil { + return nil, err + } + + s, _ := c.Sessions.LoadOrStore(id, session) + + return s, nil +} + +func (c *Client) RemoveSession(id string) error { + if c.SSHClient == nil { + return fmt.Errorf("ssh client is not connected") + } + + var err error + + if session, ok := c.Sessions.Load(id); ok { + err = session.Close() + c.Sessions.Delete(id) + } + + return err +} + +type SSHClientInfo struct { + RemoteHost string + RemoteHostname string + RemoteUser string + KeyLocation string + KeyPassphrase string +} + +func NewSSHClient(info *SSHClientInfo) (*ssh.Client, error) { + if info == nil { + return nil, fmt.Errorf("conn info is invalid") + } + + if !strings.Contains(info.RemoteHost, ":") { + info.RemoteHost += ":22" + } + + rawConn, err := net.Dial("tcp", info.RemoteHost) + if err != nil { + return nil, err + } + + var signer ssh.Signer + + if info.KeyLocation != "" { + keyPath, err := filepath.Abs(info.KeyLocation) + if err != nil { + return nil, err + } + + data, err := os.ReadFile(keyPath) + if err != nil { + return nil, err + } + + if info.KeyPassphrase != "" { + signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(info.KeyPassphrase)) + } else { + signer, err = ssh.ParsePrivateKey(data) + } + + if err != nil { + return nil, err + } + } + + var authMethods []ssh.AuthMethod + if signer != nil { + authMethods = append(authMethods, ssh.PublicKeys(signer)) + } + + sshConn, chans, reqs, err := ssh.NewClientConn(rawConn, info.RemoteHostname, &ssh.ClientConfig{ + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + User: info.RemoteUser, + }) + + if err != nil { + return nil, err + } + + sshClient := ssh.NewClient(sshConn, chans, reqs) + + return sshClient, nil +} + +func Base(id string, cmd string, ctx context.Context, info *SSHClientInfo) (io.ReadWriteCloser, error) { + client, err := NewClient(slog.Default(), info) + if err != nil { + return nil, err + } + + session, err := client.AddSession(id, cmd, 0, 0) + if err != nil { + return nil, err + } + + go func() { + <-ctx.Done() + session.Close() + client.Close() + }() + + return session, nil +} + +func Sub(cmd string, ctx context.Context, info *SSHClientInfo) (io.Reader, error) { + return Base("sub", cmd, ctx, info) +} + +func Pub(cmd string, ctx context.Context, info *SSHClientInfo) (io.WriteCloser, error) { + return Base("pub", cmd, ctx, info) +} diff --git a/pipe/log/handler.go b/pipe/log/handler.go new file mode 100644 index 0000000..5309f68 --- /dev/null +++ b/pipe/log/handler.go @@ -0,0 +1,77 @@ +package log + +import ( + "context" + "errors" + "log/slog" + "slices" + "sync" +) + +type MultiHandler struct { + Handlers []slog.Handler + mu sync.Mutex +} + +func (m *MultiHandler) Enabled(ctx context.Context, l slog.Level) bool { + m.mu.Lock() + defer m.mu.Unlock() + + for _, h := range m.Handlers { + if h.Enabled(ctx, l) { + return true + } + } + + return false +} + +func (m *MultiHandler) Handle(ctx context.Context, r slog.Record) error { + m.mu.Lock() + defer m.mu.Unlock() + + var errs []error + for _, h := range m.Handlers { + if h.Enabled(ctx, r.Level) { + errs = append(errs, h.Handle(ctx, r.Clone())) + } + } + + return errors.Join(errs...) +} + +func (m *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + m.mu.Lock() + defer m.mu.Unlock() + + var handlers []slog.Handler + + for _, h := range m.Handlers { + handlers = append(handlers, h.WithAttrs(slices.Clone(attrs))) + } + + return &MultiHandler{ + Handlers: handlers, + } +} + +func (m *MultiHandler) WithGroup(name string) slog.Handler { + if name == "" { + return m + } + + m.mu.Lock() + defer m.mu.Unlock() + + var handlers []slog.Handler + + for _, h := range m.Handlers { + handlers = append(handlers, h.WithGroup(name)) + } + + return &MultiHandler{ + Handlers: handlers, + } +} + +var _ slog.Handler = (*MultiHandler)(nil) diff --git a/pipe/log/log.go b/pipe/log/log.go new file mode 100644 index 0000000..fc11d96 --- /dev/null +++ b/pipe/log/log.go @@ -0,0 +1,43 @@ +package log + +import ( + "context" + "io" + "log/slog" + "time" + + "github.com/picosh/utils/pipe" +) + +func RegisterLogger(logger *slog.Logger, info *pipe.SSHClientInfo, buffer int, timeout time.Duration) (*slog.Logger, error) { + if buffer < 0 { + buffer = 0 + } + + logWriter, err := pipe.NewClient(logger, info) + if err != nil { + return nil, err + } + + s, err := logWriter.AddSession("rootLogger", "pub log-drain -b=false", buffer, timeout) + if err != nil { + return nil, err + } + + currentHandler := logger.Handler() + return slog.New( + &MultiHandler{ + Handlers: []slog.Handler{ + currentHandler, + slog.NewJSONHandler(s, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + }), + }, + }, + ), nil +} + +func ReadLogs(ctx context.Context, connectionInfo *pipe.SSHClientInfo) (io.Reader, error) { + return pipe.Sub("sub log-drain -k", ctx, connectionInfo) +} diff --git a/pipe/session.go b/pipe/session.go new file mode 100644 index 0000000..e698562 --- /dev/null +++ b/pipe/session.go @@ -0,0 +1,225 @@ +package pipe + +import ( + "fmt" + "io" + "slices" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +type Session struct { + Client *Client + Cmd string + Session *ssh.Session + StdinPipe io.WriteCloser + StdoutPipe io.Reader + Done chan struct{} + In chan []byte + Out chan []byte + Timeout time.Duration + BufferSize int + + startOnce sync.Once + cleanDoneOnce sync.Once + reconnectOnce sync.Once + + connectMu sync.Mutex + reconnectMu sync.Mutex +} + +var _ io.ReadWriteCloser = (*Session)(nil) + +func (s *Session) Open() error { + s.Close() + s.Client.Logger.Info("opening ssh session", "sessionCmd", s.Cmd) + + s.connectMu.Lock() + defer s.connectMu.Unlock() + + if s.Client == nil { + return fmt.Errorf("client is nil") + } + + if s.Client.SSHClient == nil { + return fmt.Errorf("ssh client is nil") + } + + session, err := s.Client.SSHClient.NewSession() + if err != nil { + return err + } + + stdinPipe, err := session.StdinPipe() + if err != nil { + return err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + + err = session.Start(s.Cmd) + if err != nil { + return err + } + + s.Session = session + s.StdinPipe = stdinPipe + s.StdoutPipe = stdoutPipe + + s.startOnce = sync.Once{} + s.cleanDoneOnce = sync.Once{} + + s.Start() + + return nil +} + +func (s *Session) Close() error { + s.Client.Logger.Info("closing session", "sessionCmd", s.Cmd) + s.connectMu.Lock() + defer s.connectMu.Unlock() + + var err error + + if s.Session != nil { + err = s.Session.Close() + } + + s.cleanDoneOnce.Do(func() { + broadcastDone(s.Done, s.Timeout) + + for len(s.Done) > 0 { + <-s.Done + } + }) + + return err +} + +func (s *Session) Reconnect() { + s.reconnectMu.Lock() + defer s.reconnectMu.Unlock() + + s.reconnectOnce.Do(func() { + go func() { + s.reconnectMu.Lock() + defer s.reconnectMu.Unlock() + + for { + err := s.Open() + if err != nil { + if s.Client != nil { + err = s.Client.Open() + } + } + + if err == nil { + break + } + + time.Sleep(5 * time.Second) + } + + s.reconnectOnce = sync.Once{} + }() + }) +} + +func (s *Session) Start() { + s.startOnce.Do(func() { + go func() { + for { + select { + case <-s.Done: + select { + case s.Done <- struct{}{}: + break + case <-time.After(s.Timeout): + break + } + return + case data, ok := <-s.In: + _, err := s.StdinPipe.Write(data) + if !ok || err != nil { + s.Client.Logger.Error("received error on write, reopening conn", "error", err) + s.Reconnect() + return + } + } + } + }() + + go func() { + for { + select { + case <-s.Done: + + return + default: + data := make([]byte, 32*1024) + + n, err := s.StdoutPipe.Read(data) + if err != nil { + s.Client.Logger.Error("received error on read, reopening conn", "error", err) + s.Reconnect() + return + } + + s.Out <- data[:n] + } + } + }() + }) +} + +func (s *Session) Write(data []byte) (int, error) { + var ( + n int + err error + ) + + select { + case s.In <- slices.Clone(data): + n = len(data) + case <-time.After(s.Timeout): + err = fmt.Errorf("unable to send data within timeout") + case <-s.Done: + broadcastDone(s.Done, s.Timeout) + break + } + + return n, err +} + +func (s *Session) Read(data []byte) (int, error) { + var ( + n int + err error + ) + + select { + case d := <-s.Out: + n = copy(data, d) + case <-time.After(s.Timeout): + err = fmt.Errorf("unable to read data within timeout") + case <-s.Done: + broadcastDone(s.Done, s.Timeout) + break + } + + return n, err +} + +func broadcastDone(done chan struct{}, timeout time.Duration) { + select { + case done <- struct{}{}: + break + case <-time.After(timeout): + break + } +}