From 4429484464d43fe75e1cbc3453af20e6a9ef554a Mon Sep 17 00:00:00 2001 From: Antonio Mika Date: Tue, 19 Nov 2024 20:02:52 -0500 Subject: [PATCH] Add cancel function to session --- pipe/client.go | 5 +++++ pipe/session.go | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/pipe/client.go b/pipe/client.go index 45e8e4f..8bae228 100644 --- a/pipe/client.go +++ b/pipe/client.go @@ -140,7 +140,10 @@ func (c *Client) AddSession(id string, command string, buffer int, readTimeout, buffer = 0 } + ctx, cancelFunc := context.WithCancel(c.Context) + session := &Session{ + Context: ctx, Logger: c.Logger.With("id", id, "command", command), Client: c, Command: command, @@ -150,6 +153,7 @@ func (c *Client) AddSession(id string, command string, buffer int, readTimeout, Done: make(chan struct{}), In: make(chan SendData, buffer), Out: make(chan SendData, buffer), + cancelFunc: cancelFunc, } err := session.Open() @@ -172,6 +176,7 @@ func (c *Client) RemoveSession(id string) error { if session, ok := c.Sessions.Load(id); ok { err = session.Close() + session.Cancel() c.Sessions.Delete(id) } diff --git a/pipe/session.go b/pipe/session.go index 8755fae..f4ee0c7 100644 --- a/pipe/session.go +++ b/pipe/session.go @@ -1,6 +1,7 @@ package pipe import ( + "context" "fmt" "io" "log/slog" @@ -13,8 +14,9 @@ import ( // Session represents a session to a remote command. type Session struct { - Logger *slog.Logger - Client *Client + Context context.Context + Logger *slog.Logger + Client *Client Done chan struct{} In chan SendData @@ -37,6 +39,8 @@ type Session struct { connectMu sync.Mutex reconnectMu sync.Mutex + + cancelFunc context.CancelFunc } type SendData struct { @@ -136,6 +140,8 @@ func (s *Session) Reconnect() { select { case <-s.Client.CtxDone: return + case <-s.Context.Done(): + return default: err := s.Open() if err != nil { @@ -167,6 +173,9 @@ func (s *Session) Start() { case <-s.Client.CtxDone: s.broadcastDone() return + case <-s.Context.Done(): + s.broadcastDone() + return case data, ok := <-s.In: _, err := s.StdinPipe.Write(data.Data) if !ok || err != nil || data.Error != nil { @@ -187,6 +196,9 @@ func (s *Session) Start() { case <-s.Client.CtxDone: s.broadcastDone() return + case <-s.Context.Done(): + s.broadcastDone() + return default: data := make([]byte, 32*1024) @@ -201,6 +213,9 @@ func (s *Session) Start() { case <-s.Client.CtxDone: s.broadcastDone() return + case <-s.Context.Done(): + s.broadcastDone() + return } if err != nil { @@ -230,6 +245,9 @@ func (s *Session) Write(data []byte) (int, error) { case <-s.Client.CtxDone: s.broadcastDone() break + case <-s.Context.Done(): + s.broadcastDone() + break case <-s.writeTimeout(): break } @@ -254,6 +272,9 @@ func (s *Session) Read(data []byte) (int, error) { case <-s.Client.CtxDone: s.broadcastDone() break + case <-s.Context.Done(): + s.broadcastDone() + break case <-s.readTimeout(): break } @@ -261,6 +282,11 @@ func (s *Session) Read(data []byte) (int, error) { return n, err } +// Cancel cancels the session. +func (s *Session) Cancel() { + s.cancelFunc() +} + func (s *Session) readTimeout() <-chan time.Time { if s.ReadTimeout < 0 { return s.Client.CtxDone @@ -287,6 +313,8 @@ func (s *Session) broadcastDone() { break case <-s.Client.CtxDone: break + case <-s.Context.Done(): + break default: break }