From dcc95d69611b04c3f1ed5e4babdd3896e1f05c4b Mon Sep 17 00:00:00 2001 From: chris erway Date: Fri, 20 Oct 2023 16:33:04 -0400 Subject: [PATCH] add high-priority stream flag using SetWriteDeadline magic value --- session.go | 47 +++++++++++++++++++++++++++++++++++------------ session_test.go | 6 +++--- stream.go | 21 ++++++++++++++++----- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/session.go b/session.go index bc775f8..c524c90 100644 --- a/session.go +++ b/session.go @@ -96,6 +96,9 @@ type Session struct { // sendCh is used to send messages sendCh chan []byte + // highPrioSendCh is used to send messages for streams marked high priority. + highPrioSendCh chan []byte + // pingCh and pingCh are used to send pings and pongs pongCh, pingCh chan uint32 @@ -144,6 +147,7 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int, newMemo synCh: make(chan struct{}, config.AcceptBacklog), acceptCh: make(chan *Stream, config.AcceptBacklog), sendCh: make(chan []byte, 64), + highPrioSendCh: make(chan []byte, 64), pongCh: make(chan uint32, config.PingBacklog), pingCh: make(chan uint32), recvDoneCh: make(chan struct{}), @@ -326,7 +330,7 @@ func (s *Session) exitErr(err error) { // GoAway can be used to prevent accepting further // connections. It does not close the underlying conn. func (s *Session) GoAway() error { - return s.sendMsg(s.goAway(goAwayNormal), nil, nil) + return s.sendMsg(s.goAway(goAwayNormal), nil, nil, false) } // goAway is used to send a goAway message @@ -483,7 +487,7 @@ func (s *Session) extendKeepalive() { } // send sends the header and body. -func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) error { +func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}, highPriority bool) error { select { case <-s.shutdownCh: return s.shutdownErr @@ -495,11 +499,16 @@ func (s *Session) sendMsg(hdr header, body []byte, deadline <-chan struct{}) err copy(buf[:headerSize], hdr[:]) copy(buf[headerSize:], body) + sendCh := s.sendCh + if highPriority { + sendCh = s.highPrioSendCh + } + select { case <-s.shutdownCh: pool.Put(buf) return s.shutdownErr - case s.sendCh <- buf: + case sendCh <- buf: return nil case <-deadline: pool.Put(buf) @@ -579,9 +588,9 @@ func (s *Session) sendLoop() (err error) { hdr := encode(typePing, flagACK, 0, pingID) copy(buf, hdr[:]) default: - // Then send normal data. + // Next, the highPrioSendCh gets to send before other streams, if data is available. select { - case buf = <-s.sendCh: + case buf = <-s.highPrioSendCh: case pingID := <-s.pingCh: buf = pool.Get(headerSize) hdr := encode(typePing, flagSYN, 0, pingID) @@ -590,8 +599,22 @@ func (s *Session) sendLoop() (err error) { buf = pool.Get(headerSize) hdr := encode(typePing, flagACK, 0, pingID) copy(buf, hdr[:]) - case <-s.shutdownCh: - return nil + default: + // Then send normal data. + select { + case buf = <-s.highPrioSendCh: + case buf = <-s.sendCh: + case pingID := <-s.pingCh: + buf = pool.Get(headerSize) + hdr := encode(typePing, flagSYN, 0, pingID) + copy(buf, hdr[:]) + case pingID := <-s.pongCh: + buf = pool.Get(headerSize) + hdr := encode(typePing, flagACK, 0, pingID) + copy(buf, hdr[:]) + case <-s.shutdownCh: + return nil + } // default: // select { // case buf = <-s.sendCh: @@ -734,7 +757,7 @@ func (s *Session) handleStreamMessage(hdr header) error { // Read the new data if err := stream.readData(hdr, flags, s.reader); err != nil { - if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { + if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } return err @@ -802,7 +825,7 @@ func (s *Session) incomingStream(id uint32) error { // Reject immediately if we are doing a go away if atomic.LoadInt32(&s.localGoAway) == 1 { hdr := encode(typeWindowUpdate, flagRST, id, 0) - return s.sendMsg(hdr, nil, nil) + return s.sendMsg(hdr, nil, nil, false) } // Allocate a new stream @@ -821,7 +844,7 @@ func (s *Session) incomingStream(id uint32) error { // Check if stream already exists if _, ok := s.streams[id]; ok { s.logger.Printf("[ERR] yamux: duplicate stream declared") - if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { + if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil, false); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } span.Done() @@ -833,7 +856,7 @@ func (s *Session) incomingStream(id uint32) error { s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset") defer span.Done() hdr := encode(typeWindowUpdate, flagRST, id, 0) - return s.sendMsg(hdr, nil, nil) + return s.sendMsg(hdr, nil, nil, false) } s.numIncomingStreams++ @@ -850,7 +873,7 @@ func (s *Session) incomingStream(id uint32) error { s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset") s.deleteStream(id) hdr := encode(typeWindowUpdate, flagRST, id, 0) - return s.sendMsg(hdr, nil, nil) + return s.sendMsg(hdr, nil, nil, false) } } diff --git a/session_test.go b/session_test.go index 974b6d5..ed13e60 100644 --- a/session_test.go +++ b/session_test.go @@ -1322,7 +1322,7 @@ func TestSession_sendMsg_Timeout(t *testing.T) { hdr := encode(typePing, flagACK, 0, 0) for { - err := client.sendMsg(hdr, nil, nil) + err := client.sendMsg(hdr, nil, nil, false) if err == nil { continue } else if err == ErrConnectionWriteTimeout { @@ -1345,14 +1345,14 @@ func TestWindowOverflow(t *testing.T) { defer server.Close() hdr1 := encode(typeData, flagSYN, i, 0) - _ = client.sendMsg(hdr1, nil, nil) + _ = client.sendMsg(hdr1, nil, nil, false) s, err := server.AcceptStream() if err != nil { t.Fatal(err) } msg := make([]byte, client.config.MaxStreamWindowSize*2) hdr2 := encode(typeData, 0, i, uint32(len(msg))) - _ = client.sendMsg(hdr2, msg, nil) + _ = client.sendMsg(hdr2, msg, nil, false) _, err = io.ReadAll(s) if err == nil { t.Fatal("expected to read no data") diff --git a/stream.go b/stream.go index e1e5602..f617dbb 100644 --- a/stream.go +++ b/stream.go @@ -26,6 +26,11 @@ const ( halfReset ) +// HighPriorityWriteDeadlineMagicValue is a special value that can be passed to +// SetWriteDeadline to indicate that this stream should get to send its data +// before other streams. +var HighPriorityWriteDeadlineMagicValue = time.Unix(1<<60, 0) + // Stream is used to represent a logical stream // within a session. type Stream struct { @@ -49,6 +54,8 @@ type Stream struct { sendNotifyCh chan struct{} readDeadline, writeDeadline pipeDeadline + + highPriority bool } // newStream is used to construct a new stream within a given session for an ID. @@ -70,6 +77,7 @@ func newStream(session *Session, id uint32, state streamState, initialWindow uin epochStart: time.Now(), recvNotifyCh: make(chan struct{}, 1), sendNotifyCh: make(chan struct{}, 1), + highPriority: false, } return s } @@ -179,7 +187,7 @@ START: // Send the header hdr = encode(typeData, flags, s.id, max) - if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait()); err != nil { + if err = s.session.sendMsg(hdr, b[:max], s.writeDeadline.wait(), s.highPriority); err != nil { return 0, err } @@ -238,7 +246,7 @@ func (s *Stream) sendWindowUpdate(deadline <-chan struct{}) error { s.epochStart = now hdr := encode(typeWindowUpdate, flags, s.id, delta) - return s.session.sendMsg(hdr, nil, deadline) + return s.session.sendMsg(hdr, nil, deadline, s.highPriority) } // sendClose is used to send a FIN @@ -246,13 +254,13 @@ func (s *Stream) sendClose() error { flags := s.sendFlags() flags |= flagFIN hdr := encode(typeWindowUpdate, flags, s.id, 0) - return s.session.sendMsg(hdr, nil, nil) + return s.session.sendMsg(hdr, nil, nil, s.highPriority) } // sendReset is used to send a RST func (s *Stream) sendReset() error { hdr := encode(typeWindowUpdate, flagRST, s.id, 0) - return s.session.sendMsg(hdr, nil, nil) + return s.session.sendMsg(hdr, nil, nil, s.highPriority) } // Reset resets the stream (forcibly closes the stream) @@ -490,7 +498,10 @@ func (s *Stream) SetReadDeadline(t time.Time) error { func (s *Stream) SetWriteDeadline(t time.Time) error { s.stateLock.Lock() defer s.stateLock.Unlock() - if s.writeState == halfOpen { + // handle magic time.Time value to signal this is a high-priority stream. + if t.Equal(HighPriorityWriteDeadlineMagicValue) { + s.highPriority = true + } else if s.writeState == halfOpen { s.writeDeadline.set(t) } return nil