OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_stream.h" | 5 #include "net/websockets/websocket_stream.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 #include <string> | 8 #include <string> |
9 #include <utility> | 9 #include <utility> |
10 #include <vector> | 10 #include <vector> |
11 | 11 |
12 #include "base/memory/scoped_vector.h" | 12 #include "base/memory/scoped_vector.h" |
13 #include "base/metrics/histogram.h" | 13 #include "base/metrics/histogram.h" |
14 #include "base/metrics/histogram_samples.h" | 14 #include "base/metrics/histogram_samples.h" |
15 #include "base/metrics/statistics_recorder.h" | 15 #include "base/metrics/statistics_recorder.h" |
16 #include "base/run_loop.h" | 16 #include "base/run_loop.h" |
17 #include "base/strings/stringprintf.h" | 17 #include "base/strings/stringprintf.h" |
18 #include "net/base/net_errors.h" | 18 #include "net/base/net_errors.h" |
| 19 #include "net/base/test_data_directory.h" |
19 #include "net/http/http_request_headers.h" | 20 #include "net/http/http_request_headers.h" |
20 #include "net/http/http_response_headers.h" | 21 #include "net/http/http_response_headers.h" |
21 #include "net/socket/client_socket_handle.h" | 22 #include "net/socket/client_socket_handle.h" |
22 #include "net/socket/socket_test_util.h" | 23 #include "net/socket/socket_test_util.h" |
| 24 #include "net/test/cert_test_util.h" |
23 #include "net/url_request/url_request_test_util.h" | 25 #include "net/url_request/url_request_test_util.h" |
24 #include "net/websockets/websocket_basic_handshake_stream.h" | 26 #include "net/websockets/websocket_basic_handshake_stream.h" |
25 #include "net/websockets/websocket_frame.h" | 27 #include "net/websockets/websocket_frame.h" |
26 #include "net/websockets/websocket_handshake_request_info.h" | 28 #include "net/websockets/websocket_handshake_request_info.h" |
27 #include "net/websockets/websocket_handshake_response_info.h" | 29 #include "net/websockets/websocket_handshake_response_info.h" |
28 #include "net/websockets/websocket_handshake_stream_create_helper.h" | 30 #include "net/websockets/websocket_handshake_stream_create_helper.h" |
29 #include "net/websockets/websocket_test_util.h" | 31 #include "net/websockets/websocket_test_util.h" |
30 #include "testing/gtest/include/gtest/gtest.h" | 32 #include "testing/gtest/include/gtest/gtest.h" |
31 #include "url/gurl.h" | 33 #include "url/gurl.h" |
32 #include "url/origin.h" | 34 #include "url/origin.h" |
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
72 // This will break in an obvious way if the type created by | 74 // This will break in an obvious way if the type created by |
73 // CreateBasicStream() changes. | 75 // CreateBasicStream() changes. |
74 static_cast<WebSocketBasicHandshakeStream*>(stream()) | 76 static_cast<WebSocketBasicHandshakeStream*>(stream()) |
75 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); | 77 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); |
76 return stream(); | 78 return stream(); |
77 } | 79 } |
78 }; | 80 }; |
79 | 81 |
80 class WebSocketStreamCreateTest : public ::testing::Test { | 82 class WebSocketStreamCreateTest : public ::testing::Test { |
81 public: | 83 public: |
82 WebSocketStreamCreateTest(): has_failed_(false) {} | 84 WebSocketStreamCreateTest() : has_failed_(false), ssl_fatal_(false) {} |
83 | 85 |
84 void CreateAndConnectCustomResponse( | 86 void CreateAndConnectCustomResponse( |
85 const std::string& socket_url, | 87 const std::string& socket_url, |
86 const std::string& socket_path, | 88 const std::string& socket_path, |
87 const std::vector<std::string>& sub_protocols, | 89 const std::vector<std::string>& sub_protocols, |
88 const std::string& origin, | 90 const std::string& origin, |
89 const std::string& extra_request_headers, | 91 const std::string& extra_request_headers, |
90 const std::string& response_body) { | 92 const std::string& response_body) { |
91 url_request_context_host_.SetExpectations( | 93 url_request_context_host_.SetExpectations( |
92 WebSocketStandardRequest(socket_path, origin, extra_request_headers), | 94 WebSocketStandardRequest(socket_path, origin, extra_request_headers), |
(...skipping 16 matching lines...) Expand all Loading... |
109 origin, | 111 origin, |
110 extra_request_headers, | 112 extra_request_headers, |
111 WebSocketStandardResponse(extra_response_headers)); | 113 WebSocketStandardResponse(extra_response_headers)); |
112 } | 114 } |
113 | 115 |
114 void CreateAndConnectRawExpectations( | 116 void CreateAndConnectRawExpectations( |
115 const std::string& socket_url, | 117 const std::string& socket_url, |
116 const std::vector<std::string>& sub_protocols, | 118 const std::vector<std::string>& sub_protocols, |
117 const std::string& origin, | 119 const std::string& origin, |
118 scoped_ptr<DeterministicSocketData> socket_data) { | 120 scoped_ptr<DeterministicSocketData> socket_data) { |
119 url_request_context_host_.SetRawExpectations(socket_data.Pass()); | 121 url_request_context_host_.AddRawExpectations(socket_data.Pass()); |
120 CreateAndConnectStream(socket_url, sub_protocols, origin); | 122 CreateAndConnectStream(socket_url, sub_protocols, origin); |
121 } | 123 } |
122 | 124 |
123 // A wrapper for CreateAndConnectStreamForTesting that knows about our default | 125 // A wrapper for CreateAndConnectStreamForTesting that knows about our default |
124 // parameters. | 126 // parameters. |
125 void CreateAndConnectStream(const std::string& socket_url, | 127 void CreateAndConnectStream(const std::string& socket_url, |
126 const std::vector<std::string>& sub_protocols, | 128 const std::vector<std::string>& sub_protocols, |
127 const std::string& origin) { | 129 const std::string& origin) { |
| 130 if (!ssl_data_.empty()) { |
| 131 for (size_t i = 0; i < ssl_data_.size(); ++i) { |
| 132 scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]); |
| 133 ssl_data_[i] = NULL; |
| 134 url_request_context_host_.AddSSLSocketDataProvider(ssl_data.Pass()); |
| 135 } |
| 136 ssl_data_.clear(); |
| 137 } |
128 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate( | 138 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate( |
129 new TestConnectDelegate(this)); | 139 new TestConnectDelegate(this)); |
130 WebSocketStream::ConnectDelegate* delegate = connect_delegate.get(); | 140 WebSocketStream::ConnectDelegate* delegate = connect_delegate.get(); |
131 stream_request_ = ::net::CreateAndConnectStreamForTesting( | 141 stream_request_ = ::net::CreateAndConnectStreamForTesting( |
132 GURL(socket_url), | 142 GURL(socket_url), |
133 scoped_ptr<WebSocketHandshakeStreamCreateHelper>( | 143 scoped_ptr<WebSocketHandshakeStreamCreateHelper>( |
134 new DeterministicKeyWebSocketHandshakeStreamCreateHelper( | 144 new DeterministicKeyWebSocketHandshakeStreamCreateHelper( |
135 delegate, sub_protocols)), | 145 delegate, sub_protocols)), |
136 url::Origin(origin), | 146 url::Origin(origin), |
137 url_request_context_host_.GetURLRequestContext(), | 147 url_request_context_host_.GetURLRequestContext(), |
(...skipping 30 matching lines...) Expand all Loading... |
168 if (owner_->request_info_) | 178 if (owner_->request_info_) |
169 ADD_FAILURE(); | 179 ADD_FAILURE(); |
170 owner_->request_info_ = request.Pass(); | 180 owner_->request_info_ = request.Pass(); |
171 } | 181 } |
172 virtual void OnFinishOpeningHandshake( | 182 virtual void OnFinishOpeningHandshake( |
173 scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { | 183 scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { |
174 if (owner_->response_info_) | 184 if (owner_->response_info_) |
175 ADD_FAILURE(); | 185 ADD_FAILURE(); |
176 owner_->response_info_ = response.Pass(); | 186 owner_->response_info_ = response.Pass(); |
177 } | 187 } |
| 188 virtual void OnSSLCertificateError( |
| 189 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> |
| 190 ssl_error_callbacks, |
| 191 const SSLInfo& ssl_info, |
| 192 bool fatal) OVERRIDE { |
| 193 owner_->ssl_error_callbacks_ = ssl_error_callbacks.Pass(); |
| 194 owner_->ssl_info_ = ssl_info; |
| 195 owner_->ssl_fatal_ = fatal; |
| 196 } |
178 | 197 |
179 private: | 198 private: |
180 WebSocketStreamCreateTest* owner_; | 199 WebSocketStreamCreateTest* owner_; |
181 }; | 200 }; |
182 | 201 |
183 WebSocketTestURLRequestContextHost url_request_context_host_; | 202 WebSocketTestURLRequestContextHost url_request_context_host_; |
184 scoped_ptr<WebSocketStreamRequest> stream_request_; | 203 scoped_ptr<WebSocketStreamRequest> stream_request_; |
185 // Only set if the connection succeeded. | 204 // Only set if the connection succeeded. |
186 scoped_ptr<WebSocketStream> stream_; | 205 scoped_ptr<WebSocketStream> stream_; |
187 // Only set if the connection failed. | 206 // Only set if the connection failed. |
188 std::string failure_message_; | 207 std::string failure_message_; |
189 bool has_failed_; | 208 bool has_failed_; |
190 scoped_ptr<WebSocketHandshakeRequestInfo> request_info_; | 209 scoped_ptr<WebSocketHandshakeRequestInfo> request_info_; |
191 scoped_ptr<WebSocketHandshakeResponseInfo> response_info_; | 210 scoped_ptr<WebSocketHandshakeResponseInfo> response_info_; |
| 211 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks_; |
| 212 SSLInfo ssl_info_; |
| 213 bool ssl_fatal_; |
| 214 ScopedVector<SSLSocketDataProvider> ssl_data_; |
192 }; | 215 }; |
193 | 216 |
194 // There are enough tests of the Sec-WebSocket-Extensions header that they | 217 // There are enough tests of the Sec-WebSocket-Extensions header that they |
195 // deserve their own test fixture. | 218 // deserve their own test fixture. |
196 class WebSocketStreamCreateExtensionTest : public WebSocketStreamCreateTest { | 219 class WebSocketStreamCreateExtensionTest : public WebSocketStreamCreateTest { |
197 public: | 220 public: |
198 // Performs a standard connect, with the value of the Sec-WebSocket-Extensions | 221 // Performs a standard connect, with the value of the Sec-WebSocket-Extensions |
199 // header in the response set to |extensions_header_value|. Runs the event | 222 // header in the response set to |extensions_header_value|. Runs the event |
200 // loop to allow the connect to complete. | 223 // loop to allow the connect to complete. |
201 void CreateAndConnectWithExtensions( | 224 void CreateAndConnectWithExtensions( |
(...skipping 823 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1025 "http://localhost", | 1048 "http://localhost", |
1026 make_scoped_ptr(socket_data)); | 1049 make_scoped_ptr(socket_data)); |
1027 socket_data->RunFor(2); | 1050 socket_data->RunFor(2); |
1028 EXPECT_TRUE(has_failed()); | 1051 EXPECT_TRUE(has_failed()); |
1029 EXPECT_FALSE(stream_); | 1052 EXPECT_FALSE(stream_); |
1030 EXPECT_FALSE(response_info_); | 1053 EXPECT_FALSE(response_info_); |
1031 EXPECT_EQ("Connection closed before receiving a handshake response", | 1054 EXPECT_EQ("Connection closed before receiving a handshake response", |
1032 failure_message()); | 1055 failure_message()); |
1033 } | 1056 } |
1034 | 1057 |
| 1058 TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateFailure) { |
| 1059 ssl_data_.push_back( |
| 1060 new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); |
| 1061 ssl_data_[0]->cert = |
| 1062 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); |
| 1063 ASSERT_TRUE(ssl_data_[0]->cert); |
| 1064 scoped_ptr<DeterministicSocketData> raw_socket_data( |
| 1065 new DeterministicSocketData(NULL, 0, NULL, 0)); |
| 1066 CreateAndConnectRawExpectations("wss://localhost/", |
| 1067 NoSubProtocols(), |
| 1068 "http://localhost", |
| 1069 raw_socket_data.Pass()); |
| 1070 RunUntilIdle(); |
| 1071 EXPECT_FALSE(has_failed()); |
| 1072 ASSERT_TRUE(ssl_error_callbacks_); |
| 1073 ssl_error_callbacks_->CancelSSLRequest(ERR_CERT_AUTHORITY_INVALID, |
| 1074 &ssl_info_); |
| 1075 RunUntilIdle(); |
| 1076 EXPECT_TRUE(has_failed()); |
| 1077 } |
| 1078 |
| 1079 TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateSuccess) { |
| 1080 scoped_ptr<SSLSocketDataProvider> ssl_data( |
| 1081 new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); |
| 1082 ssl_data->cert = |
| 1083 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); |
| 1084 ASSERT_TRUE(ssl_data->cert); |
| 1085 ssl_data_.push_back(ssl_data.release()); |
| 1086 ssl_data.reset(new SSLSocketDataProvider(ASYNC, OK)); |
| 1087 ssl_data_.push_back(ssl_data.release()); |
| 1088 url_request_context_host_.AddRawExpectations( |
| 1089 make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0))); |
| 1090 CreateAndConnectStandard( |
| 1091 "wss://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); |
| 1092 RunUntilIdle(); |
| 1093 ASSERT_TRUE(ssl_error_callbacks_); |
| 1094 ssl_error_callbacks_->ContinueSSLRequest(); |
| 1095 RunUntilIdle(); |
| 1096 EXPECT_FALSE(has_failed()); |
| 1097 EXPECT_TRUE(stream_); |
| 1098 } |
| 1099 |
1035 TEST_F(WebSocketStreamCreateUMATest, Incomplete) { | 1100 TEST_F(WebSocketStreamCreateUMATest, Incomplete) { |
1036 const std::string name("Net.WebSocket.HandshakeResult"); | 1101 const std::string name("Net.WebSocket.HandshakeResult"); |
1037 scoped_ptr<base::HistogramSamples> original(GetSamples(name)); | 1102 scoped_ptr<base::HistogramSamples> original(GetSamples(name)); |
1038 | 1103 |
1039 { | 1104 { |
1040 StreamCreation creation; | 1105 StreamCreation creation; |
1041 creation.CreateAndConnectStandard("ws://localhost/", | 1106 creation.CreateAndConnectStandard("ws://localhost/", |
1042 "/", | 1107 "/", |
1043 creation.NoSubProtocols(), | 1108 creation.NoSubProtocols(), |
1044 "http://localhost", | 1109 "http://localhost", |
(...skipping 55 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1100 "", | 1165 "", |
1101 kInvalidStatusCodeResponse); | 1166 kInvalidStatusCodeResponse); |
1102 creation.RunUntilIdle(); | 1167 creation.RunUntilIdle(); |
1103 } | 1168 } |
1104 | 1169 |
1105 scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); | 1170 scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); |
1106 ASSERT_TRUE(samples); | 1171 ASSERT_TRUE(samples); |
1107 if (original) { | 1172 if (original) { |
1108 samples->Subtract(*original); // Cancel the original values. | 1173 samples->Subtract(*original); // Cancel the original values. |
1109 } | 1174 } |
1110 EXPECT_EQ(0, samples->GetCount(INCOMPLETE)); | 1175 EXPECT_EQ(1, samples->GetCount(INCOMPLETE)); |
1111 EXPECT_EQ(0, samples->GetCount(CONNECTED)); | 1176 EXPECT_EQ(0, samples->GetCount(CONNECTED)); |
1112 EXPECT_EQ(1, samples->GetCount(FAILED)); | 1177 EXPECT_EQ(0, samples->GetCount(FAILED)); |
1113 } | 1178 } |
1114 | 1179 |
1115 } // namespace | 1180 } // namespace |
1116 } // namespace net | 1181 } // namespace net |
OLD | NEW |