Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(518)

Side by Side Diff: net/websockets/websocket_basic_handshake_stream.cc

Issue 105833003: Fail WebSocket channel when handshake fails. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Created 6 years, 11 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
« no previous file with comments | « net/websockets/websocket_basic_handshake_stream.h ('k') | net/websockets/websocket_channel.h » ('j') | no next file with comments »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
OLDNEW
1 // Copyright 2013 The Chromium Authors. All rights reserved. 1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "net/websockets/websocket_basic_handshake_stream.h" 5 #include "net/websockets/websocket_basic_handshake_stream.h"
6 6
7 #include <algorithm> 7 #include <algorithm>
8 #include <iterator> 8 #include <iterator>
9 #include <string>
9 10
10 #include "base/base64.h" 11 #include "base/base64.h"
11 #include "base/basictypes.h" 12 #include "base/basictypes.h"
12 #include "base/bind.h" 13 #include "base/bind.h"
13 #include "base/containers/hash_tables.h" 14 #include "base/containers/hash_tables.h"
14 #include "base/stl_util.h" 15 #include "base/stl_util.h"
15 #include "base/strings/string_util.h" 16 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h"
16 #include "crypto/random.h" 18 #include "crypto/random.h"
17 #include "net/http/http_request_headers.h" 19 #include "net/http/http_request_headers.h"
18 #include "net/http/http_request_info.h" 20 #include "net/http/http_request_info.h"
19 #include "net/http/http_response_body_drainer.h" 21 #include "net/http/http_response_body_drainer.h"
20 #include "net/http/http_response_headers.h" 22 #include "net/http/http_response_headers.h"
21 #include "net/http/http_status_code.h" 23 #include "net/http/http_status_code.h"
22 #include "net/http/http_stream_parser.h" 24 #include "net/http/http_stream_parser.h"
23 #include "net/socket/client_socket_handle.h" 25 #include "net/socket/client_socket_handle.h"
24 #include "net/websockets/websocket_basic_stream.h" 26 #include "net/websockets/websocket_basic_stream.h"
27 #include "net/websockets/websocket_extension_parser.h"
25 #include "net/websockets/websocket_handshake_constants.h" 28 #include "net/websockets/websocket_handshake_constants.h"
26 #include "net/websockets/websocket_handshake_handler.h" 29 #include "net/websockets/websocket_handshake_handler.h"
27 #include "net/websockets/websocket_stream.h" 30 #include "net/websockets/websocket_stream.h"
28 31
29 namespace net { 32 namespace net {
30 namespace { 33 namespace {
31 34
35 enum GetHeaderResult {
36 GET_HEADER_OK,
37 GET_HEADER_MISSING,
38 GET_HEADER_MULTIPLE,
39 };
40
41 std::string MissingHeaderMessage(const std::string& header_name) {
42 return std::string("'") + header_name + "' header is missing";
43 }
44
45 std::string MultipleHeaderValuesMessage(const std::string& header_name) {
46 return
47 std::string("'") +
48 header_name +
49 "' header must not appear more than once in a response";
50 }
51
32 std::string GenerateHandshakeChallenge() { 52 std::string GenerateHandshakeChallenge() {
33 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); 53 std::string raw_challenge(websockets::kRawChallengeLength, '\0');
34 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); 54 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
35 std::string encoded_challenge; 55 std::string encoded_challenge;
36 base::Base64Encode(raw_challenge, &encoded_challenge); 56 base::Base64Encode(raw_challenge, &encoded_challenge);
37 return encoded_challenge; 57 return encoded_challenge;
38 } 58 }
39 59
40 void AddVectorHeaderIfNonEmpty(const char* name, 60 void AddVectorHeaderIfNonEmpty(const char* name,
41 const std::vector<std::string>& value, 61 const std::vector<std::string>& value,
42 HttpRequestHeaders* headers) { 62 HttpRequestHeaders* headers) {
43 if (value.empty()) 63 if (value.empty())
44 return; 64 return;
45 headers->SetHeader(name, JoinString(value, ", ")); 65 headers->SetHeader(name, JoinString(value, ", "));
46 } 66 }
47 67
48 // If |case_sensitive| is false, then |value| must be in lower-case. 68 GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
49 bool ValidateSingleTokenHeader( 69 const base::StringPiece& name,
50 const scoped_refptr<HttpResponseHeaders>& headers, 70 std::string* value) {
51 const base::StringPiece& name,
52 const std::string& value,
53 bool case_sensitive) {
54 void* state = NULL; 71 void* state = NULL;
55 std::string token; 72 size_t num_values = 0;
56 int tokens = 0; 73 std::string temp_value;
57 bool has_value = false; 74 while (headers->EnumerateHeader(&state, name, &temp_value)) {
58 while (headers->EnumerateHeader(&state, name, &token)) { 75 if (++num_values > 1)
59 if (++tokens > 1) 76 return GET_HEADER_MULTIPLE;
60 return false; 77 *value = temp_value;
61 has_value = case_sensitive ? value == token
62 : LowerCaseEqualsASCII(token, value.c_str());
63 } 78 }
64 return has_value; 79 return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
80 }
81
82 bool ValidateHeaderHasSingleValue(GetHeaderResult result,
83 const std::string& header_name,
84 std::string* failure_message) {
85 if (result == GET_HEADER_MISSING) {
86 *failure_message = MissingHeaderMessage(header_name);
87 return false;
88 }
89 if (result == GET_HEADER_MULTIPLE) {
90 *failure_message = MultipleHeaderValuesMessage(header_name);
91 return false;
92 }
93 DCHECK_EQ(result, GET_HEADER_OK);
94 return true;
95 }
96
97 bool ValidateUpgrade(const HttpResponseHeaders* headers,
98 std::string* failure_message) {
99 std::string value;
100 GetHeaderResult result =
101 GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
102 if (!ValidateHeaderHasSingleValue(result,
103 websockets::kUpgrade,
104 failure_message)) {
105 return false;
106 }
107
108 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
109 *failure_message =
110 "'Upgrade' header value is not 'WebSocket': " + value;
111 return false;
112 }
113 return true;
114 }
115
116 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
117 const std::string& expected,
118 std::string* failure_message) {
119 std::string actual;
120 GetHeaderResult result =
121 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
122 if (!ValidateHeaderHasSingleValue(result,
123 websockets::kSecWebSocketAccept,
124 failure_message)) {
125 return false;
126 }
127
128 if (expected != actual) {
129 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
130 return false;
131 }
132 return true;
133 }
134
135 bool ValidateConnection(const HttpResponseHeaders* headers,
136 std::string* failure_message) {
137 // Connection header is permitted to contain other tokens.
138 if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
139 *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
140 return false;
141 }
142 if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
143 websockets::kUpgrade)) {
144 *failure_message = "'Connection' header value must contain 'Upgrade'";
145 return false;
146 }
147 return true;
65 } 148 }
66 149
67 bool ValidateSubProtocol( 150 bool ValidateSubProtocol(
68 const scoped_refptr<HttpResponseHeaders>& headers, 151 const HttpResponseHeaders* headers,
69 const std::vector<std::string>& requested_sub_protocols, 152 const std::vector<std::string>& requested_sub_protocols,
70 std::string* sub_protocol) { 153 std::string* sub_protocol,
154 std::string* failure_message) {
71 void* state = NULL; 155 void* state = NULL;
72 std::string token; 156 std::string value;
73 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), 157 base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
74 requested_sub_protocols.end()); 158 requested_sub_protocols.end());
75 int accepted = 0; 159 int count = 0;
76 while (headers->EnumerateHeader( 160 bool has_multiple_protocols = false;
77 &state, websockets::kSecWebSocketProtocol, &token)) { 161 bool has_invalid_protocol = false;
78 if (requested_set.count(token) == 0)
79 return false;
80 162
81 *sub_protocol = token; 163 while (!has_invalid_protocol || !has_multiple_protocols) {
82 // The server is only allowed to accept one protocol. 164 std::string temp_value;
83 if (++accepted > 1) 165 if (!headers->EnumerateHeader(
84 return false; 166 &state, websockets::kSecWebSocketProtocol, &temp_value))
167 break;
168 value = temp_value;
169 if (requested_set.count(value) == 0)
170 has_invalid_protocol = true;
171 if (++count > 1)
172 has_multiple_protocols = true;
85 } 173 }
86 // If the browser requested > 0 protocols, the server is required to accept 174
87 // one. 175 if (has_multiple_protocols) {
88 return requested_set.empty() || accepted == 1; 176 *failure_message =
177 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
178 return false;
179 } else if (count > 0 && requested_sub_protocols.size() == 0) {
180 *failure_message =
181 std::string("Response must not include 'Sec-WebSocket-Protocol' "
182 "header if not present in request: ")
183 + value;
184 return false;
185 } else if (has_invalid_protocol) {
186 *failure_message =
187 "'Sec-WebSocket-Protocol' header value '" +
188 value +
189 "' in response does not match any of sent values";
190 return false;
191 } else if (requested_sub_protocols.size() > 0 && count == 0) {
192 *failure_message =
193 "Sent non-empty 'Sec-WebSocket-Protocol' header "
194 "but no response was received";
195 return false;
196 }
197 *sub_protocol = value;
198 return true;
89 } 199 }
90 200
91 bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, 201 bool ValidateExtensions(const HttpResponseHeaders* headers,
92 const std::vector<std::string>& requested_extensions, 202 const std::vector<std::string>& requested_extensions,
93 std::string* extensions) { 203 std::string* extensions,
204 std::string* failure_message) {
94 void* state = NULL; 205 void* state = NULL;
95 std::string token; 206 std::string value;
96 while (headers->EnumerateHeader( 207 while (headers->EnumerateHeader(
97 &state, websockets::kSecWebSocketExtensions, &token)) { 208 &state, websockets::kSecWebSocketExtensions, &value)) {
209 WebSocketExtensionParser parser;
210 parser.Parse(value);
211 if (parser.has_error()) {
212 // TODO(yhirano) Set appropriate failure message.
213 *failure_message =
214 "'Sec-WebSocket-Extensions' header value is "
215 "rejected by the parser: " +
216 value;
217 return false;
218 }
98 // TODO(ricea): Accept permessage-deflate with valid parameters. 219 // TODO(ricea): Accept permessage-deflate with valid parameters.
220 *failure_message =
221 "Found an unsupported extension '" +
222 parser.extension().name() +
223 "' in 'Sec-WebSocket-Extensions' header";
99 return false; 224 return false;
100 } 225 }
101 return true; 226 return true;
102 } 227 }
103 228
104 } // namespace 229 } // namespace
105 230
106 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( 231 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream(
107 scoped_ptr<ClientSocketHandle> connection, 232 scoped_ptr<ClientSocketHandle> connection,
108 bool using_proxy, 233 bool using_proxy,
(...skipping 151 matching lines...) Expand 10 before | Expand all | Expand 10 after
260 state_.read_buf(), 385 state_.read_buf(),
261 sub_protocol_, 386 sub_protocol_,
262 extensions_)); 387 extensions_));
263 } 388 }
264 389
265 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( 390 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
266 const std::string& key) { 391 const std::string& key) {
267 handshake_challenge_for_testing_.reset(new std::string(key)); 392 handshake_challenge_for_testing_.reset(new std::string(key));
268 } 393 }
269 394
395 std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
396 return failure_message_;
397 }
398
270 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( 399 void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
271 const CompletionCallback& callback, 400 const CompletionCallback& callback,
272 int result) { 401 int result) {
273 if (result == OK) 402 if (result == OK)
274 result = ValidateResponse(); 403 result = ValidateResponse();
275 callback.Run(result); 404 callback.Run(result);
276 } 405 }
277 406
278 int WebSocketBasicHandshakeStream::ValidateResponse() { 407 int WebSocketBasicHandshakeStream::ValidateResponse() {
279 DCHECK(http_response_info_); 408 DCHECK(http_response_info_);
280 const scoped_refptr<HttpResponseHeaders>& headers = 409 const scoped_refptr<HttpResponseHeaders>& headers =
281 http_response_info_->headers; 410 http_response_info_->headers;
282 411
283 switch (headers->response_code()) { 412 switch (headers->response_code()) {
284 case HTTP_SWITCHING_PROTOCOLS: 413 case HTTP_SWITCHING_PROTOCOLS:
285 return ValidateUpgradeResponse(headers); 414 return ValidateUpgradeResponse(headers);
286 415
287 // We need to pass these through for authentication to work. 416 // We need to pass these through for authentication to work.
288 case HTTP_UNAUTHORIZED: 417 case HTTP_UNAUTHORIZED:
289 case HTTP_PROXY_AUTHENTICATION_REQUIRED: 418 case HTTP_PROXY_AUTHENTICATION_REQUIRED:
290 return OK; 419 return OK;
291 420
292 // Other status codes are potentially risky (see the warnings in the 421 // Other status codes are potentially risky (see the warnings in the
293 // WHATWG WebSocket API spec) and so are dropped by default. 422 // WHATWG WebSocket API spec) and so are dropped by default.
294 default: 423 default:
424 failure_message_ = base::StringPrintf("Unexpected status code: %d",
425 headers->response_code());
295 return ERR_INVALID_RESPONSE; 426 return ERR_INVALID_RESPONSE;
296 } 427 }
297 } 428 }
298 429
299 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( 430 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
300 const scoped_refptr<HttpResponseHeaders>& headers) { 431 const scoped_refptr<HttpResponseHeaders>& headers) {
301 if (ValidateSingleTokenHeader(headers, 432 if (ValidateUpgrade(headers.get(), &failure_message_) &&
302 websockets::kUpgrade, 433 ValidateSecWebSocketAccept(headers.get(),
303 websockets::kWebSocketLowercase, 434 handshake_challenge_response_,
304 false) && 435 &failure_message_) &&
305 ValidateSingleTokenHeader(headers, 436 ValidateConnection(headers.get(), &failure_message_) &&
306 websockets::kSecWebSocketAccept, 437 ValidateSubProtocol(headers.get(),
307 handshake_challenge_response_, 438 requested_sub_protocols_,
308 true) && 439 &sub_protocol_,
309 headers->HasHeaderValue(HttpRequestHeaders::kConnection, 440 &failure_message_) &&
310 websockets::kUpgrade) && 441 ValidateExtensions(headers.get(),
311 ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) && 442 requested_extensions_,
312 ValidateExtensions(headers, requested_extensions_, &extensions_)) { 443 &extensions_,
444 &failure_message_)) {
313 return OK; 445 return OK;
314 } 446 }
447 failure_message_ = "Error during WebSocket handshake: " + failure_message_;
315 return ERR_INVALID_RESPONSE; 448 return ERR_INVALID_RESPONSE;
316 } 449 }
317 450
318 } // namespace net 451 } // namespace net
OLDNEW
« no previous file with comments | « net/websockets/websocket_basic_handshake_stream.h ('k') | net/websockets/websocket_channel.h » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698