Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(336)

Side by Side Diff: net/test/embedded_test_server/embedded_test_server.cc

Issue 1376593007: SSL in EmbeddedTestServer (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Rebase. Created 5 years, 2 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "net/test/embedded_test_server/embedded_test_server.h" 5 #include "net/test/embedded_test_server/embedded_test_server.h"
6 6
7 #include "base/bind.h" 7 #include "base/bind.h"
8 #include "base/files/file_path.h" 8 #include "base/files/file_path.h"
9 #include "base/files/file_util.h" 9 #include "base/files/file_util.h"
10 #include "base/location.h" 10 #include "base/location.h"
11 #include "base/logging.h" 11 #include "base/logging.h"
12 #include "base/message_loop/message_loop.h" 12 #include "base/message_loop/message_loop.h"
13 #include "base/path_service.h"
13 #include "base/process/process_metrics.h" 14 #include "base/process/process_metrics.h"
14 #include "base/run_loop.h" 15 #include "base/run_loop.h"
15 #include "base/stl_util.h" 16 #include "base/stl_util.h"
16 #include "base/strings/string_util.h" 17 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h" 18 #include "base/strings/stringprintf.h"
18 #include "base/thread_task_runner_handle.h" 19 #include "base/thread_task_runner_handle.h"
19 #include "base/threading/thread_restrictions.h" 20 #include "base/threading/thread_restrictions.h"
21 #include "crypto/rsa_private_key.h"
20 #include "net/base/ip_endpoint.h" 22 #include "net/base/ip_endpoint.h"
21 #include "net/base/net_errors.h" 23 #include "net/base/net_errors.h"
24 #include "net/base/test_data_directory.h"
25 #include "net/cert/pem_tokenizer.h"
26 #include "net/cert/test_root_certs.h"
27 #include "net/socket/ssl_server_socket.h"
22 #include "net/socket/stream_socket.h" 28 #include "net/socket/stream_socket.h"
23 #include "net/socket/tcp_server_socket.h" 29 #include "net/socket/tcp_server_socket.h"
30 #include "net/ssl/ssl_server_config.h"
31 #include "net/test/cert_test_util.h"
24 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h" 32 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h"
25 #include "net/test/embedded_test_server/http_connection.h" 33 #include "net/test/embedded_test_server/http_connection.h"
26 #include "net/test/embedded_test_server/http_request.h" 34 #include "net/test/embedded_test_server/http_request.h"
27 #include "net/test/embedded_test_server/http_response.h" 35 #include "net/test/embedded_test_server/http_response.h"
36 #include "net/test/embedded_test_server/request_helpers.h"
28 37
29 namespace net { 38 namespace net {
30 namespace test_server { 39 namespace test_server {
31 40
32 namespace { 41 EmbeddedTestServer::EmbeddedTestServer(bool use_ssl)
33 42 : use_ssl_(use_ssl), connection_listener_(nullptr), port_(0) {
34 class CustomHttpResponse : public HttpResponse { 43 DCHECK(thread_checker_.CalledOnValidThread());
35 public: 44 if (use_ssl_) {
36 CustomHttpResponse(const std::string& headers, const std::string& contents) 45 LoadTestSSLRoot();
37 : headers_(headers), contents_(contents) {
38 } 46 }
39
40 std::string ToResponseString() const override {
41 return headers_ + "\r\n" + contents_;
42 }
43
44 private:
45 std::string headers_;
46 std::string contents_;
47
48 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
49 };
50
51 // Handles |request| by serving a file from under |server_root|.
52 scoped_ptr<HttpResponse> HandleFileRequest(
53 const base::FilePath& server_root,
54 const HttpRequest& request) {
55 // This is a test-only server. Ignore I/O thread restrictions.
56 base::ThreadRestrictions::ScopedAllowIO allow_io;
57
58 std::string relative_url(request.relative_url);
59 // A proxy request will have an absolute path. Simulate the proxy by stripping
60 // the scheme, host, and port.
61 GURL relative_gurl(relative_url);
62 if (relative_gurl.is_valid())
63 relative_url = relative_gurl.PathForRequest();
64
65 // Trim the first byte ('/').
66 std::string request_path = relative_url.substr(1);
67
68 // Remove the query string if present.
69 size_t query_pos = request_path.find('?');
70 if (query_pos != std::string::npos)
71 request_path = request_path.substr(0, query_pos);
72
73 base::FilePath file_path(server_root.AppendASCII(request_path));
74 std::string file_contents;
75 if (!base::ReadFileToString(file_path, &file_contents))
76 return scoped_ptr<HttpResponse>();
77
78 base::FilePath headers_path(
79 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
80
81 if (base::PathExists(headers_path)) {
82 std::string headers_contents;
83 if (!base::ReadFileToString(headers_path, &headers_contents))
84 return scoped_ptr<HttpResponse>();
85
86 scoped_ptr<CustomHttpResponse> http_response(
87 new CustomHttpResponse(headers_contents, file_contents));
88 return http_response.Pass();
89 }
90
91 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
92 http_response->set_code(HTTP_OK);
93 http_response->set_content(file_contents);
94 return http_response.Pass();
95 }
96
97 } // namespace
98
99 EmbeddedTestServer::EmbeddedTestServer()
100 : connection_listener_(nullptr), port_(0) {
101 DCHECK(thread_checker_.CalledOnValidThread());
102 } 47 }
103 48
104 EmbeddedTestServer::~EmbeddedTestServer() { 49 EmbeddedTestServer::~EmbeddedTestServer() {
105 DCHECK(thread_checker_.CalledOnValidThread()); 50 DCHECK(thread_checker_.CalledOnValidThread());
106 51
107 if (Started() && !ShutdownAndWaitUntilComplete()) { 52 if (Started() && !ShutdownAndWaitUntilComplete()) {
108 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 53 LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
109 } 54 }
110 } 55 }
111 56
112 void EmbeddedTestServer::SetConnectionListener( 57 void EmbeddedTestServer::SetConnectionListener(
113 EmbeddedTestServerConnectionListener* listener) { 58 EmbeddedTestServerConnectionListener* listener) {
114 DCHECK(!Started()); 59 DCHECK(!Started());
115 connection_listener_ = listener; 60 connection_listener_ = listener;
116 } 61 }
117 62
118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 63 bool EmbeddedTestServer::Start() {
119 bool success = InitializeAndListen(); 64 bool success = InitializeAndListen();
120 if (!success) 65 if (!success)
121 return false; 66 return false;
122 StartAcceptingConnections(); 67 StartAcceptingConnections();
123 return true; 68 return true;
124 } 69 }
125 70
71 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
72 return Start();
73 }
74
126 bool EmbeddedTestServer::InitializeAndListen() { 75 bool EmbeddedTestServer::InitializeAndListen() {
127 DCHECK(!Started()); 76 DCHECK(!Started());
128 77
129 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source())); 78 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source()));
130 79
131 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10); 80 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10);
132 if (result) { 81 if (result) {
133 LOG(ERROR) << "Listen failed: " << ErrorToString(result); 82 LOG(ERROR) << "Listen failed: " << ErrorToString(result);
134 listen_socket_.reset(); 83 listen_socket_.reset();
135 return false; 84 return false;
136 } 85 }
137 86
138 result = listen_socket_->GetLocalAddress(&local_endpoint_); 87 result = listen_socket_->GetLocalAddress(&local_endpoint_);
139 if (result != OK) { 88 if (result != OK) {
140 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 89 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
141 listen_socket_.reset(); 90 listen_socket_.reset();
142 return false; 91 return false;
143 } 92 }
144 93
145 base_url_ = GURL(std::string("http://") + local_endpoint_.ToString()); 94 base_url_ = GURL("http://" + local_endpoint_.ToString());
davidben 2015/10/13 19:43:47 Nit: I'd probably put this in an else.
svaldez 2015/10/13 20:54:43 Done.
95 if (use_ssl_) {
96 base_url_ = GURL("https://" + local_endpoint_.ToString());
97 if (ssl_config_.server_cert == SSLServerConfig::CERT_MISMATCHED_NAME ||
98 ssl_config_.server_cert ==
99 SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN) {
100 base_url_ = GURL(
101 base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
102 }
103 }
146 port_ = local_endpoint_.port(); 104 port_ = local_endpoint_.port();
147 105
148 listen_socket_->DetachFromThread(); 106 listen_socket_->DetachFromThread();
149 return true; 107 return true;
150 } 108 }
151 109
152 void EmbeddedTestServer::StartAcceptingConnections() { 110 void EmbeddedTestServer::StartAcceptingConnections() {
153 DCHECK(!io_thread_.get()); 111 DCHECK(!io_thread_.get());
154 base::Thread::Options thread_options; 112 base::Thread::Options thread_options;
155 thread_options.message_loop_type = base::MessageLoop::TYPE_IO; 113 thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
(...skipping 20 matching lines...) Expand all
176 connections_.end()); 134 connections_.end());
177 connections_.clear(); 135 connections_.clear();
178 } 136 }
179 137
180 void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 138 void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
181 scoped_ptr<HttpRequest> request) { 139 scoped_ptr<HttpRequest> request) {
182 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); 140 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
183 141
184 scoped_ptr<HttpResponse> response; 142 scoped_ptr<HttpResponse> response;
185 143
144 LOG(WARNING) << "Request incoming: " << request->relative_url;
davidben 2015/10/13 19:43:48 Probably didn't mean to keep this line.
svaldez 2015/10/13 20:54:44 Done.
186 for (size_t i = 0; i < request_handlers_.size(); ++i) { 145 for (size_t i = 0; i < request_handlers_.size(); ++i) {
187 response = request_handlers_[i].Run(*request); 146 response = request_handlers_[i].Run(*request);
188 if (response) 147 if (response)
189 break; 148 break;
190 } 149 }
191 150
192 if (!response) { 151 if (!response) {
152 for (size_t i = 0; i < default_request_handlers_.size(); ++i) {
153 response = default_request_handlers_[i].Run(*request);
154 if (response)
155 break;
156 }
157 }
158
159 if (!response) {
193 LOG(WARNING) << "Request not handled. Returning 404: " 160 LOG(WARNING) << "Request not handled. Returning 404: "
194 << request->relative_url; 161 << request->relative_url;
195 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 162 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
196 not_found_response->set_code(HTTP_NOT_FOUND); 163 not_found_response->set_code(HTTP_NOT_FOUND);
197 response = not_found_response.Pass(); 164 response = not_found_response.Pass();
198 } 165 }
199 166
200 connection->SendResponse(response.Pass(), 167 response->SendResponse(
201 base::Bind(&EmbeddedTestServer::DidClose, 168 base::Bind(&HttpConnection::SendResponse, base::Unretained(connection)),
202 base::Unretained(this), connection)); 169 base::Bind(&EmbeddedTestServer::DidClose, base::Unretained(this),
170 connection));
203 } 171 }
204 172
205 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 173 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
206 DCHECK(Started()) << "You must start the server first."; 174 DCHECK(Started()) << "You must start the server first.";
207 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE)) 175 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE))
208 << relative_url; 176 << relative_url;
209 return base_url_.Resolve(relative_url); 177 return base_url_.Resolve(relative_url);
210 } 178 }
211 179
212 GURL EmbeddedTestServer::GetURL( 180 GURL EmbeddedTestServer::GetURL(
213 const std::string& hostname, 181 const std::string& hostname,
214 const std::string& relative_url) const { 182 const std::string& relative_url) const {
215 GURL local_url = GetURL(relative_url); 183 GURL local_url = GetURL(relative_url);
216 GURL::Replacements replace_host; 184 GURL::Replacements replace_host;
217 replace_host.SetHostStr(hostname); 185 replace_host.SetHostStr(hostname);
218 return local_url.ReplaceComponents(replace_host); 186 return local_url.ReplaceComponents(replace_host);
219 } 187 }
220 188
221 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { 189 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
222 *address_list = AddressList(local_endpoint_); 190 *address_list = AddressList(local_endpoint_);
223 return true; 191 return true;
224 } 192 }
225 193
194 std::string EmbeddedTestServer::GetCertificateName() const {
195 std::string cert_name;
196
197 switch (ssl_config_.server_cert) {
198 case SSLServerConfig::CERT_OK:
199 case SSLServerConfig::CERT_MISMATCHED_NAME:
200 cert_name = "ok_cert.pem";
201 break;
202 case SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN:
203 cert_name = "localhost_cert.pem";
204 break;
205 case SSLServerConfig::CERT_EXPIRED:
206 cert_name = "expired_cert.pem";
207 break;
208 case SSLServerConfig::CERT_CHAIN_WRONG_ROOT:
209 cert_name = "redundant-server-chain.pem";
210 break;
211 case SSLServerConfig::CERT_BAD_VALIDITY:
212 cert_name = "bad_validity.pem";
213 break;
214 }
215
216 return cert_name;
217 }
218
219 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const {
220 base::FilePath certs_dir(GetTestCertsDirectory());
221 return ImportCertFromFile(certs_dir, GetCertificateName());
222 }
223
226 void EmbeddedTestServer::ServeFilesFromDirectory( 224 void EmbeddedTestServer::ServeFilesFromDirectory(
227 const base::FilePath& directory) { 225 const base::FilePath& directory) {
228 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 226 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
229 } 227 }
230 228
229 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
230 const std::string relative) {
231 base::FilePath test_data_dir;
232 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
233 ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
234 }
235
236 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
237 const base::FilePath& relative) {
238 base::FilePath test_data_dir;
239 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
240 ServeFilesFromDirectory(test_data_dir.Append(relative));
241 }
242
243 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
244 RegisterDefaultHandlers(this);
245 ServeFilesFromSourceDirectory(directory);
246 }
247
231 void EmbeddedTestServer::RegisterRequestHandler( 248 void EmbeddedTestServer::RegisterRequestHandler(
232 const HandleRequestCallback& callback) { 249 const HandleRequestCallback& callback) {
233 request_handlers_.push_back(callback); 250 request_handlers_.push_back(callback);
234 } 251 }
235 252
253 void EmbeddedTestServer::RegisterDefaultHandler(
254 const HandleRequestCallback& callback) {
255 default_request_handlers_.push_back(callback);
256 }
257
258 void EmbeddedTestServer::LoadTestSSLRoot() {
259 TestRootCerts* root_certs = TestRootCerts::GetInstance();
260 if (!root_certs)
261 return;
262 base::FilePath certs_dir(GetTestCertsDirectory());
263 root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem"));
264 }
265
266 scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade(
267 scoped_ptr<StreamSocket> connection) {
268 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
269
270 base::FilePath certs_dir(GetTestCertsDirectory());
271 std::string cert_name = GetCertificateName();
272 scoped_refptr<X509Certificate> server_cert = GetCertificate();
273
274 base::FilePath key_path = certs_dir.AppendASCII(cert_name);
275 std::string key_string;
276 DCHECK(base::ReadFileToString(key_path, &key_string));
davidben 2015/10/13 19:43:48 If wrapped in a DCHECK, release builds won't even
svaldez 2015/10/13 20:54:44 Done.
277 std::vector<std::string> headers;
278 headers.push_back("PRIVATE KEY");
279 PEMTokenizer pem_tok(key_string, headers);
davidben 2015/10/13 19:43:48 Nit: pem_tokenizer
svaldez 2015/10/13 20:54:44 Done.
280 pem_tok.GetNext();
281 std::vector<uint8> key_vector;
davidben 2015/10/13 19:43:47 Nit: uint8_t for new code
svaldez 2015/10/13 20:54:44 Done.
282 key_vector.assign(pem_tok.data().begin(), pem_tok.data().end());
283
284 scoped_ptr<crypto::RSAPrivateKey> server_key(
285 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
286
287 return CreateSSLServerSocket(connection.Pass(), server_cert.get(),
288 server_key.get(), ssl_config_);
289 }
290
236 void EmbeddedTestServer::DoAcceptLoop() { 291 void EmbeddedTestServer::DoAcceptLoop() {
237 int rv = OK; 292 int rv = OK;
238 while (rv == OK) { 293 while (rv == OK) {
239 rv = listen_socket_->Accept( 294 rv = listen_socket_->Accept(
240 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted, 295 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted,
241 base::Unretained(this))); 296 base::Unretained(this)));
242 if (rv == ERR_IO_PENDING) 297 if (rv == ERR_IO_PENDING)
243 return; 298 return;
244 HandleAcceptResult(accepted_socket_.Pass()); 299 HandleAcceptResult(accepted_socket_.Pass());
245 } 300 }
246 } 301 }
247 302
248 void EmbeddedTestServer::OnAcceptCompleted(int rv) { 303 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
249 DCHECK_NE(ERR_IO_PENDING, rv); 304 DCHECK_NE(ERR_IO_PENDING, rv);
250 HandleAcceptResult(accepted_socket_.Pass()); 305 HandleAcceptResult(accepted_socket_.Pass());
251 DoAcceptLoop(); 306 DoAcceptLoop();
252 } 307 }
253 308
309 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
310 if (connection->socket_->IsConnected()) {
311 ReadData(connection);
312 }
davidben 2015/10/13 19:43:48 Nit: No curlies
svaldez 2015/10/13 20:54:44 Done.
313 }
314
254 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) { 315 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
255 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); 316 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
256 if (connection_listener_) 317 if (connection_listener_)
257 connection_listener_->AcceptedSocket(*socket); 318 connection_listener_->AcceptedSocket(*socket);
258 319
320 if (use_ssl_) {
321 socket = DoSSLUpgrade(socket.Pass());
322 }
davidben 2015/10/13 19:43:47 Nit: No curlies
svaldez 2015/10/13 20:54:44 Done.
323
259 HttpConnection* http_connection = new HttpConnection( 324 HttpConnection* http_connection = new HttpConnection(
260 socket.Pass(), 325 socket.Pass(),
261 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this))); 326 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this)));
262 connections_[http_connection->socket_.get()] = http_connection; 327 connections_[http_connection->socket_.get()] = http_connection;
263 328
264 ReadData(http_connection); 329 if (use_ssl_) {
330 SSLServerSocket* ssl_socket =
331 (SSLServerSocket*)http_connection->socket_.get();
davidben 2015/10/13 19:43:48 static_cast<SSLServerSocket*>(....)
svaldez 2015/10/13 20:54:44 Done.
332 int rv = ssl_socket->Handshake(
333 base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this),
334 http_connection));
335 if (rv != ERR_IO_PENDING) {
336 OnHandshakeDone(http_connection, rv);
davidben 2015/10/13 19:43:47 Nit: no curlies
svaldez 2015/10/13 20:54:44 Done.
337 }
338 } else {
339 ReadData(http_connection);
340 }
265 } 341 }
266 342
267 void EmbeddedTestServer::ReadData(HttpConnection* connection) { 343 void EmbeddedTestServer::ReadData(HttpConnection* connection) {
268 while (true) { 344 while (true) {
269 int rv = 345 int rv =
270 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted, 346 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted,
271 base::Unretained(this), connection)); 347 base::Unretained(this), connection));
272 if (rv == ERR_IO_PENDING) 348 if (rv == ERR_IO_PENDING)
273 return; 349 return;
274 if (!HandleReadResult(connection, rv)) 350 if (!HandleReadResult(connection, rv))
(...skipping 64 matching lines...) Expand 10 before | Expand all | Expand 10 after
339 run_loop.QuitClosure())) { 415 run_loop.QuitClosure())) {
340 return false; 416 return false;
341 } 417 }
342 run_loop.Run(); 418 run_loop.Run();
343 419
344 return true; 420 return true;
345 } 421 }
346 422
347 } // namespace test_server 423 } // namespace test_server
348 } // namespace net 424 } // namespace net
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698