| OLD | NEW |
| (Empty) |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "net/websockets/websocket_handshake_stream_create_helper.h" | |
| 6 | |
| 7 #include <string> | |
| 8 #include <vector> | |
| 9 | |
| 10 #include "net/base/completion_callback.h" | |
| 11 #include "net/base/net_errors.h" | |
| 12 #include "net/http/http_request_headers.h" | |
| 13 #include "net/http/http_request_info.h" | |
| 14 #include "net/http/http_response_headers.h" | |
| 15 #include "net/http/http_response_info.h" | |
| 16 #include "net/socket/client_socket_handle.h" | |
| 17 #include "net/socket/socket_test_util.h" | |
| 18 #include "net/websockets/websocket_basic_handshake_stream.h" | |
| 19 #include "net/websockets/websocket_stream.h" | |
| 20 #include "net/websockets/websocket_test_util.h" | |
| 21 #include "testing/gtest/include/gtest/gtest.h" | |
| 22 #include "url/gurl.h" | |
| 23 | |
| 24 namespace net { | |
| 25 namespace { | |
| 26 | |
| 27 // This class encapsulates the details of creating a mock ClientSocketHandle. | |
| 28 class MockClientSocketHandleFactory { | |
| 29 public: | |
| 30 MockClientSocketHandleFactory() | |
| 31 : histograms_("a"), | |
| 32 pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {} | |
| 33 | |
| 34 // The created socket expects |expect_written| to be written to the socket, | |
| 35 // and will respond with |return_to_read|. The test will fail if the expected | |
| 36 // text is not written, or if all the bytes are not read. | |
| 37 scoped_ptr<ClientSocketHandle> CreateClientSocketHandle( | |
| 38 const std::string& expect_written, | |
| 39 const std::string& return_to_read) { | |
| 40 socket_factory_maker_.SetExpectations(expect_written, return_to_read); | |
| 41 scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); | |
| 42 socket_handle->Init( | |
| 43 "a", | |
| 44 scoped_refptr<MockTransportSocketParams>(), | |
| 45 MEDIUM, | |
| 46 CompletionCallback(), | |
| 47 &pool_, | |
| 48 BoundNetLog()); | |
| 49 return socket_handle.Pass(); | |
| 50 } | |
| 51 | |
| 52 private: | |
| 53 WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_; | |
| 54 ClientSocketPoolHistograms histograms_; | |
| 55 MockTransportClientSocketPool pool_; | |
| 56 | |
| 57 DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory); | |
| 58 }; | |
| 59 | |
| 60 class TestConnectDelegate : public WebSocketStream::ConnectDelegate { | |
| 61 public: | |
| 62 ~TestConnectDelegate() override {} | |
| 63 | |
| 64 void OnSuccess(scoped_ptr<WebSocketStream> stream) override {} | |
| 65 void OnFailure(const std::string& failure_message) override {} | |
| 66 void OnStartOpeningHandshake( | |
| 67 scoped_ptr<WebSocketHandshakeRequestInfo> request) override {} | |
| 68 void OnFinishOpeningHandshake( | |
| 69 scoped_ptr<WebSocketHandshakeResponseInfo> response) override {} | |
| 70 void OnSSLCertificateError( | |
| 71 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> | |
| 72 ssl_error_callbacks, | |
| 73 const SSLInfo& ssl_info, | |
| 74 bool fatal) override {} | |
| 75 }; | |
| 76 | |
| 77 class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { | |
| 78 protected: | |
| 79 scoped_ptr<WebSocketStream> CreateAndInitializeStream( | |
| 80 const std::string& socket_url, | |
| 81 const std::string& socket_host, | |
| 82 const std::string& socket_path, | |
| 83 const std::vector<std::string>& sub_protocols, | |
| 84 const std::string& origin, | |
| 85 const std::string& extra_request_headers, | |
| 86 const std::string& extra_response_headers) { | |
| 87 WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_, | |
| 88 sub_protocols); | |
| 89 create_helper.set_failure_message(&failure_message_); | |
| 90 | |
| 91 scoped_ptr<ClientSocketHandle> socket_handle = | |
| 92 socket_handle_factory_.CreateClientSocketHandle( | |
| 93 WebSocketStandardRequest(socket_path, socket_host, origin, | |
| 94 extra_request_headers), | |
| 95 WebSocketStandardResponse(extra_response_headers)); | |
| 96 | |
| 97 scoped_ptr<WebSocketHandshakeStreamBase> handshake( | |
| 98 create_helper.CreateBasicStream(socket_handle.Pass(), false)); | |
| 99 | |
| 100 // If in future the implementation type returned by CreateBasicStream() | |
| 101 // changes, this static_cast will be wrong. However, in that case the test | |
| 102 // will fail and AddressSanitizer should identify the issue. | |
| 103 static_cast<WebSocketBasicHandshakeStream*>(handshake.get()) | |
| 104 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); | |
| 105 | |
| 106 HttpRequestInfo request_info; | |
| 107 request_info.url = GURL(socket_url); | |
| 108 request_info.method = "GET"; | |
| 109 request_info.load_flags = LOAD_DISABLE_CACHE; | |
| 110 int rv = handshake->InitializeStream( | |
| 111 &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback()); | |
| 112 EXPECT_EQ(OK, rv); | |
| 113 | |
| 114 HttpRequestHeaders headers; | |
| 115 headers.SetHeader("Host", "localhost"); | |
| 116 headers.SetHeader("Connection", "Upgrade"); | |
| 117 headers.SetHeader("Pragma", "no-cache"); | |
| 118 headers.SetHeader("Cache-Control", "no-cache"); | |
| 119 headers.SetHeader("Upgrade", "websocket"); | |
| 120 headers.SetHeader("Origin", origin); | |
| 121 headers.SetHeader("Sec-WebSocket-Version", "13"); | |
| 122 headers.SetHeader("User-Agent", ""); | |
| 123 headers.SetHeader("Accept-Encoding", "gzip, deflate"); | |
| 124 headers.SetHeader("Accept-Language", "en-us,fr"); | |
| 125 | |
| 126 HttpResponseInfo response; | |
| 127 TestCompletionCallback dummy; | |
| 128 | |
| 129 rv = handshake->SendRequest(headers, &response, dummy.callback()); | |
| 130 | |
| 131 EXPECT_EQ(OK, rv); | |
| 132 | |
| 133 rv = handshake->ReadResponseHeaders(dummy.callback()); | |
| 134 EXPECT_EQ(OK, rv); | |
| 135 EXPECT_EQ(101, response.headers->response_code()); | |
| 136 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade")); | |
| 137 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket")); | |
| 138 return handshake->Upgrade(); | |
| 139 } | |
| 140 | |
| 141 MockClientSocketHandleFactory socket_handle_factory_; | |
| 142 TestConnectDelegate connect_delegate_; | |
| 143 std::string failure_message_; | |
| 144 }; | |
| 145 | |
| 146 // Confirm that the basic case works as expected. | |
| 147 TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) { | |
| 148 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
| 149 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
| 150 "http://localhost/", "", ""); | |
| 151 EXPECT_EQ("", stream->GetExtensions()); | |
| 152 EXPECT_EQ("", stream->GetSubProtocol()); | |
| 153 } | |
| 154 | |
| 155 // Verify that the sub-protocols are passed through. | |
| 156 TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) { | |
| 157 std::vector<std::string> sub_protocols; | |
| 158 sub_protocols.push_back("chat"); | |
| 159 sub_protocols.push_back("superchat"); | |
| 160 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
| 161 "ws://localhost/", "localhost", "/", sub_protocols, "http://localhost/", | |
| 162 "Sec-WebSocket-Protocol: chat, superchat\r\n", | |
| 163 "Sec-WebSocket-Protocol: superchat\r\n"); | |
| 164 EXPECT_EQ("superchat", stream->GetSubProtocol()); | |
| 165 } | |
| 166 | |
| 167 // Verify that extension name is available. Bad extension names are tested in | |
| 168 // websocket_stream_test.cc. | |
| 169 TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) { | |
| 170 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
| 171 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
| 172 "http://localhost/", "", | |
| 173 "Sec-WebSocket-Extensions: permessage-deflate\r\n"); | |
| 174 EXPECT_EQ("permessage-deflate", stream->GetExtensions()); | |
| 175 } | |
| 176 | |
| 177 // Verify that extension parameters are available. Bad parameters are tested in | |
| 178 // websocket_stream_test.cc. | |
| 179 TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) { | |
| 180 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( | |
| 181 "ws://localhost/", "localhost", "/", std::vector<std::string>(), | |
| 182 "http://localhost/", "", | |
| 183 "Sec-WebSocket-Extensions: permessage-deflate;" | |
| 184 " client_max_window_bits=14; server_max_window_bits=14;" | |
| 185 " server_no_context_takeover; client_no_context_takeover\r\n"); | |
| 186 | |
| 187 EXPECT_EQ( | |
| 188 "permessage-deflate;" | |
| 189 " client_max_window_bits=14; server_max_window_bits=14;" | |
| 190 " server_no_context_takeover; client_no_context_takeover", | |
| 191 stream->GetExtensions()); | |
| 192 } | |
| 193 | |
| 194 } // namespace | |
| 195 } // namespace net | |
| OLD | NEW |