Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(166)

Side by Side Diff: net/test/embedded_test_server/embedded_test_server.cc

Issue 1376593007: SSL in EmbeddedTestServer (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Removing unused macros. Created 5 years, 2 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "net/test/embedded_test_server/embedded_test_server.h" 5 #include "net/test/embedded_test_server/embedded_test_server.h"
6 6
7 #include "base/bind.h" 7 #include "base/bind.h"
8 #include "base/files/file_path.h" 8 #include "base/files/file_path.h"
9 #include "base/files/file_util.h" 9 #include "base/files/file_util.h"
10 #include "base/location.h" 10 #include "base/location.h"
11 #include "base/logging.h" 11 #include "base/logging.h"
12 #include "base/message_loop/message_loop.h" 12 #include "base/message_loop/message_loop.h"
13 #include "base/path_service.h"
13 #include "base/process/process_metrics.h" 14 #include "base/process/process_metrics.h"
14 #include "base/run_loop.h" 15 #include "base/run_loop.h"
15 #include "base/stl_util.h" 16 #include "base/stl_util.h"
16 #include "base/strings/string_util.h" 17 #include "base/strings/string_util.h"
17 #include "base/strings/stringprintf.h" 18 #include "base/strings/stringprintf.h"
18 #include "base/thread_task_runner_handle.h" 19 #include "base/thread_task_runner_handle.h"
19 #include "base/threading/thread_restrictions.h" 20 #include "base/threading/thread_restrictions.h"
21 #include "crypto/rsa_private_key.h"
20 #include "net/base/ip_endpoint.h" 22 #include "net/base/ip_endpoint.h"
21 #include "net/base/net_errors.h" 23 #include "net/base/net_errors.h"
24 #include "net/base/test_data_directory.h"
25 #include "net/cert/pem_tokenizer.h"
26 #include "net/cert/test_root_certs.h"
27 #include "net/socket/ssl_server_socket.h"
22 #include "net/socket/stream_socket.h" 28 #include "net/socket/stream_socket.h"
23 #include "net/socket/tcp_server_socket.h" 29 #include "net/socket/tcp_server_socket.h"
30 #include "net/ssl/ssl_server_config.h"
31 #include "net/test/cert_test_util.h"
24 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h" 32 #include "net/test/embedded_test_server/embedded_test_server_connection_listener .h"
25 #include "net/test/embedded_test_server/http_connection.h" 33 #include "net/test/embedded_test_server/http_connection.h"
26 #include "net/test/embedded_test_server/http_request.h" 34 #include "net/test/embedded_test_server/http_request.h"
27 #include "net/test/embedded_test_server/http_response.h" 35 #include "net/test/embedded_test_server/http_response.h"
36 #include "net/test/embedded_test_server/request_helpers.h"
28 37
29 namespace net { 38 namespace net {
30 namespace test_server { 39 namespace test_server {
31 40
32 namespace { 41 EmbeddedTestServer::EmbeddedTestServer() : EmbeddedTestServer(TYPE_HTTP) {}
33 42
34 class CustomHttpResponse : public HttpResponse { 43 EmbeddedTestServer::EmbeddedTestServer(const Type type)
35 public: 44 : is_using_ssl_(type == TYPE_HTTPS),
36 CustomHttpResponse(const std::string& headers, const std::string& contents) 45 connection_listener_(nullptr),
37 : headers_(headers), contents_(contents) { 46 port_(0) {
38 } 47 DCHECK(thread_checker_.CalledOnValidThread());
39 48
40 std::string ToResponseString() const override { 49 if (is_using_ssl_)
41 return headers_ + "\r\n" + contents_; 50 LoadTestSSLRoot();
42 }
43
44 private:
45 std::string headers_;
46 std::string contents_;
47
48 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
49 };
50
51 // Handles |request| by serving a file from under |server_root|.
52 scoped_ptr<HttpResponse> HandleFileRequest(
53 const base::FilePath& server_root,
54 const HttpRequest& request) {
55 // This is a test-only server. Ignore I/O thread restrictions.
56 base::ThreadRestrictions::ScopedAllowIO allow_io;
57
58 std::string relative_url(request.relative_url);
59 // A proxy request will have an absolute path. Simulate the proxy by stripping
60 // the scheme, host, and port.
61 GURL relative_gurl(relative_url);
62 if (relative_gurl.is_valid())
63 relative_url = relative_gurl.PathForRequest();
64
65 // Trim the first byte ('/').
66 std::string request_path = relative_url.substr(1);
67
68 // Remove the query string if present.
69 size_t query_pos = request_path.find('?');
70 if (query_pos != std::string::npos)
71 request_path = request_path.substr(0, query_pos);
72
73 base::FilePath file_path(server_root.AppendASCII(request_path));
74 std::string file_contents;
75 if (!base::ReadFileToString(file_path, &file_contents))
76 return scoped_ptr<HttpResponse>();
77
78 base::FilePath headers_path(
79 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
80
81 if (base::PathExists(headers_path)) {
82 std::string headers_contents;
83 if (!base::ReadFileToString(headers_path, &headers_contents))
84 return scoped_ptr<HttpResponse>();
85
86 scoped_ptr<CustomHttpResponse> http_response(
87 new CustomHttpResponse(headers_contents, file_contents));
88 return http_response.Pass();
89 }
90
91 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
92 http_response->set_code(HTTP_OK);
93 http_response->set_content(file_contents);
94 return http_response.Pass();
95 }
96
97 } // namespace
98
99 EmbeddedTestServer::EmbeddedTestServer()
100 : connection_listener_(nullptr), port_(0) {
101 DCHECK(thread_checker_.CalledOnValidThread());
102 } 51 }
103 52
104 EmbeddedTestServer::~EmbeddedTestServer() { 53 EmbeddedTestServer::~EmbeddedTestServer() {
105 DCHECK(thread_checker_.CalledOnValidThread()); 54 DCHECK(thread_checker_.CalledOnValidThread());
106 55
107 if (Started() && !ShutdownAndWaitUntilComplete()) { 56 if (Started() && !ShutdownAndWaitUntilComplete()) {
108 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 57 LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
109 } 58 }
110 } 59 }
111 60
112 void EmbeddedTestServer::SetConnectionListener( 61 void EmbeddedTestServer::SetConnectionListener(
113 EmbeddedTestServerConnectionListener* listener) { 62 EmbeddedTestServerConnectionListener* listener) {
114 DCHECK(!Started()); 63 DCHECK(!Started());
115 connection_listener_ = listener; 64 connection_listener_ = listener;
116 } 65 }
117 66
118 bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 67 bool EmbeddedTestServer::Start() {
119 bool success = InitializeAndListen(); 68 bool success = InitializeAndListen();
120 if (!success) 69 if (!success)
121 return false; 70 return false;
122 StartAcceptingConnections(); 71 StartAcceptingConnections();
123 return true; 72 return true;
124 } 73 }
125 74
75 bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
76 return Start();
77 }
78
126 bool EmbeddedTestServer::InitializeAndListen() { 79 bool EmbeddedTestServer::InitializeAndListen() {
127 DCHECK(!Started()); 80 DCHECK(!Started());
128 81
129 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source())); 82 listen_socket_.reset(new TCPServerSocket(nullptr, NetLog::Source()));
130 83
131 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10); 84 int result = listen_socket_->ListenWithAddressAndPort("127.0.0.1", 0, 10);
132 if (result) { 85 if (result) {
133 LOG(ERROR) << "Listen failed: " << ErrorToString(result); 86 LOG(ERROR) << "Listen failed: " << ErrorToString(result);
134 listen_socket_.reset(); 87 listen_socket_.reset();
135 return false; 88 return false;
136 } 89 }
137 90
138 result = listen_socket_->GetLocalAddress(&local_endpoint_); 91 result = listen_socket_->GetLocalAddress(&local_endpoint_);
139 if (result != OK) { 92 if (result != OK) {
140 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 93 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
141 listen_socket_.reset(); 94 listen_socket_.reset();
142 return false; 95 return false;
143 } 96 }
144 97
145 base_url_ = GURL(std::string("http://") + local_endpoint_.ToString()); 98 if (is_using_ssl_) {
99 base_url_ = GURL("https://" + local_endpoint_.ToString());
100 if (ssl_config_.server_cert == SSLServerConfig::CERT_MISMATCHED_NAME ||
101 ssl_config_.server_cert ==
102 SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN) {
103 base_url_ = GURL(
104 base::StringPrintf("https://localhost:%d", local_endpoint_.port()));
105 }
106 } else {
107 base_url_ = GURL("http://" + local_endpoint_.ToString());
108 }
146 port_ = local_endpoint_.port(); 109 port_ = local_endpoint_.port();
147 110
148 listen_socket_->DetachFromThread(); 111 listen_socket_->DetachFromThread();
149 return true; 112 return true;
150 } 113 }
151 114
152 void EmbeddedTestServer::StartAcceptingConnections() { 115 void EmbeddedTestServer::StartAcceptingConnections() {
153 DCHECK(!io_thread_.get()); 116 DCHECK(!io_thread_.get());
154 base::Thread::Options thread_options; 117 base::Thread::Options thread_options;
155 thread_options.message_loop_type = base::MessageLoop::TYPE_IO; 118 thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
(...skipping 20 matching lines...) Expand all
176 connections_.end()); 139 connections_.end());
177 connections_.clear(); 140 connections_.clear();
178 } 141 }
179 142
180 void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 143 void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
181 scoped_ptr<HttpRequest> request) { 144 scoped_ptr<HttpRequest> request) {
182 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); 145 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
183 146
184 scoped_ptr<HttpResponse> response; 147 scoped_ptr<HttpResponse> response;
185 148
186 for (size_t i = 0; i < request_handlers_.size(); ++i) { 149 for (auto handler : request_handlers_) {
187 response = request_handlers_[i].Run(*request); 150 response = handler.Run(*request);
188 if (response) 151 if (response)
189 break; 152 break;
190 } 153 }
191 154
192 if (!response) { 155 if (!response) {
156 for (auto handler : default_request_handlers_) {
157 response = handler.Run(*request);
158 if (response)
159 break;
160 }
161 }
162
163 if (!response) {
193 LOG(WARNING) << "Request not handled. Returning 404: " 164 LOG(WARNING) << "Request not handled. Returning 404: "
194 << request->relative_url; 165 << request->relative_url;
195 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 166 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
196 not_found_response->set_code(HTTP_NOT_FOUND); 167 not_found_response->set_code(HTTP_NOT_FOUND);
197 response = not_found_response.Pass(); 168 response = not_found_response.Pass();
198 } 169 }
199 170
200 connection->SendResponse(response.Pass(), 171 response->SendResponse(
201 base::Bind(&EmbeddedTestServer::DidClose, 172 base::Bind(&HttpConnection::SendResponse, base::Unretained(connection)),
202 base::Unretained(this), connection)); 173 base::Bind(&EmbeddedTestServer::DidClose, base::Unretained(this),
174 connection));
203 } 175 }
204 176
205 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 177 GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
206 DCHECK(Started()) << "You must start the server first."; 178 DCHECK(Started()) << "You must start the server first.";
207 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE)) 179 DCHECK(base::StartsWith(relative_url, "/", base::CompareCase::SENSITIVE))
208 << relative_url; 180 << relative_url;
209 return base_url_.Resolve(relative_url); 181 return base_url_.Resolve(relative_url);
210 } 182 }
211 183
212 GURL EmbeddedTestServer::GetURL( 184 GURL EmbeddedTestServer::GetURL(
213 const std::string& hostname, 185 const std::string& hostname,
214 const std::string& relative_url) const { 186 const std::string& relative_url) const {
215 GURL local_url = GetURL(relative_url); 187 GURL local_url = GetURL(relative_url);
216 GURL::Replacements replace_host; 188 GURL::Replacements replace_host;
217 replace_host.SetHostStr(hostname); 189 replace_host.SetHostStr(hostname);
218 return local_url.ReplaceComponents(replace_host); 190 return local_url.ReplaceComponents(replace_host);
219 } 191 }
220 192
221 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const { 193 bool EmbeddedTestServer::GetAddressList(AddressList* address_list) const {
222 *address_list = AddressList(local_endpoint_); 194 *address_list = AddressList(local_endpoint_);
223 return true; 195 return true;
224 } 196 }
225 197
198 void EmbeddedTestServer::SetSSLConfig(const SSLServerConfig& ssl_config) {
199 DCHECK(!Started());
200 ssl_config_ = ssl_config;
201 }
202
203 std::string EmbeddedTestServer::GetCertificateName() const {
mmenke 2015/10/19 18:07:40 DCHECK(is_using_ssl_);?
svaldez 2015/10/19 21:56:14 Done.
204 switch (ssl_config_.server_cert) {
205 case SSLServerConfig::CERT_OK:
206 case SSLServerConfig::CERT_MISMATCHED_NAME:
207 return "ok_cert.pem";
208 case SSLServerConfig::CERT_COMMON_NAME_IS_DOMAIN:
209 return "localhost_cert.pem";
210 case SSLServerConfig::CERT_EXPIRED:
211 return "expired_cert.pem";
212 case SSLServerConfig::CERT_CHAIN_WRONG_ROOT:
213 return "redundant-server-chain.pem";
214 case SSLServerConfig::CERT_BAD_VALIDITY:
215 return "bad_validity.pem";
216 }
217
218 return "ok_cert.pem";
219 }
220
221 scoped_refptr<X509Certificate> EmbeddedTestServer::GetCertificate() const {
222 base::FilePath certs_dir(GetTestCertsDirectory());
223 return ImportCertFromFile(certs_dir, GetCertificateName());
224 }
225
226 void EmbeddedTestServer::ServeFilesFromDirectory( 226 void EmbeddedTestServer::ServeFilesFromDirectory(
227 const base::FilePath& directory) { 227 const base::FilePath& directory) {
228 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 228 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
229 } 229 }
230 230
231 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
232 const std::string relative) {
233 base::FilePath test_data_dir;
234 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
235 ServeFilesFromDirectory(test_data_dir.AppendASCII(relative));
236 }
237
238 void EmbeddedTestServer::ServeFilesFromSourceDirectory(
239 const base::FilePath& relative) {
240 base::FilePath test_data_dir;
241 if (PathService::Get(base::DIR_SOURCE_ROOT, &test_data_dir))
242 ServeFilesFromDirectory(test_data_dir.Append(relative));
243 }
244
245 void EmbeddedTestServer::AddDefaultHandlers(const base::FilePath& directory) {
mmenke 2015/10/19 18:07:40 Maybe put DCHECK(!io_thread_.get()); in all of the
svaldez 2015/10/19 21:56:14 Its safe to call this after we've started the serv
mmenke 2015/10/19 22:14:41 But it's not safe not after we've started the IOTh
246 // TODO(svaldez): Add additional default handlers.
247 ServeFilesFromSourceDirectory(directory);
248 }
249
231 void EmbeddedTestServer::RegisterRequestHandler( 250 void EmbeddedTestServer::RegisterRequestHandler(
232 const HandleRequestCallback& callback) { 251 const HandleRequestCallback& callback) {
233 request_handlers_.push_back(callback); 252 request_handlers_.push_back(callback);
234 } 253 }
235 254
255 void EmbeddedTestServer::RegisterDefaultHandler(
256 const HandleRequestCallback& callback) {
257 default_request_handlers_.push_back(callback);
258 }
259
260 void EmbeddedTestServer::LoadTestSSLRoot() {
mmenke 2015/10/19 18:07:40 Make this method static, or put into an anonymous
svaldez 2015/10/19 21:56:14 Done.
261 TestRootCerts* root_certs = TestRootCerts::GetInstance();
262 if (!root_certs)
263 return;
mmenke 2015/10/19 18:07:40 Should we fail in some way when this happens? DC
svaldez 2015/10/19 21:56:14 Done.
264 base::FilePath certs_dir(GetTestCertsDirectory());
265 root_certs->AddFromFile(certs_dir.AppendASCII("root_ca_cert.pem"));
266 }
267
268 scoped_ptr<StreamSocket> EmbeddedTestServer::DoSSLUpgrade(
269 scoped_ptr<StreamSocket> connection) {
270 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
271
272 base::FilePath certs_dir(GetTestCertsDirectory());
273 std::string cert_name = GetCertificateName();
274
275 base::FilePath key_path = certs_dir.AppendASCII(cert_name);
276 std::string key_string;
277 CHECK(base::ReadFileToString(key_path, &key_string));
278 std::vector<std::string> headers;
279 headers.push_back("PRIVATE KEY");
280 PEMTokenizer pem_tokenizer(key_string, headers);
281 pem_tokenizer.GetNext();
282 std::vector<uint8_t> key_vector;
283 key_vector.assign(pem_tokenizer.data().begin(), pem_tokenizer.data().end());
284
285 scoped_ptr<crypto::RSAPrivateKey> server_key(
286 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
287
288 return CreateSSLServerSocket(connection.Pass(), GetCertificate().get(),
289 server_key.get(), ssl_config_);
290 }
291
236 void EmbeddedTestServer::DoAcceptLoop() { 292 void EmbeddedTestServer::DoAcceptLoop() {
237 int rv = OK; 293 int rv = OK;
238 while (rv == OK) { 294 while (rv == OK) {
239 rv = listen_socket_->Accept( 295 rv = listen_socket_->Accept(
240 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted, 296 &accepted_socket_, base::Bind(&EmbeddedTestServer::OnAcceptCompleted,
241 base::Unretained(this))); 297 base::Unretained(this)));
242 if (rv == ERR_IO_PENDING) 298 if (rv == ERR_IO_PENDING)
243 return; 299 return;
244 HandleAcceptResult(accepted_socket_.Pass()); 300 HandleAcceptResult(accepted_socket_.Pass());
245 } 301 }
246 } 302 }
247 303
248 void EmbeddedTestServer::OnAcceptCompleted(int rv) { 304 void EmbeddedTestServer::OnAcceptCompleted(int rv) {
249 DCHECK_NE(ERR_IO_PENDING, rv); 305 DCHECK_NE(ERR_IO_PENDING, rv);
250 HandleAcceptResult(accepted_socket_.Pass()); 306 HandleAcceptResult(accepted_socket_.Pass());
251 DoAcceptLoop(); 307 DoAcceptLoop();
252 } 308 }
253 309
310 void EmbeddedTestServer::OnHandshakeDone(HttpConnection* connection, int rv) {
311 if (connection->socket_->IsConnected())
312 ReadData(connection);
313 else
314 DidClose(connection);
315 }
316
254 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) { 317 void EmbeddedTestServer::HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
255 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread()); 318 DCHECK(io_thread_->task_runner()->BelongsToCurrentThread());
256 if (connection_listener_) 319 if (connection_listener_)
257 connection_listener_->AcceptedSocket(*socket); 320 connection_listener_->AcceptedSocket(*socket);
258 321
322 if (is_using_ssl_)
323 socket = DoSSLUpgrade(socket.Pass());
324
259 HttpConnection* http_connection = new HttpConnection( 325 HttpConnection* http_connection = new HttpConnection(
260 socket.Pass(), 326 socket.Pass(),
261 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this))); 327 base::Bind(&EmbeddedTestServer::HandleRequest, base::Unretained(this)));
262 connections_[http_connection->socket_.get()] = http_connection; 328 connections_[http_connection->socket_.get()] = http_connection;
263 329
264 ReadData(http_connection); 330 if (is_using_ssl_) {
331 SSLServerSocket* ssl_socket =
332 static_cast<SSLServerSocket*>(http_connection->socket_.get());
333 int rv = ssl_socket->Handshake(
334 base::Bind(&EmbeddedTestServer::OnHandshakeDone, base::Unretained(this),
335 http_connection));
336 if (rv != ERR_IO_PENDING)
337 OnHandshakeDone(http_connection, rv);
338 } else {
339 ReadData(http_connection);
340 }
265 } 341 }
266 342
267 void EmbeddedTestServer::ReadData(HttpConnection* connection) { 343 void EmbeddedTestServer::ReadData(HttpConnection* connection) {
268 while (true) { 344 while (true) {
269 int rv = 345 int rv =
270 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted, 346 connection->ReadData(base::Bind(&EmbeddedTestServer::OnReadCompleted,
271 base::Unretained(this), connection)); 347 base::Unretained(this), connection));
272 if (rv == ERR_IO_PENDING) 348 if (rv == ERR_IO_PENDING)
273 return; 349 return;
274 if (!HandleReadResult(connection, rv)) 350 if (!HandleReadResult(connection, rv))
(...skipping 64 matching lines...) Expand 10 before | Expand all | Expand 10 after
339 run_loop.QuitClosure())) { 415 run_loop.QuitClosure())) {
340 return false; 416 return false;
341 } 417 }
342 run_loop.Run(); 418 run_loop.Run();
343 419
344 return true; 420 return true;
345 } 421 }
346 422
347 } // namespace test_server 423 } // namespace test_server
348 } // namespace net 424 } // namespace net
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698