Chromium Code Reviews| Index: extensions/browser/api/cast_channel/cast_socket_unittest.cc |
| diff --git a/extensions/browser/api/cast_channel/cast_socket_unittest.cc b/extensions/browser/api/cast_channel/cast_socket_unittest.cc |
| index 08b18acf01a56c27b35bf7c3d68164d83f4ea61d..c7d5e45ee6b69ee668bbfc4b9abe5afc8d44caf8 100644 |
| --- a/extensions/browser/api/cast_channel/cast_socket_unittest.cc |
| +++ b/extensions/browser/api/cast_channel/cast_socket_unittest.cc |
| @@ -9,6 +9,7 @@ |
| #include <utility> |
| #include <vector> |
| +#include "base/files/file_util.h" |
| #include "base/location.h" |
| #include "base/macros.h" |
| #include "base/memory/ptr_util.h" |
| @@ -21,6 +22,8 @@ |
| #include "base/test/simple_test_clock.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "base/timer/mock_timer.h" |
| +#include "content/public/test/test_browser_thread_bundle.h" |
| +#include "crypto/rsa_private_key.h" |
| #include "extensions/browser/api/cast_channel/cast_auth_util.h" |
| #include "extensions/browser/api/cast_channel/cast_framer.h" |
| #include "extensions/browser/api/cast_channel/cast_message_util.h" |
| @@ -30,11 +33,14 @@ |
| #include "extensions/common/api/cast_channel/cast_channel.pb.h" |
| #include "net/base/address_list.h" |
| #include "net/base/net_errors.h" |
| +#include "net/cert/pem_tokenizer.h" |
| #include "net/log/test_net_log.h" |
| #include "net/socket/socket_test_util.h" |
| #include "net/socket/ssl_client_socket.h" |
| +#include "net/socket/ssl_server_socket.h" |
| #include "net/socket/tcp_client_socket.h" |
| #include "net/ssl/ssl_info.h" |
| +#include "net/ssl/ssl_server_config.h" |
| #include "net/test/cert_test_util.h" |
| #include "net/test/test_data_directory.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| @@ -171,6 +177,97 @@ class CompleteHandler { |
| DISALLOW_COPY_AND_ASSIGN(CompleteHandler); |
| }; |
| +class FakeTCPClientSocket : public net::TCPClientSocket { |
| + public: |
| + FakeTCPClientSocket(net::FakeSocket* socket) |
| + : net::TCPClientSocket(net::AddressList(), |
| + nullptr, |
| + nullptr, |
| + net::NetLog::Source()), |
| + socket_(socket) { |
| + DCHECK(socket_); |
| + } |
| + |
| + int Read(net::IOBuffer* buf, |
| + int buf_len, |
| + const net::CompletionCallback& callback) override { |
| + return socket_->Read(buf, buf_len, callback); |
| + } |
| + |
| + int Write(net::IOBuffer* buf, |
| + int buf_len, |
| + const net::CompletionCallback& callback) override { |
| + return socket_->Write(buf, buf_len, callback); |
| + } |
| + |
| + int SetReceiveBufferSize(int32_t size) override { |
| + return socket_->SetReceiveBufferSize(size); |
| + } |
| + |
| + int SetSendBufferSize(int32_t size) override { |
| + return socket_->SetSendBufferSize(size); |
| + } |
| + |
| + int Connect(const net::CompletionCallback& callback) override { |
| + return socket_->Connect(callback); |
| + } |
| + |
| + void Disconnect() override { return socket_->Disconnect(); } |
| + |
| + bool IsConnected() const override { return socket_->IsConnected(); } |
| + |
| + bool IsConnectedAndIdle() const override { |
| + return socket_->IsConnectedAndIdle(); |
| + } |
| + |
| + int GetPeerAddress(net::IPEndPoint* address) const override { |
| + return socket_->GetPeerAddress(address); |
| + } |
| + |
| + int GetLocalAddress(net::IPEndPoint* address) const override { |
| + return socket_->GetLocalAddress(address); |
| + } |
| + |
| + const net::BoundNetLog& NetLog() const override { return socket_->NetLog(); } |
| + |
| + void SetSubresourceSpeculation() override { |
| + socket_->SetSubresourceSpeculation(); |
| + } |
| + |
| + void SetOmniboxSpeculation() override { socket_->SetOmniboxSpeculation(); } |
| + |
| + bool WasEverUsed() const override { return socket_->WasEverUsed(); } |
| + |
| + bool WasNpnNegotiated() const override { return socket_->WasNpnNegotiated(); } |
| + |
| + net::NextProto GetNegotiatedProtocol() const override { |
| + return socket_->GetNegotiatedProtocol(); |
| + } |
| + |
| + bool GetSSLInfo(net::SSLInfo* ssl_info) override { |
| + return socket_->GetSSLInfo(ssl_info); |
| + } |
| + |
| + void GetConnectionAttempts(net::ConnectionAttempts* out) const override { |
| + socket_->GetConnectionAttempts(out); |
| + } |
| + |
| + void ClearConnectionAttempts() override { |
| + socket_->ClearConnectionAttempts(); |
| + } |
| + |
| + void AddConnectionAttempts(const net::ConnectionAttempts& attempts) override { |
| + socket_->AddConnectionAttempts(attempts); |
| + } |
| + |
| + int64_t GetTotalReceivedBytes() const override { |
| + return socket_->GetTotalReceivedBytes(); |
| + } |
| + |
| + private: |
| + net::FakeSocket* socket_; |
| +}; |
| + |
| class TestCastSocket : public CastSocketImpl { |
| public: |
| static std::unique_ptr<TestCastSocket> Create( |
| @@ -222,7 +319,8 @@ class TestCastSocket : public CastSocketImpl { |
| verify_challenge_disallow_(false), |
| tcp_unresponsive_(false), |
| mock_timer_(new base::MockTimer(false, false)), |
| - mock_transport_(nullptr) {} |
| + mock_transport_(nullptr), |
| + fake_socket_(nullptr) {} |
| ~TestCastSocket() override {} |
| @@ -231,6 +329,10 @@ class TestCastSocket : public CastSocketImpl { |
| SetTransportForTesting(base::WrapUnique(mock_transport_)); |
| } |
| + void SetFakeSocket(net::FakeSocket* fake_socket) { |
| + fake_socket_ = fake_socket; |
| + } |
| + |
| // Socket connection helpers. |
| void SetupTcpConnect(net::IoMode mode, int result) { |
| tcp_connect_data_.reset(new net::MockConnect(mode, result)); |
| @@ -296,6 +398,9 @@ class TestCastSocket : public CastSocketImpl { |
| std::unique_ptr<net::TCPClientSocket> CreateTcpSocket() override { |
| if (tcp_unresponsive_) { |
| return std::unique_ptr<net::TCPClientSocket>(new MockTCPSocket(true)); |
| + } else if (fake_socket_) { |
| + return std::unique_ptr<net::TCPClientSocket>( |
| + new FakeTCPClientSocket(fake_socket_)); |
| } else { |
| net::MockConnect* connect_data = tcp_connect_data_.get(); |
| connect_data->peer_addr = ip_; |
| @@ -306,6 +411,9 @@ class TestCastSocket : public CastSocketImpl { |
| std::unique_ptr<net::SSLClientSocket> CreateSslSocket( |
| std::unique_ptr<net::StreamSocket> socket) override { |
| + if (fake_socket_) { |
| + return CastSocketImpl::CreateSslSocket(std::move(socket)); |
| + } |
| net::MockConnect* connect_data = ssl_connect_data_.get(); |
| connect_data->peer_addr = ip_; |
| @@ -350,6 +458,12 @@ class TestCastSocket : public CastSocketImpl { |
| std::unique_ptr<base::MockTimer> mock_timer_; |
| MockCastTransport* mock_transport_; |
| + // A fake socket that is used instead of the mocks when testing with the real |
| + // SSL implementation. When this is set, CreateTCPSocket() will use this as |
| + // the underlying transport and CreateSSLSocket() will call the base |
| + // implementation in CastSocketImpl. |
| + net::FakeSocket* fake_socket_; |
| + |
| DISALLOW_COPY_AND_ASSIGN(TestCastSocket); |
| }; |
| @@ -378,6 +492,52 @@ class CastSocketTest : public testing::Test { |
| socket_ = TestCastSocket::CreateSecure(logger_); |
| } |
| + // Initializes the SSLServerSocket |server_socket_| and the fake socket |
| + // |fake_client_socket_| to be used by the CastSocket |socket_| as transport |
| + // to |server_socket_|. |
| + void CreateFakeSockets() { |
| + channel_1_.reset(new net::FakeDataChannel()); |
| + channel_2_.reset(new net::FakeDataChannel()); |
| + server_cert_ = |
| + net::ImportCertFromFile(net::GetTestCertsDirectory(), "ok_cert.pem"); |
| + ASSERT_TRUE(server_cert_); |
| + server_private_key_ = ReadTestKeyFromPEM("ok_cert.pem"); |
| + ASSERT_TRUE(server_private_key_); |
| + server_context_ = CreateSSLServerContext( |
| + server_cert_.get(), *server_private_key_, server_ssl_config_); |
| + |
| + fake_client_socket_.reset( |
| + new net::FakeSocket(channel_1_.get(), channel_2_.get())); |
| + std::unique_ptr<net::StreamSocket> server_socket( |
|
Wez
2016/07/01 01:27:56
nit: fake_server_socket?
btolsch
2016/07/01 03:21:08
This was effectively changed to |accepted_socket|.
|
| + new net::FakeSocket(channel_2_.get(), channel_1_.get())); |
| + |
| + server_socket_ = |
| + server_context_->CreateSSLServerSocket(std::move(server_socket)); |
| + ASSERT_TRUE(server_socket_); |
| + } |
| + |
| + std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM( |
| + const base::StringPiece& name) { |
| + base::FilePath certs_dir(net::GetTestCertsDirectory()); |
| + base::FilePath key_path = certs_dir.AppendASCII(name); |
| + std::vector<std::string> headers; |
| + headers.push_back("PRIVATE KEY"); |
| + std::string pem_data; |
| + if (!base::ReadFileToString(key_path, &pem_data)) { |
| + return nullptr; |
| + } |
| + net::PEMTokenizer pem_tokenizer(pem_data, headers); |
| + if (!pem_tokenizer.GetNext()) { |
| + return nullptr; |
| + } |
| + std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(), |
| + pem_tokenizer.data().end()); |
| + |
| + std::unique_ptr<crypto::RSAPrivateKey> key( |
| + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); |
| + return key; |
| + } |
| + |
| void HandleAuthHandshake() { |
| socket_->SetupMockTransport(); |
| CastMessage challenge_proto = CreateAuthChallenge(); |
| @@ -395,6 +555,36 @@ class CastSocketTest : public testing::Test { |
| RunPendingTasks(); |
| } |
| + int ReadExactLength(net::IOBuffer* buffer, |
| + int buffer_length, |
| + net::Socket* socket) { |
| + scoped_refptr<net::DrainableIOBuffer> draining_buffer( |
| + new net::DrainableIOBuffer(buffer, buffer_length)); |
| + while (draining_buffer->BytesRemaining() > 0) { |
| + net::TestCompletionCallback read_callback; |
| + int read_result = read_callback.GetResult(server_socket_->Read( |
| + draining_buffer.get(), buffer_length, read_callback.callback())); |
| + EXPECT_GT(read_result, 0); |
| + draining_buffer->DidConsume(read_result); |
| + } |
| + return buffer_length; |
| + } |
| + |
| + int WriteExactLength(net::IOBuffer* buffer, |
| + int buffer_length, |
| + net::Socket* socket) { |
| + scoped_refptr<net::DrainableIOBuffer> draining_buffer( |
| + new net::DrainableIOBuffer(buffer, buffer_length)); |
| + while (draining_buffer->BytesRemaining() > 0) { |
| + net::TestCompletionCallback write_callback; |
| + int write_result = write_callback.GetResult(server_socket_->Write( |
| + draining_buffer.get(), buffer_length, write_callback.callback())); |
| + EXPECT_GT(write_result, 0); |
| + draining_buffer->DidConsume(write_result); |
| + } |
| + return buffer_length; |
| + } |
| + |
| protected: |
| // Runs all pending tasks in the message loop. |
| void RunPendingTasks() { |
| @@ -402,12 +592,32 @@ class CastSocketTest : public testing::Test { |
| run_loop.RunUntilIdle(); |
| } |
| - base::MessageLoop message_loop_; |
| + content::TestBrowserThreadBundle thread_bundle_{ |
| + content::TestBrowserThreadBundle::IO_MAINLOOP}; |
|
Wez
2016/07/01 01:27:56
nit: Is there a good reason to use {} init rather
btolsch
2016/07/01 03:21:08
No, just a habit that slipped through. Moved to c
|
| Logger* logger_; |
| std::unique_ptr<TestCastSocket> socket_; |
| CompleteHandler handler_; |
| std::unique_ptr<MockDelegate> delegate_; |
| + // |channel_1_| and |channel_2_| form a full duplex connection between |
| + // |socket_| and |server_socket_|. Passing data through them requires pumping |
| + // the message loop. |
| + std::unique_ptr<net::FakeDataChannel> channel_1_; |
| + std::unique_ptr<net::FakeDataChannel> channel_2_; |
| + |
| + // Used to create the fake TCP socket for |socket_| when communicating with |
| + // |server_socket_|. |
| + std::unique_ptr<net::FakeSocket> fake_client_socket_; |
| + |
| + // |server_socket_| is used for the *RealSSL tests in order to test the |
| + // CastSocket over a real SSL socket. The other members below are used to |
| + // initialize |server_socket_|. |
| + std::unique_ptr<net::SSLServerSocket> server_socket_; |
| + std::unique_ptr<net::SSLServerContext> server_context_; |
| + std::unique_ptr<crypto::RSAPrivateKey> server_private_key_; |
| + scoped_refptr<net::X509Certificate> server_cert_; |
| + net::SSLServerConfig server_ssl_config_; |
| + |
| private: |
| DISALLOW_COPY_AND_ASSIGN(CastSocketTest); |
| }; |
| @@ -803,6 +1013,132 @@ TEST_F(CastSocketTest, TestConnectChallengeVerificationFails) { |
| socket_->error_state()); |
| } |
| +// Tests connecting through an actual non-mocked CastTransport object and |
| +// non-mocked SSLClientSocket, testing the components in integration. |
| +TEST_F(CastSocketTest, TestConnectEndToEndWithRealSSL) { |
| + CreateCastSocketSecure(); |
| + CreateFakeSockets(); |
| + socket_->SetFakeSocket(fake_client_socket_.get()); |
| + |
| + socket_->Connect(std::move(delegate_), |
| + base::Bind(&CompleteHandler::OnConnectComplete, |
| + base::Unretained(&handler_))); |
| + |
| + net::TestCompletionCallback handshake_callback; |
| + int server_ret = handshake_callback.GetResult( |
| + server_socket_->Handshake(handshake_callback.callback())); |
| + |
| + ASSERT_EQ(net::OK, server_ret); |
| + |
| + // Set low-level auth challenge expectations. |
| + CastMessage challenge = CreateAuthChallenge(); |
| + std::string challenge_str; |
| + EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str)); |
| + |
| + int challenge_buffer_length = challenge_str.size(); |
| + scoped_refptr<net::IOBuffer> challenge_buffer( |
| + new net::IOBuffer(challenge_buffer_length)); |
| + int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length, |
| + server_socket_.get()); |
| + |
| + EXPECT_EQ(challenge_buffer_length, read); |
| + EXPECT_EQ(challenge_str, |
| + std::string(challenge_buffer->data(), challenge_buffer_length)); |
| + |
| + // Set low-level auth reply expectations. |
| + CastMessage reply = CreateAuthReply(); |
| + std::string reply_str; |
| + EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str)); |
| + |
| + scoped_refptr<net::StringIOBuffer> reply_buffer( |
| + new net::StringIOBuffer(reply_str)); |
| + int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(), |
| + server_socket_.get()); |
| + |
| + EXPECT_EQ(reply_buffer->size(), written); |
| + EXPECT_CALL(handler_, OnConnectComplete(CHANNEL_ERROR_NONE)); |
| + RunPendingTasks(); |
| + |
| + EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state()); |
| + EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state()); |
| +} |
| + |
| +// Sends message data through an actual non-mocked CastTransport object and |
| +// non-mocked SSLClientSocket, testing the components in integration. |
| +TEST_F(CastSocketTest, TestMessageEndToEndWithRealSSL) { |
| + CreateCastSocketSecure(); |
| + CreateFakeSockets(); |
| + socket_->SetFakeSocket(fake_client_socket_.get()); |
|
Wez
2016/07/01 01:27:56
Is it necessary for the test to retain ownership o
btolsch
2016/07/01 03:21:08
Done (passed ownership to |socket_|).
|
| + |
| + socket_->Connect(std::move(delegate_), |
| + base::Bind(&CompleteHandler::OnConnectComplete, |
| + base::Unretained(&handler_))); |
| + |
| + net::TestCompletionCallback handshake_callback; |
| + int server_ret = handshake_callback.GetResult( |
| + server_socket_->Handshake(handshake_callback.callback())); |
| + |
| + ASSERT_EQ(net::OK, server_ret); |
| + |
| + // Set low-level auth challenge expectations. |
| + CastMessage challenge = CreateAuthChallenge(); |
| + std::string challenge_str; |
| + EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str)); |
| + |
| + int challenge_buffer_length = challenge_str.size(); |
| + scoped_refptr<net::IOBuffer> challenge_buffer( |
| + new net::IOBuffer(challenge_buffer_length)); |
| + |
| + int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length, |
| + server_socket_.get()); |
| + |
| + EXPECT_EQ(challenge_buffer_length, read); |
| + EXPECT_EQ(challenge_str, |
| + std::string(challenge_buffer->data(), challenge_buffer_length)); |
| + |
| + // Set low-level auth reply expectations. |
| + CastMessage reply = CreateAuthReply(); |
| + std::string reply_str; |
| + EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str)); |
| + |
| + scoped_refptr<net::StringIOBuffer> reply_buffer( |
| + new net::StringIOBuffer(reply_str)); |
| + int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(), |
| + server_socket_.get()); |
| + |
| + EXPECT_EQ(reply_buffer->size(), written); |
| + EXPECT_CALL(handler_, OnConnectComplete(CHANNEL_ERROR_NONE)); |
| + RunPendingTasks(); |
| + |
| + EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state()); |
| + EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state()); |
| + |
| + // Send a test message through the ssl socket. |
| + CastMessage test_message = CreateTestMessage(); |
| + std::string test_message_str; |
| + EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str)); |
| + |
| + int test_message_length = test_message_str.size(); |
| + scoped_refptr<net::IOBuffer> test_message_buffer( |
| + new net::IOBuffer(test_message_length)); |
| + |
| + EXPECT_CALL(handler_, OnWriteComplete(net::OK)); |
| + socket_->transport()->SendMessage( |
| + test_message, base::Bind(&CompleteHandler::OnWriteComplete, |
| + base::Unretained(&handler_))); |
| + RunPendingTasks(); |
| + |
| + read = ReadExactLength(test_message_buffer.get(), test_message_length, |
| + server_socket_.get()); |
| + |
| + EXPECT_EQ(test_message_length, read); |
| + EXPECT_EQ(test_message_str, |
| + std::string(test_message_buffer->data(), test_message_length)); |
| + |
| + EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state()); |
| + EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state()); |
| +} |
| + |
| // Sends message data through an actual non-mocked CastTransport object, |
| // testing the two components in integration. |
| TEST_F(CastSocketTest, TestConnectEndToEndWithRealTransportAsync) { |