@@ -40,32 +40,45 @@ type WriteFlusher interface {
40
40
}
41
41
42
42
type CompressResponseWriter struct {
43
- http.ResponseWriter
44
- compressWriter WriteFlusher
45
- compressionType string
46
- headersWritten bool
47
- closeNotify chan bool
48
- parentNotify <- chan bool
49
- closed bool
43
+ Header * BufferedServerHeader
44
+ ControllerResponse * Response
45
+ OriginalWriter io.Writer
46
+ compressWriter WriteFlusher
47
+ compressionType string
48
+ headersWritten bool
49
+ closeNotify chan bool
50
+ parentNotify <- chan bool
51
+ closed bool
50
52
}
51
53
52
- // CompressFilter does compresssion of response body in gzip/deflate if
54
+ // CompressFilter does compression of response body in gzip/deflate if
53
55
// `results.compressed=true` in the app.conf
54
56
func CompressFilter (c * Controller , fc []Filter ) {
55
- fc [0 ](c , fc [1 :])
56
- if Config .BoolDefault ("results.compressed" , false ) {
57
+ if c .Response .Out .internalHeader .Server != nil && Config .BoolDefault ("results.compressed" , false ) {
57
58
if c .Response .Status != http .StatusNoContent && c .Response .Status != http .StatusNotModified {
58
- writer := CompressResponseWriter {c .Response .Out , nil , "" , false , make (chan bool , 1 ), nil , false }
59
- writer .DetectCompressionType (c .Request , c .Response )
60
- w , ok := c .Response .Out .(http.CloseNotifier )
61
- if ok {
62
- writer .parentNotify = w .CloseNotify ()
59
+ if found , compressType , compressWriter := detectCompressionType (c .Request , c .Response ); found {
60
+ writer := CompressResponseWriter {
61
+ ControllerResponse : c .Response ,
62
+ OriginalWriter : c .Response .GetWriter (),
63
+ compressWriter : compressWriter ,
64
+ compressionType : compressType ,
65
+ headersWritten : false ,
66
+ closeNotify : make (chan bool , 1 ),
67
+ closed : false ,
68
+ }
69
+ // Swap out the header with our own
70
+ writer .Header = NewBufferedServerHeader (c .Response .Out .internalHeader .Server )
71
+ c .Response .Out .internalHeader .Server = writer .Header
72
+ if w , ok := c .Response .GetWriter ().(http.CloseNotifier ); ok {
73
+ writer .parentNotify = w .CloseNotify ()
74
+ }
75
+ c .Response .SetWriter (& writer )
63
76
}
64
- c .Response .Out = & writer
65
77
} else {
66
78
TRACE .Printf ("Compression disabled for response status (%d)" , c .Response .Status )
67
79
}
68
80
}
81
+ fc [0 ](c , fc [1 :])
69
82
}
70
83
71
84
func (c CompressResponseWriter ) CloseNotify () <- chan bool {
@@ -77,16 +90,19 @@ func (c CompressResponseWriter) CloseNotify() <-chan bool {
77
90
78
91
func (c * CompressResponseWriter ) prepareHeaders () {
79
92
if c .compressionType != "" {
80
- responseMime := c .Header ().Get ("Content-Type" )
93
+ responseMime := ""
94
+ if t := c .Header .Get ("Content-Type" ); len (t ) > 0 {
95
+ responseMime = t [0 ]
96
+ }
81
97
responseMime = strings .TrimSpace (strings .SplitN (responseMime , ";" , 2 )[0 ])
82
98
shouldEncode := false
83
99
84
- if c .Header () .Get ("Content-Encoding" ) == "" {
100
+ if len ( c .Header .Get ("Content-Encoding" )) == 0 {
85
101
for _ , compressableMime := range compressableMimes {
86
102
if responseMime == compressableMime {
87
103
shouldEncode = true
88
- c .Header () .Set ("Content-Encoding" , c .compressionType )
89
- c .Header () .Del ("Content-Length" )
104
+ c .Header .Set ("Content-Encoding" , c .compressionType )
105
+ c .Header .Del ("Content-Length" )
90
106
break
91
107
}
92
108
}
@@ -97,20 +113,26 @@ func (c *CompressResponseWriter) prepareHeaders() {
97
113
c .compressionType = ""
98
114
}
99
115
}
116
+ c .Header .Release ()
100
117
}
101
118
102
119
func (c * CompressResponseWriter ) WriteHeader (status int ) {
103
120
c .headersWritten = true
104
121
c .prepareHeaders ()
105
- c .ResponseWriter . WriteHeader (status )
122
+ c .Header . SetStatus (status )
106
123
}
107
124
108
125
func (c * CompressResponseWriter ) Close () error {
109
- if c . compressionType != "" {
110
- _ = c . compressWriter . Close ()
126
+ if ! c . headersWritten {
127
+ c . prepareHeaders ()
111
128
}
112
- if w , ok := c .ResponseWriter .(io.Closer ); ok {
113
- _ = w .Close ()
129
+ if c .compressionType != "" {
130
+ c .Header .Del ("Content-Length" )
131
+ if err := c .compressWriter .Close (); err != nil {
132
+ // TODO When writing directly to stream, an error will be generated
133
+ ERROR .Println ("Error closing compress writer" , c .compressionType , err )
134
+ }
135
+
114
136
}
115
137
// Non-blocking write to the closenotifier, if we for some reason should
116
138
// get called multiple times
@@ -135,23 +157,22 @@ func (c *CompressResponseWriter) Write(b []byte) (int, error) {
135
157
if c .closed {
136
158
return 0 , io .ErrClosedPipe
137
159
}
160
+
138
161
if ! c .headersWritten {
139
162
c .prepareHeaders ()
140
163
c .headersWritten = true
141
164
}
142
-
143
165
if c .compressionType != "" {
144
166
return c .compressWriter .Write (b )
145
167
}
146
-
147
- return c .ResponseWriter .Write (b )
168
+ return c .OriginalWriter .Write (b )
148
169
}
149
170
150
- // DetectCompressionType method detects the comperssion type
171
+ // DetectCompressionType method detects the compression type
151
172
// from header "Accept-Encoding"
152
- func ( c * CompressResponseWriter ) DetectCompressionType ( req * Request , resp * Response ) {
173
+ func detectCompressionType ( req * Request , resp * Response ) ( found bool , compressionType string , compressionKind WriteFlusher ) {
153
174
if Config .BoolDefault ("results.compressed" , false ) {
154
- acceptedEncodings := strings .Split (req .Request . Header . Get ("Accept-Encoding" ), "," )
175
+ acceptedEncodings := strings .Split (req .GetHttpHeader ("Accept-Encoding" ), "," )
155
176
156
177
largestQ := 0.0
157
178
chosenEncoding := len (compressionTypes )
@@ -216,13 +237,98 @@ func (c *CompressResponseWriter) DetectCompressionType(req *Request, resp *Respo
216
237
return
217
238
}
218
239
219
- c . compressionType = compressionTypes [chosenEncoding ]
240
+ compressionType = compressionTypes [chosenEncoding ]
220
241
221
- switch c . compressionType {
242
+ switch compressionType {
222
243
case "gzip" :
223
- c .compressWriter = gzip .NewWriter (resp .Out )
244
+ compressionKind = gzip .NewWriter (resp .GetWriter ())
245
+ found = true
224
246
case "deflate" :
225
- c .compressWriter = zlib .NewWriter (resp .Out )
247
+ compressionKind = zlib .NewWriter (resp .GetWriter ())
248
+ found = true
226
249
}
227
250
}
251
+ return
252
+ }
253
+
254
+ // BufferedServerHeader will not send content out until the Released is called, from that point on it will act normally
255
+ // It implements all the ServerHeader
256
+ type BufferedServerHeader struct {
257
+ cookieList []string
258
+ headerMap map [string ][]string
259
+ status int
260
+ released bool
261
+ original ServerHeader
262
+ }
263
+
264
+ func NewBufferedServerHeader (o ServerHeader ) * BufferedServerHeader {
265
+ return & BufferedServerHeader {original : o , headerMap : map [string ][]string {}}
266
+ }
267
+ func (bsh * BufferedServerHeader ) SetCookie (cookie string ) {
268
+ if bsh .released {
269
+ bsh .original .SetCookie (cookie )
270
+ } else {
271
+ bsh .cookieList = append (bsh .cookieList , cookie )
272
+ }
273
+ }
274
+ func (bsh * BufferedServerHeader ) GetCookie (key string ) (value ServerCookie , err error ) {
275
+ return bsh .original .GetCookie (key )
276
+ }
277
+ func (bsh * BufferedServerHeader ) Set (key string , value string ) {
278
+ if bsh .released {
279
+ bsh .original .Set (key , value )
280
+ } else {
281
+ bsh .headerMap [key ] = []string {value }
282
+ }
283
+ }
284
+ func (bsh * BufferedServerHeader ) Add (key string , value string ) {
285
+ if bsh .released {
286
+ bsh .original .Set (key , value )
287
+ } else {
288
+ old := []string {}
289
+ if v , found := bsh .headerMap [key ]; found {
290
+ old = v
291
+ }
292
+ bsh .headerMap [key ] = append (old , value )
293
+ }
294
+
295
+ }
296
+ func (bsh * BufferedServerHeader ) Del (key string ) {
297
+ if bsh .released {
298
+ bsh .original .Del (key )
299
+ } else {
300
+ delete (bsh .headerMap , key )
301
+ }
302
+
303
+ }
304
+ func (bsh * BufferedServerHeader ) Get (key string ) (value []string ) {
305
+ if bsh .released {
306
+ value = bsh .original .Get (key )
307
+ } else {
308
+ if v , found := bsh .headerMap [key ]; found && len (v ) > 0 {
309
+ value = v
310
+ } else {
311
+ value = bsh .original .Get (key )
312
+ }
313
+ }
314
+ return
315
+ }
316
+ func (bsh * BufferedServerHeader ) SetStatus (statusCode int ) {
317
+ if bsh .released {
318
+ bsh .original .SetStatus (statusCode )
319
+ } else {
320
+ bsh .status = statusCode
321
+ }
322
+ }
323
+ func (bsh * BufferedServerHeader ) Release () {
324
+ bsh .released = true
325
+ bsh .original .SetStatus (bsh .status )
326
+ for k , v := range bsh .headerMap {
327
+ for _ , r := range v {
328
+ bsh .original .Set (k , r )
329
+ }
330
+ }
331
+ for _ , c := range bsh .cookieList {
332
+ bsh .original .SetCookie (c )
333
+ }
228
334
}
0 commit comments