OLD | NEW |
---|---|
1 // Copyright 2014 The Chromium Authors. All rights reserved. | 1 // Copyright 2014 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/server/web_socket_encoder.h" | 5 #include "net/server/web_socket_encoder.h" |
6 | 6 |
7 #include <vector> | |
8 | |
7 #include "base/logging.h" | 9 #include "base/logging.h" |
8 #include "base/strings/string_number_conversions.h" | 10 #include "base/strings/string_number_conversions.h" |
9 #include "base/strings/stringprintf.h" | 11 #include "base/strings/stringprintf.h" |
10 #include "net/base/io_buffer.h" | 12 #include "net/base/io_buffer.h" |
13 #include "net/websockets/websocket_deflate_parameters.h" | |
14 #include "net/websockets/websocket_extension.h" | |
11 #include "net/websockets/websocket_extension_parser.h" | 15 #include "net/websockets/websocket_extension_parser.h" |
12 | 16 |
13 namespace net { | 17 namespace net { |
14 | 18 |
15 const char WebSocketEncoder::kClientExtensions[] = | 19 const char WebSocketEncoder::kClientExtensions[] = |
16 "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits"; | 20 "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits"; |
17 | 21 |
18 namespace { | 22 namespace { |
19 | 23 |
20 const int kInflaterChunkSize = 16 * 1024; | 24 const int kInflaterChunkSize = 16 * 1024; |
(...skipping 152 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
173 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); | 177 frame.push_back(data[i] ^ mask_bytes[i % kMaskingKeyWidthInBytes]); |
174 } else { | 178 } else { |
175 frame.insert(frame.end(), data, data + data_length); | 179 frame.insert(frame.end(), data, data + data_length); |
176 } | 180 } |
177 *output = std::string(&frame[0], frame.size()); | 181 *output = std::string(&frame[0], frame.size()); |
178 } | 182 } |
179 | 183 |
180 } // anonymous namespace | 184 } // anonymous namespace |
181 | 185 |
182 // static | 186 // static |
183 WebSocketEncoder* WebSocketEncoder::CreateServer( | 187 scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer() { |
188 return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr)); | |
189 } | |
190 | |
191 // static | |
192 scoped_ptr<WebSocketEncoder> WebSocketEncoder::CreateServer( | |
184 const std::string& request_extensions, | 193 const std::string& request_extensions, |
185 std::string* response_extensions) { | 194 WebSocketDeflateParameters* deflate_parameters) { |
186 bool deflate; | 195 WebSocketExtensionParser parser; |
187 bool has_client_window_bits; | 196 if (!parser.Parse(request_extensions)) { |
188 int client_window_bits; | 197 // Failed to parse Sec-WebSocket-Extensions header. We MUST fail the |
189 int server_window_bits; | 198 // connection. |
190 bool client_no_context_takeover; | 199 return nullptr; |
191 bool server_no_context_takeover; | 200 } |
192 ParseExtensions(request_extensions, &deflate, &has_client_window_bits, | |
193 &client_window_bits, &server_window_bits, | |
194 &client_no_context_takeover, &server_no_context_takeover); | |
195 | 201 |
196 if (deflate) { | 202 WebSocketDeflateParameters offered, response; |
197 *response_extensions = base::StringPrintf( | 203 std::string failure_message; |
198 "permessage-deflate; server_max_window_bits=%d%s", server_window_bits, | 204 bool found = false; |
199 server_no_context_takeover ? "; server_no_context_takeover" : ""); | 205 for (const auto& extension : parser.extensions()) { |
200 if (has_client_window_bits) { | 206 WebSocketDeflateParameters params; |
201 base::StringAppendF(response_extensions, "; client_max_window_bits=%d", | 207 if (!params.Initialize(extension, &failure_message) || |
202 client_window_bits); | 208 !params.IsValidAsRequest(&failure_message)) { |
203 } else { | 209 // We decline unknown / malformed extensions. |
204 DCHECK_EQ(client_window_bits, 15); | 210 continue; |
205 } | 211 } |
206 return new WebSocketEncoder(true /* is_server */, server_window_bits, | 212 found = true; |
207 client_window_bits, server_no_context_takeover); | 213 offered = params; |
208 } else { | 214 break; |
209 *response_extensions = std::string(); | |
210 return new WebSocketEncoder(true /* is_server */); | |
211 } | 215 } |
216 if (!found) | |
217 return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr)); | |
218 | |
219 response = offered; | |
220 if (offered.is_client_max_window_bits_specified() && | |
221 !offered.has_client_max_window_bits_value()) { | |
222 // We need to choose one value for the response. | |
223 response.SetClientMaxWindowBits(8); | |
224 } | |
225 DCHECK(response.IsValidAsResponse()); | |
226 DCHECK(offered.IsCompatibleWith(response)); | |
227 | |
228 auto deflater = make_scoped_ptr( | |
229 new WebSocketDeflater(response.server_context_take_over_mode())); | |
230 auto inflater = make_scoped_ptr( | |
231 new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); | |
232 if (!deflater->Initialize(response.PermissiveServerMaxWindowBits()) || | |
233 !inflater->Initialize(response.PermissiveClientMaxWindowBits())) { | |
234 return make_scoped_ptr(new WebSocketEncoder(FOR_SERVER, nullptr, nullptr)); | |
235 } | |
236 *deflate_parameters = response; | |
237 return make_scoped_ptr( | |
238 new WebSocketEncoder(FOR_SERVER, deflater.Pass(), inflater.Pass())); | |
212 } | 239 } |
213 | 240 |
214 // static | 241 // static |
215 WebSocketEncoder* WebSocketEncoder::CreateClient( | 242 WebSocketEncoder* WebSocketEncoder::CreateClient( |
216 const std::string& response_extensions) { | 243 const std::string& response_extensions) { |
217 bool deflate; | 244 // TODO(yhirano): Add a way to return an error. |
218 bool has_client_window_bits; | |
219 int client_window_bits; | |
220 int server_window_bits; | |
221 bool client_no_context_takeover; | |
222 bool server_no_context_takeover; | |
223 ParseExtensions(response_extensions, &deflate, &has_client_window_bits, | |
224 &client_window_bits, &server_window_bits, | |
225 &client_no_context_takeover, &server_no_context_takeover); | |
226 | 245 |
227 if (deflate) { | 246 WebSocketExtensionParser parser; |
228 return new WebSocketEncoder(false /* is_server */, client_window_bits, | 247 if (!parser.Parse(response_extensions)) { |
229 server_window_bits, client_no_context_takeover); | 248 // Parse error. |
230 } else { | 249 return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); |
231 return new WebSocketEncoder(false /* is_server */); | |
232 } | 250 } |
251 if (parser.extensions().size() != 1) { | |
252 // Only permessage-deflate extension is supported. | |
253 return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); | |
254 } | |
255 const auto& extension = parser.extensions()[0]; | |
256 WebSocketDeflateParameters params; | |
257 std::string failure_message; | |
258 if (!params.Initialize(extension, &failure_message) || | |
259 !params.IsValidAsResponse(&failure_message)) { | |
260 return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); | |
261 } | |
262 | |
263 auto deflater = make_scoped_ptr( | |
264 new WebSocketDeflater(params.client_context_take_over_mode())); | |
265 auto inflater = make_scoped_ptr( | |
266 new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); | |
267 if (!deflater->Initialize(params.PermissiveClientMaxWindowBits()) || | |
268 !inflater->Initialize(params.PermissiveServerMaxWindowBits())) { | |
269 return new WebSocketEncoder(FOR_CLIENT, nullptr, nullptr); | |
270 } | |
271 | |
272 return new WebSocketEncoder(FOR_CLIENT, deflater.Pass(), inflater.Pass()); | |
233 } | 273 } |
234 | 274 |
235 // static | 275 WebSocketEncoder::WebSocketEncoder(Type type, |
236 void WebSocketEncoder::ParseExtensions(const std::string& header_value, | 276 scoped_ptr<WebSocketDeflater> deflater, |
237 bool* deflate, | 277 scoped_ptr<WebSocketInflater> inflater) |
238 bool* has_client_window_bits, | 278 : is_server_(type == FOR_SERVER), |
dgozman
2015/09/15 17:26:19
Let's go ahead and change |is_server_| to |type_|.
yhirano
2015/09/16 05:28:17
Done.
| |
239 int* client_window_bits, | 279 deflater_(deflater.Pass()), |
240 int* server_window_bits, | 280 inflater_(inflater.Pass()) {} |
241 bool* client_no_context_takeover, | |
242 bool* server_no_context_takeover) { | |
243 *deflate = false; | |
244 *has_client_window_bits = false; | |
245 *client_window_bits = 15; | |
246 *server_window_bits = 15; | |
247 *client_no_context_takeover = false; | |
248 *server_no_context_takeover = false; | |
249 | 281 |
250 if (header_value.empty()) | 282 WebSocketEncoder::~WebSocketEncoder() {} |
251 return; | |
252 | |
253 WebSocketExtensionParser parser; | |
254 if (!parser.Parse(header_value)) | |
255 return; | |
256 const std::vector<WebSocketExtension>& extensions = parser.extensions(); | |
257 // TODO(tyoshino): Fail if this method is used for parsing a response and | |
258 // there are multiple permessage-deflate extensions or there are any unknown | |
259 // extensions. | |
260 for (const auto& extension : extensions) { | |
261 if (extension.name() != "permessage-deflate") { | |
262 continue; | |
263 } | |
264 | |
265 const std::vector<WebSocketExtension::Parameter>& parameters = | |
266 extension.parameters(); | |
267 for (const auto& param : parameters) { | |
268 const std::string& name = param.name(); | |
269 // TODO(tyoshino): Fail the connection when an invalid value is given. | |
270 if (name == "client_max_window_bits") { | |
271 *has_client_window_bits = true; | |
272 if (param.HasValue()) { | |
273 int bits = 0; | |
274 if (base::StringToInt(param.value(), &bits) && bits >= 8 && | |
275 bits <= 15) { | |
276 *client_window_bits = bits; | |
277 } | |
278 } | |
279 } | |
280 if (name == "server_max_window_bits" && param.HasValue()) { | |
281 int bits = 0; | |
282 if (base::StringToInt(param.value(), &bits) && bits >= 8 && bits <= 15) | |
283 *server_window_bits = bits; | |
284 } | |
285 if (name == "client_no_context_takeover") | |
286 *client_no_context_takeover = true; | |
287 if (name == "server_no_context_takeover") | |
288 *server_no_context_takeover = true; | |
289 } | |
290 *deflate = true; | |
291 | |
292 break; | |
293 } | |
294 } | |
295 | |
296 WebSocketEncoder::WebSocketEncoder(bool is_server) : is_server_(is_server) { | |
297 } | |
298 | |
299 WebSocketEncoder::WebSocketEncoder(bool is_server, | |
300 int deflate_bits, | |
301 int inflate_bits, | |
302 bool no_context_takeover) | |
303 : is_server_(is_server) { | |
304 deflater_.reset(new WebSocketDeflater( | |
305 no_context_takeover ? WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT | |
306 : WebSocketDeflater::TAKE_OVER_CONTEXT)); | |
307 inflater_.reset( | |
308 new WebSocketInflater(kInflaterChunkSize, kInflaterChunkSize)); | |
309 | |
310 if (!deflater_->Initialize(deflate_bits) || | |
311 !inflater_->Initialize(inflate_bits)) { | |
312 // Disable deflate support. | |
313 deflater_.reset(); | |
314 inflater_.reset(); | |
315 } | |
316 } | |
317 | |
318 WebSocketEncoder::~WebSocketEncoder() { | |
319 } | |
320 | 283 |
321 WebSocket::ParseResult WebSocketEncoder::DecodeFrame( | 284 WebSocket::ParseResult WebSocketEncoder::DecodeFrame( |
322 const base::StringPiece& frame, | 285 const base::StringPiece& frame, |
323 int* bytes_consumed, | 286 int* bytes_consumed, |
324 std::string* output) { | 287 std::string* output) { |
325 bool compressed; | 288 bool compressed; |
326 WebSocket::ParseResult result = | 289 WebSocket::ParseResult result = |
327 DecodeFrameHybi17(frame, is_server_, bytes_consumed, output, &compressed); | 290 DecodeFrameHybi17(frame, is_server_, bytes_consumed, output, &compressed); |
328 if (result == WebSocket::FRAME_OK && compressed) { | 291 if (result == WebSocket::FRAME_OK && compressed) { |
329 if (!Inflate(output)) | 292 if (!Inflate(output)) |
(...skipping 46 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
376 return false; | 339 return false; |
377 scoped_refptr<IOBufferWithSize> buffer = | 340 scoped_refptr<IOBufferWithSize> buffer = |
378 deflater_->GetOutput(deflater_->CurrentOutputSize()); | 341 deflater_->GetOutput(deflater_->CurrentOutputSize()); |
379 if (!buffer.get()) | 342 if (!buffer.get()) |
380 return false; | 343 return false; |
381 *output = std::string(buffer->data(), buffer->size()); | 344 *output = std::string(buffer->data(), buffer->size()); |
382 return true; | 345 return true; |
383 } | 346 } |
384 | 347 |
385 } // namespace net | 348 } // namespace net |
OLD | NEW |