Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing bad connection error when reading large compressed packets #863

Merged
merged 12 commits into from
May 7, 2024
125 changes: 83 additions & 42 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/sha1"
"crypto/x509"
"encoding/pem"
goErrors "errors"
"io"
"net"
"sync"
Expand Down Expand Up @@ -65,8 +66,6 @@ type Conn struct {

compressedHeader [7]byte

compressedReaderActive bool

compressedReader io.Reader
}

Expand Down Expand Up @@ -107,42 +106,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
}()

if c.Compression != MYSQL_COMPRESS_NONE {
if !c.compressedReaderActive {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

if uncompressedLength > 0 {
var err error
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
c.compressedReader, err = zlib.NewReader(c.reader)
case MYSQL_COMPRESS_ZSTD:
c.compressedReader, err = zstd.NewReader(c.reader)
}
if err != nil {
return nil, err
}
if c.compressedReader == nil {
var err error
c.compressedReader, err = c.newCompressedPacketReader()
if err != nil {
return nil, err
}
c.compressedReaderActive = true
}
}

if c.compressedReader != nil {
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
return nil, errors.Trace(err)
}
} else {
if err := c.ReadPacketTo(buf, c.reader); err != nil {
return nil, errors.Trace(err)
}
if err := c.ReadPacketTo(buf); err != nil {
return nil, errors.Trace(err)
}

readBytes := buf.Bytes()
Expand All @@ -167,22 +141,78 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
return result, nil
}

func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
// newCompressedPacketReader creates a new compressed packet reader.
func (c *Conn) newCompressedPacketReader() (io.Reader, error) {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16)
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if uncompressedLength > 0 {
limitedReader := io.LimitReader(c.reader, int64(compressedLength))
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
return zlib.NewReader(limitedReader)
case MYSQL_COMPRESS_ZSTD:
return zstd.NewReader(limitedReader)
}
}

return nil, nil
}

func (c *Conn) currentPacketReader() io.Reader {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there're

	compressedReaderActive bool

	compressedReader io.Reader

in Conn. Seems we can directly check c.compressedReader == nil as the returned reader for currentPacketReader. And compressedReaderActive always has the same value for c.compressedReader == nil so we can delete it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I attempted to delete compressedReaderActive the tests in the client package all began failing when I ran them with compression enabled. I think this is because compressedReaderActive is reset to false in WritePacket after writing the compressed packet. So I don't think I can delete it, or at least I feel deleting it is out of scope for this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated this PR with your suggestion. It was 2am for me and I wasn't thinking clearly, but after more sleep, I realized I could easily remove the compressedReaderActive boolean property from Conn.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

take care of your health ❤️

if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil {
return c.reader
} else {
return c.compressedReader
}
}

func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) {
var written int64

for n > 0 {
bcap := cap(c.copyNBuf)
if int64(bcap) > n {
bcap = int(n)
}
buf := c.copyNBuf[:bcap]

rd, err := io.ReadAtLeast(src, buf, bcap)
// Call ReadAtLeast with the currentPacketReader as it may change on every iteration
// of this loop.
rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap)

n -= int64(rd)

// ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
// bytes are read. In this case, and when we have compression then advance
// the sequence number and reset the compressed reader to continue reading
// the remaining bytes in the next compressed packet.
if c.Compression != MYSQL_COMPRESS_NONE &&
(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) {
// we have read to EOF and read an incomplete uncompressed packet
// so advance the compressed sequence number and reset the compressed reader
// to get the remaining unread uncompressed bytes from the next compressed packet.
c.CompressedSequence++
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil {
return written, errors.Trace(err)
}
}

if err != nil {
return written, errors.Trace(err)
}

wr, err := dst.Write(buf)
// careful to only write from the buffer the number of bytes read
wr, err := dst.Write(buf[:rd])
written += int64(wr)
if err != nil {
return written, errors.Trace(err)
Expand All @@ -192,9 +222,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
return written, nil
}

func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
func (c *Conn) ReadPacketTo(w io.Writer) error {
b := utils.BytesBufferGet()
defer func() {
utils.BytesBufferPut(b)
}()

// packets that come in a compressed packet may be partial
// so use the copyN function to read the packet header into a
// buffer, since copyN is capable of getting the next compressed
// packet and updating the Conn state with a new compressedReader.
if _, err := c.copyN(b, 4); err != nil {
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
} else {
// copy was successful so copy the 4 bytes from the buffer to the header
copy(c.header[:4], b.Bytes()[:4])
}

length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
Expand All @@ -211,7 +253,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
buf.Grow(length)
}

if n, err := c.copyN(w, r, int64(length)); err != nil {
if n, err := c.copyN(w, int64(length)); err != nil {
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
} else if n != int64(length) {
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
Expand All @@ -220,7 +262,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
return nil
}

if err = c.ReadPacketTo(w, r); err != nil {
if err = c.ReadPacketTo(w); err != nil {
return errors.Wrap(err, "ReadPacketTo failed")
}
}
Expand Down Expand Up @@ -270,7 +312,6 @@ func (c *Conn) WritePacket(data []byte) error {
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
}
c.compressedReader = nil
c.compressedReaderActive = false
default:
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
}
Expand Down
Loading