OLD | NEW |
| (Empty) |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "chrome_frame/test/test_server.h" | |
6 | |
7 #include <windows.h> | |
8 #include <objbase.h> | |
9 #include <urlmon.h> | |
10 | |
11 #include "base/bind.h" | |
12 #include "base/logging.h" | |
13 #include "base/strings/string_number_conversions.h" | |
14 #include "base/strings/string_piece.h" | |
15 #include "base/strings/string_util.h" | |
16 #include "base/strings/stringprintf.h" | |
17 #include "base/strings/utf_string_conversions.h" | |
18 #include "chrome_frame/test/chrome_frame_test_utils.h" | |
19 #include "net/base/winsock_init.h" | |
20 #include "net/http/http_util.h" | |
21 #include "net/socket/tcp_listen_socket.h" | |
22 | |
23 namespace test_server { | |
24 const char kDefaultHeaderTemplate[] = | |
25 "HTTP/1.1 %hs\r\n" | |
26 "Connection: close\r\n" | |
27 "Content-Type: %hs\r\n" | |
28 "Content-Length: %i\r\n\r\n"; | |
29 const char kStatusOk[] = "200 OK"; | |
30 const char kStatusNotFound[] = "404 Not Found"; | |
31 const char kDefaultContentType[] = "text/html; charset=UTF-8"; | |
32 | |
33 void Request::ParseHeaders(const std::string& headers) { | |
34 DCHECK(method_.length() == 0); | |
35 | |
36 size_t pos = headers.find("\r\n"); | |
37 DCHECK(pos != std::string::npos); | |
38 if (pos != std::string::npos) { | |
39 headers_ = headers.substr(pos + 2); | |
40 | |
41 base::StringTokenizer tokenizer( | |
42 headers.begin(), headers.begin() + pos, " "); | |
43 std::string* parse[] = { &method_, &path_, &version_ }; | |
44 int field = 0; | |
45 while (tokenizer.GetNext() && field < arraysize(parse)) { | |
46 parse[field++]->assign(tokenizer.token_begin(), | |
47 tokenizer.token_end()); | |
48 } | |
49 } | |
50 | |
51 // Check for content-length in case we're being sent some data. | |
52 net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(), | |
53 "\r\n"); | |
54 while (it.GetNext()) { | |
55 if (LowerCaseEqualsASCII(it.name(), "content-length")) { | |
56 int int_content_length; | |
57 base::StringToInt(base::StringPiece(it.values_begin(), | |
58 it.values_end()), | |
59 &int_content_length); | |
60 content_length_ = int_content_length; | |
61 break; | |
62 } | |
63 } | |
64 } | |
65 | |
66 void Request::OnDataReceived(const std::string& data) { | |
67 content_ += data; | |
68 | |
69 if (method_.length() == 0) { | |
70 size_t index = content_.find("\r\n\r\n"); | |
71 if (index != std::string::npos) { | |
72 // Parse the headers before returning and chop them of the | |
73 // data buffer we've already received. | |
74 std::string headers(content_.substr(0, index + 2)); | |
75 ParseHeaders(headers); | |
76 content_.erase(0, index + 4); | |
77 } | |
78 } | |
79 } | |
80 | |
81 ResponseForPath::~ResponseForPath() { | |
82 } | |
83 | |
84 SimpleResponse::~SimpleResponse() { | |
85 } | |
86 | |
87 bool FileResponse::GetContentType(std::string* content_type) const { | |
88 size_t length = ContentLength(); | |
89 char buffer[4096]; | |
90 void* data = NULL; | |
91 | |
92 if (length) { | |
93 // Create a copy of the first few bytes of the file. | |
94 // If we try and use the mapped file directly, FindMimeFromData will crash | |
95 // 'cause it cheats and temporarily tries to write to the buffer! | |
96 length = std::min(arraysize(buffer), length); | |
97 memcpy(buffer, file_->data(), length); | |
98 data = buffer; | |
99 } | |
100 | |
101 LPOLESTR mime_type = NULL; | |
102 FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL, | |
103 FMFD_DEFAULT, &mime_type, 0); | |
104 if (mime_type) { | |
105 *content_type = WideToASCII(mime_type); | |
106 ::CoTaskMemFree(mime_type); | |
107 } | |
108 | |
109 return content_type->length() > 0; | |
110 } | |
111 | |
112 void FileResponse::WriteContents(net::StreamListenSocket* socket) const { | |
113 DCHECK(file_.get()); | |
114 if (file_.get()) { | |
115 socket->Send(reinterpret_cast<const char*>(file_->data()), | |
116 file_->length(), false); | |
117 } | |
118 } | |
119 | |
120 size_t FileResponse::ContentLength() const { | |
121 if (file_.get() == NULL) { | |
122 file_.reset(new base::MemoryMappedFile()); | |
123 if (!file_->Initialize(file_path_)) { | |
124 NOTREACHED(); | |
125 file_.reset(); | |
126 } | |
127 } | |
128 return file_.get() ? file_->length() : 0; | |
129 } | |
130 | |
131 bool RedirectResponse::GetCustomHeaders(std::string* headers) const { | |
132 *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n" | |
133 "Connection: close\r\n" | |
134 "Content-Length: 0\r\n" | |
135 "Content-Type: text/html\r\n" | |
136 "Location: %hs\r\n\r\n", | |
137 redirect_url_.c_str()); | |
138 return true; | |
139 } | |
140 | |
141 SimpleWebServer::SimpleWebServer(int port) { | |
142 Construct(chrome_frame_test::GetLocalIPv4Address(), port); | |
143 } | |
144 | |
145 SimpleWebServer::SimpleWebServer(const std::string& address, int port) { | |
146 Construct(address, port); | |
147 } | |
148 | |
149 SimpleWebServer::~SimpleWebServer() { | |
150 ConnectionList::const_iterator it; | |
151 for (it = connections_.begin(); it != connections_.end(); ++it) | |
152 delete (*it); | |
153 connections_.clear(); | |
154 } | |
155 | |
156 void SimpleWebServer::Construct(const std::string& address, int port) { | |
157 CHECK(base::MessageLoop::current()) | |
158 << "SimpleWebServer requires a message loop"; | |
159 net::EnsureWinsockInit(); | |
160 AddResponse(&quit_); | |
161 host_ = address; | |
162 server_ = net::TCPListenSocket::CreateAndListen(address, port, this); | |
163 LOG_IF(DFATAL, !server_.get()) | |
164 << "Failed to create listener socket at " << address << ":" << port; | |
165 } | |
166 | |
167 void SimpleWebServer::AddResponse(Response* response) { | |
168 responses_.push_back(response); | |
169 } | |
170 | |
171 void SimpleWebServer::DeleteAllResponses() { | |
172 std::list<Response*>::const_iterator it; | |
173 for (it = responses_.begin(); it != responses_.end(); ++it) { | |
174 if ((*it) != &quit_) | |
175 delete (*it); | |
176 } | |
177 } | |
178 | |
179 Response* SimpleWebServer::FindResponse(const Request& request) const { | |
180 std::list<Response*>::const_iterator it; | |
181 for (it = responses_.begin(); it != responses_.end(); it++) { | |
182 Response* response = (*it); | |
183 if (response->Matches(request)) { | |
184 return response; | |
185 } | |
186 } | |
187 return NULL; | |
188 } | |
189 | |
190 Connection* SimpleWebServer::FindConnection( | |
191 const net::StreamListenSocket* socket) const { | |
192 ConnectionList::const_iterator it; | |
193 for (it = connections_.begin(); it != connections_.end(); it++) { | |
194 if ((*it)->IsSame(socket)) { | |
195 return (*it); | |
196 } | |
197 } | |
198 return NULL; | |
199 } | |
200 | |
201 void SimpleWebServer::DidAccept( | |
202 net::StreamListenSocket* server, | |
203 scoped_ptr<net::StreamListenSocket> connection) { | |
204 connections_.push_back(new Connection(connection.Pass())); | |
205 } | |
206 | |
207 void SimpleWebServer::DidRead(net::StreamListenSocket* connection, | |
208 const char* data, | |
209 int len) { | |
210 Connection* c = FindConnection(connection); | |
211 DCHECK(c); | |
212 Request& r = c->request(); | |
213 std::string str(data, len); | |
214 r.OnDataReceived(str); | |
215 if (r.AllContentReceived()) { | |
216 const Request& request = c->request(); | |
217 Response* response = FindResponse(request); | |
218 if (response) { | |
219 std::string headers; | |
220 if (!response->GetCustomHeaders(&headers)) { | |
221 std::string content_type; | |
222 if (!response->GetContentType(&content_type)) | |
223 content_type = kDefaultContentType; | |
224 headers = base::StringPrintf(kDefaultHeaderTemplate, kStatusOk, | |
225 content_type.c_str(), | |
226 response->ContentLength()); | |
227 } | |
228 | |
229 connection->Send(headers, false); | |
230 response->WriteContents(connection); | |
231 response->IncrementAccessCounter(); | |
232 } else { | |
233 std::string payload = "sorry, I can't find " + request.path(); | |
234 std::string headers(base::StringPrintf(kDefaultHeaderTemplate, | |
235 kStatusNotFound, | |
236 kDefaultContentType, | |
237 payload.length())); | |
238 connection->Send(headers, false); | |
239 connection->Send(payload, false); | |
240 } | |
241 } | |
242 } | |
243 | |
244 void SimpleWebServer::DidClose(net::StreamListenSocket* sock) { | |
245 // To keep the historical list of connections reasonably tidy, we delete | |
246 // 404's when the connection ends. | |
247 Connection* c = FindConnection(sock); | |
248 DCHECK(c); | |
249 c->OnSocketClosed(); | |
250 if (!FindResponse(c->request())) { | |
251 // extremely inefficient, but in one line and not that common... :) | |
252 connections_.erase(std::find(connections_.begin(), connections_.end(), c)); | |
253 delete c; | |
254 } | |
255 } | |
256 | |
257 HTTPTestServer::HTTPTestServer(int port, const std::wstring& address, | |
258 base::FilePath root_dir) | |
259 : port_(port), address_(address), root_dir_(root_dir) { | |
260 net::EnsureWinsockInit(); | |
261 server_ = net::TCPListenSocket::CreateAndListen( | |
262 base::WideToUTF8(address), port, this); | |
263 } | |
264 | |
265 HTTPTestServer::~HTTPTestServer() { | |
266 } | |
267 | |
268 std::list<scoped_refptr<ConfigurableConnection>>::iterator | |
269 HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) { | |
270 ConnectionList::iterator it; | |
271 // Scan through the list searching for the desired socket. Along the way, | |
272 // erase any connections for which the corresponding socket has already been | |
273 // forgotten about as a result of all data having been sent. | |
274 for (it = connection_list_.begin(); it != connection_list_.end(); ) { | |
275 ConfigurableConnection* connection = it->get(); | |
276 if (connection->socket_ == NULL) { | |
277 connection_list_.erase(it++); | |
278 continue; | |
279 } | |
280 if (connection->socket_ == socket) | |
281 break; | |
282 ++it; | |
283 } | |
284 | |
285 return it; | |
286 } | |
287 | |
288 scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket( | |
289 const net::StreamListenSocket* socket) { | |
290 ConnectionList::iterator it = FindConnection(socket); | |
291 if (it != connection_list_.end()) | |
292 return *it; | |
293 return NULL; | |
294 } | |
295 | |
296 void HTTPTestServer::DidAccept(net::StreamListenSocket* server, | |
297 scoped_ptr<net::StreamListenSocket> socket) { | |
298 connection_list_.push_back(new ConfigurableConnection(socket.Pass())); | |
299 } | |
300 | |
301 void HTTPTestServer::DidRead(net::StreamListenSocket* socket, | |
302 const char* data, | |
303 int len) { | |
304 scoped_refptr<ConfigurableConnection> connection = | |
305 ConnectionFromSocket(socket); | |
306 if (connection) { | |
307 std::string str(data, len); | |
308 connection->r_.OnDataReceived(str); | |
309 if (connection->r_.AllContentReceived()) { | |
310 VLOG(1) << __FUNCTION__ << ": " << connection->r_.method() << " " | |
311 << connection->r_.path(); | |
312 std::wstring path = base::UTF8ToWide(connection->r_.path()); | |
313 if (LowerCaseEqualsASCII(connection->r_.method(), "post")) | |
314 this->Post(connection, path, connection->r_); | |
315 else | |
316 this->Get(connection, path, connection->r_); | |
317 } | |
318 } | |
319 } | |
320 | |
321 void HTTPTestServer::DidClose(net::StreamListenSocket* socket) { | |
322 ConnectionList::iterator it = FindConnection(socket); | |
323 if (it != connection_list_.end()) | |
324 connection_list_.erase(it); | |
325 } | |
326 | |
327 std::wstring HTTPTestServer::Resolve(const std::wstring& path) { | |
328 // Remove the first '/' if needed. | |
329 std::wstring stripped_path = path; | |
330 if (path.size() && path[0] == L'/') | |
331 stripped_path = path.substr(1); | |
332 | |
333 if (port_ == 80) { | |
334 if (stripped_path.empty()) { | |
335 return base::StringPrintf(L"http://%ls", address_.c_str()); | |
336 } else { | |
337 return base::StringPrintf(L"http://%ls/%ls", address_.c_str(), | |
338 stripped_path.c_str()); | |
339 } | |
340 } else { | |
341 if (stripped_path.empty()) { | |
342 return base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_); | |
343 } else { | |
344 return base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_, | |
345 stripped_path.c_str()); | |
346 } | |
347 } | |
348 } | |
349 | |
350 void ConfigurableConnection::SendChunk() { | |
351 int size = (int)data_.size(); | |
352 const char* chunk_ptr = data_.c_str() + cur_pos_; | |
353 int bytes_to_send = std::min(options_.chunk_size_, size - cur_pos_); | |
354 | |
355 socket_->Send(chunk_ptr, bytes_to_send); | |
356 VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): " | |
357 << base::StringPiece(chunk_ptr, bytes_to_send); | |
358 | |
359 cur_pos_ += bytes_to_send; | |
360 if (cur_pos_ < size) { | |
361 base::MessageLoop::current()->PostDelayedTask( | |
362 FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), | |
363 base::TimeDelta::FromMilliseconds(options_.timeout_)); | |
364 } else { | |
365 Close(); | |
366 } | |
367 } | |
368 | |
369 void ConfigurableConnection::Close() { | |
370 socket_.reset(); | |
371 } | |
372 | |
373 void ConfigurableConnection::Send(const std::string& headers, | |
374 const std::string& content) { | |
375 SendOptions options(SendOptions::IMMEDIATE, 0, 0); | |
376 SendWithOptions(headers, content, options); | |
377 } | |
378 | |
379 void ConfigurableConnection::SendWithOptions(const std::string& headers, | |
380 const std::string& content, | |
381 const SendOptions& options) { | |
382 std::string content_length_header; | |
383 if (!content.empty() && | |
384 std::string::npos == headers.find("Context-Length:")) { | |
385 content_length_header = base::StringPrintf("Content-Length: %u\r\n", | |
386 content.size()); | |
387 } | |
388 | |
389 // Save the options. | |
390 options_ = options; | |
391 | |
392 if (options_.speed_ == SendOptions::IMMEDIATE) { | |
393 socket_->Send(headers); | |
394 socket_->Send(content_length_header, true); | |
395 socket_->Send(content); | |
396 // Post a task to close the socket since StreamListenSocket doesn't like | |
397 // instances to go away from within its callbacks. | |
398 base::MessageLoop::current()->PostTask( | |
399 FROM_HERE, base::Bind(&ConfigurableConnection::Close, this)); | |
400 | |
401 return; | |
402 } | |
403 | |
404 if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) { | |
405 socket_->Send(headers); | |
406 socket_->Send(content_length_header, true); | |
407 VLOG(1) << "Headers sent: " << headers << content_length_header; | |
408 data_.append(content); | |
409 } | |
410 | |
411 if (options_.speed_ == SendOptions::DELAYED) { | |
412 data_ = headers; | |
413 data_.append(content_length_header); | |
414 data_.append("\r\n"); | |
415 } | |
416 | |
417 base::MessageLoop::current()->PostDelayedTask( | |
418 FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this), | |
419 base::TimeDelta::FromMilliseconds(options.timeout_)); | |
420 } | |
421 | |
422 } // namespace test_server | |
OLD | NEW |