Index: net/tools/quic/quic_dispatcher.cc |
diff --git a/net/tools/quic/quic_dispatcher.cc b/net/tools/quic/quic_dispatcher.cc |
index 297234d9755b176210acaa50cb2b3a75fe315c8c..21070169a979d8e1f082f78bdddc1637be4db411 100644 |
--- a/net/tools/quic/quic_dispatcher.cc |
+++ b/net/tools/quic/quic_dispatcher.cc |
@@ -17,6 +17,7 @@ |
namespace net { |
namespace tools { |
+using base::StringPiece; |
using std::make_pair; |
class DeleteSessionsAlarm : public EpollAlarm { |
@@ -35,6 +36,90 @@ class DeleteSessionsAlarm : public EpollAlarm { |
QuicDispatcher* dispatcher_; |
}; |
+class QuicDispatcher::QuicFramerVisitor : public QuicFramerVisitorInterface { |
+ public: |
+ explicit QuicFramerVisitor(QuicDispatcher* dispatcher) |
+ : dispatcher_(dispatcher) { |
+ } |
+ |
+ // QuicFramerVisitorInterface implementation |
+ virtual void OnPacket() OVERRIDE {} |
+ virtual bool OnUnauthenticatedPublicHeader( |
+ const QuicPacketPublicHeader& header) OVERRIDE { |
+ return dispatcher_->OnUnauthenticatedPublicHeader(header); |
+ } |
+ virtual bool OnUnauthenticatedHeader( |
+ const QuicPacketHeader& header) OVERRIDE { |
+ dispatcher_->OnUnauthenticatedHeader(header); |
+ return false; |
+ } |
+ virtual void OnError(QuicFramer* framer) OVERRIDE { |
+ DLOG(INFO) << QuicUtils::ErrorToString(framer->error()); |
+ } |
+ |
+ // The following methods should never get called because we always return |
+ // false from OnUnauthenticatedHeader(). As a result, we never process the |
+ // payload of the packet. |
+ virtual bool OnProtocolVersionMismatch( |
+ QuicVersion /*received_version*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual void OnPublicResetPacket( |
+ const QuicPublicResetPacket& /*packet*/) OVERRIDE { |
+ DCHECK(false); |
+ } |
+ virtual void OnVersionNegotiationPacket( |
+ const QuicVersionNegotiationPacket& /*packet*/) OVERRIDE { |
+ DCHECK(false); |
+ } |
+ virtual void OnPacketComplete() OVERRIDE { |
+ DCHECK(false); |
+ } |
+ virtual bool OnPacketHeader(const QuicPacketHeader& /*header*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual void OnRevivedPacket() OVERRIDE { |
+ DCHECK(false); |
+ } |
+ virtual void OnFecProtectedPayload(StringPiece /*payload*/) OVERRIDE { |
+ DCHECK(false); |
+ } |
+ virtual bool OnStreamFrame(const QuicStreamFrame& /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual bool OnAckFrame(const QuicAckFrame& /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual bool OnCongestionFeedbackFrame( |
+ const QuicCongestionFeedbackFrame& /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual bool OnRstStreamFrame(const QuicRstStreamFrame& /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual bool OnConnectionCloseFrame( |
+ const QuicConnectionCloseFrame & /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual bool OnGoAwayFrame(const QuicGoAwayFrame& /*frame*/) OVERRIDE { |
+ DCHECK(false); |
+ return false; |
+ } |
+ virtual void OnFecData(const QuicFecData& /*fec*/) OVERRIDE { |
+ DCHECK(false); |
+ } |
+ |
+ private: |
+ QuicDispatcher* dispatcher_; |
+}; |
+ |
QuicDispatcher::QuicDispatcher(const QuicConfig& config, |
const QuicCryptoServerConfig& crypto_config, |
const QuicVersionVector& supported_versions, |
@@ -50,7 +135,11 @@ QuicDispatcher::QuicDispatcher(const QuicConfig& config, |
write_blocked_(false), |
helper_(new QuicEpollConnectionHelper(epoll_server_)), |
writer_(new QuicDefaultPacketWriter(fd)), |
- supported_versions_(supported_versions) { |
+ supported_versions_(supported_versions), |
+ current_packet_(NULL), |
+ framer_(supported_versions, /*unused*/ QuicTime::Zero(), true), |
+ framer_visitor_(new QuicFramerVisitor(this)) { |
+ framer_.set_visitor(framer_visitor_.get()); |
} |
QuicDispatcher::~QuicDispatcher() { |
@@ -87,41 +176,56 @@ bool QuicDispatcher::IsWriteBlockedDataBuffered() const { |
void QuicDispatcher::ProcessPacket(const IPEndPoint& server_address, |
const IPEndPoint& client_address, |
- QuicGuid guid, |
- bool has_version_flag, |
const QuicEncryptedPacket& packet) { |
+ current_server_address_ = server_address; |
+ current_client_address_ = client_address; |
+ current_packet_ = &packet; |
+ // ProcessPacket will cause the packet to be dispatched in |
+ // OnUnauthenticatedPublicHeader, or sent to the time wait list manager |
+ // in OnAuthenticatedHeader. |
+ framer_.ProcessPacket(packet); |
+ // TODO(rjshade): Return a status describing if/why a packet was dropped, |
+ // and log somehow. Maybe expose as a varz. |
+} |
+ |
+bool QuicDispatcher::OnUnauthenticatedPublicHeader( |
+ const QuicPacketPublicHeader& header) { |
QuicSession* session = NULL; |
+ QuicGuid guid = header.guid; |
SessionMap::iterator it = session_map_.find(guid); |
if (it == session_map_.end()) { |
+ if (header.reset_flag) { |
+ return false; |
+ } |
if (time_wait_list_manager_->IsGuidInTimeWait(guid)) { |
- time_wait_list_manager_->ProcessPacket(server_address, |
- client_address, |
- guid, |
- packet); |
- return; |
+ return HandlePacketForTimeWait(header); |
} |
// Ensure the packet has a version negotiation bit set before creating a new |
// session for it. All initial packets for a new connection are required to |
// have the flag set. Otherwise it may be a stray packet. |
- if (has_version_flag) { |
- session = CreateQuicSession(guid, server_address, client_address); |
+ if (header.version_flag) { |
+ session = CreateQuicSession(guid, current_server_address_, |
+ current_client_address_); |
} |
if (session == NULL) { |
DLOG(INFO) << "Failed to create session for " << guid; |
// Add this guid fo the time-wait state, to safely reject future packets. |
- // We don't know the version here, so assume latest. |
- // TODO(ianswett): Produce a no-version version negotiation packet. |
- time_wait_list_manager_->AddGuidToTimeWait(guid, |
- supported_versions_.front(), |
- NULL); |
- time_wait_list_manager_->ProcessPacket(server_address, |
- client_address, |
- guid, |
- packet); |
- return; |
+ |
+ if (header.version_flag && |
+ !framer_.IsSupportedVersion(header.versions.front())) { |
+ // TODO(ianswett): Produce a no-version version negotiation packet. |
+ return false; |
+ } |
+ |
+ // Use the version in the packet if possible, otherwise assume the latest. |
+ QuicVersion version = header.version_flag ? header.versions.front() : |
+ supported_versions_.front(); |
+ time_wait_list_manager_->AddGuidToTimeWait(guid, version, NULL); |
+ DCHECK(time_wait_list_manager_->IsGuidInTimeWait(guid)); |
+ return HandlePacketForTimeWait(header); |
} |
DLOG(INFO) << "Created new session for " << guid; |
session_map_.insert(make_pair(guid, session)); |
@@ -130,7 +234,18 @@ void QuicDispatcher::ProcessPacket(const IPEndPoint& server_address, |
} |
session->connection()->ProcessUdpPacket( |
- server_address, client_address, packet); |
+ current_server_address_, current_client_address_, *current_packet_); |
+ |
+ // Do not parse the packet further. The session will process it completely. |
+ return false; |
+} |
+ |
+void QuicDispatcher::OnUnauthenticatedHeader(const QuicPacketHeader& header) { |
+ DCHECK(time_wait_list_manager_->IsGuidInTimeWait(header.public_header.guid)); |
+ time_wait_list_manager_->ProcessPacket(current_server_address_, |
+ current_client_address_, |
+ header.public_header.guid, |
+ header.packet_sequence_number); |
} |
void QuicDispatcher::CleanUpSession(SessionMap::iterator it) { |
@@ -223,5 +338,22 @@ QuicSession* QuicDispatcher::CreateQuicSession( |
return session; |
} |
+bool QuicDispatcher::HandlePacketForTimeWait( |
+ const QuicPacketPublicHeader& header) { |
+ if (header.reset_flag) { |
+ // Public reset packets do not have sequence numbers, so ignore the packet. |
+ return false; |
+ } |
+ |
+ // Switch the framer to the correct version, so that the sequence number can |
+ // be parsed correctly. |
+ framer_.set_version(time_wait_list_manager_->GetQuicVersionFromGuid( |
+ header.guid)); |
+ |
+ // Continue parsing the packet to extract the sequence number. Then |
+ // send it to the time wait manager in OnUnathenticatedHeader. |
+ return true; |
+} |
+ |
} // namespace tools |
} // namespace net |