Index: net/server/http_server_unittest.cc |
diff --git a/net/server/http_server_unittest.cc b/net/server/http_server_unittest.cc |
index 03544f12b3d5f82a90c83ab5d8a0ba1cfbf2d4d7..997d8650cde1513f43e72e18b40d28e90bd27e45 100644 |
--- a/net/server/http_server_unittest.cc |
+++ b/net/server/http_server_unittest.cc |
@@ -132,6 +132,8 @@ class TestHttpClient { |
return true; |
} |
+ TCPClientSocket& socket() { return *socket_; } |
+ |
private: |
void OnConnect(const base::Closure& quit_loop, int result) { |
connect_result_ = result; |
@@ -198,7 +200,10 @@ class HttpServerTest : public testing::Test, |
ASSERT_THAT(server_->GetLocalAddress(&server_address_), IsOk()); |
} |
- void OnConnect(int connection_id) override {} |
+ void OnConnect(int connection_id) override { |
+ DCHECK(connection_map_.find(connection_id) == connection_map_.end()); |
+ connection_map_[connection_id] = true; |
+ } |
void OnHttpRequest(int connection_id, |
const HttpServerRequestInfo& info) override { |
@@ -216,7 +221,10 @@ class HttpServerTest : public testing::Test, |
NOTREACHED(); |
} |
- void OnClose(int connection_id) override {} |
+ void OnClose(int connection_id) override { |
+ DCHECK(connection_map_.find(connection_id) != connection_map_.end()); |
+ connection_map_[connection_id] = false; |
+ } |
bool RunUntilRequestsReceived(size_t count) { |
quit_after_request_count_ = count; |
@@ -243,11 +251,15 @@ class HttpServerTest : public testing::Test, |
server_->HandleAcceptResult(OK); |
} |
+ std::unordered_map<int, bool>& connection_map() { return connection_map_; } |
+ |
protected: |
std::unique_ptr<HttpServer> server_; |
IPEndPoint server_address_; |
base::Closure run_loop_quit_func_; |
std::vector<std::pair<HttpServerRequestInfo, int> > requests_; |
+ std::unordered_map<int /* connection_id */, bool /* connected */> |
+ connection_map_; |
private: |
size_t quit_after_request_count_; |
@@ -472,6 +484,38 @@ TEST_F(HttpServerTest, SendRaw) { |
ASSERT_EQ(expected_response, response); |
} |
+TEST_F(HttpServerTest, WrongProtocolRequest) { |
+ const char* const kBadProtocolRequests[] = { |
+ "GET /test HTTP/1.0\r\n\r\n", |
+ "GET /test foo\r\n\r\n", |
+ "GET /test \r\n\r\n", |
+ }; |
+ |
+ for (size_t i = 0; i < arraysize(kBadProtocolRequests); ++i) { |
+ TestHttpClient client; |
+ ASSERT_THAT(client.ConnectAndWait(server_address_), IsOk()); |
+ |
+ client.Send(kBadProtocolRequests[i]); |
+ ASSERT_FALSE(RunUntilRequestsReceived(1)); |
+ |
+ // Assert that the delegate was updated properly. |
+ ASSERT_EQ(1u, connection_map().size()); |
+ ASSERT_FALSE(connection_map().begin()->second); |
+ |
+ // Assert that the socket was opened... |
+ ASSERT_TRUE(client.socket().WasEverUsed()); |
+ |
+ // ...then closed when the server disconnected. Verify that the socket was |
+ // closed by checking that a Read() fails. |
+ std::string response; |
+ ASSERT_FALSE(client.Read(&response, 1u)); |
+ ASSERT_EQ(std::string(), response); |
+ |
+ // Reset the state of the connection map. |
+ connection_map().clear(); |
+ } |
+} |
+ |
class MockStreamSocket : public StreamSocket { |
public: |
MockStreamSocket() |
@@ -640,6 +684,7 @@ TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) { |
class CloseOnConnectHttpServerTest : public HttpServerTest { |
public: |
void OnConnect(int connection_id) override { |
+ HttpServerTest::OnConnect(connection_id); |
connection_ids_.push_back(connection_id); |
server_->Close(connection_id); |
} |