Chromium Code Reviews| Index: net/test/embedded_test_server/embedded_test_server.cc |
| diff --git a/net/test/embedded_test_server/embedded_test_server.cc b/net/test/embedded_test_server/embedded_test_server.cc |
| index 62158b4ecc1dec795114eb37aa3872fac04fac17..131b243e868e3ef8037cf4cfa438a096fb421de9 100644 |
| --- a/net/test/embedded_test_server/embedded_test_server.cc |
| +++ b/net/test/embedded_test_server/embedded_test_server.cc |
| @@ -10,6 +10,7 @@ |
| #include "base/location.h" |
| #include "base/logging.h" |
| #include "base/message_loop/message_loop.h" |
| +#include "base/path_service.h" |
| #include "base/process/process_metrics.h" |
| #include "base/run_loop.h" |
| #include "base/stl_util.h" |
| @@ -17,88 +18,35 @@ |
| #include "base/strings/stringprintf.h" |
| #include "base/thread_task_runner_handle.h" |
| #include "base/threading/thread_restrictions.h" |
| +#include "crypto/rsa_private_key.h" |
| #include "net/base/ip_endpoint.h" |
| #include "net/base/net_errors.h" |
| +#include "net/base/test_data_directory.h" |
| +#include "net/cert/pem_tokenizer.h" |
| +#include "net/cert/test_root_certs.h" |
| +#include "net/socket/ssl_server_socket.h" |
| #include "net/socket/stream_socket.h" |
| #include "net/socket/tcp_server_socket.h" |
| +#include "net/ssl/ssl_server_config.h" |
| +#include "net/test/cert_test_util.h" |
| #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h" |
| #include "net/test/embedded_test_server/http_connection.h" |
| #include "net/test/embedded_test_server/http_request.h" |
| #include "net/test/embedded_test_server/http_response.h" |
| +#include "net/test/embedded_test_server/request_helpers.h" |
| namespace net { |
| namespace test_server { |
| -namespace { |
| - |
| -class CustomHttpResponse : public HttpResponse { |
| - public: |
| - CustomHttpResponse(const std::string& headers, const std::string& contents) |
| - : headers_(headers), contents_(contents) { |
| - } |
| - |
| - std::string ToResponseString() const override { |
| - return headers_ + "\r\n" + contents_; |
| - } |
| - |
| - private: |
| - std::string headers_; |
| - std::string contents_; |
| - |
| - DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse); |
| -}; |
| - |
| -// Handles |request| by serving a file from under |server_root|. |
| -scoped_ptr<HttpResponse> HandleFileRequest( |
| - const base::FilePath& server_root, |
| - const HttpRequest& request) { |
| - // This is a test-only server. Ignore I/O thread restrictions. |
| - base::ThreadRestrictions::ScopedAllowIO allow_io; |
| - |
| - std::string relative_url(request.relative_url); |
| - // A proxy request will have an absolute path. Simulate the proxy by stripping |
| - // the scheme, host, and port. |
| - GURL relative_gurl(relative_url); |
| - if (relative_gurl.is_valid()) |
| - relative_url = relative_gurl.PathForRequest(); |
| - |
| - // Trim the first byte ('/'). |
| - std::string request_path = relative_url.substr(1); |
| - |
| - // Remove the query string if present. |
| - size_t query_pos = request_path.find('?'); |
| - if (query_pos != std::string::npos) |
| - request_path = request_path.substr(0, query_pos); |
| - |
| - base::FilePath file_path(server_root.AppendASCII(request_path)); |
| - std::string file_contents; |
| - if (!base::ReadFileToString(file_path, &file_contents)) |
| - return scoped_ptr<HttpResponse>(); |
| - |
| - base::FilePath headers_path( |
| - file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers"))); |
| - |
| - if (base::PathExists(headers_path)) { |
| - std::string headers_contents; |
| - if (!base::ReadFileToString(headers_path, &headers_contents)) |
| - return scoped_ptr<HttpResponse>(); |
| - |
| - scoped_ptr<CustomHttpResponse> http_response( |
| - new CustomHttpResponse(headers_contents, file_contents)); |
| - return http_response.Pass(); |
| - } |
| - |
| - scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); |
| - http_response->set_code(HTTP_OK); |
| - http_response->set_content(file_contents); |
| - return http_response.Pass(); |
| -} |
| - |
| -} // namespace |
| - |
| EmbeddedTestServer::EmbeddedTestServer() |
| + : use_ssl_(false), connection_listener_(nullptr), port_(0) {} |
|
mmenke
2015/10/14 21:59:15
I believe this can just be:
EmbeddedTestServer::E
svaldez
2015/10/14 22:33:39
Done.
|
| + |
| +EmbeddedTestServer::EmbeddedTestServer(Type type) |
| : connection_listener_(nullptr), port_(0) { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| + use_ssl_ = (type == TYPE_HTTPS); |
|
mmenke
2015/10/14 21:59:15
nit: Put in initializer list, and can make it con
svaldez
2015/10/14 22:33:39
Done.
|
| + if (use_ssl_) |
| + LoadTestSSLRoot(); |
| } |
| EmbeddedTestServer::~EmbeddedTestServer() { |
| @@ -115,7 +63,7 @@ void EmbeddedTestServer::SetConnectionListener( |
| connection_listener_ = listener; |
| } |
| -bool EmbeddedTestServer::InitializeAndWaitUntilReady() { |
| +bool EmbeddedTestServer::Start() { |
| bool success = InitializeAndListen(); |
| if (!success) |
| return false; |
| @@ -123,6 +71,10 @@ bool EmbeddedTestServer::InitializeAndWaitUntilReady() { |
| return true; |
| } |
| +bool EmbeddedTestServer::InitializeAndWaitUntilReady() { |
| + return Start(); |
| +} |
| + |
| bool EmbeddedTestServer::InitializeAndListen() { |
| DCHECK(!Started()); |
| @@ -142,7 +94,17 @@ bool EmbeddedTestServer::InitializeAndListen() { |
| return false; |
| } |
| - base_url_ = GURL(std::string("http://") + local_endpoint_.ToString()); |
| + if (use_ssl_) { |
| + base_url_ = GURL("https://" + local_endpoint_.ToString()); |
| + if (ssl_config_.server_cert == SSLServerConfig::CERT_MISMATCHED_NAME || |
| + ssl_config_.server_cert == |
| + SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN) { |
| + base_url_ = GURL( |
| + base::StringPrintf("https://localhost:%d", local_endpoint_.port())); |
| + } |
| + } else { |
| + base_url_ = GURL("http://" + local_endpoint_.ToString()); |
| + } |
| port_ = local_endpoint_.port(); |
| listen_socket_->DetachFromThread(); |
| @@ -190,6 +152,14 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection, |
| } |
| if (!response) { |
| + for (size_t i = 0; i < default_request_handlers_.size(); ++i) { |
|
mmenke
2015/10/14 21:59:15
Can use a range loop.
svaldez
2015/10/14 22:33:39
Done.
|
| + response = default_request_handlers_[i].Run(*request); |
| + if (response) |
| + break; |
| + } |
| + } |
| + |
| + if (!response) { |
| LOG(WARNING) << "Request not handled. Returning 404: " |
| << request->relative_url; |
| scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); |
| @@ -197,9 +167,10 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection, |
| response = not_found_response.Pass(); |
| } |
| - connection->SendResponse(response.Pass(), |
| - base::Bind(&EmbeddedTestServer::DidClose, |
| - base::Unretained(this), connection)); |
| + response->SendResponse( |
| + base::Bind(&HttpConnection::SendResponse, base::Unretained(connection)), |
| + base::Bind(&EmbeddedTestServer::DidClose, base::Unretained(this), |
| + connection)); |
| } |
| GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { |
| @@ -223,16 +194,108 @@ bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { |
| return true; |
| } |
| +void EmbeddedTestServer::SetSSLConfig(net::SSLServerConfig ssl_config) { |
| + DCHECK(!Started()); |
| + ssl_config_ = ssl_config; |
| +} |
| + |
| +std::string EmbeddedTestServer::GetCertificateName() const { |
| + std::string cert_name; |
| + |
| + switch (ssl_config_.server_cert) { |
| + case SSLServerConfig::CERT_OK: |
| + case SSLServerConfig::CERT_MISMATCHED_NAME: |
| + cert_name = "ok_cert.pem"; |
| + break; |
| + case SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN: |
| + cert_name = "localhost_cert.pem"; |
| + break; |
| + case SSLServerConfig::CERT_EXPIRED: |
| + cert_name = "expired_cert.pem"; |
| + break; |
| + case SSLServerConfig::CERT_CHAIN_WRONG_ROOT: |
| + cert_name = "redundant-server-chain.pem"; |
| + break; |
| + case SSLServerConfig::CERT_BAD_VALIDITY: |
| + cert_name = "bad_validity.pem"; |
| + break; |
| + } |
| + |
| + return cert_name; |
| +} |
| + |
| +scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const { |
| + base::FilePath certs_dir(GetTestCertsDirectory()); |
| + return ImportCertFromFile(certs_dir, GetCertificateName()); |
| +} |
| + |
| void EmbeddedTestServer::ServeFilesFromDirectory( |
| const base::FilePath& directory) { |
| RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); |
| } |
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory( |
| + const std::string relative) { |
| + base::FilePath test_data_dir; |
| + if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir)) |
| + ServeFilesFromDirectory(test_data_dir.AppendASCII(relative)); |
| +} |
| + |
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory( |
| + const base::FilePath& relative) { |
| + base::FilePath test_data_dir; |
| + if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir)) |
| + ServeFilesFromDirectory(test_data_dir.Append(relative)); |
| +} |
| + |
| +void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) { |
| + RegisterDefaultHandlers(this); |
| + ServeFilesFromSourceDirectory(directory); |
| +} |
| + |
| void EmbeddedTestServer::RegisterRequestHandler( |
| const HandleRequestCallback& callback) { |
| request_handlers_.push_back(callback); |
| } |
| +void EmbeddedTestServer::RegisterDefaultHandler( |
| + const HandleRequestCallback& callback) { |
| + default_request_handlers_.push_back(callback); |
| +} |
| + |
| +void EmbeddedTestServer::LoadTestSSLRoot() { |
| + TestRootCerts* root_certs = TestRootCerts::GetInstance(); |
| + if (!root_certs) |
| + return; |
| + base::FilePath certs_dir(GetTestCertsDirectory()); |
| + root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem")); |
| +} |
| + |
| +scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade( |
| + scoped_ptr<StreamSocket> connection) { |
| + DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); |
| + |
| + base::FilePath certs_dir(GetTestCertsDirectory()); |
| + std::string cert_name = GetCertificateName(); |
| + scoped_refptr<X509Certificate> server_cert = GetCertificate(); |
| + |
| + base::FilePath key_path = certs_dir.AppendASCII(cert_name); |
| + std::string key_string; |
| + CHECK(base::ReadFileToString(key_path, &key_string)); |
| + std::vector<std::string> headers; |
| + headers.push_back("PRIVATE KEY"); |
| + PEMTokenizer pem_tokenizer(key_string, headers); |
| + pem_tokenizer.GetNext(); |
| + std::vector<uint8_t> key_vector; |
| + key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end()); |
| + |
| + scoped_ptr<crypto::RSAPrivateKey> server_key( |
| + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); |
| + |
| + return CreateSSLServerSocket(connection.Pass(), server_cert.get(), |
| + server_key.get(), ssl_config_); |
| +} |
| + |
| void EmbeddedTestServer::DoAcceptLoop() { |
| int rv = OK; |
| while (rv == OK) { |
| @@ -251,17 +314,35 @@ void EmbeddedTestServer::OnAcceptCompleted(int rv) { |
| DoAcceptLoop(); |
| } |
| +void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) { |
| + if (connection->socket_->IsConnected()) |
| + ReadData(connection); |
| +} |
| + |
| void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) { |
| DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); |
| if (connection_listener_) |
| connection_listener_->AcceptedSocket(*socket); |
| + if (use_ssl_) |
| + socket = DoSSLUpgrade(socket.Pass()); |
| + |
| HttpConnection* http_connection = new HttpConnection( |
| socket.Pass(), |
| base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this))); |
| connections_[http_connection->socket_.get()] = http_connection; |
| - ReadData(http_connection); |
| + if (use_ssl_) { |
| + SSLServerSocket* ssl_socket = |
| + static_cast<SSLServerSocket*>(http_connection->socket_.get()); |
| + int rv = ssl_socket->Handshake( |
| + base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this), |
| + http_connection)); |
| + if (rv != ERR_IO_PENDING) |
| + OnHandshakeDone(http_connection, rv); |
| + } else { |
| + ReadData(http_connection); |
| + } |
| } |
| void EmbeddedTestServer::ReadData(HttpConnection* connection) { |