OLD | NEW |
| (Empty) |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "net/websockets/websocket_handshake_stream_create_helper.h" | |
6 | |
7 #include <string> | |
8 #include <vector> | |
9 | |
10 #include "net/base/completion_callback.h" | |
11 #include "net/base/net_errors.h" | |
12 #include "net/http/http_request_headers.h" | |
13 #include "net/http/http_request_info.h" | |
14 #include "net/http/http_response_headers.h" | |
15 #include "net/http/http_response_info.h" | |
16 #include "net/socket/client_socket_handle.h" | |
17 #include "net/socket/socket_test_util.h" | |
18 #include "net/websockets/websocket_basic_handshake_stream.h" | |
19 #include "net/websockets/websocket_stream.h" | |
20 #include "net/websockets/websocket_test_util.h" | |
21 #include "testing/gtest/include/gtest/gtest.h" | |
22 #include "url/gurl.h" | |
23 | |
24 namespace net { | |
25 namespace { | |
26 | |
27 // This class encapsulates the details of creating a mock ClientSocketHandle. | |
28 class MockClientSocketHandleFactory { | |
29 public: | |
30 MockClientSocketHandleFactory() | |
31 : histograms_("a"), | |
32 pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {} | |
33 | |
34 // The created socket expects |expect_written| to be written to the socket, | |
35 // and will respond with |return_to_read|. The test will fail if the expected | |
36 // text is not written, or if all the bytes are not read. | |
37 scoped_ptr<ClientSocketHandle> CreateClientSocketHandle( | |
38 const std::string& expect_written, | |
39 const std::string& return_to_read) { | |
40 socket_factory_maker_.SetExpectations(expect_written, return_to_read); | |
41 scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); | |
42 socket_handle->Init( | |
43 "a", | |
44 scoped_refptr<MockTransportSocketParams>(), | |
45 MEDIUM, | |
46 CompletionCallback(), | |
47 &pool_, | |
48 BoundNetLog()); | |
49 return socket_handle.Pass(); | |
50 } | |
51 | |
52 private: | |
53 WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_; | |
54 ClientSocketPoolHistograms histograms_; | |
55 MockTransportClientSocketPool pool_; | |
56 | |
57 DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory); | |
58 }; | |
59 | |
60 class TestConnectDelegate : public WebSocketStream::ConnectDelegate { | |
61 public: | |
62 ~TestConnectDelegate() override {} | |
63 | |
64 void OnSuccess(scoped_ptr<WebSocketStream> stream) override {} | |
65 void OnFailure(const std::string& failure_message) override {} | |
66 void OnStartOpeningHandshake( | |
67 scoped_ptr<WebSocketHandshakeRequestInfo> request) override {} | |
68 void OnFinishOpeningHandshake( | |
69 scoped_ptr<WebSocketHandshakeResponseInfo> response) override {} | |
70 void OnSSLCertificateError( | |
71 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> | |
72 ssl_error_callbacks, | |
73 const SSLInfo& ssl_info, | |
74 bool fatal) override {} | |
75 }; | |
76 | |
77 class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { | |
78 protected: | |
79 scoped_ptr<WebSocketStream> CreateAndInitializeStream( | |
80 const std::string& socket_url, | |
81 const std::string& socket_host, | |
82 const std::string& socket_path, | |
83 const std::vector<std::string>& sub_protocols, | |
84 const std::string& origin, | |
85 const std::string& extra_request_headers, | |
86 const std::string& extra_response_headers) { | |
87 WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_, | |
88 sub_protocols); | |
89 create_helper.set_failure_message(&failure_message_); | |
90 | |
91 scoped_ptr<ClientSocketHandle> socket_handle = | |
92 socket_handle_factory_.CreateClientSocketHandle( | |
93 WebSocketStandardRequest(socket_path, socket_host, origin, | |
94 extra_request_headers), | |
95 WebSocketStandardResponse(extra_response_headers)); | |
96 | |
97 scoped_ptr<WebSocketHandshakeStreamBase> handshake( | |
98 create_helper.CreateBasicStream(socket_handle.Pass(), false)); | |
99 | |
100 // If in future the implementation type returned by CreateBasicStream() | |
101 // changes, this static_cast will be wrong. However, in that case the test | |
102 // will fail and AddressSanitizer should identify the issue. | |
103 static_cast<WebSocketBasicHandshakeStream*>(handshake.get()) | |
104 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); | |
105 | |
106 HttpRequestInfo request_info; | |
107 request_info.url = GURL(socket_url); | |
108 request_info.method = "GET"; | |
109 request_info.load_flags = LOAD_DISABLE_CACHE; | |
110 int rv = handshake->InitializeStream( | |
111 &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback()); | |
112 EXPECT_EQ(OK, rv); | |
113 | |
114 HttpRequestHeaders headers; | |
115 headers.SetHeader("Host", "localhost"); | |
116 headers.SetHeader("Connection", "Upgrade"); | |
117 headers.SetHeader("Pragma", "no-cache"); | |
118 headers.SetHeader("Cache-Control", "no-cache"); | |
119 headers.SetHeader("Upgrade", "websocket"); | |
120 headers.SetHeader("Origin", origin); | |
121 headers.SetHeader("Sec-WebSocket-Version", "13"); | |
122 headers.SetHeader("User-Agent", ""); | |
123 headers.SetHeader("Accept-Encoding", "gzip, deflate"); | |
124 headers.SetHeader("Accept-Language", "en-us,fr"); | |
125 | |
126 HttpResponseInfo response; | |
127 TestCompletionCallback dummy; | |
128 | |
129 rv = handshake->SendRequest(headers, &response, dummy.callback()); | |
130 | |
131 EXPECT_EQ(OK, rv); | |
132 | |
133 rv = handshake->ReadResponseHeaders(dummy.callback()); | |
134 EXPECT_EQ(OK, rv); | |
135 EXPECT_EQ(101, response.headers->response_code()); | |
136 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade")); | |
137 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket")); | |
138 return handshake->Upgrade(); | |
139 } | |
140 | |
141 MockClientSocketHandleFactory socket_handle_factory_; | |
142 TestConnectDelegate connect_delegate_; | |
143 std::string failure_message_; | |
144 }; | |
145 | |
146 // Confirm that the basic case works as expected. | |
147 TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) { | |
148 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
149 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
150 "http://localhost/", "", ""); | |
151 EXPECT_EQ("", stream->GetExtensions()); | |
152 EXPECT_EQ("", stream->GetSubProtocol()); | |
153 } | |
154 | |
155 // Verify that the sub-protocols are passed through. | |
156 TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) { | |
157 std::vector<std::string> sub_protocols; | |
158 sub_protocols.push_back("chat"); | |
159 sub_protocols.push_back("superchat"); | |
160 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
161 "ws://localhost/", "localhost", "/", sub_protocols, "http://localhost/", | |
162 "Sec-WebSocket-Protocol: chat, superchat\r\n", | |
163 "Sec-WebSocket-Protocol: superchat\r\n"); | |
164 EXPECT_EQ("superchat", stream->GetSubProtocol()); | |
165 } | |
166 | |
167 // Verify that extension name is available. Bad extension names are tested in | |
168 // websocket_stream_test.cc. | |
169 TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) { | |
170 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
171 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
172 "http://localhost/", "", | |
173 "Sec-WebSocket-Extensions: permessage-deflate\r\n"); | |
174 EXPECT_EQ("permessage-deflate", stream->GetExtensions()); | |
175 } | |
176 | |
177 // Verify that extension parameters are available. Bad parameters are tested in | |
178 // websocket_stream_test.cc. | |
179 TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) { | |
180 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
181 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
182 "http://localhost/", "", | |
183 "Sec-WebSocket-Extensions: permessage-deflate;" | |
184 " client_max_window_bits=14; server_max_window_bits=14;" | |
185 " server_no_context_takeover; client_no_context_takeover\r\n"); | |
186 | |
187 EXPECT_EQ( | |
188 "permessage-deflate;" | |
189 " client_max_window_bits=14; server_max_window_bits=14;" | |
190 " server_no_context_takeover; client_no_context_takeover", | |
191 stream->GetExtensions()); | |
192 } | |
193 | |
194 } // namespace | |
195 } // namespace net | |
OLD | NEW |