Skip to content

Commit e33a749

Browse files
mafredri and Emyrk authored
fix: Deadlock and race in peer, test improvements (#3086)
* fix: Potential deadlock in peer.Channel dc.OnOpen
* fix: Potential send on closed channel
* fix: Improve robustness of waitOpened during close
* chore: Simplify statements
* fix: Improve teardown and timeout of peer tests
* fix: Improve robustness of TestConn/Buffering test
* Update peer/channel.go

Co-authored-by: Steven Masley <Emyrk@users.noreply.github.com>
1 parent 62e6856 commit e33a749

File tree

3 files changed

+85
-40
lines changed

3 files changed

+85
-40
lines changed

peer/channel.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,15 @@ func (c *Channel) init() {
106106
// write operations to block once the threshold is set.
107107
c.dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
108108
c.dc.OnBufferedAmountLow(func() {
109+
// Grab the lock to protect the sendMore channel from being
110+
// closed in between the isClosed check and the send.
111+
c.closeMutex.Lock()
112+
defer c.closeMutex.Unlock()
109113
if c.isClosed() {
110114
return
111115
}
112116
select {
113117
case <-c.closed:
114-
return
115118
case c.sendMore <- struct{}{}:
116119
default:
117120
}
@@ -122,15 +125,16 @@ func (c *Channel) init() {
122125
})
123126
c.dc.OnOpen(func() {
124127
c.closeMutex.Lock()
125-
defer c.closeMutex.Unlock()
126-
127128
c.conn.logger().Debug(context.Background(), "datachannel opening", slog.F("id", c.dc.ID()), slog.F("label", c.dc.Label()))
128129
var err error
129130
c.rwc, err = c.dc.Detach()
130131
if err != nil {
132+
c.closeMutex.Unlock()
131133
_ = c.closeWithError(xerrors.Errorf("detach: %w", err))
132134
return
133135
}
136+
c.closeMutex.Unlock()
137+
134138
// pion/webrtc will return an io.ErrShortBuffer when a read
135139
// is triggered with a buffer size less than the chunks written.
136140
//
@@ -189,9 +193,6 @@ func (c *Channel) init() {
189193
//
190194
// This will block until the underlying DataChannel has been opened.
191195
func (c *Channel) Read(bytes []byte) (int, error) {
192-
if c.isClosed() {
193-
return 0, c.closeError
194-
}
195196
err := c.waitOpened()
196197
if err != nil {
197198
return 0, err
@@ -228,9 +229,6 @@ func (c *Channel) Write(bytes []byte) (n int, err error) {
228229
c.writeMutex.Lock()
229230
defer c.writeMutex.Unlock()
230231

231-
if c.isClosed() {
232-
return 0, c.closeWithError(nil)
233-
}
234232
err = c.waitOpened()
235233
if err != nil {
236234
return 0, err
@@ -308,6 +306,10 @@ func (c *Channel) isClosed() bool {
308306
func (c *Channel) waitOpened() error {
309307
select {
310308
case <-c.opened:
309+
// Re-check the closed channel to prioritize closure.
310+
if c.isClosed() {
311+
return c.closeError
312+
}
311313
return nil
312314
case <-c.closed:
313315
return c.closeError

peer/conn.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package peer
33
import (
44
"bytes"
55
"context"
6-
76
"crypto/rand"
87
"io"
98
"sync"
@@ -256,7 +255,6 @@ func (c *Conn) init() error {
256255
c.logger().Debug(context.Background(), "sending local candidate", slog.F("candidate", iceCandidate.ToJSON().Candidate))
257256
select {
258257
case <-c.closed:
259-
break
260258
case c.localCandidateChannel <- iceCandidate.ToJSON():
261259
}
262260
}()
@@ -265,7 +263,6 @@ func (c *Conn) init() error {
265263
go func() {
266264
select {
267265
case <-c.closed:
268-
return
269266
case c.dcOpenChannel <- dc:
270267
}
271268
}()
@@ -435,9 +432,6 @@ func (c *Conn) pingEchoChannel() (*Channel, error) {
435432
data := make([]byte, pingDataLength)
436433
bytesRead, err := c.pingEchoChan.Read(data)
437434
if err != nil {
438-
if c.isClosed() {
439-
return
440-
}
441435
_ = c.CloseWithError(xerrors.Errorf("read ping echo channel: %w", err))
442436
return
443437
}

peer/conn_test.go

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ func TestConn(t *testing.T) {
9191
// Create a channel that closes on disconnect.
9292
channel, err := server.CreateChannel(context.Background(), "wow", nil)
9393
assert.NoError(t, err)
94+
defer channel.Close()
95+
9496
err = wan.Stop()
9597
require.NoError(t, err)
9698
// Once the connection is marked as disconnected, this
@@ -107,10 +109,13 @@ func TestConn(t *testing.T) {
107109
t.Parallel()
108110
client, server, _ := createPair(t)
109111
exchange(t, client, server)
110-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
112+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
113+
defer cancel()
114+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
111115
require.NoError(t, err)
116+
defer cch.Close()
112117

113-
sch, err := server.Accept(context.Background())
118+
sch, err := server.Accept(ctx)
114119
require.NoError(t, err)
115120
defer sch.Close()
116121

@@ -123,9 +128,12 @@ func TestConn(t *testing.T) {
123128
t.Parallel()
124129
client, server, wan := createPair(t)
125130
exchange(t, client, server)
126-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
131+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
132+
defer cancel()
133+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
127134
require.NoError(t, err)
128-
sch, err := server.Accept(context.Background())
135+
defer cch.Close()
136+
sch, err := server.Accept(ctx)
129137
require.NoError(t, err)
130138
defer sch.Close()
131139

@@ -140,26 +148,44 @@ func TestConn(t *testing.T) {
140148
t.Parallel()
141149
client, server, _ := createPair(t)
142150
exchange(t, client, server)
143-
cch, err := client.CreateChannel(context.Background(), "hello", &peer.ChannelOptions{})
144-
require.NoError(t, err)
145-
sch, err := server.Accept(context.Background())
151+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
152+
defer cancel()
153+
cch, err := client.CreateChannel(ctx, "hello", &peer.ChannelOptions{})
146154
require.NoError(t, err)
147-
defer sch.Close()
155+
defer cch.Close()
156+
157+
readErr := make(chan error, 1)
148158
go func() {
159+
sch, err := server.Accept(ctx)
160+
if err != nil {
161+
readErr <- err
162+
_ = cch.Close()
163+
return
164+
}
165+
defer sch.Close()
166+
149167
bytes := make([]byte, 4096)
150-
for i := 0; i < 1024; i++ {
151-
_, err := cch.Write(bytes)
152-
require.NoError(t, err)
168+
for {
169+
_, err = sch.Read(bytes)
170+
if err != nil {
171+
readErr <- err
172+
return
173+
}
153174
}
154-
_ = cch.Close()
155175
}()
176+
156177
bytes := make([]byte, 4096)
157-
for {
158-
_, err = sch.Read(bytes)
159-
if err != nil {
160-
require.ErrorIs(t, err, peer.ErrClosed)
161-
break
162-
}
178+
for i := 0; i < 1024; i++ {
179+
_, err = cch.Write(bytes)
180+
require.NoError(t, err, "write i=%d", i)
181+
}
182+
_ = cch.Close()
183+
184+
select {
185+
case err = <-readErr:
186+
require.ErrorIs(t, err, peer.ErrClosed, "read error")
187+
case <-ctx.Done():
188+
require.Fail(t, "timeout waiting for read error")
163189
}
164190
})
165191

@@ -170,13 +196,29 @@ func TestConn(t *testing.T) {
170196
srv, err := net.Listen("tcp", "127.0.0.1:0")
171197
require.NoError(t, err)
172198
defer srv.Close()
199+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
200+
defer cancel()
173201
go func() {
174-
sch, err := server.Accept(context.Background())
175-
assert.NoError(t, err)
202+
sch, err := server.Accept(ctx)
203+
if err != nil {
204+
assert.NoError(t, err)
205+
return
206+
}
207+
defer sch.Close()
208+
176209
nc2 := sch.NetConn()
210+
defer nc2.Close()
211+
177212
nc1, err := net.Dial("tcp", srv.Addr().String())
178-
assert.NoError(t, err)
213+
if err != nil {
214+
assert.NoError(t, err)
215+
return
216+
}
217+
defer nc1.Close()
218+
179219
go func() {
220+
defer nc1.Close()
221+
defer nc2.Close()
180222
_, _ = io.Copy(nc1, nc2)
181223
}()
182224
_, _ = io.Copy(nc2, nc1)
@@ -204,7 +246,7 @@ func TestConn(t *testing.T) {
204246
c := http.Client{
205247
Transport: defaultTransport,
206248
}
207-
req, err := http.NewRequestWithContext(context.Background(), "GET", "http://localhost/", nil)
249+
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost/", nil)
208250
require.NoError(t, err)
209251
resp, err := c.Do(req)
210252
require.NoError(t, err)
@@ -272,14 +314,21 @@ func TestConn(t *testing.T) {
272314
t.Parallel()
273315
client, server, _ := createPair(t)
274316
exchange(t, client, server)
317+
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
318+
defer cancel()
275319
go func() {
276-
channel, err := client.CreateChannel(context.Background(), "test", nil)
277-
assert.NoError(t, err)
320+
channel, err := client.CreateChannel(ctx, "test", nil)
321+
if err != nil {
322+
assert.NoError(t, err)
323+
return
324+
}
325+
defer channel.Close()
278326
_, err = channel.Write([]byte{1, 2})
279327
assert.NoError(t, err)
280328
}()
281-
channel, err := server.Accept(context.Background())
329+
channel, err := server.Accept(ctx)
282330
require.NoError(t, err)
331+
defer channel.Close()
283332
data := make([]byte, 1)
284333
_, err = channel.Read(data)
285334
require.NoError(t, err)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy