Index: net/server/http_server_unittest.cc |
diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc |
index a492cf13de0e609dcbd9f1e236fa184706207c37..5783ca20fbda4eeadb5e4e00923ec9a262a9afda 100644 |
--- a/net/server/http_server_unittest.cc |
+++ b/net/server/http_server_unittest.cc |
@@ -2,11 +2,13 @@ |
// Use of this source code is governed by a BSD-style license that can be |
// found in the LICENSE file. |
+#include <algorithm> |
#include <utility> |
#include <vector> |
#include "base/bind.h" |
#include "base/bind_helpers.h" |
+#include "base/callback_helpers.h" |
#include "base/compiler_specific.h" |
#include "base/format_macros.h" |
#include "base/memory/ref_counted.h" |
@@ -24,11 +26,14 @@ |
#include "net/base/ip_endpoint.h" |
#include "net/base/net_errors.h" |
#include "net/base/net_log.h" |
+#include "net/base/net_util.h" |
#include "net/base/test_completion_callback.h" |
+#include "net/http/http_response_headers.h" |
+#include "net/http/http_util.h" |
#include "net/server/http_server.h" |
#include "net/server/http_server_request_info.h" |
#include "net/socket/tcp_client_socket.h" |
-#include "net/socket/tcp_listen_socket.h" |
+#include "net/socket/tcp_server_socket.h" |
#include "net/url_request/url_fetcher.h" |
#include "net/url_request/url_fetcher_delegate.h" |
#include "net/url_request/url_request_context.h" |
@@ -90,10 +95,6 @@ class TestHttpClient { |
Write(); |
} |
- bool Read(std::string* message) { |
- return Read(message, 1); |
- } |
- |
bool Read(std::string* message, int expected_bytes) { |
int total_bytes_received = 0; |
message->clear(); |
@@ -110,6 +111,18 @@ class TestHttpClient { |
return true; |
} |
+ bool ReadResponse(std::string* message) { |
+ if (!Read(message, 1)) |
+ return false; |
+ while (!IsCompleteResponse(*message)) { |
+ std::string chunk; |
+ if (!Read(&chunk, 1)) |
+ return false; |
+ message->append(chunk); |
+ } |
+ return true; |
+ } |
+ |
private: |
void OnConnect(const base::Closure& quit_loop, int result) { |
connect_result_ = result; |
@@ -141,6 +154,21 @@ class TestHttpClient { |
callback.Run(result); |
} |
+ bool IsCompleteResponse(const std::string& response) { |
+ // Check end of headers first. |
+ int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(), |
+ response.size()); |
+ if (end_of_headers < 0) |
+ return false; |
+ |
+ // Return true if response has data equal to or more than content length. |
+ int64 body_size = static_cast<int64>(response.size()) - end_of_headers; |
+ DCHECK_LE(0, body_size); |
+ scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders( |
+ HttpUtil::AssembleRawHeaders(response.data(), end_of_headers))); |
+ return body_size >= headers->GetContentLength(); |
+ } |
+ |
scoped_refptr<IOBufferWithSize> read_buffer_; |
scoped_refptr<DrainableIOBuffer> write_buffer_; |
scoped_ptr<TCPClientSocket> socket_; |
@@ -155,8 +183,10 @@ class HttpServerTest : public testing::Test, |
HttpServerTest() : quit_after_request_count_(0) {} |
virtual void SetUp() OVERRIDE { |
- TCPListenSocketFactory socket_factory("127.0.0.1", 0); |
- server_ = new HttpServer(socket_factory, this); |
+ scoped_ptr<ServerSocket> server_socket( |
+ new TCPServerSocket(NULL, net::NetLog::Source())); |
+ server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1); |
+ server_.reset(new HttpServer(server_socket.Pass(), this)); |
ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); |
} |
@@ -199,8 +229,13 @@ class HttpServerTest : public testing::Test, |
return requests_[request_index].second; |
} |
+ void HandleAcceptResult(scoped_ptr<StreamSocket> socket) { |
+ server_->accepted_socket_.reset(socket.release()); |
+ server_->HandleAcceptResult(OK); |
+ } |
+ |
protected: |
- scoped_refptr<HttpServer> server_; |
+ scoped_ptr<HttpServer> server_; |
IPEndPoint server_address_; |
base::Closure run_loop_quit_func_; |
std::vector<std::pair<HttpServerRequestInfo, int> > requests_; |
@@ -407,7 +442,7 @@ TEST_F(HttpServerTest, Send200) { |
server_->Send200(GetConnectionId(0), "Response!", "text/plain"); |
std::string response; |
- ASSERT_TRUE(client.Read(&response)); |
+ ASSERT_TRUE(client.ReadResponse(&response)); |
ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true)); |
ASSERT_TRUE(EndsWith(response, "Response!", true)); |
} |
@@ -429,23 +464,105 @@ TEST_F(HttpServerTest, SendRaw) { |
namespace { |
-class MockStreamListenSocket : public StreamListenSocket { |
+class MockStreamSocket : public StreamSocket { |
public: |
- MockStreamListenSocket(StreamListenSocket::Delegate* delegate) |
- : StreamListenSocket(kInvalidSocket, delegate) {} |
+ MockStreamSocket() |
+ : connected_(true), |
+ read_buf_(NULL), |
+ read_buf_len_(0) {} |
+ |
+ // StreamSocket |
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
+ virtual void Disconnect() OVERRIDE { |
+ connected_ = false; |
+ if (!read_callback_.is_null()) { |
+ read_buf_ = NULL; |
+ read_buf_len_ = 0; |
+ base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); |
+ } |
+ } |
+ virtual bool IsConnected() const OVERRIDE { return connected_; } |
+ virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); } |
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
+ virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; } |
+ virtual void SetSubresourceSpeculation() OVERRIDE {} |
+ virtual void SetOmniboxSpeculation() OVERRIDE {} |
+ virtual bool WasEverUsed() const OVERRIDE { return true; } |
+ virtual bool UsingTCPFastOpen() const OVERRIDE { return false; } |
+ virtual bool WasNpnNegotiated() const OVERRIDE { return false; } |
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE { |
+ return kProtoUnknown; |
+ } |
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; } |
+ |
+ // Socket |
+ virtual int Read(IOBuffer* buf, int buf_len, |
+ const CompletionCallback& callback) OVERRIDE { |
+ if (!connected_) { |
+ return ERR_SOCKET_NOT_CONNECTED; |
+ } |
+ if (pending_read_data_.empty()) { |
+ read_buf_ = buf; |
+ read_buf_len_ = buf_len; |
+ read_callback_ = callback; |
+ return ERR_IO_PENDING; |
+ } |
+ DCHECK_GT(buf_len, 0); |
+ int read_len = std::min(static_cast<int>(pending_read_data_.size()), |
+ buf_len); |
+ memcpy(buf->data(), pending_read_data_.data(), read_len); |
+ pending_read_data_.erase(0, read_len); |
+ return read_len; |
+ } |
+ virtual int Write(IOBuffer* buf, int buf_len, |
+ const CompletionCallback& callback) OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
+ virtual int SetReceiveBufferSize(int32 size) OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
+ virtual int SetSendBufferSize(int32 size) OVERRIDE { |
+ return ERR_NOT_IMPLEMENTED; |
+ } |
- virtual void Accept() OVERRIDE { NOTREACHED(); } |
+ void DidRead(const char* data, int data_len) { |
+ if (!read_buf_) { |
+ pending_read_data_.append(data, data_len); |
+ return; |
+ } |
+ int read_len = std::min(data_len, read_buf_len_); |
+ memcpy(read_buf_->data(), data, read_len); |
+ pending_read_data_.assign(data + read_len, data_len - read_len); |
+ read_buf_ = NULL; |
+ read_buf_len_ = 0; |
+ base::ResetAndReturn(&read_callback_).Run(read_len); |
+ } |
private: |
- virtual ~MockStreamListenSocket() {} |
+ virtual ~MockStreamSocket() {} |
+ |
+ bool connected_; |
+ scoped_refptr<IOBuffer> read_buf_; |
+ int read_buf_len_; |
+ CompletionCallback read_callback_; |
+ std::string pending_read_data_; |
+ BoundNetLog net_log_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(MockStreamSocket); |
}; |
} // namespace |
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { |
- StreamListenSocket* socket = |
- new MockStreamListenSocket(server_.get()); |
- server_->DidAccept(NULL, make_scoped_ptr(socket)); |
+ MockStreamSocket* socket = new MockStreamSocket(); |
+ HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket)); |
std::string body("body"); |
std::string request_text = base::StringPrintf( |
"GET /test HTTP/1.1\r\n" |
@@ -453,9 +570,9 @@ TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { |
"Content-Length: %" PRIuS "\r\n\r\n%s", |
body.length(), |
body.c_str()); |
- server_->DidRead(socket, request_text.c_str(), request_text.length() - 2); |
+ socket->DidRead(request_text.c_str(), request_text.length() - 2); |
ASSERT_EQ(0u, requests_.size()); |
- server_->DidRead(socket, request_text.c_str() + request_text.length() - 2, 2); |
+ socket->DidRead(request_text.c_str() + request_text.length() - 2, 2); |
ASSERT_EQ(1u, requests_.size()); |
ASSERT_EQ(body, GetRequest(0).data); |
} |
@@ -477,7 +594,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { |
int client_connection_id = GetConnectionId(0); |
server_->Send200(client_connection_id, "Content for /test", "text/plain"); |
std::string response1; |
- ASSERT_TRUE(client.Read(&response1)); |
+ ASSERT_TRUE(client.ReadResponse(&response1)); |
ASSERT_TRUE(StartsWithASCII(response1, "HTTP/1.1 200 OK", true)); |
ASSERT_TRUE(EndsWith(response1, "Content for /test", true)); |
@@ -488,7 +605,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { |
ASSERT_EQ(client_connection_id, GetConnectionId(1)); |
server_->Send404(client_connection_id); |
std::string response2; |
- ASSERT_TRUE(client.Read(&response2)); |
+ ASSERT_TRUE(client.ReadResponse(&response2)); |
ASSERT_TRUE(StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true)); |
client.Send("GET /test3 HTTP/1.1\r\n\r\n"); |
@@ -498,12 +615,9 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { |
ASSERT_EQ(client_connection_id, GetConnectionId(2)); |
server_->Send200(client_connection_id, "Content for /test3", "text/plain"); |
std::string response3; |
- ASSERT_TRUE(client.Read(&response3)); |
+ ASSERT_TRUE(client.ReadResponse(&response3)); |
ASSERT_TRUE(StartsWithASCII(response3, "HTTP/1.1 200 OK", true)); |
-#if 0 |
- // TODO(byungchul): Figure out why it fails in windows build bot. |
ASSERT_TRUE(EndsWith(response3, "Content for /test3", true)); |
-#endif |
} |
} // namespace net |