OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_stream.h" | 5 #include "net/websockets/websocket_stream.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 #include <string> | 8 #include <string> |
9 #include <utility> | 9 #include <utility> |
10 #include <vector> | 10 #include <vector> |
11 | 11 |
12 #include "base/memory/scoped_vector.h" | 12 #include "base/memory/scoped_vector.h" |
13 #include "base/metrics/histogram.h" | 13 #include "base/metrics/histogram.h" |
14 #include "base/metrics/histogram_samples.h" | 14 #include "base/metrics/histogram_samples.h" |
15 #include "base/metrics/statistics_recorder.h" | 15 #include "base/metrics/statistics_recorder.h" |
16 #include "base/run_loop.h" | 16 #include "base/run_loop.h" |
17 #include "base/strings/stringprintf.h" | 17 #include "base/strings/stringprintf.h" |
18 #include "net/base/net_errors.h" | 18 #include "net/base/net_errors.h" |
| 19 #include "net/base/test_data_directory.h" |
19 #include "net/http/http_request_headers.h" | 20 #include "net/http/http_request_headers.h" |
20 #include "net/http/http_response_headers.h" | 21 #include "net/http/http_response_headers.h" |
21 #include "net/socket/client_socket_handle.h" | 22 #include "net/socket/client_socket_handle.h" |
22 #include "net/socket/socket_test_util.h" | 23 #include "net/socket/socket_test_util.h" |
| 24 #include "net/test/cert_test_util.h" |
23 #include "net/url_request/url_request_test_util.h" | 25 #include "net/url_request/url_request_test_util.h" |
24 #include "net/websockets/websocket_basic_handshake_stream.h" | 26 #include "net/websockets/websocket_basic_handshake_stream.h" |
25 #include "net/websockets/websocket_frame.h" | 27 #include "net/websockets/websocket_frame.h" |
26 #include "net/websockets/websocket_handshake_request_info.h" | 28 #include "net/websockets/websocket_handshake_request_info.h" |
27 #include "net/websockets/websocket_handshake_response_info.h" | 29 #include "net/websockets/websocket_handshake_response_info.h" |
28 #include "net/websockets/websocket_handshake_stream_create_helper.h" | 30 #include "net/websockets/websocket_handshake_stream_create_helper.h" |
29 #include "net/websockets/websocket_test_util.h" | 31 #include "net/websockets/websocket_test_util.h" |
30 #include "testing/gtest/include/gtest/gtest.h" | 32 #include "testing/gtest/include/gtest/gtest.h" |
31 #include "url/gurl.h" | 33 #include "url/gurl.h" |
32 #include "url/origin.h" | 34 #include "url/origin.h" |
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
72 // This will break in an obvious way if the type created by | 74 // This will break in an obvious way if the type created by |
73 // CreateBasicStream() changes. | 75 // CreateBasicStream() changes. |
74 static_cast<WebSocketBasicHandshakeStream*>(stream()) | 76 static_cast<WebSocketBasicHandshakeStream*>(stream()) |
75 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); | 77 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); |
76 return stream(); | 78 return stream(); |
77 } | 79 } |
78 }; | 80 }; |
79 | 81 |
80 class WebSocketStreamCreateTest : public ::testing::Test { | 82 class WebSocketStreamCreateTest : public ::testing::Test { |
81 public: | 83 public: |
82 WebSocketStreamCreateTest(): has_failed_(false) {} | 84 WebSocketStreamCreateTest() : has_failed_(false), ssl_fatal_(false) {} |
83 | 85 |
84 void CreateAndConnectCustomResponse( | 86 void CreateAndConnectCustomResponse( |
85 const std::string& socket_url, | 87 const std::string& socket_url, |
86 const std::string& socket_path, | 88 const std::string& socket_path, |
87 const std::vector<std::string>& sub_protocols, | 89 const std::vector<std::string>& sub_protocols, |
88 const std::string& origin, | 90 const std::string& origin, |
89 const std::string& extra_request_headers, | 91 const std::string& extra_request_headers, |
90 const std::string& response_body) { | 92 const std::string& response_body) { |
91 url_request_context_host_.SetExpectations( | 93 url_request_context_host_.SetExpectations( |
92 WebSocketStandardRequest(socket_path, origin, extra_request_headers), | 94 WebSocketStandardRequest(socket_path, origin, extra_request_headers), |
(...skipping 16 matching lines...) Expand all Loading... |
109 origin, | 111 origin, |
110 extra_request_headers, | 112 extra_request_headers, |
111 WebSocketStandardResponse(extra_response_headers)); | 113 WebSocketStandardResponse(extra_response_headers)); |
112 } | 114 } |
113 | 115 |
114 void CreateAndConnectRawExpectations( | 116 void CreateAndConnectRawExpectations( |
115 const std::string& socket_url, | 117 const std::string& socket_url, |
116 const std::vector<std::string>& sub_protocols, | 118 const std::vector<std::string>& sub_protocols, |
117 const std::string& origin, | 119 const std::string& origin, |
118 scoped_ptr<DeterministicSocketData> socket_data) { | 120 scoped_ptr<DeterministicSocketData> socket_data) { |
119 url_request_context_host_.SetRawExpectations(socket_data.Pass()); | 121 url_request_context_host_.AddRawExpectations(socket_data.Pass()); |
120 CreateAndConnectStream(socket_url, sub_protocols, origin); | 122 CreateAndConnectStream(socket_url, sub_protocols, origin); |
121 } | 123 } |
122 | 124 |
123 // A wrapper for CreateAndConnectStreamForTesting that knows about our default | 125 // A wrapper for CreateAndConnectStreamForTesting that knows about our default |
124 // parameters. | 126 // parameters. |
125 void CreateAndConnectStream(const std::string& socket_url, | 127 void CreateAndConnectStream(const std::string& socket_url, |
126 const std::vector<std::string>& sub_protocols, | 128 const std::vector<std::string>& sub_protocols, |
127 const std::string& origin) { | 129 const std::string& origin) { |
| 130 for (size_t i = 0; i < ssl_data_.size(); ++i) { |
| 131 scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]); |
| 132 ssl_data_[i] = NULL; |
| 133 url_request_context_host_.AddSSLSocketDataProvider(ssl_data.Pass()); |
| 134 } |
| 135 ssl_data_.clear(); |
128 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate( | 136 scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate( |
129 new TestConnectDelegate(this)); | 137 new TestConnectDelegate(this)); |
130 WebSocketStream::ConnectDelegate* delegate = connect_delegate.get(); | 138 WebSocketStream::ConnectDelegate* delegate = connect_delegate.get(); |
131 stream_request_ = ::net::CreateAndConnectStreamForTesting( | 139 stream_request_ = ::net::CreateAndConnectStreamForTesting( |
132 GURL(socket_url), | 140 GURL(socket_url), |
133 scoped_ptr<WebSocketHandshakeStreamCreateHelper>( | 141 scoped_ptr<WebSocketHandshakeStreamCreateHelper>( |
134 new DeterministicKeyWebSocketHandshakeStreamCreateHelper( | 142 new DeterministicKeyWebSocketHandshakeStreamCreateHelper( |
135 delegate, sub_protocols)), | 143 delegate, sub_protocols)), |
136 url::Origin(origin), | 144 url::Origin(origin), |
137 url_request_context_host_.GetURLRequestContext(), | 145 url_request_context_host_.GetURLRequestContext(), |
(...skipping 30 matching lines...) Expand all Loading... |
168 if (owner_->request_info_) | 176 if (owner_->request_info_) |
169 ADD_FAILURE(); | 177 ADD_FAILURE(); |
170 owner_->request_info_ = request.Pass(); | 178 owner_->request_info_ = request.Pass(); |
171 } | 179 } |
172 virtual void OnFinishOpeningHandshake( | 180 virtual void OnFinishOpeningHandshake( |
173 scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { | 181 scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { |
174 if (owner_->response_info_) | 182 if (owner_->response_info_) |
175 ADD_FAILURE(); | 183 ADD_FAILURE(); |
176 owner_->response_info_ = response.Pass(); | 184 owner_->response_info_ = response.Pass(); |
177 } | 185 } |
| 186 virtual void OnSSLCertificateError( |
| 187 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> |
| 188 ssl_error_callbacks, |
| 189 const SSLInfo& ssl_info, |
| 190 bool fatal) OVERRIDE { |
| 191 owner_->ssl_error_callbacks_ = ssl_error_callbacks.Pass(); |
| 192 owner_->ssl_info_ = ssl_info; |
| 193 owner_->ssl_fatal_ = fatal; |
| 194 } |
178 | 195 |
179 private: | 196 private: |
180 WebSocketStreamCreateTest* owner_; | 197 WebSocketStreamCreateTest* owner_; |
181 }; | 198 }; |
182 | 199 |
183 WebSocketTestURLRequestContextHost url_request_context_host_; | 200 WebSocketTestURLRequestContextHost url_request_context_host_; |
184 scoped_ptr<WebSocketStreamRequest> stream_request_; | 201 scoped_ptr<WebSocketStreamRequest> stream_request_; |
185 // Only set if the connection succeeded. | 202 // Only set if the connection succeeded. |
186 scoped_ptr<WebSocketStream> stream_; | 203 scoped_ptr<WebSocketStream> stream_; |
187 // Only set if the connection failed. | 204 // Only set if the connection failed. |
188 std::string failure_message_; | 205 std::string failure_message_; |
189 bool has_failed_; | 206 bool has_failed_; |
190 scoped_ptr<WebSocketHandshakeRequestInfo> request_info_; | 207 scoped_ptr<WebSocketHandshakeRequestInfo> request_info_; |
191 scoped_ptr<WebSocketHandshakeResponseInfo> response_info_; | 208 scoped_ptr<WebSocketHandshakeResponseInfo> response_info_; |
| 209 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks_; |
| 210 SSLInfo ssl_info_; |
| 211 bool ssl_fatal_; |
| 212 ScopedVector<SSLSocketDataProvider> ssl_data_; |
192 }; | 213 }; |
193 | 214 |
194 // There are enough tests of the Sec-WebSocket-Extensions header that they | 215 // There are enough tests of the Sec-WebSocket-Extensions header that they |
195 // deserve their own test fixture. | 216 // deserve their own test fixture. |
196 class WebSocketStreamCreateExtensionTest : public WebSocketStreamCreateTest { | 217 class WebSocketStreamCreateExtensionTest : public WebSocketStreamCreateTest { |
197 public: | 218 public: |
198 // Performs a standard connect, with the value of the Sec-WebSocket-Extensions | 219 // Performs a standard connect, with the value of the Sec-WebSocket-Extensions |
199 // header in the response set to |extensions_header_value|. Runs the event | 220 // header in the response set to |extensions_header_value|. Runs the event |
200 // loop to allow the connect to complete. | 221 // loop to allow the connect to complete. |
201 void CreateAndConnectWithExtensions( | 222 void CreateAndConnectWithExtensions( |
(...skipping 823 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1025 "http://localhost", | 1046 "http://localhost", |
1026 make_scoped_ptr(socket_data)); | 1047 make_scoped_ptr(socket_data)); |
1027 socket_data->RunFor(2); | 1048 socket_data->RunFor(2); |
1028 EXPECT_TRUE(has_failed()); | 1049 EXPECT_TRUE(has_failed()); |
1029 EXPECT_FALSE(stream_); | 1050 EXPECT_FALSE(stream_); |
1030 EXPECT_FALSE(response_info_); | 1051 EXPECT_FALSE(response_info_); |
1031 EXPECT_EQ("Connection closed before receiving a handshake response", | 1052 EXPECT_EQ("Connection closed before receiving a handshake response", |
1032 failure_message()); | 1053 failure_message()); |
1033 } | 1054 } |
1034 | 1055 |
| 1056 TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateFailure) { |
| 1057 ssl_data_.push_back( |
| 1058 new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); |
| 1059 ssl_data_[0]->cert = |
| 1060 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); |
| 1061 ASSERT_TRUE(ssl_data_[0]->cert); |
| 1062 scoped_ptr<DeterministicSocketData> raw_socket_data( |
| 1063 new DeterministicSocketData(NULL, 0, NULL, 0)); |
| 1064 CreateAndConnectRawExpectations("wss://localhost/", |
| 1065 NoSubProtocols(), |
| 1066 "http://localhost", |
| 1067 raw_socket_data.Pass()); |
| 1068 RunUntilIdle(); |
| 1069 EXPECT_FALSE(has_failed()); |
| 1070 ASSERT_TRUE(ssl_error_callbacks_); |
| 1071 ssl_error_callbacks_->CancelSSLRequest(ERR_CERT_AUTHORITY_INVALID, |
| 1072 &ssl_info_); |
| 1073 RunUntilIdle(); |
| 1074 EXPECT_TRUE(has_failed()); |
| 1075 } |
| 1076 |
| 1077 TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateSuccess) { |
| 1078 scoped_ptr<SSLSocketDataProvider> ssl_data( |
| 1079 new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); |
| 1080 ssl_data->cert = |
| 1081 ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); |
| 1082 ASSERT_TRUE(ssl_data->cert); |
| 1083 ssl_data_.push_back(ssl_data.release()); |
| 1084 ssl_data.reset(new SSLSocketDataProvider(ASYNC, OK)); |
| 1085 ssl_data_.push_back(ssl_data.release()); |
| 1086 url_request_context_host_.AddRawExpectations( |
| 1087 make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0))); |
| 1088 CreateAndConnectStandard( |
| 1089 "wss://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); |
| 1090 RunUntilIdle(); |
| 1091 ASSERT_TRUE(ssl_error_callbacks_); |
| 1092 ssl_error_callbacks_->ContinueSSLRequest(); |
| 1093 RunUntilIdle(); |
| 1094 EXPECT_FALSE(has_failed()); |
| 1095 EXPECT_TRUE(stream_); |
| 1096 } |
| 1097 |
1035 TEST_F(WebSocketStreamCreateUMATest, Incomplete) { | 1098 TEST_F(WebSocketStreamCreateUMATest, Incomplete) { |
1036 const std::string name("Net.WebSocket.HandshakeResult"); | 1099 const std::string name("Net.WebSocket.HandshakeResult"); |
1037 scoped_ptr<base::HistogramSamples> original(GetSamples(name)); | 1100 scoped_ptr<base::HistogramSamples> original(GetSamples(name)); |
1038 | 1101 |
1039 { | 1102 { |
1040 StreamCreation creation; | 1103 StreamCreation creation; |
1041 creation.CreateAndConnectStandard("ws://localhost/", | 1104 creation.CreateAndConnectStandard("ws://localhost/", |
1042 "/", | 1105 "/", |
1043 creation.NoSubProtocols(), | 1106 creation.NoSubProtocols(), |
1044 "http://localhost", | 1107 "http://localhost", |
(...skipping 55 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
1100 "", | 1163 "", |
1101 kInvalidStatusCodeResponse); | 1164 kInvalidStatusCodeResponse); |
1102 creation.RunUntilIdle(); | 1165 creation.RunUntilIdle(); |
1103 } | 1166 } |
1104 | 1167 |
1105 scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); | 1168 scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); |
1106 ASSERT_TRUE(samples); | 1169 ASSERT_TRUE(samples); |
1107 if (original) { | 1170 if (original) { |
1108 samples->Subtract(*original); // Cancel the original values. | 1171 samples->Subtract(*original); // Cancel the original values. |
1109 } | 1172 } |
1110 EXPECT_EQ(0, samples->GetCount(INCOMPLETE)); | 1173 EXPECT_EQ(1, samples->GetCount(INCOMPLETE)); |
1111 EXPECT_EQ(0, samples->GetCount(CONNECTED)); | 1174 EXPECT_EQ(0, samples->GetCount(CONNECTED)); |
1112 EXPECT_EQ(1, samples->GetCount(FAILED)); | 1175 EXPECT_EQ(0, samples->GetCount(FAILED)); |
1113 } | 1176 } |
1114 | 1177 |
1115 } // namespace | 1178 } // namespace |
1116 } // namespace net | 1179 } // namespace net |
OLD | NEW |