diff --git a/pipe.go b/pipe.go index a2da6639..c5999a5d 100644 --- a/pipe.go +++ b/pipe.go @@ -11,6 +11,7 @@ import ( "net" "os" "runtime" + "sync" "time" "unsafe" @@ -316,8 +317,9 @@ type win32PipeListener struct { path string config PipeConfig acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int + closeOnce sync.Once + closeCh chan struct{} + doneCh chan struct{} } func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (windows.Handle, error) { @@ -530,8 +532,8 @@ func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { path: path, config: *c, acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), } go l.listenerRoutine() return l, nil @@ -573,11 +575,10 @@ func (l *win32PipeListener) Accept() (net.Conn, error) { } func (l *win32PipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } + l.closeOnce.Do(func() { + close(l.closeCh) + }) + <-l.doneCh return nil } diff --git a/pipe_test.go b/pipe_test.go index 8b1d5b94..27223850 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -644,5 +644,49 @@ func TestListenConnectRace(t *testing.T) { s.Close() } wg.Wait() + + t.Logf("iteration %d", i) + } +} + +func TestCloseRace(t *testing.T) { + for i := 0; i < 1000 && !t.Failed(); i++ { + l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true}) + if err != nil { + t.Fatal(err) + } + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + b, err := io.ReadAll(c) + if err != nil { + t.Error(err) + return + } + _, _ = c.Write(b) + _ = c.Close() + } + }() + + c, err := DialPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + if _, err = c.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if err := c.(CloseWriter).CloseWrite(); err != nil { + t.Fatal(err) + } + if _, err := io.ReadAll(c); err != nil { + t.Fatal(err) + } + _ = c.Close() + _ = l.Close() + + t.Logf("iteration %d", i) } }