Skip to content

Commit c62c0dc

Browse files
authored
Merge pull request #193 from nhooyr/ensure-close
Ensure connection is closed at all error points
2 parents 43c4dc0 + 2e0dd1c commit c62c0dc

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

read.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
304304
defer c.readMu.unlock()
305305

306306
if !c.msgReader.fin {
307-
return 0, nil, errors.New("previous message not read to completion")
307+
err = errors.New("previous message not read to completion")
308+
c.close(fmt.Errorf("failed to get reader: %w", err))
309+
return 0, nil, err
308310
}
309311

310312
h, err := c.readLoop(ctx)
@@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
361363
}
362364

363365
func (mr *msgReader) Read(p []byte) (n int, err error) {
364-
defer func() {
365-
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
366-
err = io.EOF
367-
}
368-
if errors.Is(err, io.EOF) {
369-
err = io.EOF
370-
mr.putFlateReader()
371-
return
372-
}
373-
errd.Wrap(&err, "failed to read")
374-
}()
375-
376366
err = mr.c.readMu.lock(mr.ctx)
377367
if err != nil {
378-
return 0, err
368+
return 0, fmt.Errorf("failed to read: %w", err)
379369
}
380370
defer mr.c.readMu.unlock()
381371

@@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
384374
p = p[:n]
385375
mr.dict.write(p)
386376
}
377+
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
378+
mr.putFlateReader()
379+
return n, io.EOF
380+
}
381+
if err != nil {
382+
err = fmt.Errorf("failed to read: %w", err)
383+
mr.c.close(err)
384+
}
387385
return n, err
388386
}
389387

write.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"errors"
1111
"fmt"
1212
"io"
13-
"sync"
1413
"time"
1514

1615
"github.com/klauspost/compress/flate"
@@ -71,7 +70,7 @@ type msgWriterState struct {
7170
c *Conn
7271

7372
mu *mu
74-
writeMu sync.Mutex
73+
writeMu *mu
7574

7675
ctx context.Context
7776
opcode opcode
@@ -83,8 +82,9 @@ type msgWriterState struct {
8382

8483
func newMsgWriterState(c *Conn) *msgWriterState {
8584
mw := &msgWriterState{
86-
c: c,
87-
mu: newMu(c),
85+
c: c,
86+
mu: newMu(c),
87+
writeMu: newMu(c),
8888
}
8989
return mw
9090
}
@@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
155155

156156
// Write writes the given bytes to the WebSocket connection.
157157
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
158-
defer errd.Wrap(&err, "failed to write")
158+
err = mw.writeMu.lock(mw.ctx)
159+
if err != nil {
160+
return 0, fmt.Errorf("failed to write: %w", err)
161+
}
162+
defer mw.writeMu.unlock()
159163

160-
mw.writeMu.Lock()
161-
defer mw.writeMu.Unlock()
164+
defer func() {
165+
if err != nil {
166+
err = fmt.Errorf("failed to write: %w", err)
167+
mw.c.close(err)
168+
}
169+
}()
162170

163171
if mw.c.flate() {
164172
// Only enables flate if the length crosses the
@@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
193201
func (mw *msgWriterState) Close() (err error) {
194202
defer errd.Wrap(&err, "failed to close writer")
195203

196-
mw.writeMu.Lock()
197-
defer mw.writeMu.Unlock()
204+
err = mw.writeMu.lock(mw.ctx)
205+
if err != nil {
206+
return err
207+
}
208+
defer mw.writeMu.unlock()
198209

199210
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
200211
if err != nil {
@@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
214225
putBufioWriter(mw.c.bw)
215226
}
216227

217-
mw.writeMu.Lock()
228+
mw.writeMu.forceLock()
218229
mw.dict.close()
219230
}
220231

@@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
230241
}
231242

232243
// frame handles all writes to the connection.
233-
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
234-
err := c.writeFrameMu.lock(ctx)
244+
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
245+
err = c.writeFrameMu.lock(ctx)
235246
if err != nil {
236247
return 0, err
237248
}
@@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
243254
case c.writeTimeout <- ctx:
244255
}
245256

257+
defer func() {
258+
if err != nil {
259+
err = fmt.Errorf("failed to write frame: %w", err)
260+
c.close(err)
261+
}
262+
}()
263+
246264
c.writeHeader.fin = fin
247265
c.writeHeader.opcode = opcode
248266
c.writeHeader.payloadLength = int64(len(p))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy