Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(221)

Side by Side Diff: net/websockets/websocket_deflate_stream.cc

Issue 39193005: Introduce WebSocketDeflatePredictor. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Created 7 years, 1 month ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 // Copyright 2013 The Chromium Authors. All rights reserved. 1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "net/websockets/websocket_deflate_stream.h" 5 #include "net/websockets/websocket_deflate_stream.h"
6 6
7 #include <algorithm> 7 #include <algorithm>
8 #include <string> 8 #include <string>
9 9
10 #include "base/bind.h" 10 #include "base/bind.h"
11 #include "base/logging.h" 11 #include "base/logging.h"
12 #include "base/memory/ref_counted.h" 12 #include "base/memory/ref_counted.h"
13 #include "base/memory/scoped_ptr.h" 13 #include "base/memory/scoped_ptr.h"
14 #include "base/memory/scoped_vector.h" 14 #include "base/memory/scoped_vector.h"
15 #include "net/base/completion_callback.h" 15 #include "net/base/completion_callback.h"
16 #include "net/base/io_buffer.h" 16 #include "net/base/io_buffer.h"
17 #include "net/base/net_errors.h" 17 #include "net/base/net_errors.h"
18 #include "net/websockets/websocket_deflate_predictor.h"
18 #include "net/websockets/websocket_deflater.h" 19 #include "net/websockets/websocket_deflater.h"
19 #include "net/websockets/websocket_errors.h" 20 #include "net/websockets/websocket_errors.h"
20 #include "net/websockets/websocket_frame.h" 21 #include "net/websockets/websocket_frame.h"
21 #include "net/websockets/websocket_inflater.h" 22 #include "net/websockets/websocket_inflater.h"
22 #include "net/websockets/websocket_stream.h" 23 #include "net/websockets/websocket_stream.h"
23 24
24 class GURL; 25 class GURL;
25 26
26 namespace net { 27 namespace net {
27 28
28 namespace { 29 namespace {
29 30
30 const int kWindowBits = 15; 31 const int kWindowBits = 15;
31 const size_t kChunkSize = 4 * 1024; 32 const size_t kChunkSize = 4 * 1024;
32 33
33 } // namespace 34 } // namespace
34 35
35 WebSocketDeflateStream::WebSocketDeflateStream( 36 WebSocketDeflateStream::WebSocketDeflateStream(
36 scoped_ptr<WebSocketStream> stream, 37 scoped_ptr<WebSocketStream> stream,
37 WebSocketDeflater::ContextTakeOverMode mode) 38 WebSocketDeflater::ContextTakeOverMode mode,
39 scoped_ptr<WebSocketDeflatePredictor> predictor)
38 : stream_(stream.Pass()), 40 : stream_(stream.Pass()),
39 deflater_(mode), 41 deflater_(mode),
40 inflater_(kChunkSize, kChunkSize), 42 inflater_(kChunkSize, kChunkSize),
41 reading_state_(NOT_READING), 43 reading_state_(NOT_READING),
42 writing_state_(NOT_WRITING), 44 writing_state_(NOT_WRITING),
43 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText), 45 current_reading_opcode_(WebSocketFrameHeader::kOpCodeText),
44 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText) { 46 current_writing_opcode_(WebSocketFrameHeader::kOpCodeText),
47 predictor_(predictor.Pass()) {
45 DCHECK(stream_); 48 DCHECK(stream_);
46 deflater_.Initialize(kWindowBits); 49 deflater_.Initialize(kWindowBits);
47 inflater_.Initialize(kWindowBits); 50 inflater_.Initialize(kWindowBits);
48 } 51 }
49 52
50 WebSocketDeflateStream::~WebSocketDeflateStream() {} 53 WebSocketDeflateStream::~WebSocketDeflateStream() {}
51 54
52 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames, 55 int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames,
53 const CompletionCallback& callback) { 56 const CompletionCallback& callback) {
54 CompletionCallback callback_to_pass = 57 CompletionCallback callback_to_pass =
(...skipping 38 matching lines...) Expand 10 before | Expand all | Expand 10 after
93 return; 96 return;
94 } 97 }
95 98
96 int r = InflateAndReadIfNecessary(frames, callback); 99 int r = InflateAndReadIfNecessary(frames, callback);
97 if (r != ERR_IO_PENDING) 100 if (r != ERR_IO_PENDING)
98 callback.Run(r); 101 callback.Run(r);
99 } 102 }
100 103
101 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) { 104 int WebSocketDeflateStream::Deflate(ScopedVector<WebSocketFrame>* frames) {
102 ScopedVector<WebSocketFrame> frames_to_write; 105 ScopedVector<WebSocketFrame> frames_to_write;
106 // Store frames of the currently processed message if writing_state_ equals to
107 // WRITING_POSSIBLY_COMPRESSED_MESSAGE.
108 ScopedVector<WebSocketFrame> frames_of_message;
103 for (size_t i = 0; i < frames->size(); ++i) { 109 for (size_t i = 0; i < frames->size(); ++i) {
110 DCHECK(!(*frames)[i]->header.reserved1);
111 if (!WebSocketFrameHeader::IsKnownDataOpCode((*frames)[i]->header.opcode)) {
112 frames_to_write.push_back((*frames)[i]);
113 (*frames)[i] = NULL;
114 continue;
115 }
116 if (writing_state_ == NOT_WRITING)
117 OnMessageStart(*frames, i);
118
104 scoped_ptr<WebSocketFrame> frame((*frames)[i]); 119 scoped_ptr<WebSocketFrame> frame((*frames)[i]);
105 (*frames)[i] = NULL; 120 (*frames)[i] = NULL;
106 DCHECK(!frame->header.reserved1); 121 predictor_->RecordInputDataFrame(frame.get());
107 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
108 frames_to_write.push_back(frame.release());
109 continue;
110 }
111 122
112 if (writing_state_ == NOT_WRITING) {
113 current_writing_opcode_ = frame->header.opcode;
114 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
115 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
116 // TODO(yhirano): For now, we unconditionally compress data messages.
117 // Further optimization is needed.
118 // http://crbug.com/163882
119 writing_state_ = WRITING_COMPRESSED_MESSAGE;
120 }
121 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) { 123 if (writing_state_ == WRITING_UNCOMPRESSED_MESSAGE) {
122 if (frame->header.final) 124 if (frame->header.final)
123 writing_state_ = NOT_WRITING; 125 writing_state_ = NOT_WRITING;
126 predictor_->RecordWrittenDataFrame(frame.get());
124 frames_to_write.push_back(frame.release()); 127 frames_to_write.push_back(frame.release());
125 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; 128 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
126 } else { 129 } else {
127 DCHECK_EQ(WRITING_COMPRESSED_MESSAGE, writing_state_);
128 if (frame->data && !deflater_.AddBytes(frame->data->data(), 130 if (frame->data && !deflater_.AddBytes(frame->data->data(),
129 frame->header.payload_length)) { 131 frame->header.payload_length)) {
130 DVLOG(1) << "WebSocket protocol error. " 132 DVLOG(1) << "WebSocket protocol error. "
131 << "deflater_.AddBytes() returns an error."; 133 << "deflater_.AddBytes() returns an error.";
132 return ERR_WS_PROTOCOL_ERROR; 134 return ERR_WS_PROTOCOL_ERROR;
133 } 135 }
134 if (frame->header.final && !deflater_.Finish()) { 136 if (frame->header.final && !deflater_.Finish()) {
135 DVLOG(1) << "WebSocket protocol error. " 137 DVLOG(1) << "WebSocket protocol error. "
136 << "deflater_.Finish() returns an error."; 138 << "deflater_.Finish() returns an error.";
137 return ERR_WS_PROTOCOL_ERROR; 139 return ERR_WS_PROTOCOL_ERROR;
138 } 140 }
139 if (deflater_.CurrentOutputSize() >= kChunkSize || frame->header.final) { 141
140 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_; 142 if (writing_state_ == WRITING_COMPRESSED_MESSAGE) {
141 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode)); 143 if (deflater_.CurrentOutputSize() >= kChunkSize ||
142 scoped_refptr<IOBufferWithSize> data = 144 frame->header.final) {
143 deflater_.GetOutput(deflater_.CurrentOutputSize()); 145 int result = AppendCompressedFrame(frame->header, &frames_to_write);
144 if (!data) { 146 if (result != OK)
145 DVLOG(1) << "WebSocket protocol error. " 147 return result;
146 << "deflater_.GetOutput() returns an error.";
147 return ERR_WS_PROTOCOL_ERROR;
148 } 148 }
149 compressed->header.CopyFrom(frame->header); 149 if (frame->header.final)
150 compressed->header.opcode = opcode; 150 writing_state_ = NOT_WRITING;
151 compressed->header.final = frame->header.final; 151 } else {
152 compressed->header.reserved1 = 152 DCHECK_EQ(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
153 (opcode != WebSocketFrameHeader::kOpCodeContinuation); 153 bool final = frame->header.final;
154 compressed->data = data; 154 frames_of_message.push_back(frame.release());
155 compressed->header.payload_length = data->size(); 155 if (final) {
156 156 int result = AppendPossiblyCompressedMessage(&frames_of_message,
157 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; 157 &frames_to_write);
158 frames_to_write.push_back(compressed.release()); 158 if (result != OK)
159 return result;
160 frames_of_message.clear();
161 writing_state_ = NOT_WRITING;
162 }
159 } 163 }
160 if (frame->header.final)
161 writing_state_ = NOT_WRITING;
162 } 164 }
163 } 165 }
166 DCHECK_NE(WRITING_POSSIBLY_COMPRESSED_MESSAGE, writing_state_);
164 frames->swap(frames_to_write); 167 frames->swap(frames_to_write);
165 return OK; 168 return OK;
166 } 169 }
167 170
171 void WebSocketDeflateStream::OnMessageStart(
172 const ScopedVector<WebSocketFrame>& frames, size_t index) {
173 WebSocketFrame* frame = frames[index];
174 current_writing_opcode_ = frame->header.opcode;
175 DCHECK(current_writing_opcode_ == WebSocketFrameHeader::kOpCodeText ||
176 current_writing_opcode_ == WebSocketFrameHeader::kOpCodeBinary);
177 WebSocketDeflatePredictor::Result prediction =
178 predictor_->Predict(frames, index);
179
180 switch (prediction) {
181 case WebSocketDeflatePredictor::DEFLATE:
182 writing_state_ = WRITING_COMPRESSED_MESSAGE;
183 return;
184 case WebSocketDeflatePredictor::DO_NOT_DEFLATE:
185 writing_state_ = WRITING_UNCOMPRESSED_MESSAGE;
186 return;
187 case WebSocketDeflatePredictor::TRY_DEFLATE:
188 writing_state_ = WRITING_POSSIBLY_COMPRESSED_MESSAGE;
189 return;
190 }
191 NOTREACHED();
192 }
193
194 int WebSocketDeflateStream::AppendCompressedFrame(
195 const WebSocketFrameHeader& header,
196 ScopedVector<WebSocketFrame>* frames_to_write) {
197 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
198 scoped_refptr<IOBufferWithSize> data =
199 deflater_.GetOutput(deflater_.CurrentOutputSize());
200 if (!data) {
201 DVLOG(1) << "WebSocket protocol error. "
202 << "deflater_.GetOutput() returns an error.";
203 return ERR_WS_PROTOCOL_ERROR;
204 }
205 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
206 compressed->header.CopyFrom(header);
207 compressed->header.opcode = opcode;
208 compressed->header.final = header.final;
209 compressed->header.reserved1 =
210 (opcode != WebSocketFrameHeader::kOpCodeContinuation);
211 compressed->data = data;
212 compressed->header.payload_length = data->size();
213
214 current_writing_opcode_ = WebSocketFrameHeader::kOpCodeContinuation;
215 predictor_->RecordWrittenDataFrame(compressed.get());
216 frames_to_write->push_back(compressed.release());
217 return OK;
218 }
219
220 int WebSocketDeflateStream::AppendPossiblyCompressedMessage(
221 ScopedVector<WebSocketFrame>* frames,
222 ScopedVector<WebSocketFrame>* frames_to_write) {
223 DCHECK(!frames->empty());
224
225 const WebSocketFrameHeader::OpCode opcode = current_writing_opcode_;
226 scoped_refptr<IOBufferWithSize> data =
227 deflater_.GetOutput(deflater_.CurrentOutputSize());
228 if (!data) {
229 DVLOG(1) << "WebSocket protocol error. "
230 << "deflater_.GetOutput() returns an error.";
231 return ERR_WS_PROTOCOL_ERROR;
232 }
233
234 uint64 original_payload_length = 0;
235 for (size_t i = 0; i < frames->size(); ++i) {
236 WebSocketFrame* frame = (*frames)[i];
237 // Asserts checking that frames represent one whole data message.
238 DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode));
239 DCHECK_EQ(i == 0,
240 WebSocketFrameHeader::kOpCodeContinuation !=
241 frame->header.opcode);
242 DCHECK_EQ(i == frames->size() - 1, frame->header.final);
243 original_payload_length += frame->header.payload_length;
244 }
245 if (original_payload_length <= static_cast<uint64>(data->size())) {
246 // Compression is not effective. Use the original frames.
247 for (size_t i = 0; i < frames->size(); ++i) {
248 WebSocketFrame* frame = (*frames)[i];
249 frames_to_write->push_back(frame);
250 predictor_->RecordWrittenDataFrame(frame);
251 (*frames)[i] = NULL;
252 }
253 frames->weak_clear();
254 return OK;
255 }
256 scoped_ptr<WebSocketFrame> compressed(new WebSocketFrame(opcode));
257 compressed->header.CopyFrom((*frames)[0]->header);
258 compressed->header.opcode = opcode;
259 compressed->header.final = true;
260 compressed->header.reserved1 = true;
261 compressed->data = data;
262 compressed->header.payload_length = data->size();
263
264 predictor_->RecordWrittenDataFrame(compressed.get());
265 frames_to_write->push_back(compressed.release());
266 return OK;
267 }
268
168 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) { 269 int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) {
169 ScopedVector<WebSocketFrame> frames_to_output; 270 ScopedVector<WebSocketFrame> frames_to_output;
170 ScopedVector<WebSocketFrame> frames_passed; 271 ScopedVector<WebSocketFrame> frames_passed;
171 frames->swap(frames_passed); 272 frames->swap(frames_passed);
172 for (size_t i = 0; i < frames_passed.size(); ++i) { 273 for (size_t i = 0; i < frames_passed.size(); ++i) {
173 scoped_ptr<WebSocketFrame> frame(frames_passed[i]); 274 scoped_ptr<WebSocketFrame> frame(frames_passed[i]);
174 frames_passed[i] = NULL; 275 frames_passed[i] = NULL;
175 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) { 276 if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) {
176 frames_to_output.push_back(frame.release()); 277 frames_to_output.push_back(frame.release());
177 continue; 278 continue;
(...skipping 83 matching lines...) Expand 10 before | Expand all | Expand 10 after
261 DCHECK_EQ(OK, result); 362 DCHECK_EQ(OK, result);
262 DCHECK(!frames->empty()); 363 DCHECK(!frames->empty());
263 result = Inflate(frames); 364 result = Inflate(frames);
264 } 365 }
265 if (result < 0) 366 if (result < 0)
266 frames->clear(); 367 frames->clear();
267 return result; 368 return result;
268 } 369 }
269 370
270 } // namespace net 371 } // namespace net
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698