diff --git a/peer/channel.go b/peer/channel.go index c0415e50baa1a..d7119d1eafb7d 100644 --- a/peer/channel.go +++ b/peer/channel.go @@ -106,12 +106,15 @@ func (c *Channel) init() { // write operations to block once the threshold is set. c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) c.dc.OnBufferedAmountLow(func() { + // Grab the lock to protect the sendMore channel from being + // closed in between the isClosed check and the send. + c.closeMutex.Lock() + defer c.closeMutex.Unlock() if c.isClosed() { return } select { case <-c.closed: - return case c.sendMore <- struct{}{}: default: } @@ -122,15 +125,16 @@ func (c *Channel) init() { }) c.dc.OnOpen(func() { c.closeMutex.Lock() - defer c.closeMutex.Unlock() - c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label())) var err error c.rwc, err = c.dc.Detach() if err != nil { + c.closeMutex.Unlock() _ = c.closeWithError(xerrors.Errorf("detach: %w", err)) return } + c.closeMutex.Unlock() + // pion/webrtc will return an io.ErrShortBuffer when a read // is triggerred with a buffer size less than the chunks written. // @@ -189,9 +193,6 @@ func (c *Channel) init() { // // This will block until the underlying DataChannel has been opened. func (c *Channel) Read(bytes []byte) (int, error) { - if c.isClosed() { - return 0, c.closeError - } err := c.waitOpened() if err != nil { return 0, err @@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) { c.writeMutex.Lock() defer c.writeMutex.Unlock() - if c.isClosed() { - return 0, c.closeWithError(nil) - } err = c.waitOpened() if err != nil { return 0, err @@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool { func (c *Channel) waitOpened() error { select { case <-c.opened: + // Re-check the closed channel to prioritize closure. + if c.isClosed() { + return c.closeError + } return nil case <-c.closed: return c.closeError diff --git a/peer/conn.go b/peer/conn.go index 8eae101ccdbbe..2e67b500ee5fd 100644 --- a/peer/conn.go +++ b/peer/conn.go @@ -3,7 +3,6 @@ package peer import ( "bytes" "context" - "crypto/rand" "io" "sync" @@ -256,7 +255,6 @@ func (c *Conn) init() error { c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate)) select { case <-c.closed: - break case c.localCandidateChannel <- iceCandidate.ToJSON(): } }() @@ -265,7 +263,6 @@ func (c *Conn) init() error { go func() { select { case <-c.closed: - return case c.dcOpenChannel <- dc: } }() @@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) { data := make([]byte, pingDataLength) bytesRead, err := c.pingEchoChan.Read(data) if err != nil { - if c.isClosed() { - return - } _ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err)) return } diff --git a/peer/conn_test.go b/peer/conn_test.go index 20f4c84638b0c..d1fbf63d15ab6 100644 --- a/peer/conn_test.go +++ b/peer/conn_test.go @@ -91,6 +91,8 @@ func TestConn(t *testing.T) { // Create a channel that closes on disconnect. channel, err := server.CreateChannel(context.Background(), "wow", nil) assert.NoError(t, err) + defer channel.Close() + err = wan.Stop() require.NoError(t, err) // Once the connection is marked as disconnected, this @@ -107,10 +109,13 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) require.NoError(t, err) + defer cch.Close() - sch, err := server.Accept(context.Background()) + sch, err := server.Accept(ctx) require.NoError(t, err) defer sch.Close() @@ -123,9 +128,12 @@ func TestConn(t *testing.T) { t.Parallel() client, server, wan := createPair(t) exchange(t, client, server) - cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) require.NoError(t, err) - sch, err := server.Accept(context.Background()) + defer cch.Close() + sch, err := server.Accept(ctx) require.NoError(t, err) defer sch.Close() @@ -140,26 +148,44 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) - cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{}) - require.NoError(t, err) - sch, err := server.Accept(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{}) require.NoError(t, err) - defer sch.Close() + defer cch.Close() + + readErr := make(chan error, 1) go func() { + sch, err := server.Accept(ctx) + if err != nil { + readErr <- err + _ = cch.Close() + return + } + defer sch.Close() + bytes := make([]byte, 4096) - for i := 0; i < 1024; i++ { - _, err := cch.Write(bytes) - require.NoError(t, err) + for { + _, err = sch.Read(bytes) + if err != nil { + readErr <- err + return + } } - _ = cch.Close() }() + bytes := make([]byte, 4096) - for { - _, err = sch.Read(bytes) - if err != nil { - require.ErrorIs(t, err, peer.ErrClosed) - break - } + for i := 0; i < 1024; i++ { + _, err = cch.Write(bytes) + require.NoError(t, err, "write i=%d", i) + } + _ = cch.Close() + + select { + case err = <-readErr: + require.ErrorIs(t, err, peer.ErrClosed, "read error") + case <-ctx.Done(): + require.Fail(t, "timeout waiting for read error") } }) @@ -170,13 +196,29 @@ func TestConn(t *testing.T) { srv, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) defer srv.Close() + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() go func() { - sch, err := server.Accept(context.Background()) - assert.NoError(t, err) + sch, err := server.Accept(ctx) + if err != nil { + assert.NoError(t, err) + return + } + defer sch.Close() + nc2 := sch.NetConn() + defer nc2.Close() + nc1, err := net.Dial("tcp", srv.Addr().String()) - assert.NoError(t, err) + if err != nil { + assert.NoError(t, err) + return + } + defer nc1.Close() + go func() { + defer nc1.Close() + defer nc2.Close() _, _ = io.Copy(nc1, nc2) }() _, _ = io.Copy(nc2, nc1) @@ -204,7 +246,7 @@ func TestConn(t *testing.T) { c := http.Client{ Transport: defaultTransport, } - req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil) require.NoError(t, err) resp, err := c.Do(req) require.NoError(t, err) @@ -272,14 +314,21 @@ func TestConn(t *testing.T) { t.Parallel() client, server, _ := createPair(t) exchange(t, client, server) + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() go func() { - channel, err := client.CreateChannel(context.Background(), "test", nil) - assert.NoError(t, err) + channel, err := client.CreateChannel(ctx, "test", nil) + if err != nil { + assert.NoError(t, err) + return + } + defer channel.Close() _, err = channel.Write([]byte{1, 2}) assert.NoError(t, err) }() - channel, err := server.Accept(context.Background()) + channel, err := server.Accept(ctx) require.NoError(t, err) + defer channel.Close() data := make([]byte, 1) _, err = channel.Read(data) require.NoError(t, err)
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: