OLD | NEW |
| (Empty) |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "net/websockets/websocket_deflate_stream.h" | |
6 | |
7 #include <algorithm> | |
8 #include <string> | |
9 | |
10 #include "base/bind.h" | |
11 #include "base/logging.h" | |
12 #include "base/memory/ref_counted.h" | |
13 #include "base/memory/scoped_ptr.h" | |
14 #include "base/memory/scoped_vector.h" | |
15 #include "net/base/completion_callback.h" | |
16 #include "net/base/io_buffer.h" | |
17 #include "net/base/net_errors.h" | |
18 #include "net/websockets/websocket_deflate_predictor.h" | |
19 #include "net/websockets/websocket_deflater.h" | |
20 #include "net/websockets/websocket_errors.h" | |
21 #include "net/websockets/websocket_frame.h" | |
22 #include "net/websockets/websocket_inflater.h" | |
23 #include "net/websockets/websocket_stream.h" | |
24 | |
25 class GURL; | |
26 | |
27 namespace net { | |
28 | |
29 namespace { | |
30 | |
31 const int kWindowBits = 15; | |
32 const size_t kChunkSize = 4 * 1024; | |
33 | |
34 } // namespace | |
35 | |
36 WebSocketDeflateStream::WebSocketDeflateStream( | |
37 scoped_ptr<WebSocketStream> stream, | |
38 WebSocketDeflater::ContextTakeOverMode mode, | |
39 int client_window_bits, | |
40 scoped_ptr<WebSocketDeflatePredictor> predictor) | |
41 : stream_(stream.Pass()), | |
42 deflater_(mode), | |
43 inflater_(kChunkSize, kChunkSize), | |
44 reading_state_(NOT_READING), | |
45 writing_state_(NOT_WRITING), | |
46 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText), | |
47 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText), | |
48 predictor_(predictor.Pass()) { | |
49 DCHECK(stream_); | |
50 DCHECK_GE(client_window_bits, 8); | |
51 DCHECK_LE(client_window_bits, 15); | |
52 deflater_.Initialize(client_window_bits); | |
53 inflater_.Initialize(kWindowBits); | |
54 } | |
55 | |
56 WebSocketDeflateStream::~WebSocketDeflateStream() {} | |
57 | |
58 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames, | |
59 const CompletionCallback& callback) { | |
60 int result = stream_->ReadFrames( | |
61 frames, | |
62 base::Bind(&WebSocketDeflateStream::OnReadComplete, | |
63 base::Unretained(this), | |
64 base::Unretained(frames), | |
65 callback)); | |
66 if (result < 0) | |
67 return result; | |
68 DCHECK_EQ(OK, result); | |
69 DCHECK(!frames->empty()); | |
70 | |
71 return InflateAndReadIfNecessary(frames, callback); | |
72 } | |
73 | |
74 int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames, | |
75 const CompletionCallback& callback) { | |
76 int result = Deflate(frames); | |
77 if (result != OK) | |
78 return result; | |
79 if (frames->empty()) | |
80 return OK; | |
81 return stream_->WriteFrames(frames, callback); | |
82 } | |
83 | |
84 void WebSocketDeflateStream::Close() { stream_->Close(); } | |
85 | |
86 std::string WebSocketDeflateStream::GetSubProtocol() const { | |
87 return stream_->GetSubProtocol(); | |
88 } | |
89 | |
90 std::string WebSocketDeflateStream::GetExtensions() const { | |
91 return stream_->GetExtensions(); | |
92 } | |
93 | |
94 void WebSocketDeflateStream::OnReadComplete( | |
95 ScopedVector<WebSocketFrame>* frames, | |
96 const CompletionCallback& callback, | |
97 int result) { | |
98 if (result != OK) { | |
99 frames->clear(); | |
100 callback.Run(result); | |
101 return; | |
102 } | |
103 | |
104 int r = InflateAndReadIfNecessary(frames, callback); | |
105 if (r != ERR_IO_PENDING) | |
106 callback.Run(r); | |
107 } | |
108 | |
109 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) { | |
110 ScopedVector<WebSocketFrame> frames_to_write; | |
111 // Store frames of the currently processed message if writing_state_ equals to | |
112 // WRITING_POSSIBLY_COMPRESSED_MESSAGE. | |
113 ScopedVector<WebSocketFrame> frames_of_message; | |
114 for (size_t i = 0; i < frames->size(); ++i) { | |
115 DCHECK(!(*frames)[i]->header.reserved1); | |
116 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) { | |
117 frames_to_write.push_back((*frames)[i]); | |
118 (*frames)[i] = NULL; | |
119 continue; | |
120 } | |
121 if (writing_state_ == NOT_WRITING) | |
122 OnMessageStart(*frames, i); | |
123 | |
124 scoped_ptr<WebSocketFrame> frame((*frames)[i]); | |
125 (*frames)[i] = NULL; | |
126 predictor_->RecordInputDataFrame(frame.get()); | |
127 | |
128 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) { | |
129 if (frame->header.final) | |
130 writing_state_ = NOT_WRITING; | |
131 predictor_->RecordWrittenDataFrame(frame.get()); | |
132 frames_to_write.push_back(frame.release()); | |
133 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | |
134 } else { | |
135 if (frame->data.get() && | |
136 !deflater_.AddBytes( | |
137 frame->data->data(), | |
138 static_cast<size_t>(frame->header.payload_length))) { | |
139 DVLOG(1) << "WebSocket protocol error. " | |
140 << "deflater_.AddBytes() returns an error."; | |
141 return ERR_WS_PROTOCOL_ERROR; | |
142 } | |
143 if (frame->header.final && !deflater_.Finish()) { | |
144 DVLOG(1) << "WebSocket protocol error. " | |
145 << "deflater_.Finish() returns an error."; | |
146 return ERR_WS_PROTOCOL_ERROR; | |
147 } | |
148 | |
149 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) { | |
150 if (deflater_.CurrentOutputSize() >= kChunkSize || | |
151 frame->header.final) { | |
152 int result = AppendCompressedFrame(frame->header, &frames_to_write); | |
153 if (result != OK) | |
154 return result; | |
155 } | |
156 if (frame->header.final) | |
157 writing_state_ = NOT_WRITING; | |
158 } else { | |
159 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_); | |
160 bool final = frame->header.final; | |
161 frames_of_message.push_back(frame.release()); | |
162 if (final) { | |
163 int result = AppendPossiblyCompressedMessage(&frames_of_message, | |
164 &frames_to_write); | |
165 if (result != OK) | |
166 return result; | |
167 frames_of_message.clear(); | |
168 writing_state_ = NOT_WRITING; | |
169 } | |
170 } | |
171 } | |
172 } | |
173 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_); | |
174 frames->swap(frames_to_write); | |
175 return OK; | |
176 } | |
177 | |
178 void WebSocketDeflateStream::OnMessageStart( | |
179 const ScopedVector<WebSocketFrame>& frames, size_t index) { | |
180 WebSocketFrame* frame = frames[index]; | |
181 current_writing_opcode_ = frame->header.opcode; | |
182 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText || | |
183 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary); | |
184 WebSocketDeflatePredictor::Result prediction = | |
185 predictor_->Predict(frames, index); | |
186 | |
187 switch (prediction) { | |
188 case WebSocketDeflatePredictor::DEFLATE: | |
189 writing_state_ = WRITING_COMPRESSED_MESSAGE; | |
190 return; | |
191 case WebSocketDeflatePredictor::DO_NOT_DEFLATE: | |
192 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE; | |
193 return; | |
194 case WebSocketDeflatePredictor::TRY_DEFLATE: | |
195 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE; | |
196 return; | |
197 } | |
198 NOTREACHED(); | |
199 } | |
200 | |
201 int WebSocketDeflateStream::AppendCompressedFrame( | |
202 const WebSocketFrameHeader& header, | |
203 ScopedVector<WebSocketFrame>* frames_to_write) { | |
204 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_; | |
205 scoped_refptr<IOBufferWithSize> compressed_payload = | |
206 deflater_.GetOutput(deflater_.CurrentOutputSize()); | |
207 if (!compressed_payload.get()) { | |
208 DVLOG(1) << "WebSocket protocol error. " | |
209 << "deflater_.GetOutput() returns an error."; | |
210 return ERR_WS_PROTOCOL_ERROR; | |
211 } | |
212 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode)); | |
213 compressed->header.CopyFrom(header); | |
214 compressed->header.opcode = opcode; | |
215 compressed->header.final = header.final; | |
216 compressed->header.reserved1 = | |
217 (opcode != WebSocketFrameHeader::kOpCodeContinuation); | |
218 compressed->data = compressed_payload; | |
219 compressed->header.payload_length = compressed_payload->size(); | |
220 | |
221 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | |
222 predictor_->RecordWrittenDataFrame(compressed.get()); | |
223 frames_to_write->push_back(compressed.release()); | |
224 return OK; | |
225 } | |
226 | |
227 int WebSocketDeflateStream::AppendPossiblyCompressedMessage( | |
228 ScopedVector<WebSocketFrame>* frames, | |
229 ScopedVector<WebSocketFrame>* frames_to_write) { | |
230 DCHECK(!frames->empty()); | |
231 | |
232 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_; | |
233 scoped_refptr<IOBufferWithSize> compressed_payload = | |
234 deflater_.GetOutput(deflater_.CurrentOutputSize()); | |
235 if (!compressed_payload.get()) { | |
236 DVLOG(1) << "WebSocket protocol error. " | |
237 << "deflater_.GetOutput() returns an error."; | |
238 return ERR_WS_PROTOCOL_ERROR; | |
239 } | |
240 | |
241 uint64 original_payload_length = 0; | |
242 for (size_t i = 0; i < frames->size(); ++i) { | |
243 WebSocketFrame* frame = (*frames)[i]; | |
244 // Asserts checking that frames represent one whole data message. | |
245 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)); | |
246 DCHECK_EQ(i == 0, | |
247 WebSocketFrameHeader::kOpCodeContinuation != | |
248 frame->header.opcode); | |
249 DCHECK_EQ(i == frames->size() - 1, frame->header.final); | |
250 original_payload_length += frame->header.payload_length; | |
251 } | |
252 if (original_payload_length <= | |
253 static_cast<uint64>(compressed_payload->size())) { | |
254 // Compression is not effective. Use the original frames. | |
255 for (size_t i = 0; i < frames->size(); ++i) { | |
256 WebSocketFrame* frame = (*frames)[i]; | |
257 frames_to_write->push_back(frame); | |
258 predictor_->RecordWrittenDataFrame(frame); | |
259 (*frames)[i] = NULL; | |
260 } | |
261 frames->weak_clear(); | |
262 return OK; | |
263 } | |
264 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode)); | |
265 compressed->header.CopyFrom((*frames)[0]->header); | |
266 compressed->header.opcode = opcode; | |
267 compressed->header.final = true; | |
268 compressed->header.reserved1 = true; | |
269 compressed->data = compressed_payload; | |
270 compressed->header.payload_length = compressed_payload->size(); | |
271 | |
272 predictor_->RecordWrittenDataFrame(compressed.get()); | |
273 frames_to_write->push_back(compressed.release()); | |
274 return OK; | |
275 } | |
276 | |
277 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) { | |
278 ScopedVector<WebSocketFrame> frames_to_output; | |
279 ScopedVector<WebSocketFrame> frames_passed; | |
280 frames->swap(frames_passed); | |
281 for (size_t i = 0; i < frames_passed.size(); ++i) { | |
282 scoped_ptr<WebSocketFrame> frame(frames_passed[i]); | |
283 frames_passed[i] = NULL; | |
284 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode | |
285 << " final=" << frame->header.final | |
286 << " reserved1=" << frame->header.reserved1 | |
287 << " payload_length=" << frame->header.payload_length; | |
288 | |
289 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) { | |
290 frames_to_output.push_back(frame.release()); | |
291 continue; | |
292 } | |
293 | |
294 if (reading_state_ == NOT_READING) { | |
295 if (frame->header.reserved1) | |
296 reading_state_ = READING_COMPRESSED_MESSAGE; | |
297 else | |
298 reading_state_ = READING_UNCOMPRESSED_MESSAGE; | |
299 current_reading_opcode_ = frame->header.opcode; | |
300 } else { | |
301 if (frame->header.reserved1) { | |
302 DVLOG(1) << "WebSocket protocol error. " | |
303 << "Receiving a non-first frame with RSV1 flag set."; | |
304 return ERR_WS_PROTOCOL_ERROR; | |
305 } | |
306 } | |
307 | |
308 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) { | |
309 if (frame->header.final) | |
310 reading_state_ = NOT_READING; | |
311 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | |
312 frames_to_output.push_back(frame.release()); | |
313 } else { | |
314 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE); | |
315 if (frame->data.get() && | |
316 !inflater_.AddBytes( | |
317 frame->data->data(), | |
318 static_cast<size_t>(frame->header.payload_length))) { | |
319 DVLOG(1) << "WebSocket protocol error. " | |
320 << "inflater_.AddBytes() returns an error."; | |
321 return ERR_WS_PROTOCOL_ERROR; | |
322 } | |
323 if (frame->header.final) { | |
324 if (!inflater_.Finish()) { | |
325 DVLOG(1) << "WebSocket protocol error. " | |
326 << "inflater_.Finish() returns an error."; | |
327 return ERR_WS_PROTOCOL_ERROR; | |
328 } | |
329 } | |
330 // TODO(yhirano): Many frames can be generated by the inflater and | |
331 // memory consumption can grow. | |
332 // We could avoid it, but avoiding it makes this class much more | |
333 // complicated. | |
334 while (inflater_.CurrentOutputSize() >= kChunkSize || | |
335 frame->header.final) { | |
336 size_t size = std::min(kChunkSize, inflater_.CurrentOutputSize()); | |
337 scoped_ptr<WebSocketFrame> inflated( | |
338 new WebSocketFrame(WebSocketFrameHeader::kOpCodeText)); | |
339 scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size); | |
340 bool is_final = !inflater_.CurrentOutputSize() && frame->header.final; | |
341 if (!data.get()) { | |
342 DVLOG(1) << "WebSocket protocol error. " | |
343 << "inflater_.GetOutput() returns an error."; | |
344 return ERR_WS_PROTOCOL_ERROR; | |
345 } | |
346 inflated->header.CopyFrom(frame->header); | |
347 inflated->header.opcode = current_reading_opcode_; | |
348 inflated->header.final = is_final; | |
349 inflated->header.reserved1 = false; | |
350 inflated->data = data; | |
351 inflated->header.payload_length = data->size(); | |
352 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode | |
353 << " final=" << inflated->header.final | |
354 << " reserved1=" << inflated->header.reserved1 | |
355 << " payload_length=" << inflated->header.payload_length; | |
356 frames_to_output.push_back(inflated.release()); | |
357 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | |
358 if (is_final) | |
359 break; | |
360 } | |
361 if (frame->header.final) | |
362 reading_state_ = NOT_READING; | |
363 } | |
364 } | |
365 frames->swap(frames_to_output); | |
366 return frames->empty() ? ERR_IO_PENDING : OK; | |
367 } | |
368 | |
369 int WebSocketDeflateStream::InflateAndReadIfNecessary( | |
370 ScopedVector<WebSocketFrame>* frames, | |
371 const CompletionCallback& callback) { | |
372 int result = Inflate(frames); | |
373 while (result == ERR_IO_PENDING) { | |
374 DCHECK(frames->empty()); | |
375 | |
376 result = stream_->ReadFrames( | |
377 frames, | |
378 base::Bind(&WebSocketDeflateStream::OnReadComplete, | |
379 base::Unretained(this), | |
380 base::Unretained(frames), | |
381 callback)); | |
382 if (result < 0) | |
383 break; | |
384 DCHECK_EQ(OK, result); | |
385 DCHECK(!frames->empty()); | |
386 | |
387 result = Inflate(frames); | |
388 } | |
389 if (result < 0) | |
390 frames->clear(); | |
391 return result; | |
392 } | |
393 | |
394 } // namespace net | |
OLD | NEW |