-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make client race proof and fix reconnect
- Loading branch information
1 parent
b351d5d
commit b922efc
Showing
6 changed files
with
599 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,12 @@ | ||
github.com/antoniomika/syncmap v1.0.0 h1:iFSfbQFQOvHZILFZF+hqWosO0no+W9+uF4y2VEyMKWU= | ||
github.com/antoniomika/syncmap v1.0.0/go.mod h1:fK2829foEYnO4riNfyUn0SHQZt4ue3DStYjGU+sJj38= | ||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= | ||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= | ||
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= | ||
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= | ||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= | ||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= | ||
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= | ||
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
package pipe | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"log/slog" | ||
"net" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
"sync" | ||
"time" | ||
|
||
"github.com/antoniomika/syncmap" | ||
"golang.org/x/crypto/ssh" | ||
) | ||
|
||
type Client struct { | ||
Logger *slog.Logger | ||
Info *SSHClientInfo | ||
SSHClient *ssh.Client | ||
Sessions *syncmap.Map[string, *Session] | ||
|
||
Done chan struct{} | ||
connectMu sync.Mutex | ||
closeDoneOnce sync.Once | ||
} | ||
|
||
func NewClient(logger *slog.Logger, info *SSHClientInfo) (*Client, error) { | ||
c := &Client{ | ||
Logger: logger, | ||
Info: info, | ||
} | ||
|
||
err := c.Open() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return c, nil | ||
} | ||
|
||
func (c *Client) Open() error { | ||
c.Close() | ||
c.Logger.Info("opening ssh conn", "info", c.Info) | ||
|
||
c.connectMu.Lock() | ||
defer c.connectMu.Unlock() | ||
|
||
c.closeDoneOnce = sync.Once{} | ||
c.Done = make(chan struct{}) | ||
|
||
if c.Sessions == nil { | ||
c.Sessions = syncmap.New[string, *Session]() | ||
} | ||
|
||
sshClient, err := NewSSHClient(c.Info) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
c.SSHClient = sshClient | ||
|
||
c.Sessions.Range(func(key string, value *Session) bool { | ||
value.Open() | ||
return true | ||
}) | ||
|
||
return nil | ||
} | ||
|
||
func (c *Client) Close() error { | ||
c.Logger.Info("closing ssh conn", "info", c.Info) | ||
|
||
c.connectMu.Lock() | ||
defer c.connectMu.Unlock() | ||
|
||
if c.Done != nil { | ||
c.closeDoneOnce.Do(func() { | ||
close(c.Done) | ||
c.Done = nil | ||
}) | ||
} | ||
|
||
var errs []error | ||
|
||
if c.Sessions != nil { | ||
c.Sessions.Range(func(key string, value *Session) bool { | ||
errs = append(errs, value.Close()) | ||
return true | ||
}) | ||
} | ||
|
||
if c.SSHClient != nil { | ||
errs = append(errs, c.SSHClient.Close()) | ||
} | ||
|
||
return errors.Join(errs...) | ||
} | ||
|
||
func (c *Client) AddSession(id string, cmd string, buffer int, timeout time.Duration) (*Session, error) { | ||
if c.SSHClient == nil { | ||
return nil, fmt.Errorf("ssh client is not connected") | ||
} | ||
|
||
if buffer < 0 { | ||
buffer = 0 | ||
} | ||
|
||
if timeout < 0 { | ||
timeout = 10 * time.Millisecond | ||
} | ||
|
||
session := &Session{ | ||
Client: c, | ||
Cmd: cmd, | ||
BufferSize: buffer, | ||
Timeout: timeout, | ||
Done: make(chan struct{}), | ||
In: make(chan []byte, buffer), | ||
Out: make(chan []byte, buffer), | ||
} | ||
|
||
err := session.Open() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
s, _ := c.Sessions.LoadOrStore(id, session) | ||
|
||
return s, nil | ||
} | ||
|
||
func (c *Client) RemoveSession(id string) error { | ||
if c.SSHClient == nil { | ||
return fmt.Errorf("ssh client is not connected") | ||
} | ||
|
||
var err error | ||
|
||
if session, ok := c.Sessions.Load(id); ok { | ||
err = session.Close() | ||
c.Sessions.Delete(id) | ||
} | ||
|
||
return err | ||
} | ||
|
||
type SSHClientInfo struct { | ||
RemoteHost string | ||
RemoteHostname string | ||
RemoteUser string | ||
KeyLocation string | ||
KeyPassphrase string | ||
} | ||
|
||
func NewSSHClient(info *SSHClientInfo) (*ssh.Client, error) { | ||
if info == nil { | ||
return nil, fmt.Errorf("conn info is invalid") | ||
} | ||
|
||
if !strings.Contains(info.RemoteHost, ":") { | ||
info.RemoteHost += ":22" | ||
} | ||
|
||
rawConn, err := net.Dial("tcp", info.RemoteHost) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
var signer ssh.Signer | ||
|
||
if info.KeyLocation != "" { | ||
keyPath, err := filepath.Abs(info.KeyLocation) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
data, err := os.ReadFile(keyPath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if info.KeyPassphrase != "" { | ||
signer, err = ssh.ParsePrivateKeyWithPassphrase(data, []byte(info.KeyPassphrase)) | ||
} else { | ||
signer, err = ssh.ParsePrivateKey(data) | ||
} | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
var authMethods []ssh.AuthMethod | ||
if signer != nil { | ||
authMethods = append(authMethods, ssh.PublicKeys(signer)) | ||
} | ||
|
||
sshConn, chans, reqs, err := ssh.NewClientConn(rawConn, info.RemoteHostname, &ssh.ClientConfig{ | ||
Auth: authMethods, | ||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), | ||
User: info.RemoteUser, | ||
}) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
sshClient := ssh.NewClient(sshConn, chans, reqs) | ||
|
||
return sshClient, nil | ||
} | ||
|
||
func Base(id string, cmd string, ctx context.Context, info *SSHClientInfo) (io.ReadWriteCloser, error) { | ||
client, err := NewClient(slog.Default(), info) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
session, err := client.AddSession(id, cmd, 0, 0) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
go func() { | ||
<-ctx.Done() | ||
session.Close() | ||
client.Close() | ||
}() | ||
|
||
return session, nil | ||
} | ||
|
||
func Sub(cmd string, ctx context.Context, info *SSHClientInfo) (io.Reader, error) { | ||
return Base("sub", cmd, ctx, info) | ||
} | ||
|
||
func Pub(cmd string, ctx context.Context, info *SSHClientInfo) (io.WriteCloser, error) { | ||
return Base("pub", cmd, ctx, info) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package log | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"log/slog" | ||
"slices" | ||
"sync" | ||
) | ||
|
||
type MultiHandler struct { | ||
Handlers []slog.Handler | ||
mu sync.Mutex | ||
} | ||
|
||
func (m *MultiHandler) Enabled(ctx context.Context, l slog.Level) bool { | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
for _, h := range m.Handlers { | ||
if h.Enabled(ctx, l) { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} | ||
|
||
func (m *MultiHandler) Handle(ctx context.Context, r slog.Record) error { | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
var errs []error | ||
for _, h := range m.Handlers { | ||
if h.Enabled(ctx, r.Level) { | ||
errs = append(errs, h.Handle(ctx, r.Clone())) | ||
} | ||
} | ||
|
||
return errors.Join(errs...) | ||
} | ||
|
||
func (m *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler { | ||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
var handlers []slog.Handler | ||
|
||
for _, h := range m.Handlers { | ||
handlers = append(handlers, h.WithAttrs(slices.Clone(attrs))) | ||
} | ||
|
||
return &MultiHandler{ | ||
Handlers: handlers, | ||
} | ||
} | ||
|
||
func (m *MultiHandler) WithGroup(name string) slog.Handler { | ||
if name == "" { | ||
return m | ||
} | ||
|
||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
|
||
var handlers []slog.Handler | ||
|
||
for _, h := range m.Handlers { | ||
handlers = append(handlers, h.WithGroup(name)) | ||
} | ||
|
||
return &MultiHandler{ | ||
Handlers: handlers, | ||
} | ||
} | ||
|
||
var _ slog.Handler = (*MultiHandler)(nil) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package log | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"log/slog" | ||
"time" | ||
|
||
"github.com/picosh/utils/pipe" | ||
) | ||
|
||
func RegisterLogger(logger *slog.Logger, info *pipe.SSHClientInfo, buffer int, timeout time.Duration) (*slog.Logger, error) { | ||
if buffer < 0 { | ||
buffer = 0 | ||
} | ||
|
||
logWriter, err := pipe.NewClient(logger, info) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
s, err := logWriter.AddSession("rootLogger", "pub log-drain -b=false", buffer, timeout) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
currentHandler := logger.Handler() | ||
return slog.New( | ||
&MultiHandler{ | ||
Handlers: []slog.Handler{ | ||
currentHandler, | ||
slog.NewJSONHandler(s, &slog.HandlerOptions{ | ||
AddSource: true, | ||
Level: slog.LevelDebug, | ||
}), | ||
}, | ||
}, | ||
), nil | ||
} | ||
|
||
func ReadLogs(ctx context.Context, connectionInfo *pipe.SSHClientInfo) (io.Reader, error) { | ||
return pipe.Sub("sub log-drain -k", ctx, connectionInfo) | ||
} |
Oops, something went wrong.