Skip to content

Commit

Permalink
fix: backport tinygo/v1 fd_write fixes
Browse files Browse the repository at this point in the history
Backport following commits for tinygo/v0:

* fix: unfairWorker may fail due to partial write

* refactor: use == comparison for basic errors

Signed-off-by: Gaukas Wang <[email protected]>
  • Loading branch information
gaukas committed Jun 18, 2024
1 parent 9845789 commit 658dde3
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 48 deletions.
1 change: 1 addition & 0 deletions tinygo/v0/examples/plain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import v0 "github.com/refraction-networking/watm/tinygo/v0"

func init() {
v0.WorkerFairness(false) // by default, use unfairWorker for better performance
v0.BuildDialerWithWrappingTransport(&PlainWrappingTransport{})
v0.BuildListenerWithWrappingTransport(&PlainWrappingTransport{})
v0.BuildRelayWithWrappingTransport(&PlainWrappingTransport{}, v0.RelayWrapRemote)
Expand Down
4 changes: 2 additions & 2 deletions tinygo/v0/net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ func (c *TCPConn) Write(b []byte) (n int, writeErr error) {
// if no deadline set, behavior depends on blocking mode of the
// underlying file descriptor.
return syscallFnFd(c.rawConn, func(fd uintptr) (int, error) {
return syscall.Write(syscallFd(fd), b)
return writeFD(fd, b)
})
} else {
// writeDeadline is set, if EAGAIN/EWOULDBLOCK is returned,
// we retry until the deadline is reached.
if err := c.rawConn.Write(func(fd uintptr) (done bool) {
n, writeErr = syscall.Write(syscallFd(fd), b)
n, writeErr = writeFD(fd, b)
if errors.Is(writeErr, syscall.EAGAIN) {
if time.Now().Before(wdl) {
return false
Expand Down
47 changes: 47 additions & 0 deletions tinygo/v0/net/fd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package net

import "syscall"

// writeFD writes data to the file descriptor fd. When a partial write occurs,
// it will continue with the remaining data until all data is written or an
// error occurs. If no progress is made in a single write call, it will return
// syscall.EIO.
//
// It is ported from (*FD).Write in golang/go/src/internal/poll/fd_unix.go
func writeFD(fd uintptr, p []byte) (int, error) {
var nn int
for {
n, err := ignoringEINTRIO(syscall.Write, syscallFd(fd), p[nn:])
if n > 0 {
nn += n
}
if nn == len(p) {
return nn, err
}
if err != nil {
return nn, err
}
if n == 0 {
return nn, syscall.EIO
}

// // TODO: retry if EAGAIN or no progress?
// if n == 0 {
// noprogress++
// }
// if noprogress == 10 {
// return nn, syscall.EIO
// }
// runtime.Gosched()
}
}

// ignoringEINTRIO is like ignoringEINTR, but just for IO calls.
func ignoringEINTRIO(fn func(fd syscallFd, p []byte) (int, error), fd syscallFd, p []byte) (int, error) {
for {
n, err := fn(fd, p)
if err != syscall.EINTR {
return n, err
}
}
}
111 changes: 65 additions & 46 deletions tinygo/v0/worker.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package v0

import (
"errors"
"io"
"log"
"net"
"runtime"
"syscall"

v0net "github.com/refraction-networking/watm/tinygo/v0/net"
Expand Down Expand Up @@ -32,7 +32,7 @@ var sourceConn v0net.Conn // sourceConn is used to communicate between WASM and
var remoteConn v0net.Conn // remoteConn is used to communicate between WASM and a dialed remote destination (for dialer/relay) or a dialing party (for listener only)
var cancelConn v0net.Conn // cancelConn is used to cancel the entire worker.

var workerFn func() int32 = unfairWorker // by default, use unfairWorker
var workerFn func() int32 = unfairWorker // by default, use unfairWorker for better performance under mostly unidirectional I/O

var readBuf []byte = make([]byte, 16384) // 16k buffer for reading

Expand Down Expand Up @@ -79,8 +79,6 @@ func untilError(f func() error) error {
// connection is not properly set to non-blocking mode, i.e., never returns
// EAGAIN, this function will block forever and never work on a lower priority
// connection. Thus it is called unfairWorker.
//
// TODO: use poll_oneoff instead of busy polling
func unfairWorker() int32 {
for {
pollFd := []pollFd{
Expand All @@ -100,8 +98,8 @@ func unfairWorker() int32 {

n, err := _poll(pollFd, -1)
if n == 0 { // TODO: re-evaluate the condition
if err == nil || errors.Is(err, syscall.EAGAIN) {
usleep(100) // wait 100us before retrying _poll
if err == nil || err == syscall.EAGAIN {
runtime.Gosched() // yield the current goroutine
continue
}
log.Println("worker: unfairWorker: _poll:", err)
Expand All @@ -110,8 +108,8 @@ func unfairWorker() int32 {

// 1st priority: cancelConn
_, err = cancelConn.Read(readBuf)
if !errors.Is(err, syscall.EAGAIN) {
if errors.Is(err, io.EOF) || err == nil {
if !(err == syscall.EAGAIN) {
if err == io.EOF || err == nil {
log.Println("worker: unfairWorker: cancelConn is closed")
return wasip1.EncodeWATERError(syscall.ECANCELED) // operation canceled
}
Expand All @@ -121,57 +119,67 @@ func unfairWorker() int32 {

// 2nd priority: sourceConn
if err := untilError(func() error {
readN, readErr := sourceConn.Read(readBuf)
nRead, readErr := sourceConn.Read(readBuf)
if readErr != nil {
if readErr != syscall.EAGAIN {
log.Println("worker: unfairWorker: sourceConn.Read:", readErr)
}
return readErr
}

writeN, writeErr := remoteConn.Write(readBuf[:readN])
nWritten, writeErr := remoteConn.Write(readBuf[:nRead])
if writeErr != nil {
log.Println("worker: unfairWorker: remoteConn.Write:", writeErr)
return syscall.EIO // input/output error, we cannot retry async write yet
return writeErr
}

if readN != writeN {
log.Println("worker: unfairWorker: readN != writeN")
return syscall.EIO // input/output error
if nRead != nWritten {
log.Printf("worker: unfairWorker: nRead != nWritten")
return syscall.EMSGSIZE // message too long to fit in send buffer even after auto partial write
}

return nil
}); !errors.Is(err, syscall.EAGAIN) {
if errors.Is(err, io.EOF) {
}); err != syscall.EAGAIN { // silently ignore EAGAIN
if err == io.EOF {
log.Println("worker: unfairWorker: sourceConn is closed")
return wasip1.EncodeWATERError(syscall.EPIPE) // broken pipe
return wasip1.EncodeWATERError(0) // success, no error
}
if errno, ok := err.(syscall.Errno); ok {
return wasip1.EncodeWATERError(errno)
}
log.Println("worker: unfairWorker: sourceConn.Read:", err)
return wasip1.EncodeWATERError(syscall.EIO) // input/output error
}

// 3rd priority: remoteConn
if err := untilError(func() error {
readN, readErr := remoteConn.Read(readBuf)
nRead, readErr := remoteConn.Read(readBuf)
if readErr != nil {
if readErr != syscall.EAGAIN {
log.Println("worker: unfairWorker: remoteConn.Read:", readErr)
}
return readErr
}

writeN, writeErr := sourceConn.Write(readBuf[:readN])
nWrite, writeErr := sourceConn.Write(readBuf[:nRead])
if writeErr != nil {
log.Println("worker: unfairWorker: sourceConn.Write:", writeErr)
return syscall.EIO // input/output error, we cannot retry async write yet
return writeErr
}

if readN != writeN {
log.Println("worker: unfairWorker: readN != writeN")
return syscall.EIO // input/output error
if nRead != nWrite {
log.Printf("worker: unfairWorker: nRead != nWrite")
return syscall.EMSGSIZE // message too long to fit in send buffer even after auto partial write
}

return nil
}); !errors.Is(err, syscall.EAGAIN) {
if errors.Is(err, io.EOF) {
}); err != syscall.EAGAIN { // silently ignore EAGAIN
if err == io.EOF {
log.Println("worker: unfairWorker: remoteConn is closed")
return wasip1.EncodeWATERError(syscall.EPIPE) // broken pipe
return wasip1.EncodeWATERError(0) // success, no error
}
if errno, ok := err.(syscall.Errno); ok {
return wasip1.EncodeWATERError(errno)
}
log.Println("worker: unfairWorker: remoteConn.Read:", err)
return wasip1.EncodeWATERError(syscall.EIO) // input/output error
}
}
Expand All @@ -183,8 +191,6 @@ func unfairWorker() int32 {
// But different from unfairWorker, fairWorker spend equal amount of turns on each connection
// for calling Read. Therefore it has a better fairness than unfairWorker, which may still
// make progress if one of the connection is not properly set to non-blocking mode.
//
// TODO: use poll_oneoff instead of busy polling
func fairWorker() int32 {
for {
pollFd := []pollFd{
Expand All @@ -204,22 +210,22 @@ func fairWorker() int32 {

n, err := _poll(pollFd, -1)
if n == 0 { // TODO: re-evaluate the condition
if err == nil || errors.Is(err, syscall.EAGAIN) {
usleep(100) // wait 100us before retrying _poll
if err == nil || err == syscall.EAGAIN {
runtime.Gosched() // yield the current goroutine
continue
}
log.Println("worker: unfairWorker: _poll:", err)
log.Println("worker: fairWorker: _poll:", err)
return int32(err.(syscall.Errno))
}

// 1st priority: cancelConn
_, err = cancelConn.Read(readBuf)
if !errors.Is(err, syscall.EAGAIN) {
if errors.Is(err, io.EOF) || err == nil {
log.Println("worker: unfairWorker: cancelConn is closed")
if !(err == syscall.EAGAIN) {
if err == io.EOF || err == nil {
log.Println("worker: fairWorker: cancelConn is closed")
return wasip1.EncodeWATERError(syscall.ECANCELED) // operation canceled
}
log.Println("worker: unfairWorker: cancelConn.Read:", err)
log.Println("worker: fairWorker: cancelConn.Read:", err)
return wasip1.EncodeWATERError(syscall.EIO) // input/output error
}

Expand All @@ -230,7 +236,13 @@ func fairWorker() int32 {
remoteConn, // dst
sourceConn, // src
readBuf); err != nil {
return wasip1.EncodeWATERError(err.(syscall.Errno))
if err == io.EOF {
return wasip1.EncodeWATERError(0) // success, no error
}
if errno, ok := err.(syscall.Errno); ok {
return wasip1.EncodeWATERError(errno)
}
return wasip1.EncodeWATERError(syscall.EIO) // other input/output error
}

// 3rd priority: remoteConn -> sourceConn
Expand All @@ -240,7 +252,13 @@ func fairWorker() int32 {
sourceConn, // dst
remoteConn, // src
readBuf); err != nil {
return wasip1.EncodeWATERError(err.(syscall.Errno))
if err == io.EOF {
return wasip1.EncodeWATERError(0)
}
if errno, ok := err.(syscall.Errno); ok {
return wasip1.EncodeWATERError(errno)
}
return wasip1.EncodeWATERError(syscall.EIO) // other input/output error
}
}
}
Expand All @@ -250,23 +268,24 @@ func copyOnce(dstName, srcName string, dst, src net.Conn, buf []byte) error {
buf = make([]byte, 16384) // 16k buffer for reading
}

readN, readErr := src.Read(buf)
if !errors.Is(readErr, syscall.EAGAIN) { // if EAGAIN, do nothing and return
if errors.Is(readErr, io.EOF) {
return syscall.EPIPE // broken pipe
nRead, readErr := src.Read(buf)
if !(readErr == syscall.EAGAIN) { // if EAGAIN, do nothing and return
if readErr == io.EOF {
log.Printf("worker: copyOnce: EOF on %s", srcName)
return io.EOF
} else if readErr != nil {
log.Printf("worker: copyOnce: %s.Read: %v", srcName, readErr)
return syscall.EIO // input/output error
}

writeN, writeErr := dst.Write(buf[:readN])
nWritten, writeErr := dst.Write(buf[:nRead])
if writeErr != nil {
log.Printf("worker: copyOnce: %s.Write: %v", dstName, writeErr)
return syscall.EIO // no matter input/output error or EAGAIN we cannot retry async write yet
}

if readN != writeN {
log.Printf("worker: copyOnce: %s.read != %s.written", srcName, dstName)
if nRead != nWritten {
log.Printf("worker: copyOnce: %s.nRead != %s.nWritten", srcName, dstName)
return syscall.EIO // input/output error
}
}
Expand Down

0 comments on commit 658dde3

Please sign in to comment.