Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(76)

Side by Side Diff: net/test/embedded_test_server/embedded_test_server.cc

Issue 1376593007: SSL in EmbeddedTestServer (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: More cleanup. Created 5 years, 2 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "net/test/embedded_test_server/embedded_test_server.h" 5 #include "net/test/embedded_test_server/embedded_test_server.h"
6 6
7 #include "base/bind.h" 7 #include "base/bind.h"
8 #include "base/files/file_path.h" 8 #include "base/files/file_path.h"
9 #include "base/files/file_util.h" 9 #include "base/files/file_util.h"
10 #include "base/location.h" 10 #include "base/location.h"
11 #include "base/logging.h" 11 #include "base/logging.h"
12 #include "base/message_loop/message_loop.h" 12 #include "base/message_loop/message_loop.h"
13 #include "base/path_service.h"
13 #include "base/process/process_metrics.h" 14 #include "base/process/process_metrics.h"
14 #include "base/run_loop.h" 15 #include "base/run_loop.h"
15 #include "base/stl_util.h" 16 #include "base/stl_util.h"
16 #include "base/strings/string_util.h" 17 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h" 18 #include "base/strings/stringprintf.h"
18 #include "base/thread_task_runner_handle.h" 19 #include "base/thread_task_runner_handle.h"
19 #include "base/threading/thread_restrictions.h" 20 #include "base/threading/thread_restrictions.h"
21 #include "crypto/rsa_private_key.h"
20 #include "net/base/ip_endpoint.h" 22 #include "net/base/ip_endpoint.h"
21 #include "net/base/net_errors.h" 23 #include "net/base/net_errors.h"
24 #include "net/base/test_data_directory.h"
25 #include "net/cert/pem_tokenizer.h"
26 #include "net/cert/test_root_certs.h"
27 #include "net/socket/ssl_server_socket.h"
22 #include "net/socket/stream_socket.h" 28 #include "net/socket/stream_socket.h"
23 #include "net/socket/tcp_server_socket.h" 29 #include "net/socket/tcp_server_socket.h"
30 #include "net/ssl/ssl_server_config.h"
31 #include "net/test/cert_test_util.h"
24 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h" 32 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h"
25 #include "net/test/embedded_test_server/http_connection.h" 33 #include "net/test/embedded_test_server/http_connection.h"
26 #include "net/test/embedded_test_server/http_request.h" 34 #include "net/test/embedded_test_server/http_request.h"
27 #include "net/test/embedded_test_server/http_response.h" 35 #include "net/test/embedded_test_server/http_response.h"
36 #include "net/test/embedded_test_server/request_helpers.h"
28 37
29 namespace net { 38 namespace net {
30 namespace test_server { 39 namespace test_server {
31 40
32 namespace { 41 EmbeddedTestServer::EmbeddedTestServer()
42 : use_ssl_(false), connection_listener_(nullptr), port_(0) {}
mmenke 2015/10/14 21:59:15 I believe this can just be: EmbeddedTestServer::E
svaldez 2015/10/14 22:33:39 Done.
33 43
34 class CustomHttpResponse : public HttpResponse { 44 EmbeddedTestServer::EmbeddedTestServer(Type type)
35 public:
36 CustomHttpResponse(const std::string& headers, const std::string& contents)
37 : headers_(headers), contents_(contents) {
38 }
39
40 std::string ToResponseString() const override {
41 return headers_ + "\r\n" + contents_;
42 }
43
44 private:
45 std::string headers_;
46 std::string contents_;
47
48 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
49 };
50
51 // Handles |request| by serving a file from under |server_root|.
52 scoped_ptr<HttpResponse> HandleFileRequest(
53 const base::FilePath& server_root,
54 const HttpRequest& request) {
55 // This is a test-only server. Ignore I/O thread restrictions.
56 base::ThreadRestrictions::ScopedAllowIO allow_io;
57
58 std::string relative_url(request.relative_url);
59 // A proxy request will have an absolute path. Simulate the proxy by stripping
60 // the scheme, host, and port.
61 GURL relative_gurl(relative_url);
62 if (relative_gurl.is_valid())
63 relative_url = relative_gurl.PathForRequest();
64
65 // Trim the first byte ('/').
66 std::string request_path = relative_url.substr(1);
67
68 // Remove the query string if present.
69 size_t query_pos = request_path.find('?');
70 if (query_pos != std::string::npos)
71 request_path = request_path.substr(0, query_pos);
72
73 base::FilePath file_path(server_root.AppendASCII(request_path));
74 std::string file_contents;
75 if (!base::ReadFileToString(file_path, &file_contents))
76 return scoped_ptr<HttpResponse>();
77
78 base::FilePath headers_path(
79 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
80
81 if (base::PathExists(headers_path)) {
82 std::string headers_contents;
83 if (!base::ReadFileToString(headers_path, &headers_contents))
84 return scoped_ptr<HttpResponse>();
85
86 scoped_ptr<CustomHttpResponse> http_response(
87 new CustomHttpResponse(headers_contents, file_contents));
88 return http_response.Pass();
89 }
90
91 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
92 http_response->set_code(HTTP_OK);
93 http_response->set_content(file_contents);
94 return http_response.Pass();
95 }
96
97 } // namespace
98
99 EmbeddedTestServer::EmbeddedTestServer()
100 : connection_listener_(nullptr), port_(0) { 45 : connection_listener_(nullptr), port_(0) {
101 DCHECK(thread_checker_.CalledOnValidThread()); 46 DCHECK(thread_checker_.CalledOnValidThread());
47 use_ssl_ = (type == TYPE_HTTPS);
mmenke 2015/10/14 21:59:15 nit: Put in initializer list, and can make it con
svaldez 2015/10/14 22:33:39 Done.
48 if (use_ssl_)
49 LoadTestSSLRoot();
102 } 50 }
103 51
104 EmbeddedTestServer::~EmbeddedTestServer() { 52 EmbeddedTestServer::~EmbeddedTestServer() {
105 DCHECK(thread_checker_.CalledOnValidThread()); 53 DCHECK(thread_checker_.CalledOnValidThread());
106 54
107 if (Started() && !ShutdownAndWaitUntilComplete()) { 55 if (Started() && !ShutdownAndWaitUntilComplete()) {
108 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 56 LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
109 } 57 }
110 } 58 }
111 59
112 void EmbeddedTestServer::SetConnectionListener( 60 void EmbeddedTestServer::SetConnectionListener(
113 EmbeddedTestServerConnectionListener* listener) { 61 EmbeddedTestServerConnectionListener* listener) {
114 DCHECK(!Started()); 62 DCHECK(!Started());
115 connection_listener_ = listener; 63 connection_listener_ = listener;
116 } 64 }
117 65
118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 66 bool EmbeddedTestServer::Start() {
119 bool success = InitializeAndListen(); 67 bool success = InitializeAndListen();
120 if (!success) 68 if (!success)
121 return false; 69 return false;
122 StartAcceptingConnections(); 70 StartAcceptingConnections();
123 return true; 71 return true;
124 } 72 }
125 73
74 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
75 return Start();
76 }
77
126 bool EmbeddedTestServer::InitializeAndListen() { 78 bool EmbeddedTestServer::InitializeAndListen() {
127 DCHECK(!Started()); 79 DCHECK(!Started());
128 80
129 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source())); 81 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source()));
130 82
131 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10); 83 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10);
132 if (result) { 84 if (result) {
133 LOG(ERROR) << "Listen failed: " << ErrorToString(result); 85 LOG(ERROR) << "Listen failed: " << ErrorToString(result);
134 listen_socket_.reset(); 86 listen_socket_.reset();
135 return false; 87 return false;
136 } 88 }
137 89
138 result = listen_socket_->GetLocalAddress(&local_endpoint_); 90 result = listen_socket_->GetLocalAddress(&local_endpoint_);
139 if (result != OK) { 91 if (result != OK) {
140 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 92 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
141 listen_socket_.reset(); 93 listen_socket_.reset();
142 return false; 94 return false;
143 } 95 }
144 96
145 base_url_ = GURL(std::string("http://") + local_endpoint_.ToString()); 97 if (use_ssl_) {
98 base_url_ = GURL("https://" + local_endpoint_.ToString());
99 if (ssl_config_.server_cert == SSLServerConfig::CERT_MISMATCHED_NAME ||
100 ssl_config_.server_cert ==
101 SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN) {
102 base_url_ = GURL(
103 base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
104 }
105 } else {
106 base_url_ = GURL("http://" + local_endpoint_.ToString());
107 }
146 port_ = local_endpoint_.port(); 108 port_ = local_endpoint_.port();
147 109
148 listen_socket_->DetachFromThread(); 110 listen_socket_->DetachFromThread();
149 return true; 111 return true;
150 } 112 }
151 113
152 void EmbeddedTestServer::StartAcceptingConnections() { 114 void EmbeddedTestServer::StartAcceptingConnections() {
153 DCHECK(!io_thread_.get()); 115 DCHECK(!io_thread_.get());
154 base::Thread::Options thread_options; 116 base::Thread::Options thread_options;
155 thread_options.message_loop_type = base::MessageLoop::TYPE_IO; 117 thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
(...skipping 27 matching lines...) Expand all
183 145
184 scoped_ptr<HttpResponse> response; 146 scoped_ptr<HttpResponse> response;
185 147
186 for (size_t i = 0; i < request_handlers_.size(); ++i) { 148 for (size_t i = 0; i < request_handlers_.size(); ++i) {
187 response = request_handlers_[i].Run(*request); 149 response = request_handlers_[i].Run(*request);
188 if (response) 150 if (response)
189 break; 151 break;
190 } 152 }
191 153
192 if (!response) { 154 if (!response) {
155 for (size_t i = 0; i < default_request_handlers_.size(); ++i) {
mmenke 2015/10/14 21:59:15 Can use a range loop.
svaldez 2015/10/14 22:33:39 Done.
156 response = default_request_handlers_[i].Run(*request);
157 if (response)
158 break;
159 }
160 }
161
162 if (!response) {
193 LOG(WARNING) << "Request not handled. Returning 404: " 163 LOG(WARNING) << "Request not handled. Returning 404: "
194 << request->relative_url; 164 << request->relative_url;
195 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 165 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
196 not_found_response->set_code(HTTP_NOT_FOUND); 166 not_found_response->set_code(HTTP_NOT_FOUND);
197 response = not_found_response.Pass(); 167 response = not_found_response.Pass();
198 } 168 }
199 169
200 connection->SendResponse(response.Pass(), 170 response->SendResponse(
201 base::Bind(&EmbeddedTestServer::DidClose, 171 base::Bind(&HttpConnection::SendResponse, base::Unretained(connection)),
202 base::Unretained(this), connection)); 172 base::Bind(&EmbeddedTestServer::DidClose, base::Unretained(this),
173 connection));
203 } 174 }
204 175
205 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 176 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
206 DCHECK(Started()) << "You must start the server first."; 177 DCHECK(Started()) << "You must start the server first.";
207 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE)) 178 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE))
208 << relative_url; 179 << relative_url;
209 return base_url_.Resolve(relative_url); 180 return base_url_.Resolve(relative_url);
210 } 181 }
211 182
212 GURL EmbeddedTestServer::GetURL( 183 GURL EmbeddedTestServer::GetURL(
213 const std::string& hostname, 184 const std::string& hostname,
214 const std::string& relative_url) const { 185 const std::string& relative_url) const {
215 GURL local_url = GetURL(relative_url); 186 GURL local_url = GetURL(relative_url);
216 GURL::Replacements replace_host; 187 GURL::Replacements replace_host;
217 replace_host.SetHostStr(hostname); 188 replace_host.SetHostStr(hostname);
218 return local_url.ReplaceComponents(replace_host); 189 return local_url.ReplaceComponents(replace_host);
219 } 190 }
220 191
221 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { 192 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
222 *address_list = AddressList(local_endpoint_); 193 *address_list = AddressList(local_endpoint_);
223 return true; 194 return true;
224 } 195 }
225 196
197 void EmbeddedTestServer::SetSSLConfig(net::SSLServerConfig ssl_config) {
198 DCHECK(!Started());
199 ssl_config_ = ssl_config;
200 }
201
202 std::string EmbeddedTestServer::GetCertificateName() const {
203 std::string cert_name;
204
205 switch (ssl_config_.server_cert) {
206 case SSLServerConfig::CERT_OK:
207 case SSLServerConfig::CERT_MISMATCHED_NAME:
208 cert_name = "ok_cert.pem";
209 break;
210 case SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN:
211 cert_name = "localhost_cert.pem";
212 break;
213 case SSLServerConfig::CERT_EXPIRED:
214 cert_name = "expired_cert.pem";
215 break;
216 case SSLServerConfig::CERT_CHAIN_WRONG_ROOT:
217 cert_name = "redundant-server-chain.pem";
218 break;
219 case SSLServerConfig::CERT_BAD_VALIDITY:
220 cert_name = "bad_validity.pem";
221 break;
222 }
223
224 return cert_name;
225 }
226
227 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const {
228 base::FilePath certs_dir(GetTestCertsDirectory());
229 return ImportCertFromFile(certs_dir, GetCertificateName());
230 }
231
226 void EmbeddedTestServer::ServeFilesFromDirectory( 232 void EmbeddedTestServer::ServeFilesFromDirectory(
227 const base::FilePath& directory) { 233 const base::FilePath& directory) {
228 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 234 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
229 } 235 }
230 236
237 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
238 const std::string relative) {
239 base::FilePath test_data_dir;
240 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
241 ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
242 }
243
244 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
245 const base::FilePath& relative) {
246 base::FilePath test_data_dir;
247 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
248 ServeFilesFromDirectory(test_data_dir.Append(relative));
249 }
250
251 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
252 RegisterDefaultHandlers(this);
253 ServeFilesFromSourceDirectory(directory);
254 }
255
231 void EmbeddedTestServer::RegisterRequestHandler( 256 void EmbeddedTestServer::RegisterRequestHandler(
232 const HandleRequestCallback& callback) { 257 const HandleRequestCallback& callback) {
233 request_handlers_.push_back(callback); 258 request_handlers_.push_back(callback);
234 } 259 }
235 260
261 void EmbeddedTestServer::RegisterDefaultHandler(
262 const HandleRequestCallback& callback) {
263 default_request_handlers_.push_back(callback);
264 }
265
266 void EmbeddedTestServer::LoadTestSSLRoot() {
267 TestRootCerts* root_certs = TestRootCerts::GetInstance();
268 if (!root_certs)
269 return;
270 base::FilePath certs_dir(GetTestCertsDirectory());
271 root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem"));
272 }
273
274 scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade(
275 scoped_ptr<StreamSocket> connection) {
276 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
277
278 base::FilePath certs_dir(GetTestCertsDirectory());
279 std::string cert_name = GetCertificateName();
280 scoped_refptr<X509Certificate> server_cert = GetCertificate();
281
282 base::FilePath key_path = certs_dir.AppendASCII(cert_name);
283 std::string key_string;
284 CHECK(base::ReadFileToString(key_path, &key_string));
285 std::vector<std::string> headers;
286 headers.push_back("PRIVATE KEY");
287 PEMTokenizer pem_tokenizer(key_string, headers);
288 pem_tokenizer.GetNext();
289 std::vector<uint8_t> key_vector;
290 key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end());
291
292 scoped_ptr<crypto::RSAPrivateKey> server_key(
293 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
294
295 return CreateSSLServerSocket(connection.Pass(), server_cert.get(),
296 server_key.get(), ssl_config_);
297 }
298
236 void EmbeddedTestServer::DoAcceptLoop() { 299 void EmbeddedTestServer::DoAcceptLoop() {
237 int rv = OK; 300 int rv = OK;
238 while (rv == OK) { 301 while (rv == OK) {
239 rv = listen_socket_->Accept( 302 rv = listen_socket_->Accept(
240 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted, 303 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted,
241 base::Unretained(this))); 304 base::Unretained(this)));
242 if (rv == ERR_IO_PENDING) 305 if (rv == ERR_IO_PENDING)
243 return; 306 return;
244 HandleAcceptResult(accepted_socket_.Pass()); 307 HandleAcceptResult(accepted_socket_.Pass());
245 } 308 }
246 } 309 }
247 310
248 void EmbeddedTestServer::OnAcceptCompleted(int rv) { 311 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
249 DCHECK_NE(ERR_IO_PENDING, rv); 312 DCHECK_NE(ERR_IO_PENDING, rv);
250 HandleAcceptResult(accepted_socket_.Pass()); 313 HandleAcceptResult(accepted_socket_.Pass());
251 DoAcceptLoop(); 314 DoAcceptLoop();
252 } 315 }
253 316
317 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
318 if (connection->socket_->IsConnected())
319 ReadData(connection);
320 }
321
254 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) { 322 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
255 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); 323 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
256 if (connection_listener_) 324 if (connection_listener_)
257 connection_listener_->AcceptedSocket(*socket); 325 connection_listener_->AcceptedSocket(*socket);
258 326
327 if (use_ssl_)
328 socket = DoSSLUpgrade(socket.Pass());
329
259 HttpConnection* http_connection = new HttpConnection( 330 HttpConnection* http_connection = new HttpConnection(
260 socket.Pass(), 331 socket.Pass(),
261 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this))); 332 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this)));
262 connections_[http_connection->socket_.get()] = http_connection; 333 connections_[http_connection->socket_.get()] = http_connection;
263 334
264 ReadData(http_connection); 335 if (use_ssl_) {
336 SSLServerSocket* ssl_socket =
337 static_cast<SSLServerSocket*>(http_connection->socket_.get());
338 int rv = ssl_socket->Handshake(
339 base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this),
340 http_connection));
341 if (rv != ERR_IO_PENDING)
342 OnHandshakeDone(http_connection, rv);
343 } else {
344 ReadData(http_connection);
345 }
265 } 346 }
266 347
267 void EmbeddedTestServer::ReadData(HttpConnection* connection) { 348 void EmbeddedTestServer::ReadData(HttpConnection* connection) {
268 while (true) { 349 while (true) {
269 int rv = 350 int rv =
270 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted, 351 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted,
271 base::Unretained(this), connection)); 352 base::Unretained(this), connection));
272 if (rv == ERR_IO_PENDING) 353 if (rv == ERR_IO_PENDING)
273 return; 354 return;
274 if (!HandleReadResult(connection, rv)) 355 if (!HandleReadResult(connection, rv))
(...skipping 64 matching lines...) Expand 10 before | Expand all | Expand 10 after
339 run_loop.QuitClosure())) { 420 run_loop.QuitClosure())) {
340 return false; 421 return false;
341 } 422 }
342 run_loop.Run(); 423 run_loop.Run();
343 424
344 return true; 425 return true;
345 } 426 }
346 427
347 } // namespace test_server 428 } // namespace test_server
348 } // namespace net 429 } // namespace net
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698