| Index: net/test/embedded_test_server/embedded_test_server.cc | 
| diff --git a/net/test/embedded_test_server/embedded_test_server.cc b/net/test/embedded_test_server/embedded_test_server.cc | 
| index 62158b4ecc1dec795114eb37aa3872fac04fac17..0f72f6078b1362075506da0430a52337c45161ea 100644 | 
| --- a/net/test/embedded_test_server/embedded_test_server.cc | 
| +++ b/net/test/embedded_test_server/embedded_test_server.cc | 
| @@ -10,6 +10,7 @@ | 
| #include "base/location.h" | 
| #include "base/logging.h" | 
| #include "base/message_loop/message_loop.h" | 
| +#include "base/path_service.h" | 
| #include "base/process/process_metrics.h" | 
| #include "base/run_loop.h" | 
| #include "base/stl_util.h" | 
| @@ -17,88 +18,39 @@ | 
| #include "base/strings/stringprintf.h" | 
| #include "base/thread_task_runner_handle.h" | 
| #include "base/threading/thread_restrictions.h" | 
| +#include "crypto/rsa_private_key.h" | 
| #include "net/base/ip_endpoint.h" | 
| #include "net/base/net_errors.h" | 
| +#include "net/base/test_data_directory.h" | 
| +#include "net/cert/pem_tokenizer.h" | 
| +#include "net/cert/test_root_certs.h" | 
| +#include "net/socket/ssl_server_socket.h" | 
| #include "net/socket/stream_socket.h" | 
| #include "net/socket/tcp_server_socket.h" | 
| +#include "net/ssl/ssl_server_config.h" | 
| +#include "net/test/cert_test_util.h" | 
| #include "net/test/embedded_test_server/embedded_test_server_connection_listener.h" | 
| #include "net/test/embedded_test_server/http_connection.h" | 
| #include "net/test/embedded_test_server/http_request.h" | 
| #include "net/test/embedded_test_server/http_response.h" | 
| +#include "net/test/embedded_test_server/request_helpers.h" | 
|  | 
| namespace net { | 
| namespace test_server { | 
|  | 
| -namespace { | 
| +EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {} | 
|  | 
| -class CustomHttpResponse : public HttpResponse { | 
| - public: | 
| -  CustomHttpResponse(const std::string& headers, const std::string& contents) | 
| -      : headers_(headers), contents_(contents) { | 
| -  } | 
| - | 
| -  std::string ToResponseString() const override { | 
| -    return headers_ + "\r\n" + contents_; | 
| -  } | 
| +EmbeddedTestServer::EmbeddedTestServer(Type type) | 
| +    : is_using_ssl_(type == TYPE_HTTPS), | 
| +      connection_listener_(nullptr), | 
| +      port_(0) { | 
| +  DCHECK(thread_checker_.CalledOnValidThread()); | 
|  | 
| - private: | 
| -  std::string headers_; | 
| -  std::string contents_; | 
| - | 
| -  DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse); | 
| -}; | 
| - | 
| -// Handles |request| by serving a file from under |server_root|. | 
| -scoped_ptr<HttpResponse> HandleFileRequest( | 
| -    const base::FilePath& server_root, | 
| -    const HttpRequest& request) { | 
| -  // This is a test-only server. Ignore I/O thread restrictions. | 
| -  base::ThreadRestrictions::ScopedAllowIO allow_io; | 
| - | 
| -  std::string relative_url(request.relative_url); | 
| -  // A proxy request will have an absolute path. Simulate the proxy by stripping | 
| -  // the scheme, host, and port. | 
| -  GURL relative_gurl(relative_url); | 
| -  if (relative_gurl.is_valid()) | 
| -    relative_url = relative_gurl.PathForRequest(); | 
| - | 
| -  // Trim the first byte ('/'). | 
| -  std::string request_path = relative_url.substr(1); | 
| - | 
| -  // Remove the query string if present. | 
| -  size_t query_pos = request_path.find('?'); | 
| -  if (query_pos != std::string::npos) | 
| -    request_path = request_path.substr(0, query_pos); | 
| - | 
| -  base::FilePath file_path(server_root.AppendASCII(request_path)); | 
| -  std::string file_contents; | 
| -  if (!base::ReadFileToString(file_path, &file_contents)) | 
| -    return scoped_ptr<HttpResponse>(); | 
| - | 
| -  base::FilePath headers_path( | 
| -      file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers"))); | 
| - | 
| -  if (base::PathExists(headers_path)) { | 
| -    std::string headers_contents; | 
| -    if (!base::ReadFileToString(headers_path, &headers_contents)) | 
| -      return scoped_ptr<HttpResponse>(); | 
| - | 
| -    scoped_ptr<CustomHttpResponse> http_response( | 
| -        new CustomHttpResponse(headers_contents, file_contents)); | 
| -    return http_response.Pass(); | 
| +  if (is_using_ssl_) { | 
| +    TestRootCerts* root_certs = TestRootCerts::GetInstance(); | 
| +    base::FilePath certs_dir(GetTestCertsDirectory()); | 
| +    root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem")); | 
| } | 
| - | 
| -  scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); | 
| -  http_response->set_code(HTTP_OK); | 
| -  http_response->set_content(file_contents); | 
| -  return http_response.Pass(); | 
| -} | 
| - | 
| -}  // namespace | 
| - | 
| -EmbeddedTestServer::EmbeddedTestServer() | 
| -    : connection_listener_(nullptr), port_(0) { | 
| -  DCHECK(thread_checker_.CalledOnValidThread()); | 
| } | 
|  | 
| EmbeddedTestServer::~EmbeddedTestServer() { | 
| @@ -115,7 +67,7 @@ void EmbeddedTestServer::SetConnectionListener( | 
| connection_listener_ = listener; | 
| } | 
|  | 
| -bool EmbeddedTestServer::InitializeAndWaitUntilReady() { | 
| +bool EmbeddedTestServer::Start() { | 
| bool success = InitializeAndListen(); | 
| if (!success) | 
| return false; | 
| @@ -123,6 +75,10 @@ bool EmbeddedTestServer::InitializeAndWaitUntilReady() { | 
| return true; | 
| } | 
|  | 
| +bool EmbeddedTestServer::InitializeAndWaitUntilReady() { | 
| +  return Start(); | 
| +} | 
| + | 
| bool EmbeddedTestServer::InitializeAndListen() { | 
| DCHECK(!Started()); | 
|  | 
| @@ -142,7 +98,17 @@ bool EmbeddedTestServer::InitializeAndListen() { | 
| return false; | 
| } | 
|  | 
| -  base_url_ = GURL(std::string("http://") + local_endpoint_.ToString()); | 
| +  if (is_using_ssl_) { | 
| +    base_url_ = GURL("https://" + local_endpoint_.ToString()); | 
| +    if (ssl_config_.server_cert == SSLServerConfig::CERT_MISMATCHED_NAME || | 
| +        ssl_config_.server_cert == | 
| +            SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN) { | 
| +      base_url_ = GURL( | 
| +          base::StringPrintf("https://localhost:%d", local_endpoint_.port())); | 
| +    } | 
| +  } else { | 
| +    base_url_ = GURL("http://" + local_endpoint_.ToString()); | 
| +  } | 
| port_ = local_endpoint_.port(); | 
|  | 
| listen_socket_->DetachFromThread(); | 
| @@ -183,13 +149,21 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection, | 
|  | 
| scoped_ptr<HttpResponse> response; | 
|  | 
| -  for (size_t i = 0; i < request_handlers_.size(); ++i) { | 
| -    response = request_handlers_[i].Run(*request); | 
| +  for (auto handler : request_handlers_) { | 
| +    response = handler.Run(*request); | 
| if (response) | 
| break; | 
| } | 
|  | 
| if (!response) { | 
| +    for (auto handler : default_request_handlers_) { | 
| +      response = handler.Run(*request); | 
| +      if (response) | 
| +        break; | 
| +    } | 
| +  } | 
| + | 
| +  if (!response) { | 
| LOG(WARNING) << "Request not handled. Returning 404: " | 
| << request->relative_url; | 
| scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); | 
| @@ -197,9 +171,10 @@ void EmbeddedTestServer::HandleRequest(HttpConnection* connection, | 
| response = not_found_response.Pass(); | 
| } | 
|  | 
| -  connection->SendResponse(response.Pass(), | 
| -                           base::Bind(&EmbeddedTestServer::DidClose, | 
| -                                      base::Unretained(this), connection)); | 
| +  response->SendResponse(base::Bind(&HttpConnection::SendResponseBytes, | 
| +                                    base::Unretained(connection)), | 
| +                         base::Bind(&EmbeddedTestServer::DidClose, | 
| +                                    base::Unretained(this), connection)); | 
| } | 
|  | 
| GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { | 
| @@ -223,16 +198,98 @@ bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { | 
| return true; | 
| } | 
|  | 
| +void EmbeddedTestServer::SetSSLConfig(const SSLServerConfig& ssl_config) { | 
| +  DCHECK(!Started()); | 
| +  ssl_config_ = ssl_config; | 
| +} | 
| + | 
| +std::string EmbeddedTestServer::GetCertificateName() const { | 
| +  DCHECK(is_using_ssl()); | 
| +  switch (ssl_config_.server_cert) { | 
| +    case SSLServerConfig::CERT_OK: | 
| +    case SSLServerConfig::CERT_MISMATCHED_NAME: | 
| +      return "ok_cert.pem"; | 
| +    case SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN: | 
| +      return "localhost_cert.pem"; | 
| +    case SSLServerConfig::CERT_EXPIRED: | 
| +      return "expired_cert.pem"; | 
| +    case SSLServerConfig::CERT_CHAIN_WRONG_ROOT: | 
| +      return "redundant-server-chain.pem"; | 
| +    case SSLServerConfig::CERT_BAD_VALIDITY: | 
| +      return "bad_validity.pem"; | 
| +  } | 
| + | 
| +  return "ok_cert.pem"; | 
| +} | 
| + | 
| +scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const { | 
| +  DCHECK(is_using_ssl()); | 
| +  base::FilePath certs_dir(GetTestCertsDirectory()); | 
| +  return ImportCertFromFile(certs_dir, GetCertificateName()); | 
| +} | 
| + | 
| void EmbeddedTestServer::ServeFilesFromDirectory( | 
| const base::FilePath& directory) { | 
| RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); | 
| } | 
|  | 
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory( | 
| +    const std::string& relative) { | 
| +  base::FilePath test_data_dir; | 
| +  if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir)) | 
| +    ServeFilesFromDirectory(test_data_dir.AppendASCII(relative)); | 
| +} | 
| + | 
| +void EmbeddedTestServer::ServeFilesFromSourceDirectory( | 
| +    const base::FilePath& relative) { | 
| +  base::FilePath test_data_dir; | 
| +  if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir)) | 
| +    ServeFilesFromDirectory(test_data_dir.Append(relative)); | 
| +} | 
| + | 
| +void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) { | 
| +  // TODO(svaldez): Add additional default handlers. | 
| +  ServeFilesFromSourceDirectory(directory); | 
| +} | 
| + | 
| void EmbeddedTestServer::RegisterRequestHandler( | 
| const HandleRequestCallback& callback) { | 
| +  // TODO(svaldez): Add check to prevent RegisterHandler from being called | 
| +  // after the server has started. crbug.com/546060 | 
| request_handlers_.push_back(callback); | 
| } | 
|  | 
| +void EmbeddedTestServer::RegisterDefaultHandler( | 
| +    const HandleRequestCallback& callback) { | 
| +  // TODO(svaldez): Add check to prevent RegisterHandler from being called | 
| +  // after the server has started. crbug.com/546060 | 
| +  default_request_handlers_.push_back(callback); | 
| +} | 
| + | 
| +scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade( | 
| +    scoped_ptr<StreamSocket> connection) { | 
| +  DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
| + | 
| +  base::FilePath certs_dir(GetTestCertsDirectory()); | 
| +  std::string cert_name = GetCertificateName(); | 
| + | 
| +  base::FilePath key_path = certs_dir.AppendASCII(cert_name); | 
| +  std::string key_string; | 
| +  CHECK(base::ReadFileToString(key_path, &key_string)); | 
| +  std::vector<std::string> headers; | 
| +  headers.push_back("PRIVATE KEY"); | 
| +  PEMTokenizer pem_tokenizer(key_string, headers); | 
| +  pem_tokenizer.GetNext(); | 
| +  std::vector<uint8_t> key_vector; | 
| +  key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end()); | 
| + | 
| +  scoped_ptr<crypto::RSAPrivateKey> server_key( | 
| +      crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); | 
| + | 
| +  return CreateSSLServerSocket(connection.Pass(), GetCertificate().get(), | 
| +                               server_key.get(), ssl_config_); | 
| +} | 
| + | 
| void EmbeddedTestServer::DoAcceptLoop() { | 
| int rv = OK; | 
| while (rv == OK) { | 
| @@ -251,17 +308,37 @@ void EmbeddedTestServer::OnAcceptCompleted(int rv) { | 
| DoAcceptLoop(); | 
| } | 
|  | 
| +void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) { | 
| +  if (connection->socket_->IsConnected()) | 
| +    ReadData(connection); | 
| +  else | 
| +    DidClose(connection); | 
| +} | 
| + | 
| void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) { | 
| DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); | 
| if (connection_listener_) | 
| connection_listener_->AcceptedSocket(*socket); | 
|  | 
| +  if (is_using_ssl_) | 
| +    socket = DoSSLUpgrade(socket.Pass()); | 
| + | 
| HttpConnection* http_connection = new HttpConnection( | 
| socket.Pass(), | 
| base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this))); | 
| connections_[http_connection->socket_.get()] = http_connection; | 
|  | 
| -  ReadData(http_connection); | 
| +  if (is_using_ssl_) { | 
| +    SSLServerSocket* ssl_socket = | 
| +        static_cast<SSLServerSocket*>(http_connection->socket_.get()); | 
| +    int rv = ssl_socket->Handshake( | 
| +        base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this), | 
| +                   http_connection)); | 
| +    if (rv != ERR_IO_PENDING) | 
| +      OnHandshakeDone(http_connection, rv); | 
| +  } else { | 
| +    ReadData(http_connection); | 
| +  } | 
| } | 
|  | 
| void EmbeddedTestServer::ReadData(HttpConnection* connection) { | 
|  |