| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_deflate_stream.h" | 5 #include "net/websockets/websocket_deflate_stream.h" |
| 6 | 6 |
| 7 #include <stdint.h> | 7 #include <stdint.h> |
| 8 #include <algorithm> | 8 #include <algorithm> |
| 9 #include <string> | 9 #include <string> |
| 10 #include <utility> | 10 #include <utility> |
| (...skipping 22 matching lines...) Expand all Loading... |
| 33 | 33 |
| 34 const int kWindowBits = 15; | 34 const int kWindowBits = 15; |
| 35 const size_t kChunkSize = 4 * 1024; | 35 const size_t kChunkSize = 4 * 1024; |
| 36 | 36 |
| 37 } // namespace | 37 } // namespace |
| 38 | 38 |
| 39 WebSocketDeflateStream::WebSocketDeflateStream( | 39 WebSocketDeflateStream::WebSocketDeflateStream( |
| 40 scoped_ptr<WebSocketStream> stream, | 40 scoped_ptr<WebSocketStream> stream, |
| 41 const WebSocketDeflateParameters& params, | 41 const WebSocketDeflateParameters& params, |
| 42 scoped_ptr<WebSocketDeflatePredictor> predictor) | 42 scoped_ptr<WebSocketDeflatePredictor> predictor) |
| 43 : stream_(stream.Pass()), | 43 : stream_(std::move(stream)), |
| 44 deflater_(params.client_context_take_over_mode()), | 44 deflater_(params.client_context_take_over_mode()), |
| 45 inflater_(kChunkSize, kChunkSize), | 45 inflater_(kChunkSize, kChunkSize), |
| 46 reading_state_(NOT_READING), | 46 reading_state_(NOT_READING), |
| 47 writing_state_(NOT_WRITING), | 47 writing_state_(NOT_WRITING), |
| 48 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText), | 48 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText), |
| 49 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText), | 49 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText), |
| 50 predictor_(predictor.Pass()) { | 50 predictor_(std::move(predictor)) { |
| 51 DCHECK(stream_); | 51 DCHECK(stream_); |
| 52 DCHECK(params.IsValidAsResponse()); | 52 DCHECK(params.IsValidAsResponse()); |
| 53 int client_max_window_bits = 15; | 53 int client_max_window_bits = 15; |
| 54 if (params.is_client_max_window_bits_specified()) { | 54 if (params.is_client_max_window_bits_specified()) { |
| 55 DCHECK(params.has_client_max_window_bits_value()); | 55 DCHECK(params.has_client_max_window_bits_value()); |
| 56 client_max_window_bits = params.client_max_window_bits(); | 56 client_max_window_bits = params.client_max_window_bits(); |
| 57 } | 57 } |
| 58 deflater_.Initialize(client_max_window_bits); | 58 deflater_.Initialize(client_max_window_bits); |
| 59 inflater_.Initialize(kWindowBits); | 59 inflater_.Initialize(kWindowBits); |
| 60 } | 60 } |
| (...skipping 68 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 129 if (writing_state_ == NOT_WRITING) | 129 if (writing_state_ == NOT_WRITING) |
| 130 OnMessageStart(*frames, i); | 130 OnMessageStart(*frames, i); |
| 131 | 131 |
| 132 scoped_ptr<WebSocketFrame> frame(std::move((*frames)[i])); | 132 scoped_ptr<WebSocketFrame> frame(std::move((*frames)[i])); |
| 133 predictor_->RecordInputDataFrame(frame.get()); | 133 predictor_->RecordInputDataFrame(frame.get()); |
| 134 | 134 |
| 135 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) { | 135 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) { |
| 136 if (frame->header.final) | 136 if (frame->header.final) |
| 137 writing_state_ = NOT_WRITING; | 137 writing_state_ = NOT_WRITING; |
| 138 predictor_->RecordWrittenDataFrame(frame.get()); | 138 predictor_->RecordWrittenDataFrame(frame.get()); |
| 139 frames_to_write.push_back(frame.Pass()); | 139 frames_to_write.push_back(std::move(frame)); |
| 140 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | 140 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; |
| 141 } else { | 141 } else { |
| 142 if (frame->data.get() && | 142 if (frame->data.get() && |
| 143 !deflater_.AddBytes( | 143 !deflater_.AddBytes( |
| 144 frame->data->data(), | 144 frame->data->data(), |
| 145 static_cast<size_t>(frame->header.payload_length))) { | 145 static_cast<size_t>(frame->header.payload_length))) { |
| 146 DVLOG(1) << "WebSocket protocol error. " | 146 DVLOG(1) << "WebSocket protocol error. " |
| 147 << "deflater_.AddBytes() returns an error."; | 147 << "deflater_.AddBytes() returns an error."; |
| 148 return ERR_WS_PROTOCOL_ERROR; | 148 return ERR_WS_PROTOCOL_ERROR; |
| 149 } | 149 } |
| 150 if (frame->header.final && !deflater_.Finish()) { | 150 if (frame->header.final && !deflater_.Finish()) { |
| 151 DVLOG(1) << "WebSocket protocol error. " | 151 DVLOG(1) << "WebSocket protocol error. " |
| 152 << "deflater_.Finish() returns an error."; | 152 << "deflater_.Finish() returns an error."; |
| 153 return ERR_WS_PROTOCOL_ERROR; | 153 return ERR_WS_PROTOCOL_ERROR; |
| 154 } | 154 } |
| 155 | 155 |
| 156 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) { | 156 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) { |
| 157 if (deflater_.CurrentOutputSize() >= kChunkSize || | 157 if (deflater_.CurrentOutputSize() >= kChunkSize || |
| 158 frame->header.final) { | 158 frame->header.final) { |
| 159 int result = AppendCompressedFrame(frame->header, &frames_to_write); | 159 int result = AppendCompressedFrame(frame->header, &frames_to_write); |
| 160 if (result != OK) | 160 if (result != OK) |
| 161 return result; | 161 return result; |
| 162 } | 162 } |
| 163 if (frame->header.final) | 163 if (frame->header.final) |
| 164 writing_state_ = NOT_WRITING; | 164 writing_state_ = NOT_WRITING; |
| 165 } else { | 165 } else { |
| 166 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_); | 166 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_); |
| 167 bool final = frame->header.final; | 167 bool final = frame->header.final; |
| 168 frames_of_message.push_back(frame.Pass()); | 168 frames_of_message.push_back(std::move(frame)); |
| 169 if (final) { | 169 if (final) { |
| 170 int result = AppendPossiblyCompressedMessage(&frames_of_message, | 170 int result = AppendPossiblyCompressedMessage(&frames_of_message, |
| 171 &frames_to_write); | 171 &frames_to_write); |
| 172 if (result != OK) | 172 if (result != OK) |
| 173 return result; | 173 return result; |
| 174 frames_of_message.clear(); | 174 frames_of_message.clear(); |
| 175 writing_state_ = NOT_WRITING; | 175 writing_state_ = NOT_WRITING; |
| 176 } | 176 } |
| 177 } | 177 } |
| 178 } | 178 } |
| (...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 221 compressed->header.CopyFrom(header); | 221 compressed->header.CopyFrom(header); |
| 222 compressed->header.opcode = opcode; | 222 compressed->header.opcode = opcode; |
| 223 compressed->header.final = header.final; | 223 compressed->header.final = header.final; |
| 224 compressed->header.reserved1 = | 224 compressed->header.reserved1 = |
| 225 (opcode != WebSocketFrameHeader::kOpCodeContinuation); | 225 (opcode != WebSocketFrameHeader::kOpCodeContinuation); |
| 226 compressed->data = compressed_payload; | 226 compressed->data = compressed_payload; |
| 227 compressed->header.payload_length = compressed_payload->size(); | 227 compressed->header.payload_length = compressed_payload->size(); |
| 228 | 228 |
| 229 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | 229 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; |
| 230 predictor_->RecordWrittenDataFrame(compressed.get()); | 230 predictor_->RecordWrittenDataFrame(compressed.get()); |
| 231 frames_to_write->push_back(compressed.Pass()); | 231 frames_to_write->push_back(std::move(compressed)); |
| 232 return OK; | 232 return OK; |
| 233 } | 233 } |
| 234 | 234 |
| 235 int WebSocketDeflateStream::AppendPossiblyCompressedMessage( | 235 int WebSocketDeflateStream::AppendPossiblyCompressedMessage( |
| 236 std::vector<scoped_ptr<WebSocketFrame>>* frames, | 236 std::vector<scoped_ptr<WebSocketFrame>>* frames, |
| 237 std::vector<scoped_ptr<WebSocketFrame>>* frames_to_write) { | 237 std::vector<scoped_ptr<WebSocketFrame>>* frames_to_write) { |
| 238 DCHECK(!frames->empty()); | 238 DCHECK(!frames->empty()); |
| 239 | 239 |
| 240 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_; | 240 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_; |
| 241 scoped_refptr<IOBufferWithSize> compressed_payload = | 241 scoped_refptr<IOBufferWithSize> compressed_payload = |
| (...skipping 28 matching lines...) Expand all Loading... |
| 270 } | 270 } |
| 271 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode)); | 271 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode)); |
| 272 compressed->header.CopyFrom((*frames)[0]->header); | 272 compressed->header.CopyFrom((*frames)[0]->header); |
| 273 compressed->header.opcode = opcode; | 273 compressed->header.opcode = opcode; |
| 274 compressed->header.final = true; | 274 compressed->header.final = true; |
| 275 compressed->header.reserved1 = true; | 275 compressed->header.reserved1 = true; |
| 276 compressed->data = compressed_payload; | 276 compressed->data = compressed_payload; |
| 277 compressed->header.payload_length = compressed_payload->size(); | 277 compressed->header.payload_length = compressed_payload->size(); |
| 278 | 278 |
| 279 predictor_->RecordWrittenDataFrame(compressed.get()); | 279 predictor_->RecordWrittenDataFrame(compressed.get()); |
| 280 frames_to_write->push_back(compressed.Pass()); | 280 frames_to_write->push_back(std::move(compressed)); |
| 281 return OK; | 281 return OK; |
| 282 } | 282 } |
| 283 | 283 |
| 284 int WebSocketDeflateStream::Inflate( | 284 int WebSocketDeflateStream::Inflate( |
| 285 std::vector<scoped_ptr<WebSocketFrame>>* frames) { | 285 std::vector<scoped_ptr<WebSocketFrame>>* frames) { |
| 286 std::vector<scoped_ptr<WebSocketFrame>> frames_to_output; | 286 std::vector<scoped_ptr<WebSocketFrame>> frames_to_output; |
| 287 std::vector<scoped_ptr<WebSocketFrame>> frames_passed; | 287 std::vector<scoped_ptr<WebSocketFrame>> frames_passed; |
| 288 frames->swap(frames_passed); | 288 frames->swap(frames_passed); |
| 289 for (size_t i = 0; i < frames_passed.size(); ++i) { | 289 for (size_t i = 0; i < frames_passed.size(); ++i) { |
| 290 scoped_ptr<WebSocketFrame> frame(std::move(frames_passed[i])); | 290 scoped_ptr<WebSocketFrame> frame(std::move(frames_passed[i])); |
| 291 frames_passed[i] = NULL; | 291 frames_passed[i] = NULL; |
| 292 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode | 292 DVLOG(3) << "Input frame: opcode=" << frame->header.opcode |
| 293 << " final=" << frame->header.final | 293 << " final=" << frame->header.final |
| 294 << " reserved1=" << frame->header.reserved1 | 294 << " reserved1=" << frame->header.reserved1 |
| 295 << " payload_length=" << frame->header.payload_length; | 295 << " payload_length=" << frame->header.payload_length; |
| 296 | 296 |
| 297 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) { | 297 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) { |
| 298 frames_to_output.push_back(frame.Pass()); | 298 frames_to_output.push_back(std::move(frame)); |
| 299 continue; | 299 continue; |
| 300 } | 300 } |
| 301 | 301 |
| 302 if (reading_state_ == NOT_READING) { | 302 if (reading_state_ == NOT_READING) { |
| 303 if (frame->header.reserved1) | 303 if (frame->header.reserved1) |
| 304 reading_state_ = READING_COMPRESSED_MESSAGE; | 304 reading_state_ = READING_COMPRESSED_MESSAGE; |
| 305 else | 305 else |
| 306 reading_state_ = READING_UNCOMPRESSED_MESSAGE; | 306 reading_state_ = READING_UNCOMPRESSED_MESSAGE; |
| 307 current_reading_opcode_ = frame->header.opcode; | 307 current_reading_opcode_ = frame->header.opcode; |
| 308 } else { | 308 } else { |
| 309 if (frame->header.reserved1) { | 309 if (frame->header.reserved1) { |
| 310 DVLOG(1) << "WebSocket protocol error. " | 310 DVLOG(1) << "WebSocket protocol error. " |
| 311 << "Receiving a non-first frame with RSV1 flag set."; | 311 << "Receiving a non-first frame with RSV1 flag set."; |
| 312 return ERR_WS_PROTOCOL_ERROR; | 312 return ERR_WS_PROTOCOL_ERROR; |
| 313 } | 313 } |
| 314 } | 314 } |
| 315 | 315 |
| 316 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) { | 316 if (reading_state_ == READING_UNCOMPRESSED_MESSAGE) { |
| 317 if (frame->header.final) | 317 if (frame->header.final) |
| 318 reading_state_ = NOT_READING; | 318 reading_state_ = NOT_READING; |
| 319 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | 319 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; |
| 320 frames_to_output.push_back(frame.Pass()); | 320 frames_to_output.push_back(std::move(frame)); |
| 321 } else { | 321 } else { |
| 322 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE); | 322 DCHECK_EQ(reading_state_, READING_COMPRESSED_MESSAGE); |
| 323 if (frame->data.get() && | 323 if (frame->data.get() && |
| 324 !inflater_.AddBytes( | 324 !inflater_.AddBytes( |
| 325 frame->data->data(), | 325 frame->data->data(), |
| 326 static_cast<size_t>(frame->header.payload_length))) { | 326 static_cast<size_t>(frame->header.payload_length))) { |
| 327 DVLOG(1) << "WebSocket protocol error. " | 327 DVLOG(1) << "WebSocket protocol error. " |
| 328 << "inflater_.AddBytes() returns an error."; | 328 << "inflater_.AddBytes() returns an error."; |
| 329 return ERR_WS_PROTOCOL_ERROR; | 329 return ERR_WS_PROTOCOL_ERROR; |
| 330 } | 330 } |
| (...skipping 23 matching lines...) Expand all Loading... |
| 354 inflated->header.CopyFrom(frame->header); | 354 inflated->header.CopyFrom(frame->header); |
| 355 inflated->header.opcode = current_reading_opcode_; | 355 inflated->header.opcode = current_reading_opcode_; |
| 356 inflated->header.final = is_final; | 356 inflated->header.final = is_final; |
| 357 inflated->header.reserved1 = false; | 357 inflated->header.reserved1 = false; |
| 358 inflated->data = data; | 358 inflated->data = data; |
| 359 inflated->header.payload_length = data->size(); | 359 inflated->header.payload_length = data->size(); |
| 360 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode | 360 DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode |
| 361 << " final=" << inflated->header.final | 361 << " final=" << inflated->header.final |
| 362 << " reserved1=" << inflated->header.reserved1 | 362 << " reserved1=" << inflated->header.reserved1 |
| 363 << " payload_length=" << inflated->header.payload_length; | 363 << " payload_length=" << inflated->header.payload_length; |
| 364 frames_to_output.push_back(inflated.Pass()); | 364 frames_to_output.push_back(std::move(inflated)); |
| 365 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; | 365 current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; |
| 366 if (is_final) | 366 if (is_final) |
| 367 break; | 367 break; |
| 368 } | 368 } |
| 369 if (frame->header.final) | 369 if (frame->header.final) |
| 370 reading_state_ = NOT_READING; | 370 reading_state_ = NOT_READING; |
| 371 } | 371 } |
| 372 } | 372 } |
| 373 frames->swap(frames_to_output); | 373 frames->swap(frames_to_output); |
| 374 return frames->empty() ? ERR_IO_PENDING : OK; | 374 return frames->empty() ? ERR_IO_PENDING : OK; |
| (...skipping 18 matching lines...) Expand all Loading... |
| 393 DCHECK(!frames->empty()); | 393 DCHECK(!frames->empty()); |
| 394 | 394 |
| 395 result = Inflate(frames); | 395 result = Inflate(frames); |
| 396 } | 396 } |
| 397 if (result < 0) | 397 if (result < 0) |
| 398 frames->clear(); | 398 frames->clear(); |
| 399 return result; | 399 return result; |
| 400 } | 400 } |
| 401 | 401 |
| 402 } // namespace net | 402 } // namespace net |
| OLD | NEW |