Skip to content

Commit

Permalink
Add cancel function to session
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniomika committed Nov 20, 2024
1 parent 4b7dcdc commit 4429484
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
5 changes: 5 additions & 0 deletions pipe/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ func (c *Client) AddSession(id string, command string, buffer int, readTimeout,
buffer = 0
}

ctx, cancelFunc := context.WithCancel(c.Context)

session := &Session{
Context: ctx,
Logger: c.Logger.With("id", id, "command", command),
Client: c,
Command: command,
Expand All @@ -150,6 +153,7 @@ func (c *Client) AddSession(id string, command string, buffer int, readTimeout,
Done: make(chan struct{}),
In: make(chan SendData, buffer),
Out: make(chan SendData, buffer),
cancelFunc: cancelFunc,
}

err := session.Open()
Expand All @@ -172,6 +176,7 @@ func (c *Client) RemoveSession(id string) error {

if session, ok := c.Sessions.Load(id); ok {
err = session.Close()
session.Cancel()
c.Sessions.Delete(id)
}

Expand Down
32 changes: 30 additions & 2 deletions pipe/session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pipe

import (
"context"
"fmt"
"io"
"log/slog"
Expand All @@ -13,8 +14,9 @@ import (

// Session represents a session to a remote command.
type Session struct {
Logger *slog.Logger
Client *Client
Context context.Context
Logger *slog.Logger
Client *Client

Done chan struct{}
In chan SendData
Expand All @@ -37,6 +39,8 @@ type Session struct {

connectMu sync.Mutex
reconnectMu sync.Mutex

cancelFunc context.CancelFunc
}

type SendData struct {
Expand Down Expand Up @@ -136,6 +140,8 @@ func (s *Session) Reconnect() {
select {
case <-s.Client.CtxDone:
return
case <-s.Context.Done():
return
default:
err := s.Open()
if err != nil {
Expand Down Expand Up @@ -167,6 +173,9 @@ func (s *Session) Start() {
case <-s.Client.CtxDone:
s.broadcastDone()
return
case <-s.Context.Done():
s.broadcastDone()
return
case data, ok := <-s.In:
_, err := s.StdinPipe.Write(data.Data)
if !ok || err != nil || data.Error != nil {
Expand All @@ -187,6 +196,9 @@ func (s *Session) Start() {
case <-s.Client.CtxDone:
s.broadcastDone()
return
case <-s.Context.Done():
s.broadcastDone()
return
default:
data := make([]byte, 32*1024)

Expand All @@ -201,6 +213,9 @@ func (s *Session) Start() {
case <-s.Client.CtxDone:
s.broadcastDone()
return
case <-s.Context.Done():
s.broadcastDone()
return
}

if err != nil {
Expand Down Expand Up @@ -230,6 +245,9 @@ func (s *Session) Write(data []byte) (int, error) {
case <-s.Client.CtxDone:
s.broadcastDone()
break
case <-s.Context.Done():
s.broadcastDone()
break
case <-s.writeTimeout():
break
}
Expand All @@ -254,13 +272,21 @@ func (s *Session) Read(data []byte) (int, error) {
case <-s.Client.CtxDone:
s.broadcastDone()
break
case <-s.Context.Done():
s.broadcastDone()
break
case <-s.readTimeout():
break
}

return n, err
}

// Cancel cancels the session.
func (s *Session) Cancel() {
s.cancelFunc()
}

func (s *Session) readTimeout() <-chan time.Time {
if s.ReadTimeout < 0 {
return s.Client.CtxDone
Expand All @@ -287,6 +313,8 @@ func (s *Session) broadcastDone() {
break
case <-s.Client.CtxDone:
break
case <-s.Context.Done():
break
default:
break
}
Expand Down

0 comments on commit 4429484

Please sign in to comment.