Skip to content

Commit

Permalink
Make client race proof and fix reconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Nov 16, 2024
1 parent b351d5d commit b922efc
Show file tree
Hide file tree
Showing 6 changed files with 599 additions and 2 deletions.
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ module github.com/picosh/utils

go 1.23.1

require golang.org/x/crypto v0.28.0
require (
github.com/antoniomika/syncmap v1.0.0
golang.org/x/crypto v0.29.0
)

require golang.org/x/sys v0.26.0 // indirect
require golang.org/x/sys v0.27.0 // indirect
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
github.com/antoniomika/syncmap v1.0.0 h1:iFSfbQFQOvHZILFZF+hqWosO0no+W9+uF4y2VEyMKWU=
github.com/antoniomika/syncmap v1.0.0/go.mod h1:fK2829foEYnO4riNfyUn0SHQZt4ue3DStYjGU+sJj38=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
243 changes: 243 additions & 0 deletions pipe/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
package pipe

import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/antoniomika/syncmap"
"golang.org/x/crypto/ssh"
)

type Client struct {
Logger *slog.Logger
Info *SSHClientInfo
SSHClient *ssh.Client
Sessions *syncmap.Map[string, *Session]

Done chan struct{}
connectMu sync.Mutex
closeDoneOnce sync.Once
}

func NewClient(logger *slog.Logger, info *SSHClientInfo) (*Client, error) {
c := &Client{
Logger: logger,
Info: info,
}

err := c.Open()
if err != nil {
return nil, err
}

return c, nil
}

func (c *Client) Open() error {
c.Close()
c.Logger.Info("opening ssh conn", "info", c.Info)

c.connectMu.Lock()
defer c.connectMu.Unlock()

c.closeDoneOnce = sync.Once{}
c.Done = make(chan struct{})

if c.Sessions == nil {
c.Sessions = syncmap.New[string, *Session]()
}

sshClient, err := NewSSHClient(c.Info)
if err != nil {
return err
}

c.SSHClient = sshClient

c.Sessions.Range(func(key string, value *Session) bool {
value.Open()

Check failure on line 67 in pipe/client.go

View workflow job for this annotation

GitHub Actions / test

Error return value of `value.Open` is not checked (errcheck)
return true
})

return nil
}

func (c *Client) Close() error {
c.Logger.Info("closing ssh conn", "info", c.Info)

c.connectMu.Lock()
defer c.connectMu.Unlock()

if c.Done != nil {
c.closeDoneOnce.Do(func() {
close(c.Done)
c.Done = nil
})
}

var errs []error

if c.Sessions != nil {
c.Sessions.Range(func(key string, value *Session) bool {
errs = append(errs, value.Close())
return true
})
}

if c.SSHClient != nil {
errs = append(errs, c.SSHClient.Close())
}

return errors.Join(errs...)
}

func (c *Client) AddSession(id string, cmd string, buffer int, timeout time.Duration) (*Session, error) {
if c.SSHClient == nil {
return nil, fmt.Errorf("ssh client is not connected")
}

if buffer < 0 {
buffer = 0
}

if timeout < 0 {
timeout = 10 * time.Millisecond
}

session := &Session{
Client: c,
Cmd: cmd,
BufferSize: buffer,
Timeout: timeout,
Done: make(chan struct{}),
In: make(chan []byte, buffer),
Out: make(chan []byte, buffer),
}

err := session.Open()
if err != nil {
return nil, err
}

s, _ := c.Sessions.LoadOrStore(id, session)

return s, nil
}

func (c *Client) RemoveSession(id string) error {
if c.SSHClient == nil {
return fmt.Errorf("ssh client is not connected")
}

var err error

if session, ok := c.Sessions.Load(id); ok {
err = session.Close()
c.Sessions.Delete(id)
}

return err
}

type SSHClientInfo struct {
RemoteHost string
RemoteHostname string
RemoteUser string
KeyLocation string
KeyPassphrase string
}

func NewSSHClient(info *SSHClientInfo) (*ssh.Client, error) {
if info == nil {
return nil, fmt.Errorf("conn info is invalid")
}

if !strings.Contains(info.RemoteHost, ":") {
info.RemoteHost += ":22"
}

rawConn, err := net.Dial("tcp", info.RemoteHost)
if err != nil {
return nil, err
}

var signer ssh.Signer

if info.KeyLocation != "" {
keyPath, err := filepath.Abs(info.KeyLocation)
if err != nil {
return nil, err
}

data, err := os.ReadFile(keyPath)
if err != nil {
return nil, err
}

if info.KeyPassphrase != "" {
signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(info.KeyPassphrase))
} else {
signer, err = ssh.ParsePrivateKey(data)
}

if err != nil {
return nil, err
}
}

var authMethods []ssh.AuthMethod
if signer != nil {
authMethods = append(authMethods, ssh.PublicKeys(signer))
}

sshConn, chans, reqs, err := ssh.NewClientConn(rawConn, info.RemoteHostname, &ssh.ClientConfig{
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
User: info.RemoteUser,
})

if err != nil {
return nil, err
}

sshClient := ssh.NewClient(sshConn, chans, reqs)

return sshClient, nil
}

func Base(id string, cmd string, ctx context.Context, info *SSHClientInfo) (io.ReadWriteCloser, error) {
client, err := NewClient(slog.Default(), info)
if err != nil {
return nil, err
}

session, err := client.AddSession(id, cmd, 0, 0)
if err != nil {
return nil, err
}

go func() {
<-ctx.Done()
session.Close()
client.Close()
}()

return session, nil
}

func Sub(cmd string, ctx context.Context, info *SSHClientInfo) (io.Reader, error) {
return Base("sub", cmd, ctx, info)
}

func Pub(cmd string, ctx context.Context, info *SSHClientInfo) (io.WriteCloser, error) {
return Base("pub", cmd, ctx, info)
}
77 changes: 77 additions & 0 deletions pipe/log/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package log

import (
"context"
"errors"
"log/slog"
"slices"
"sync"
)

type MultiHandler struct {
Handlers []slog.Handler
mu sync.Mutex
}

func (m *MultiHandler) Enabled(ctx context.Context, l slog.Level) bool {
m.mu.Lock()
defer m.mu.Unlock()

for _, h := range m.Handlers {
if h.Enabled(ctx, l) {
return true
}
}

return false
}

func (m *MultiHandler) Handle(ctx context.Context, r slog.Record) error {
m.mu.Lock()
defer m.mu.Unlock()

var errs []error
for _, h := range m.Handlers {
if h.Enabled(ctx, r.Level) {
errs = append(errs, h.Handle(ctx, r.Clone()))
}
}

return errors.Join(errs...)
}

func (m *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
m.mu.Lock()
defer m.mu.Unlock()

var handlers []slog.Handler

for _, h := range m.Handlers {
handlers = append(handlers, h.WithAttrs(slices.Clone(attrs)))
}

return &MultiHandler{
Handlers: handlers,
}
}

func (m *MultiHandler) WithGroup(name string) slog.Handler {
if name == "" {
return m
}

m.mu.Lock()
defer m.mu.Unlock()

var handlers []slog.Handler

for _, h := range m.Handlers {
handlers = append(handlers, h.WithGroup(name))
}

return &MultiHandler{
Handlers: handlers,
}
}

var _ slog.Handler = (*MultiHandler)(nil)
43 changes: 43 additions & 0 deletions pipe/log/log.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package log

import (
"context"
"io"
"log/slog"
"time"

"github.com/picosh/utils/pipe"
)

func RegisterLogger(logger *slog.Logger, info *pipe.SSHClientInfo, buffer int, timeout time.Duration) (*slog.Logger, error) {
if buffer < 0 {
buffer = 0
}

logWriter, err := pipe.NewClient(logger, info)
if err != nil {
return nil, err
}

s, err := logWriter.AddSession("rootLogger", "pub log-drain -b=false", buffer, timeout)
if err != nil {
return nil, err
}

currentHandler := logger.Handler()
return slog.New(
&MultiHandler{
Handlers: []slog.Handler{
currentHandler,
slog.NewJSONHandler(s, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelDebug,
}),
},
},
), nil
}

func ReadLogs(ctx context.Context, connectionInfo *pipe.SSHClientInfo) (io.Reader, error) {
return pipe.Sub("sub log-drain -k", ctx, connectionInfo)
}
Loading

0 comments on commit b922efc

Please sign in to comment.