Index: net/server/http_server_unittest.cc |
diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc |
index 42e56399a10afc4d3ed71d17ebe8573957fad1f8..216cb03416178e69e7328c91b64f6f8bcfd9f99b 100644 |
--- a/net/server/http_server_unittest.cc |
+++ b/net/server/http_server_unittest.cc |
@@ -189,6 +189,8 @@ class HttpServerTest : public testing::Test, |
ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_)); |
} |
+ virtual void OnConnect(int connection_id) OVERRIDE {} |
+ |
virtual void OnHttpRequest(int connection_id, |
const HttpServerRequestInfo& info) OVERRIDE { |
requests_.push_back(std::make_pair(info, connection_id)); |
@@ -243,6 +245,8 @@ class HttpServerTest : public testing::Test, |
size_t quit_after_request_count_; |
}; |
+namespace { |
+ |
class WebSocketTest : public HttpServerTest { |
virtual void OnHttpRequest(int connection_id, |
const HttpServerRequestInfo& info) OVERRIDE { |
@@ -461,8 +465,6 @@ TEST_F(HttpServerTest, SendRaw) { |
ASSERT_EQ(expected_response, response); |
} |
-namespace { |
- |
class MockStreamSocket : public StreamSocket { |
public: |
MockStreamSocket() |
@@ -557,8 +559,6 @@ class MockStreamSocket : public StreamSocket { |
DISALLOW_COPY_AND_ASSIGN(MockStreamSocket); |
}; |
-} // namespace |
- |
TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) { |
MockStreamSocket* socket = new MockStreamSocket(); |
HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket)); |
@@ -619,4 +619,26 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { |
ASSERT_TRUE(EndsWith(response3, "Content for /test3", true)); |
} |
+class CloseOnConnectHttpServerTest : public HttpServerTest { |
+ public: |
+ virtual void OnConnect(int connection_id) OVERRIDE { |
+ connection_ids_.push_back(connection_id); |
+ server_->Close(connection_id); |
+ } |
+ |
+ protected: |
+ std::vector<int> connection_ids_; |
+}; |
+ |
+TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) { |
+ TestHttpClient client; |
+ ASSERT_EQ(OK, client.ConnectAndWait(server_address_)); |
+ client.Send("GET / HTTP/1.1\r\n\r\n"); |
+ ASSERT_FALSE(RunUntilRequestsReceived(1)); |
+ ASSERT_EQ(1ul, connection_ids_.size()); |
+ ASSERT_EQ(0ul, requests_.size()); |
+} |
+ |
+} // namespace |
+ |
} // namespace net |