| Index: net/test/embedded_test_server/embedded_test_server.cc
|
| diff --git a/net/test/embedded_test_server/embedded_test_server.cc b/net/test/embedded_test_server/embedded_test_server.cc
|
| index 62158b4ecc1dec795114eb37aa3872fac04fac17..8725fa118f20d74dbf0bffe50d9d296c9b4c9e98 100644
|
| --- a/net/test/embedded_test_server/embedded_test_server.cc
|
| +++ b/net/test/embedded_test_server/embedded_test_server.cc
|
| @@ -10,6 +10,7 @@
|
| #include "base/location.h"
|
| #include "base/logging.h"
|
| #include "base/message_loop/message_loop.h"
|
| +#include "base/path_service.h"
|
| #include "base/process/process_metrics.h"
|
| #include "base/run_loop.h"
|
| #include "base/stl_util.h"
|
| @@ -17,88 +18,40 @@
|
| #include "base/strings/stringprintf.h"
|
| #include "base/thread_task_runner_handle.h"
|
| #include "base/threading/thread_restrictions.h"
|
| +#include "crypto/rsa_private_key.h"
|
| #include "net/base/ip_endpoint.h"
|
| #include "net/base/net_errors.h"
|
| +#include "net/base/test_data_directory.h"
|
| +#include "net/cert/pem_tokenizer.h"
|
| +#include "net/cert/test_root_certs.h"
|
| +#include "net/socket/ssl_server_socket.h"
|
| #include "net/socket/stream_socket.h"
|
| #include "net/socket/tcp_server_socket.h"
|
| +#include "net/ssl/ssl_server_config.h"
|
| +#include "net/test/cert_test_util.h"
|
| #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h"
|
| #include "net/test/embedded_test_server/http_connection.h"
|
| #include "net/test/embedded_test_server/http_request.h"
|
| #include "net/test/embedded_test_server/http_response.h"
|
| +#include "net/test/embedded_test_server/request_handler_util.h"
|
|
|
| namespace net {
|
| namespace test_server {
|
|
|
| -namespace {
|
| +EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
|
|
|
| -class CustomHttpResponse : public HttpResponse {
|
| - public:
|
| - CustomHttpResponse(const std::string& headers, const std::string& contents)
|
| - : headers_(headers), contents_(contents) {
|
| - }
|
| -
|
| - std::string ToResponseString() const override {
|
| - return headers_ + "\r\n" + contents_;
|
| - }
|
| +EmbeddedTestServer::EmbeddedTestServer(Type type)
|
| + : is_using_ssl_(type == TYPE_HTTPS),
|
| + connection_listener_(nullptr),
|
| + port_(0),
|
| + cert_(CERT_OK) {
|
| + DCHECK(thread_checker_.CalledOnValidThread());
|
|
|
| - private:
|
| - std::string headers_;
|
| - std::string contents_;
|
| -
|
| - DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
|
| -};
|
| -
|
| -// Handles |request| by serving a file from under |server_root|.
|
| -scoped_ptr<HttpResponse> HandleFileRequest(
|
| - const base::FilePath& server_root,
|
| - const HttpRequest& request) {
|
| - // This is a test-only server. Ignore I/O thread restrictions.
|
| - base::ThreadRestrictions::ScopedAllowIO allow_io;
|
| -
|
| - std::string relative_url(request.relative_url);
|
| - // A proxy request will have an absolute path. Simulate the proxy by stripping
|
| - // the scheme, host, and port.
|
| - GURL relative_gurl(relative_url);
|
| - if (relative_gurl.is_valid())
|
| - relative_url = relative_gurl.PathForRequest();
|
| -
|
| - // Trim the first byte ('/').
|
| - std::string request_path = relative_url.substr(1);
|
| -
|
| - // Remove the query string if present.
|
| - size_t query_pos = request_path.find('?');
|
| - if (query_pos != std::string::npos)
|
| - request_path = request_path.substr(0, query_pos);
|
| -
|
| - base::FilePath file_path(server_root.AppendASCII(request_path));
|
| - std::string file_contents;
|
| - if (!base::ReadFileToString(file_path, &file_contents))
|
| - return scoped_ptr<HttpResponse>();
|
| -
|
| - base::FilePath headers_path(
|
| - file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
|
| -
|
| - if (base::PathExists(headers_path)) {
|
| - std::string headers_contents;
|
| - if (!base::ReadFileToString(headers_path, &headers_contents))
|
| - return scoped_ptr<HttpResponse>();
|
| -
|
| - scoped_ptr<CustomHttpResponse> http_response(
|
| - new CustomHttpResponse(headers_contents, file_contents));
|
| - return http_response.Pass();
|
| + if (is_using_ssl_) {
|
| + TestRootCerts* root_certs = TestRootCerts::GetInstance();
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem"));
|
| }
|
| -
|
| - scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
|
| - http_response->set_code(HTTP_OK);
|
| - http_response->set_content(file_contents);
|
| - return http_response.Pass();
|
| -}
|
| -
|
| -} // namespace
|
| -
|
| -EmbeddedTestServer::EmbeddedTestServer()
|
| - : connection_listener_(nullptr), port_(0) {
|
| - DCHECK(thread_checker_.CalledOnValidThread());
|
| }
|
|
|
| EmbeddedTestServer::~EmbeddedTestServer() {
|
| @@ -115,7 +68,7 @@ void EmbeddedTestServer::SetConnectionListener(
|
| connection_listener_ = listener;
|
| }
|
|
|
| -bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
|
| +bool EmbeddedTestServer::Start() {
|
| bool success = InitializeAndListen();
|
| if (!success)
|
| return false;
|
| @@ -123,6 +76,10 @@ bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
|
| return true;
|
| }
|
|
|
| +bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
|
| + return Start();
|
| +}
|
| +
|
| bool EmbeddedTestServer::InitializeAndListen() {
|
| DCHECK(!Started());
|
|
|
| @@ -142,7 +99,15 @@ bool EmbeddedTestServer::InitializeAndListen() {
|
| return false;
|
| }
|
|
|
| - base_url_ = GURL(std::string("http://") + local_endpoint_.ToString());
|
| + if (is_using_ssl_) {
|
| + base_url_ = GURL("https://" + local_endpoint_.ToString());
|
| + if (cert_ == CERT_MISMATCHED_NAME || cert_ == CERT_COMMON_NAME_IS_DOMAIN) {
|
| + base_url_ = GURL(
|
| + base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
|
| + }
|
| + } else {
|
| + base_url_ = GURL("http://" + local_endpoint_.ToString());
|
| + }
|
| port_ = local_endpoint_.port();
|
|
|
| listen_socket_->DetachFromThread();
|
| @@ -183,13 +148,21 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
|
|
|
| scoped_ptr<HttpResponse> response;
|
|
|
| - for (size_t i = 0; i < request_handlers_.size(); ++i) {
|
| - response = request_handlers_[i].Run(*request);
|
| + for (const auto& handler : request_handlers_) {
|
| + response = handler.Run(*request);
|
| if (response)
|
| break;
|
| }
|
|
|
| if (!response) {
|
| + for (const auto& handler : default_request_handlers_) {
|
| + response = handler.Run(*request);
|
| + if (response)
|
| + break;
|
| + }
|
| + }
|
| +
|
| + if (!response) {
|
| LOG(WARNING) << "Request not handled. Returning 404: "
|
| << request->relative_url;
|
| scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
|
| @@ -197,9 +170,10 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
|
| response = not_found_response.Pass();
|
| }
|
|
|
| - connection->SendResponse(response.Pass(),
|
| - base::Bind(&EmbeddedTestServer::DidClose,
|
| - base::Unretained(this), connection));
|
| + response->SendResponse(base::Bind(&HttpConnection::SendResponseBytes,
|
| + base::Unretained(connection)),
|
| + base::Bind(&EmbeddedTestServer::DidClose,
|
| + base::Unretained(this), connection));
|
| }
|
|
|
| GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
|
| @@ -223,16 +197,104 @@ bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
|
| return true;
|
| }
|
|
|
| +void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert,
|
| + const SSLServerConfig& ssl_config) {
|
| + DCHECK(!Started());
|
| + cert_ = cert;
|
| + ssl_config_ = ssl_config;
|
| +}
|
| +
|
| +void EmbeddedTestServer::SetSSLConfig(ServerCertificate cert) {
|
| + SetSSLConfig(cert, SSLServerConfig());
|
| +}
|
| +
|
| +std::string EmbeddedTestServer::GetCertificateName() const {
|
| + DCHECK(is_using_ssl_);
|
| + switch (cert_) {
|
| + case CERT_OK:
|
| + case CERT_MISMATCHED_NAME:
|
| + return "ok_cert.pem";
|
| + case CERT_COMMON_NAME_IS_DOMAIN:
|
| + return "localhost_cert.pem";
|
| + case CERT_EXPIRED:
|
| + return "expired_cert.pem";
|
| + case CERT_CHAIN_WRONG_ROOT:
|
| + return "redundant-server-chain.pem";
|
| + case CERT_BAD_VALIDITY:
|
| + return "bad_validity.pem";
|
| + }
|
| +
|
| + return "ok_cert.pem";
|
| +}
|
| +
|
| +scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const {
|
| + DCHECK(is_using_ssl_);
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + return ImportCertFromFile(certs_dir, GetCertificateName());
|
| +}
|
| +
|
| void EmbeddedTestServer::ServeFilesFromDirectory(
|
| const base::FilePath& directory) {
|
| RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
|
| }
|
|
|
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory(
|
| + const std::string& relative) {
|
| + base::FilePath test_data_dir;
|
| + CHECK(PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir));
|
| + ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
|
| +}
|
| +
|
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory(
|
| + const base::FilePath& relative) {
|
| + base::FilePath test_data_dir;
|
| + CHECK(PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir));
|
| + ServeFilesFromDirectory(test_data_dir.Append(relative));
|
| +}
|
| +
|
| +void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
|
| + // TODO(svaldez): Add additional default handlers.
|
| + ServeFilesFromSourceDirectory(directory);
|
| +}
|
| +
|
| void EmbeddedTestServer::RegisterRequestHandler(
|
| const HandleRequestCallback& callback) {
|
| + // TODO(svaldez): Add check to prevent RegisterHandler from being called
|
| + // after the server has started. https://crbug.com/546060
|
| request_handlers_.push_back(callback);
|
| }
|
|
|
| +void EmbeddedTestServer::RegisterDefaultHandler(
|
| + const HandleRequestCallback& callback) {
|
| + // TODO(svaldez): Add check to prevent RegisterHandler from being called
|
| + // after the server has started. https://crbug.com/546060
|
| + default_request_handlers_.push_back(callback);
|
| +}
|
| +
|
| +scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade(
|
| + scoped_ptr<StreamSocket> connection) {
|
| + DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
|
| +
|
| + base::FilePath certs_dir(GetTestCertsDirectory());
|
| + std::string cert_name = GetCertificateName();
|
| +
|
| + base::FilePath key_path = certs_dir.AppendASCII(cert_name);
|
| + std::string key_string;
|
| + CHECK(base::ReadFileToString(key_path, &key_string));
|
| + std::vector<std::string> headers;
|
| + headers.push_back("PRIVATE KEY");
|
| + PEMTokenizer pem_tokenizer(key_string, headers);
|
| + pem_tokenizer.GetNext();
|
| + std::vector<uint8_t> key_vector;
|
| + key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end());
|
| +
|
| + scoped_ptr<crypto::RSAPrivateKey> server_key(
|
| + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
|
| +
|
| + return CreateSSLServerSocket(connection.Pass(), GetCertificate().get(),
|
| + server_key.get(), ssl_config_);
|
| +}
|
| +
|
| void EmbeddedTestServer::DoAcceptLoop() {
|
| int rv = OK;
|
| while (rv == OK) {
|
| @@ -251,17 +313,37 @@ void EmbeddedTestServer::OnAcceptCompleted(int rv) {
|
| DoAcceptLoop();
|
| }
|
|
|
| +void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
|
| + if (connection->socket_->IsConnected())
|
| + ReadData(connection);
|
| + else
|
| + DidClose(connection);
|
| +}
|
| +
|
| void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
|
| DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
|
| if (connection_listener_)
|
| connection_listener_->AcceptedSocket(*socket);
|
|
|
| + if (is_using_ssl_)
|
| + socket = DoSSLUpgrade(socket.Pass());
|
| +
|
| HttpConnection* http_connection = new HttpConnection(
|
| socket.Pass(),
|
| base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this)));
|
| connections_[http_connection->socket_.get()] = http_connection;
|
|
|
| - ReadData(http_connection);
|
| + if (is_using_ssl_) {
|
| + SSLServerSocket* ssl_socket =
|
| + static_cast<SSLServerSocket*>(http_connection->socket_.get());
|
| + int rv = ssl_socket->Handshake(
|
| + base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this),
|
| + http_connection));
|
| + if (rv != ERR_IO_PENDING)
|
| + OnHandshakeDone(http_connection, rv);
|
| + } else {
|
| + ReadData(http_connection);
|
| + }
|
| }
|
|
|
| void EmbeddedTestServer::ReadData(HttpConnection* connection) {
|
|
|