| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "chrome/browser/extensions/api/socket/tcp_socket.h" | 5 #include "chrome/browser/extensions/api/socket/tls_socket.h" |
| 6 | 6 |
| 7 #include "base/memory/scoped_ptr.h" | 7 #include "base/memory/scoped_ptr.h" |
| 8 #include "base/strings/string_piece.h" |
| 8 #include "net/base/address_list.h" | 9 #include "net/base/address_list.h" |
| 9 #include "net/base/completion_callback.h" | 10 #include "net/base/completion_callback.h" |
| 10 #include "net/base/io_buffer.h" | 11 #include "net/base/io_buffer.h" |
| 11 #include "net/base/net_errors.h" | 12 #include "net/base/net_errors.h" |
| 12 #include "net/base/rand_callback.h" | 13 #include "net/base/rand_callback.h" |
| 14 #include "net/socket/ssl_client_socket.h" |
| 13 #include "net/socket/tcp_client_socket.h" | 15 #include "net/socket/tcp_client_socket.h" |
| 14 #include "net/socket/tcp_server_socket.h" | |
| 15 #include "testing/gmock/include/gmock/gmock.h" | 16 #include "testing/gmock/include/gmock/gmock.h" |
| 16 | 17 |
| 17 using testing::_; | 18 using testing::_; |
| 18 using testing::DoAll; | 19 using testing::DoAll; |
| 19 using testing::Return; | 20 using testing::Return; |
| 20 using testing::SaveArg; | 21 using testing::SaveArg; |
| 22 using base::StringPiece; |
| 23 |
| 24 namespace net { |
| 25 class ServerBoundCertService; |
| 26 } |
| 21 | 27 |
| 22 namespace extensions { | 28 namespace extensions { |
| 23 | 29 |
| 24 class MockTCPSocket : public net::TCPClientSocket { | 30 class MockSSLClientSocket : public net::SSLClientSocket { |
| 31 public: |
| 32 MockSSLClientSocket() {} |
| 33 MOCK_METHOD0(Disconnect, void()); |
| 34 MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len, |
| 35 const net::CompletionCallback& callback)); |
| 36 MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len, |
| 37 const net::CompletionCallback& callback)); |
| 38 MOCK_METHOD1(SetReceiveBufferSize, bool(int32)); |
| 39 MOCK_METHOD1(SetSendBufferSize, bool(int32)); |
| 40 MOCK_METHOD1(Connect, int(const CompletionCallback&)); |
| 41 MOCK_CONST_METHOD0(IsConnectedAndIdle, bool()); |
| 42 MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*) ); |
| 43 MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*)); |
| 44 MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog& ()); |
| 45 MOCK_METHOD0(SetSubresourceSpeculation, void ()); |
| 46 MOCK_METHOD0(SetOmniboxSpeculation, void ()); |
| 47 MOCK_CONST_METHOD0(WasEverUsed, bool ()); |
| 48 MOCK_CONST_METHOD0(UsingTCPFastOpen, bool ()); |
| 49 MOCK_METHOD1(GetSSLInfo, bool (net::SSLInfo*)); |
| 50 MOCK_METHOD5(ExportKeyingMaterial, int( |
| 51 const StringPiece&, bool, const StringPiece&, unsigned char*, |
| 52 unsigned int)); |
| 53 MOCK_METHOD1(GetTLSUniqueChannelBinding, int (std::string*)); |
| 54 MOCK_METHOD1(GetSSLCertRequestInfo, void (net::SSLCertRequestInfo*)); |
| 55 MOCK_METHOD2(GetNextProto, net::SSLClientSocket::NextProtoStatus( |
| 56 std::string*, std::string*)); |
| 57 MOCK_CONST_METHOD0(GetServerBoundCertService, net::ServerBoundCertService*()); |
| 58 virtual bool IsConnected() const OVERRIDE { |
| 59 return true; |
| 60 } |
| 61 private: |
| 62 DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket); |
| 63 }; |
| 64 |
| 65 class MockTCPSocket: public net::TCPClientSocket { |
| 25 public: | 66 public: |
| 26 explicit MockTCPSocket(const net::AddressList& address_list) | 67 explicit MockTCPSocket(const net::AddressList& address_list) |
| 27 : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) { | 68 : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) { |
| 28 } | 69 } |
| 29 | 70 |
| 30 MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len, | 71 MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len, |
| 31 const net::CompletionCallback& callback)); | 72 const net::CompletionCallback& callback)); |
| 32 MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len, | 73 MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len, |
| 33 const net::CompletionCallback& callback)); | 74 const net::CompletionCallback& callback)); |
| 34 MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay)); | 75 MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay)); |
| 35 MOCK_METHOD1(SetNoDelay, bool(bool no_delay)); | 76 MOCK_METHOD1(SetNoDelay, bool(bool no_delay)); |
| 77 |
| 36 virtual bool IsConnected() const OVERRIDE { | 78 virtual bool IsConnected() const OVERRIDE { |
| 37 return true; | 79 return true; |
| 38 } | 80 } |
| 39 | |
| 40 private: | 81 private: |
| 41 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket); | 82 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket); |
| 42 }; | 83 }; |
| 43 | 84 |
| 44 class MockTCPServerSocket : public net::TCPServerSocket { | |
| 45 public: | |
| 46 explicit MockTCPServerSocket() | |
| 47 : net::TCPServerSocket(NULL, net::NetLog::Source()) { | |
| 48 } | |
| 49 MOCK_METHOD2(Listen, int(const net::IPEndPoint& address, int backlog)); | |
| 50 MOCK_METHOD2(Accept, int(scoped_ptr<net::StreamSocket>* socket, | |
| 51 const net::CompletionCallback& callback)); | |
| 52 | |
| 53 private: | |
| 54 DISALLOW_COPY_AND_ASSIGN(MockTCPServerSocket); | |
| 55 }; | |
| 56 | |
| 57 class CompleteHandler { | 85 class CompleteHandler { |
| 58 public: | 86 public: |
| 59 CompleteHandler() {} | 87 CompleteHandler() {} |
| 60 MOCK_METHOD1(OnComplete, void(int result_code)); | 88 MOCK_METHOD1(OnComplete, void(int result_code)); |
| 61 MOCK_METHOD2(OnReadComplete, void(int result_code, | 89 MOCK_METHOD2(OnReadComplete, void(int result_code, |
| 62 scoped_refptr<net::IOBuffer> io_buffer)); | 90 scoped_refptr<net::IOBuffer> io_buffer)); |
| 63 MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*)); | 91 MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*)); |
| 64 private: | 92 private: |
| 65 DISALLOW_COPY_AND_ASSIGN(CompleteHandler); | 93 DISALLOW_COPY_AND_ASSIGN(CompleteHandler); |
| 66 }; | 94 }; |
| 67 | 95 |
| 68 const std::string FAKE_ID = "abcdefghijklmnopqrst"; | 96 static const char FAKE_ID[]="faktetesttlssocketunittest"; |
| 69 | 97 |
| 70 TEST(SocketTest, TestTCPSocketRead) { | 98 TEST(SocketTest, TestTLSSocketRead) { |
| 71 net::AddressList address_list; | 99 net::AddressList address_list; |
| 72 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | 100 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); |
| 101 MockSSLClientSocket *ssl_socket = new MockSSLClientSocket; |
| 73 CompleteHandler handler; | 102 CompleteHandler handler; |
| 74 | 103 |
| 75 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | 104 scoped_ptr<TLSSocket> socket(new TLSSocket( |
| 76 tcp_client_socket, FAKE_ID, true)); | 105 ssl_socket, tcp_client_socket, FAKE_ID)); |
| 77 | 106 |
| 78 EXPECT_CALL(*tcp_client_socket, Read(_, _, _)) | 107 EXPECT_CALL(*ssl_socket, Read(_, _, _)) |
| 79 .Times(1); | 108 .Times(1); |
| 80 EXPECT_CALL(handler, OnReadComplete(_, _)) | 109 EXPECT_CALL(handler, OnReadComplete(_, _)) |
| 81 .Times(1); | 110 .Times(1); |
| 82 | 111 |
| 83 const int count = 512; | 112 const int count = 512; |
| 84 socket->Read(count, base::Bind(&CompleteHandler::OnReadComplete, | 113 socket->Read(count, base::Bind(&CompleteHandler::OnReadComplete, |
| 85 base::Unretained(&handler))); | 114 base::Unretained(&handler))); |
| 86 } | 115 } |
| 87 | 116 |
| 88 TEST(SocketTest, TestTCPSocketWrite) { | 117 TEST(SocketTest, TestTLSSocketWrite) { |
| 89 net::AddressList address_list; | 118 net::AddressList address_list; |
| 90 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | 119 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); |
| 120 MockSSLClientSocket *ssl_socket = new MockSSLClientSocket; |
| 91 CompleteHandler handler; | 121 CompleteHandler handler; |
| 92 | 122 |
| 93 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | 123 scoped_ptr<TLSSocket> socket(new TLSSocket( |
| 94 tcp_client_socket, FAKE_ID, true)); | 124 ssl_socket, tcp_client_socket, FAKE_ID)); |
| 95 | 125 |
| 96 net::CompletionCallback callback; | 126 net::CompletionCallback callback; |
| 97 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) | 127 EXPECT_CALL(*ssl_socket, Write(_, _, _)) |
| 98 .Times(2) | 128 .Times(2) |
| 99 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), | 129 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), |
| 100 Return(128))); | 130 Return(128))); |
| 101 EXPECT_CALL(handler, OnComplete(_)) | 131 EXPECT_CALL(handler, OnComplete(_)) |
| 102 .Times(1); | 132 .Times(1); |
| 103 | 133 |
| 104 scoped_refptr<net::IOBufferWithSize> io_buffer( | 134 scoped_refptr<net::IOBufferWithSize> io_buffer( |
| 105 new net::IOBufferWithSize(256)); | 135 new net::IOBufferWithSize(256)); |
| 106 socket->Write(io_buffer.get(), io_buffer->size(), | 136 socket->Write(io_buffer.get(), io_buffer->size(), |
| 107 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); | 137 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); |
| 108 } | 138 } |
| 109 | 139 |
| 110 TEST(SocketTest, TestTCPSocketBlockedWrite) { | 140 TEST(SocketTest, TestTLSSocketBlockedWrite) { |
| 111 net::AddressList address_list; | 141 net::AddressList address_list; |
| 112 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | 142 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); |
| 143 MockSSLClientSocket *ssl_socket = new MockSSLClientSocket; |
| 113 CompleteHandler handler; | 144 CompleteHandler handler; |
| 114 | 145 |
| 115 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | 146 scoped_ptr<TLSSocket> socket(new TLSSocket( |
| 116 tcp_client_socket, FAKE_ID, true)); | 147 ssl_socket, tcp_client_socket, FAKE_ID)); |
| 117 | 148 |
| 118 net::CompletionCallback callback; | 149 net::CompletionCallback callback; |
| 119 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) | 150 EXPECT_CALL(*ssl_socket, Write(_, _, _)) |
| 120 .Times(2) | 151 .Times(2) |
| 121 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), | 152 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), |
| 122 Return(net::ERR_IO_PENDING))); | 153 Return(net::ERR_IO_PENDING))); |
| 123 scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42)); | 154 scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42)); |
| 124 socket->Write(io_buffer.get(), io_buffer->size(), | 155 socket->Write(io_buffer.get(), io_buffer->size(), |
| 125 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); | 156 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); |
| 126 | 157 |
| 127 // Good. Original call came back unable to complete. Now pretend the socket | 158 // Good. Original call came back unable to complete. Now pretend the socket |
| 128 // finished, and confirm that we passed the error back. | 159 // finished, and confirm that we passed the error back. |
| 129 EXPECT_CALL(handler, OnComplete(42)) | 160 EXPECT_CALL(handler, OnComplete(42)) |
| 130 .Times(1); | 161 .Times(1); |
| 131 callback.Run(40); | 162 callback.Run(40); |
| 132 callback.Run(2); | 163 callback.Run(2); |
| 133 } | 164 } |
| 134 | 165 |
| 135 TEST(SocketTest, TestTCPSocketBlockedWriteReentry) { | 166 TEST(SocketTest, TestTLSSocketBlockedWriteReentry) { |
| 136 net::AddressList address_list; | 167 net::AddressList address_list; |
| 137 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | 168 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); |
| 169 MockSSLClientSocket *ssl_socket = new MockSSLClientSocket; |
| 138 CompleteHandler handlers[5]; | 170 CompleteHandler handlers[5]; |
| 139 | 171 |
| 140 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | 172 scoped_ptr<TLSSocket> socket(new TLSSocket( |
| 141 tcp_client_socket, FAKE_ID, true)); | 173 ssl_socket, tcp_client_socket, FAKE_ID)); |
| 142 | 174 |
| 143 net::CompletionCallback callback; | 175 net::CompletionCallback callback; |
| 144 EXPECT_CALL(*tcp_client_socket, Write(_, _, _)) | 176 EXPECT_CALL(*ssl_socket, Write(_, _, _)) |
| 145 .Times(5) | 177 .Times(5) |
| 146 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), | 178 .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback), |
| 147 Return(net::ERR_IO_PENDING))); | 179 Return(net::ERR_IO_PENDING))); |
| 148 scoped_refptr<net::IOBufferWithSize> io_buffers[5]; | 180 scoped_refptr<net::IOBufferWithSize> io_buffers[5]; |
| 149 int i; | 181 int i; |
| 150 for (i = 0; i < 5; i++) { | 182 for (i = 0; i < 5; i++) { |
| 151 io_buffers[i] = new net::IOBufferWithSize(128 + i * 50); | 183 io_buffers[i] = new net::IOBufferWithSize(128 + i * 50); |
| 152 scoped_refptr<net::IOBufferWithSize> io_buffer1( | 184 scoped_refptr<net::IOBufferWithSize> io_buffer1( |
| 153 new net::IOBufferWithSize(42)); | 185 new net::IOBufferWithSize(42)); |
| 154 socket->Write(io_buffers[i].get(), io_buffers[i]->size(), | 186 socket->Write(io_buffers[i].get(), io_buffers[i]->size(), |
| 155 base::Bind(&CompleteHandler::OnComplete, | 187 base::Bind(&CompleteHandler::OnComplete, |
| 156 base::Unretained(&handlers[i]))); | 188 base::Unretained(&handlers[i]))); |
| 157 | 189 |
| 158 EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size())) | 190 EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size())) |
| 159 .Times(1); | 191 .Times(1); |
| 160 } | 192 } |
| 161 | 193 |
| 162 for (i = 0; i < 5; i++) { | 194 for (i = 0; i < 5; i++) { |
| 163 callback.Run(128 + i * 50); | 195 callback.Run(128 + i * 50); |
| 164 } | 196 } |
| 165 } | 197 } |
| 166 | 198 |
| 167 TEST(SocketTest, TestTCPSocketSetNoDelay) { | |
| 168 net::AddressList address_list; | |
| 169 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | |
| 170 | |
| 171 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | |
| 172 tcp_client_socket, FAKE_ID)); | |
| 173 | |
| 174 bool no_delay = false; | |
| 175 EXPECT_CALL(*tcp_client_socket, SetNoDelay(_)) | |
| 176 .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(true))); | |
| 177 int result = socket->SetNoDelay(true); | |
| 178 EXPECT_TRUE(result); | |
| 179 EXPECT_TRUE(no_delay); | |
| 180 | |
| 181 EXPECT_CALL(*tcp_client_socket, SetNoDelay(_)) | |
| 182 .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(false))); | |
| 183 | |
| 184 result = socket->SetNoDelay(false); | |
| 185 EXPECT_FALSE(result); | |
| 186 EXPECT_FALSE(no_delay); | |
| 187 } | |
| 188 | |
| 189 TEST(SocketTest, TestTCPSocketSetKeepAlive) { | |
| 190 net::AddressList address_list; | |
| 191 MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list); | |
| 192 | |
| 193 scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting( | |
| 194 tcp_client_socket, FAKE_ID)); | |
| 195 | |
| 196 bool enable = false; | |
| 197 int delay = 0; | |
| 198 EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _)) | |
| 199 .WillOnce(testing::DoAll(SaveArg<0>(&enable), | |
| 200 SaveArg<1>(&delay), | |
| 201 Return(true))); | |
| 202 int result = socket->SetKeepAlive(true, 4500); | |
| 203 EXPECT_TRUE(result); | |
| 204 EXPECT_TRUE(enable); | |
| 205 EXPECT_EQ(4500, delay); | |
| 206 | |
| 207 EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _)) | |
| 208 .WillOnce(testing::DoAll(SaveArg<0>(&enable), | |
| 209 SaveArg<1>(&delay), | |
| 210 Return(false))); | |
| 211 result = socket->SetKeepAlive(false, 0); | |
| 212 EXPECT_FALSE(result); | |
| 213 EXPECT_FALSE(enable); | |
| 214 EXPECT_EQ(0, delay); | |
| 215 } | |
| 216 | |
| 217 TEST(SocketTest, TestTCPServerSocketListenAccept) { | |
| 218 MockTCPServerSocket* tcp_server_socket = new MockTCPServerSocket(); | |
| 219 CompleteHandler handler; | |
| 220 | |
| 221 scoped_ptr<TCPSocket> socket(TCPSocket::CreateServerSocketForTesting( | |
| 222 tcp_server_socket, FAKE_ID)); | |
| 223 | |
| 224 EXPECT_CALL(*tcp_server_socket, Accept(_, _)).Times(1); | |
| 225 EXPECT_CALL(*tcp_server_socket, Listen(_, _)).Times(1); | |
| 226 EXPECT_CALL(handler, OnAccept(_, _)); | |
| 227 | |
| 228 std::string err_msg; | |
| 229 EXPECT_EQ(net::OK, socket->Listen("127.0.0.1", 9999, 10, &err_msg)); | |
| 230 socket->Accept(base::Bind(&CompleteHandler::OnAccept, | |
| 231 base::Unretained(&handler))); | |
| 232 } | |
| 233 | |
| 234 } // namespace extensions | 199 } // namespace extensions |
| OLD | NEW |