| OLD | NEW |
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "remoting/host/websocket_connection.h" |
| 6 #include "remoting/host/websocket_listener.h" |
| 7 |
| 8 #include "base/message_loop.h" |
| 9 #include "base/run_loop.h" |
| 10 #include "net/base/ip_endpoint.h" |
| 11 #include "net/base/net_util.h" |
| 12 #include "net/socket/tcp_client_socket.h" |
| 13 #include "remoting/base/socket_reader.h" |
| 14 #include "testing/gmock/include/gmock/gmock.h" |
| 15 #include "testing/gtest/include/gtest/gtest.h" |
| 16 |
| 17 namespace remoting { |
| 18 |
| 19 namespace { |
| 20 |
| 21 const int kPortRangeMin = 12800; |
| 22 const int kPortRangeMax = 12810; |
| 23 const char kHeaderEndMarker[] = "\r\n\r\n"; |
| 24 |
| 25 } // namespace |
| 26 |
| 27 class WebsocketTestReader { |
| 28 public: |
| 29 WebsocketTestReader(net::Socket* socket, const base::Closure& on_data) |
| 30 : on_data_(on_data), |
| 31 closed_(false), |
| 32 reading_header_(true) { |
| 33 reader_.Init(socket, base::Bind(&WebsocketTestReader::OnReadResult, |
| 34 base::Unretained(this))); |
| 35 } |
| 36 ~WebsocketTestReader() { |
| 37 } |
| 38 |
| 39 bool closed() { return closed_; } |
| 40 |
| 41 bool reading_header() { return reading_header_; } |
| 42 const std::string& header() { return header_; } |
| 43 const std::string& data_receved() { return data_receved_; } |
| 44 |
| 45 protected: |
| 46 void OnReadResult(scoped_refptr<net::IOBuffer> buffer, int result) { |
| 47 if (result <= 0) { |
| 48 closed_ = true; |
| 49 } else if (reading_header_) { |
| 50 header_.append(buffer->data(), buffer->data() + result); |
| 51 size_t end_pos = header_.find(kHeaderEndMarker); |
| 52 if (end_pos != std::string::npos) { |
| 53 data_receved_ = header_.substr(end_pos + strlen(kHeaderEndMarker)); |
| 54 header_ = header_.substr(0, end_pos); |
| 55 reading_header_ = false; |
| 56 } |
| 57 } else { |
| 58 data_receved_.append(buffer->data(), buffer->data() + result); |
| 59 } |
| 60 on_data_.Run(); |
| 61 } |
| 62 |
| 63 private: |
| 64 SocketReader reader_; |
| 65 base::Closure on_data_; |
| 66 bool closed_; |
| 67 bool reading_header_; |
| 68 std::string header_; |
| 69 std::string data_receved_; |
| 70 }; |
| 71 |
| 72 class WebsocketConnectionTest : public testing::Test, |
| 73 public WebsocketConnection::Delegate { |
| 74 public: |
| 75 virtual void OnWebsocketMessage(const std::string& message) OVERRIDE { |
| 76 last_message_ = message; |
| 77 if (message_run_loop_.get()) { |
| 78 message_run_loop_->Quit(); |
| 79 } |
| 80 } |
| 81 |
| 82 virtual void OnWebsocketClosed() OVERRIDE { |
| 83 connection_.reset(); |
| 84 if (closed_run_loop_.get()) { |
| 85 closed_run_loop_->Quit(); |
| 86 } |
| 87 } |
| 88 |
| 89 protected: |
| 90 void Initialize() { |
| 91 listener_.reset(new WebsocketListener()); |
| 92 net::IPAddressNumber localhost; |
| 93 ASSERT_TRUE(net::ParseIPLiteralToNumber("127.0.0.1", &localhost)); |
| 94 for (int port = kPortRangeMin; port < kPortRangeMax; ++port) { |
| 95 endpoint_ = net::IPEndPoint(localhost, port); |
| 96 if (listener_->Listen( |
| 97 endpoint_, base::Bind(&WebsocketConnectionTest::OnNewConnection, |
| 98 base::Unretained(this)))) { |
| 99 return; |
| 100 } |
| 101 } |
| 102 } |
| 103 |
| 104 void OnNewConnection(scoped_ptr<WebsocketConnection> connection) { |
| 105 EXPECT_TRUE(connection_.get() == NULL); |
| 106 connection_ = connection.Pass(); |
| 107 connection_->Accept(this); |
| 108 } |
| 109 |
| 110 void ConnectSocket() { |
| 111 client_.reset(new net::TCPClientSocket( |
| 112 net::AddressList(endpoint_), NULL, net::NetLog::Source())); |
| 113 client_connect_result_ = -1; |
| 114 client_->Connect(base::Bind(&WebsocketConnectionTest::OnClientConnected, |
| 115 base::Unretained(this))); |
| 116 connect_run_loop_.reset(new base::RunLoop()); |
| 117 connect_run_loop_->Run(); |
| 118 ASSERT_EQ(client_connect_result_, 0); |
| 119 client_writer_.reset(new protocol::BufferedSocketWriter()); |
| 120 client_writer_->Init( |
| 121 client_.get(), protocol::BufferedSocketWriter::WriteFailedCallback()); |
| 122 client_reader_.reset(new WebsocketTestReader( |
| 123 client_.get(), |
| 124 base::Bind(&WebsocketConnectionTest::OnClientDataReceived, |
| 125 base::Unretained(this)))); |
| 126 } |
| 127 |
| 128 void OnClientConnected(int result) { |
| 129 client_connect_result_ = result; |
| 130 connect_run_loop_->Quit(); |
| 131 } |
| 132 |
| 133 void OnClientDataReceived() { |
| 134 if (handshake_run_loop_.get() && !client_reader_->reading_header()) { |
| 135 handshake_run_loop_->Quit(); |
| 136 } |
| 137 if (closed_run_loop_.get() && client_reader_->closed()) { |
| 138 closed_run_loop_->Quit(); |
| 139 } |
| 140 if (data_received_run_loop_.get()) { |
| 141 data_received_run_loop_->Quit(); |
| 142 } |
| 143 } |
| 144 |
| 145 void Send(const std::string& data) { |
| 146 scoped_refptr<net::IOBufferWithSize> buffer = |
| 147 new net::IOBufferWithSize(data.size()); |
| 148 memcpy(buffer->data(), data.data(), data.size()); |
| 149 client_writer_->Write(buffer, base::Closure()); |
| 150 } |
| 151 |
| 152 void Handshake() { |
| 153 Send("GET /chat HTTP/1.1\r\n" |
| 154 "Host: server.example.com\r\n" |
| 155 "Upgrade: websocket\r\n" |
| 156 "Connection: Upgrade\r\n" |
| 157 "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" |
| 158 "Origin: http://example.com\r\n" |
| 159 "Sec-WebSocket-Version: 13\r\n\r\n"); |
| 160 handshake_run_loop_.reset(new base::RunLoop()); |
| 161 handshake_run_loop_->Run(); |
| 162 EXPECT_EQ("http://example.com", connection_->origin()); |
| 163 EXPECT_EQ("server.example.com", connection_->request_host()); |
| 164 EXPECT_EQ("/chat", connection_->request_path()); |
| 165 EXPECT_EQ("HTTP/1.1 101 Switching Protocol\r\n" |
| 166 "Upgrade: websocket\r\n" |
| 167 "Connection: Upgrade\r\n" |
| 168 "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", |
| 169 client_reader_->header()); |
| 170 } |
| 171 |
| 172 void ReceiveFrame(std::string expected_frame) { |
| 173 while (client_reader_->data_receved().size() < expected_frame.size()) { |
| 174 data_received_run_loop_.reset(new base::RunLoop()); |
| 175 data_received_run_loop_->Run(); |
| 176 } |
| 177 EXPECT_EQ(expected_frame, client_reader_->data_receved()); |
| 178 } |
| 179 |
| 180 MessageLoopForIO message_loop_; |
| 181 |
| 182 scoped_ptr<base::RunLoop> connect_run_loop_; |
| 183 scoped_ptr<base::RunLoop> handshake_run_loop_; |
| 184 scoped_ptr<base::RunLoop> closed_run_loop_; |
| 185 scoped_ptr<base::RunLoop> data_received_run_loop_; |
| 186 scoped_ptr<base::RunLoop> message_run_loop_; |
| 187 |
| 188 scoped_ptr<WebsocketListener> listener_; |
| 189 net::IPEndPoint endpoint_; |
| 190 scoped_ptr<WebsocketConnection> connection_; |
| 191 scoped_ptr<net::TCPClientSocket> client_; |
| 192 scoped_ptr<protocol::BufferedSocketWriter> client_writer_; |
| 193 scoped_ptr<WebsocketTestReader> client_reader_; |
| 194 int client_connect_result_; |
| 195 std::string last_message_; |
| 196 }; |
| 197 |
| 198 TEST_F(WebsocketConnectionTest, ConnectSocket) { |
| 199 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 200 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 201 } |
| 202 |
| 203 TEST_F(WebsocketConnectionTest, SuccessfulHandshake) { |
| 204 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 205 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 206 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 207 ASSERT_TRUE(connection_.get() != NULL); |
| 208 } |
| 209 |
| 210 TEST_F(WebsocketConnectionTest, DisconnectAndConnect) { |
| 211 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 212 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 213 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 214 ASSERT_TRUE(connection_.get() != NULL); |
| 215 |
| 216 client_.reset(); |
| 217 closed_run_loop_.reset(new base::RunLoop()); |
| 218 closed_run_loop_->Run(); |
| 219 EXPECT_TRUE(connection_.get() == NULL); |
| 220 |
| 221 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 222 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 223 ASSERT_TRUE(connection_.get() != NULL); |
| 224 } |
| 225 |
| 226 TEST_F(WebsocketConnectionTest, NonWebsocketHeader) { |
| 227 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 228 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 229 EXPECT_TRUE(connection_.get() == NULL); |
| 230 Send("GET /chat HTTP/1.1\r\n" |
| 231 "Host: server.example.com\r\n\r\n"); |
| 232 closed_run_loop_.reset(new base::RunLoop()); |
| 233 closed_run_loop_->Run(); |
| 234 EXPECT_TRUE(connection_.get() == NULL); |
| 235 } |
| 236 |
| 237 TEST_F(WebsocketConnectionTest, SendMessage) { |
| 238 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 239 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 240 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 241 connection_->SendText("Hello"); |
| 242 char expected_frame[] = { 0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f }; |
| 243 ReceiveFrame(std::string(expected_frame, |
| 244 expected_frame + sizeof(expected_frame))); |
| 245 } |
| 246 |
| 247 TEST_F(WebsocketConnectionTest, ReceiveMessage) { |
| 248 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 249 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 250 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 251 char message_frame[] = { 0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, |
| 252 0x9f, 0x4d, 0x51, 0x58 }; |
| 253 Send(std::string(message_frame, message_frame + sizeof(message_frame))); |
| 254 |
| 255 message_run_loop_.reset(new base::RunLoop()); |
| 256 message_run_loop_->Run(); |
| 257 |
| 258 EXPECT_EQ("Hello", last_message_); |
| 259 } |
| 260 |
| 261 TEST_F(WebsocketConnectionTest, FragmentedMessage) { |
| 262 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 263 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 264 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 265 |
| 266 char fragment1[] = { 0x01, 0x83, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d }; |
| 267 Send(std::string(fragment1, fragment1 + sizeof(fragment1))); |
| 268 char fragment2[] = { 0x80, 0x82, 0x3d, 0x37, 0x12, 0x42, 0x51, 0x58 }; |
| 269 Send(std::string(fragment2, fragment2 + sizeof(fragment2))); |
| 270 |
| 271 message_run_loop_.reset(new base::RunLoop()); |
| 272 message_run_loop_->Run(); |
| 273 |
| 274 EXPECT_EQ("Hello", last_message_); |
| 275 } |
| 276 |
| 277 TEST_F(WebsocketConnectionTest, PingResponse) { |
| 278 ASSERT_NO_FATAL_FAILURE(Initialize()); |
| 279 ASSERT_NO_FATAL_FAILURE(ConnectSocket()); |
| 280 ASSERT_NO_FATAL_FAILURE(Handshake()); |
| 281 |
| 282 char ping_frame[] = { 0x89, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, |
| 283 0x4d, 0x51, 0x58 }; |
| 284 Send(std::string(ping_frame, ping_frame + sizeof(ping_frame))); |
| 285 |
| 286 char expected_frame[] = { 0x8a, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f }; |
| 287 ReceiveFrame(std::string(expected_frame, |
| 288 expected_frame + sizeof(expected_frame))); |
| 289 } |
| 290 |
| 291 } // namespace remoting |
| OLD | NEW |