Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(14)

Unified Diff: extensions/browser/api/cast_channel/cast_socket_unittest.cc

Issue 2093923004: [Cast Channel] Add real SSL tests to CastSocketTest (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Responding to rsleevi's comments Created 4 years, 6 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
« no previous file with comments | « no previous file | net/socket/socket_test_util.h » ('j') | no next file with comments »
Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
Index: extensions/browser/api/cast_channel/cast_socket_unittest.cc
diff --git a/extensions/browser/api/cast_channel/cast_socket_unittest.cc b/extensions/browser/api/cast_channel/cast_socket_unittest.cc
index 08b18acf01a56c27b35bf7c3d68164d83f4ea61d..c7d5e45ee6b69ee668bbfc4b9abe5afc8d44caf8 100644
--- a/extensions/browser/api/cast_channel/cast_socket_unittest.cc
+++ b/extensions/browser/api/cast_channel/cast_socket_unittest.cc
@@ -9,6 +9,7 @@
#include <utility>
#include <vector>
+#include "base/files/file_util.h"
#include "base/location.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
@@ -21,6 +22,8 @@
#include "base/test/simple_test_clock.h"
#include "base/threading/thread_task_runner_handle.h"
#include "base/timer/mock_timer.h"
+#include "content/public/test/test_browser_thread_bundle.h"
+#include "crypto/rsa_private_key.h"
#include "extensions/browser/api/cast_channel/cast_auth_util.h"
#include "extensions/browser/api/cast_channel/cast_framer.h"
#include "extensions/browser/api/cast_channel/cast_message_util.h"
@@ -30,11 +33,14 @@
#include "extensions/common/api/cast_channel/cast_channel.pb.h"
#include "net/base/address_list.h"
#include "net/base/net_errors.h"
+#include "net/cert/pem_tokenizer.h"
#include "net/log/test_net_log.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/ssl_client_socket.h"
+#include "net/socket/ssl_server_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "net/ssl/ssl_info.h"
+#include "net/ssl/ssl_server_config.h"
#include "net/test/cert_test_util.h"
#include "net/test/test_data_directory.h"
#include "testing/gmock/include/gmock/gmock.h"
@@ -171,6 +177,97 @@ class CompleteHandler {
DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
};
+class FakeTCPClientSocket : public net::TCPClientSocket {
+ public:
+ FakeTCPClientSocket(net::FakeSocket* socket)
+ : net::TCPClientSocket(net::AddressList(),
+ nullptr,
+ nullptr,
+ net::NetLog::Source()),
+ socket_(socket) {
+ DCHECK(socket_);
+ }
+
+ int Read(net::IOBuffer* buf,
+ int buf_len,
+ const net::CompletionCallback& callback) override {
+ return socket_->Read(buf, buf_len, callback);
+ }
+
+ int Write(net::IOBuffer* buf,
+ int buf_len,
+ const net::CompletionCallback& callback) override {
+ return socket_->Write(buf, buf_len, callback);
+ }
+
+ int SetReceiveBufferSize(int32_t size) override {
+ return socket_->SetReceiveBufferSize(size);
+ }
+
+ int SetSendBufferSize(int32_t size) override {
+ return socket_->SetSendBufferSize(size);
+ }
+
+ int Connect(const net::CompletionCallback& callback) override {
+ return socket_->Connect(callback);
+ }
+
+ void Disconnect() override { return socket_->Disconnect(); }
+
+ bool IsConnected() const override { return socket_->IsConnected(); }
+
+ bool IsConnectedAndIdle() const override {
+ return socket_->IsConnectedAndIdle();
+ }
+
+ int GetPeerAddress(net::IPEndPoint* address) const override {
+ return socket_->GetPeerAddress(address);
+ }
+
+ int GetLocalAddress(net::IPEndPoint* address) const override {
+ return socket_->GetLocalAddress(address);
+ }
+
+ const net::BoundNetLog& NetLog() const override { return socket_->NetLog(); }
+
+ void SetSubresourceSpeculation() override {
+ socket_->SetSubresourceSpeculation();
+ }
+
+ void SetOmniboxSpeculation() override { socket_->SetOmniboxSpeculation(); }
+
+ bool WasEverUsed() const override { return socket_->WasEverUsed(); }
+
+ bool WasNpnNegotiated() const override { return socket_->WasNpnNegotiated(); }
+
+ net::NextProto GetNegotiatedProtocol() const override {
+ return socket_->GetNegotiatedProtocol();
+ }
+
+ bool GetSSLInfo(net::SSLInfo* ssl_info) override {
+ return socket_->GetSSLInfo(ssl_info);
+ }
+
+ void GetConnectionAttempts(net::ConnectionAttempts* out) const override {
+ socket_->GetConnectionAttempts(out);
+ }
+
+ void ClearConnectionAttempts() override {
+ socket_->ClearConnectionAttempts();
+ }
+
+ void AddConnectionAttempts(const net::ConnectionAttempts& attempts) override {
+ socket_->AddConnectionAttempts(attempts);
+ }
+
+ int64_t GetTotalReceivedBytes() const override {
+ return socket_->GetTotalReceivedBytes();
+ }
+
+ private:
+ net::FakeSocket* socket_;
+};
+
class TestCastSocket : public CastSocketImpl {
public:
static std::unique_ptr<TestCastSocket> Create(
@@ -222,7 +319,8 @@ class TestCastSocket : public CastSocketImpl {
verify_challenge_disallow_(false),
tcp_unresponsive_(false),
mock_timer_(new base::MockTimer(false, false)),
- mock_transport_(nullptr) {}
+ mock_transport_(nullptr),
+ fake_socket_(nullptr) {}
~TestCastSocket() override {}
@@ -231,6 +329,10 @@ class TestCastSocket : public CastSocketImpl {
SetTransportForTesting(base::WrapUnique(mock_transport_));
}
+ void SetFakeSocket(net::FakeSocket* fake_socket) {
+ fake_socket_ = fake_socket;
+ }
+
// Socket connection helpers.
void SetupTcpConnect(net::IoMode mode, int result) {
tcp_connect_data_.reset(new net::MockConnect(mode, result));
@@ -296,6 +398,9 @@ class TestCastSocket : public CastSocketImpl {
std::unique_ptr<net::TCPClientSocket> CreateTcpSocket() override {
if (tcp_unresponsive_) {
return std::unique_ptr<net::TCPClientSocket>(new MockTCPSocket(true));
+ } else if (fake_socket_) {
+ return std::unique_ptr<net::TCPClientSocket>(
+ new FakeTCPClientSocket(fake_socket_));
} else {
net::MockConnect* connect_data = tcp_connect_data_.get();
connect_data->peer_addr = ip_;
@@ -306,6 +411,9 @@ class TestCastSocket : public CastSocketImpl {
std::unique_ptr<net::SSLClientSocket> CreateSslSocket(
std::unique_ptr<net::StreamSocket> socket) override {
+ if (fake_socket_) {
+ return CastSocketImpl::CreateSslSocket(std::move(socket));
+ }
net::MockConnect* connect_data = ssl_connect_data_.get();
connect_data->peer_addr = ip_;
@@ -350,6 +458,12 @@ class TestCastSocket : public CastSocketImpl {
std::unique_ptr<base::MockTimer> mock_timer_;
MockCastTransport* mock_transport_;
+ // A fake socket that is used instead of the mocks when testing with the real
+ // SSL implementation. When this is set, CreateTCPSocket() will use this as
+ // the underlying transport and CreateSSLSocket() will call the base
+ // implementation in CastSocketImpl.
+ net::FakeSocket* fake_socket_;
+
DISALLOW_COPY_AND_ASSIGN(TestCastSocket);
};
@@ -378,6 +492,52 @@ class CastSocketTest : public testing::Test {
socket_ = TestCastSocket::CreateSecure(logger_);
}
+ // Initializes the SSLServerSocket |server_socket_| and the fake socket
+ // |fake_client_socket_| to be used by the CastSocket |socket_| as transport
+ // to |server_socket_|.
+ void CreateFakeSockets() {
+ channel_1_.reset(new net::FakeDataChannel());
+ channel_2_.reset(new net::FakeDataChannel());
+ server_cert_ =
+ net::ImportCertFromFile(net::GetTestCertsDirectory(), "ok_cert.pem");
+ ASSERT_TRUE(server_cert_);
+ server_private_key_ = ReadTestKeyFromPEM("ok_cert.pem");
+ ASSERT_TRUE(server_private_key_);
+ server_context_ = CreateSSLServerContext(
+ server_cert_.get(), *server_private_key_, server_ssl_config_);
+
+ fake_client_socket_.reset(
+ new net::FakeSocket(channel_1_.get(), channel_2_.get()));
+ std::unique_ptr<net::StreamSocket> server_socket(
Wez 2016/07/01 01:27:56 nit: fake_server_socket?
btolsch 2016/07/01 03:21:08 This was effectively changed to |accepted_socket|.
+ new net::FakeSocket(channel_2_.get(), channel_1_.get()));
+
+ server_socket_ =
+ server_context_->CreateSSLServerSocket(std::move(server_socket));
+ ASSERT_TRUE(server_socket_);
+ }
+
+ std::unique_ptr<crypto::RSAPrivateKey> ReadTestKeyFromPEM(
+ const base::StringPiece& name) {
+ base::FilePath certs_dir(net::GetTestCertsDirectory());
+ base::FilePath key_path = certs_dir.AppendASCII(name);
+ std::vector<std::string> headers;
+ headers.push_back("PRIVATE KEY");
+ std::string pem_data;
+ if (!base::ReadFileToString(key_path, &pem_data)) {
+ return nullptr;
+ }
+ net::PEMTokenizer pem_tokenizer(pem_data, headers);
+ if (!pem_tokenizer.GetNext()) {
+ return nullptr;
+ }
+ std::vector<uint8_t> key_vector(pem_tokenizer.data().begin(),
+ pem_tokenizer.data().end());
+
+ std::unique_ptr<crypto::RSAPrivateKey> key(
+ crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
+ return key;
+ }
+
void HandleAuthHandshake() {
socket_->SetupMockTransport();
CastMessage challenge_proto = CreateAuthChallenge();
@@ -395,6 +555,36 @@ class CastSocketTest : public testing::Test {
RunPendingTasks();
}
+ int ReadExactLength(net::IOBuffer* buffer,
+ int buffer_length,
+ net::Socket* socket) {
+ scoped_refptr<net::DrainableIOBuffer> draining_buffer(
+ new net::DrainableIOBuffer(buffer, buffer_length));
+ while (draining_buffer->BytesRemaining() > 0) {
+ net::TestCompletionCallback read_callback;
+ int read_result = read_callback.GetResult(server_socket_->Read(
+ draining_buffer.get(), buffer_length, read_callback.callback()));
+ EXPECT_GT(read_result, 0);
+ draining_buffer->DidConsume(read_result);
+ }
+ return buffer_length;
+ }
+
+ int WriteExactLength(net::IOBuffer* buffer,
+ int buffer_length,
+ net::Socket* socket) {
+ scoped_refptr<net::DrainableIOBuffer> draining_buffer(
+ new net::DrainableIOBuffer(buffer, buffer_length));
+ while (draining_buffer->BytesRemaining() > 0) {
+ net::TestCompletionCallback write_callback;
+ int write_result = write_callback.GetResult(server_socket_->Write(
+ draining_buffer.get(), buffer_length, write_callback.callback()));
+ EXPECT_GT(write_result, 0);
+ draining_buffer->DidConsume(write_result);
+ }
+ return buffer_length;
+ }
+
protected:
// Runs all pending tasks in the message loop.
void RunPendingTasks() {
@@ -402,12 +592,32 @@ class CastSocketTest : public testing::Test {
run_loop.RunUntilIdle();
}
- base::MessageLoop message_loop_;
+ content::TestBrowserThreadBundle thread_bundle_{
+ content::TestBrowserThreadBundle::IO_MAINLOOP};
Wez 2016/07/01 01:27:56 nit: Is there a good reason to use {} init rather
btolsch 2016/07/01 03:21:08 No, just a habit that slipped through. Moved to c
Logger* logger_;
std::unique_ptr<TestCastSocket> socket_;
CompleteHandler handler_;
std::unique_ptr<MockDelegate> delegate_;
+ // |channel_1_| and |channel_2_| form a full duplex connection between
+ // |socket_| and |server_socket_|. Passing data through them requires pumping
+ // the message loop.
+ std::unique_ptr<net::FakeDataChannel> channel_1_;
+ std::unique_ptr<net::FakeDataChannel> channel_2_;
+
+ // Used to create the fake TCP socket for |socket_| when communicating with
+ // |server_socket_|.
+ std::unique_ptr<net::FakeSocket> fake_client_socket_;
+
+ // |server_socket_| is used for the *RealSSL tests in order to test the
+ // CastSocket over a real SSL socket. The other members below are used to
+ // initialize |server_socket_|.
+ std::unique_ptr<net::SSLServerSocket> server_socket_;
+ std::unique_ptr<net::SSLServerContext> server_context_;
+ std::unique_ptr<crypto::RSAPrivateKey> server_private_key_;
+ scoped_refptr<net::X509Certificate> server_cert_;
+ net::SSLServerConfig server_ssl_config_;
+
private:
DISALLOW_COPY_AND_ASSIGN(CastSocketTest);
};
@@ -803,6 +1013,132 @@ TEST_F(CastSocketTest, TestConnectChallengeVerificationFails) {
socket_->error_state());
}
+// Tests connecting through an actual non-mocked CastTransport object and
+// non-mocked SSLClientSocket, testing the components in integration.
+TEST_F(CastSocketTest, TestConnectEndToEndWithRealSSL) {
+ CreateCastSocketSecure();
+ CreateFakeSockets();
+ socket_->SetFakeSocket(fake_client_socket_.get());
+
+ socket_->Connect(std::move(delegate_),
+ base::Bind(&CompleteHandler::OnConnectComplete,
+ base::Unretained(&handler_)));
+
+ net::TestCompletionCallback handshake_callback;
+ int server_ret = handshake_callback.GetResult(
+ server_socket_->Handshake(handshake_callback.callback()));
+
+ ASSERT_EQ(net::OK, server_ret);
+
+ // Set low-level auth challenge expectations.
+ CastMessage challenge = CreateAuthChallenge();
+ std::string challenge_str;
+ EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
+
+ int challenge_buffer_length = challenge_str.size();
+ scoped_refptr<net::IOBuffer> challenge_buffer(
+ new net::IOBuffer(challenge_buffer_length));
+ int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
+ server_socket_.get());
+
+ EXPECT_EQ(challenge_buffer_length, read);
+ EXPECT_EQ(challenge_str,
+ std::string(challenge_buffer->data(), challenge_buffer_length));
+
+ // Set low-level auth reply expectations.
+ CastMessage reply = CreateAuthReply();
+ std::string reply_str;
+ EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
+
+ scoped_refptr<net::StringIOBuffer> reply_buffer(
+ new net::StringIOBuffer(reply_str));
+ int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
+ server_socket_.get());
+
+ EXPECT_EQ(reply_buffer->size(), written);
+ EXPECT_CALL(handler_, OnConnectComplete(CHANNEL_ERROR_NONE));
+ RunPendingTasks();
+
+ EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state());
+ EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state());
+}
+
+// Sends message data through an actual non-mocked CastTransport object and
+// non-mocked SSLClientSocket, testing the components in integration.
+TEST_F(CastSocketTest, TestMessageEndToEndWithRealSSL) {
+ CreateCastSocketSecure();
+ CreateFakeSockets();
+ socket_->SetFakeSocket(fake_client_socket_.get());
Wez 2016/07/01 01:27:56 Is it necessary for the test to retain ownership o
btolsch 2016/07/01 03:21:08 Done (passed ownership to |socket_|).
+
+ socket_->Connect(std::move(delegate_),
+ base::Bind(&CompleteHandler::OnConnectComplete,
+ base::Unretained(&handler_)));
+
+ net::TestCompletionCallback handshake_callback;
+ int server_ret = handshake_callback.GetResult(
+ server_socket_->Handshake(handshake_callback.callback()));
+
+ ASSERT_EQ(net::OK, server_ret);
+
+ // Set low-level auth challenge expectations.
+ CastMessage challenge = CreateAuthChallenge();
+ std::string challenge_str;
+ EXPECT_TRUE(MessageFramer::Serialize(challenge, &challenge_str));
+
+ int challenge_buffer_length = challenge_str.size();
+ scoped_refptr<net::IOBuffer> challenge_buffer(
+ new net::IOBuffer(challenge_buffer_length));
+
+ int read = ReadExactLength(challenge_buffer.get(), challenge_buffer_length,
+ server_socket_.get());
+
+ EXPECT_EQ(challenge_buffer_length, read);
+ EXPECT_EQ(challenge_str,
+ std::string(challenge_buffer->data(), challenge_buffer_length));
+
+ // Set low-level auth reply expectations.
+ CastMessage reply = CreateAuthReply();
+ std::string reply_str;
+ EXPECT_TRUE(MessageFramer::Serialize(reply, &reply_str));
+
+ scoped_refptr<net::StringIOBuffer> reply_buffer(
+ new net::StringIOBuffer(reply_str));
+ int written = WriteExactLength(reply_buffer.get(), reply_buffer->size(),
+ server_socket_.get());
+
+ EXPECT_EQ(reply_buffer->size(), written);
+ EXPECT_CALL(handler_, OnConnectComplete(CHANNEL_ERROR_NONE));
+ RunPendingTasks();
+
+ EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state());
+ EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state());
+
+ // Send a test message through the ssl socket.
+ CastMessage test_message = CreateTestMessage();
+ std::string test_message_str;
+ EXPECT_TRUE(MessageFramer::Serialize(test_message, &test_message_str));
+
+ int test_message_length = test_message_str.size();
+ scoped_refptr<net::IOBuffer> test_message_buffer(
+ new net::IOBuffer(test_message_length));
+
+ EXPECT_CALL(handler_, OnWriteComplete(net::OK));
+ socket_->transport()->SendMessage(
+ test_message, base::Bind(&CompleteHandler::OnWriteComplete,
+ base::Unretained(&handler_)));
+ RunPendingTasks();
+
+ read = ReadExactLength(test_message_buffer.get(), test_message_length,
+ server_socket_.get());
+
+ EXPECT_EQ(test_message_length, read);
+ EXPECT_EQ(test_message_str,
+ std::string(test_message_buffer->data(), test_message_length));
+
+ EXPECT_EQ(cast_channel::READY_STATE_OPEN, socket_->ready_state());
+ EXPECT_EQ(cast_channel::CHANNEL_ERROR_NONE, socket_->error_state());
+}
+
// Sends message data through an actual non-mocked CastTransport object,
// testing the two components in integration.
TEST_F(CastSocketTest, TestConnectEndToEndWithRealTransportAsync) {
« no previous file with comments | « no previous file | net/socket/socket_test_util.h » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698