Index: net/tools/quic/stateless_rejector.cc |
diff --git a/net/tools/quic/stateless_rejector.cc b/net/tools/quic/stateless_rejector.cc |
index e85e0a6a82842b23fa266c6a8843b74cb2761a1d..3c142cedd677b3f75272160615be4a7ef5b1d122 100644 |
--- a/net/tools/quic/stateless_rejector.cc |
+++ b/net/tools/quic/stateless_rejector.cc |
@@ -12,8 +12,10 @@ namespace net { |
class StatelessRejector::ValidateCallback |
: public ValidateClientHelloResultCallback { |
public: |
- explicit ValidateCallback(StatelessRejector* rejector) |
- : rejector_(rejector) {} |
+ explicit ValidateCallback( |
+ std::unique_ptr<StatelessRejector> rejector, |
+ std::unique_ptr<StatelessRejector::ProcessDoneCallback> cb) |
+ : rejector_(std::move(rejector)), cb_(std::move(cb)) {} |
~ValidateCallback() override {} |
@@ -21,11 +23,14 @@ class StatelessRejector::ValidateCallback |
const Result& result, |
std::unique_ptr<ProofSource::Details> /* proof_source_details */) |
override { |
- rejector_->ProcessClientHello(client_hello, result); |
+ StatelessRejector* rejector_ptr = rejector_.get(); |
+ rejector_ptr->ProcessClientHello(client_hello, result, std::move(rejector_), |
+ std::move(cb_)); |
} |
private: |
- StatelessRejector* rejector_; |
+ std::unique_ptr<StatelessRejector> rejector_; |
+ std::unique_ptr<StatelessRejector::ProcessDoneCallback> cb_; |
}; |
StatelessRejector::StatelessRejector( |
@@ -38,7 +43,7 @@ StatelessRejector::StatelessRejector( |
QuicByteCount chlo_packet_size, |
const IPEndPoint& client_address, |
const IPEndPoint& server_address) |
- : state_(FAILED), |
+ : state_(UNKNOWN), |
error_(QUIC_INTERNAL_ERROR), |
version_(version), |
versions_(versions), |
@@ -49,8 +54,7 @@ StatelessRejector::StatelessRejector( |
clock_(clock), |
random_(random), |
crypto_config_(crypto_config), |
- compressed_certs_cache_(compressed_certs_cache), |
- chlo_(nullptr) {} |
+ compressed_certs_cache_(compressed_certs_cache) {} |
StatelessRejector::~StatelessRejector() {} |
@@ -60,6 +64,7 @@ void StatelessRejector::OnChlo(QuicVersion version, |
const CryptoHandshakeMessage& message) { |
DCHECK_EQ(kCHLO, message.tag()); |
DCHECK_NE(connection_id, server_designated_connection_id); |
+ DCHECK_EQ(state_, UNKNOWN); |
if (!FLAGS_enable_quic_stateless_reject_support || |
!FLAGS_quic_use_cheap_stateless_rejects || |
@@ -71,16 +76,31 @@ void StatelessRejector::OnChlo(QuicVersion version, |
connection_id_ = connection_id; |
server_designated_connection_id_ = server_designated_connection_id; |
- chlo_ = &message; |
+ chlo_ = message; // Note: copies the message |
+} |
- crypto_config_->ValidateClientHello( |
- message, client_address_.address(), server_address_.address(), version_, |
- clock_, &proof_, new ValidateCallback(this)); |
+void StatelessRejector::Process(std::unique_ptr<StatelessRejector> rejector, |
+ std::unique_ptr<ProcessDoneCallback> cb) { |
+ // If we were able to make a decision about this CHLO based purely on the |
+ // information available in OnChlo, just invoke the done callback immediately. |
+ if (rejector->state() != UNKNOWN) { |
+ cb->Run(std::move(rejector)); |
+ return; |
+ } |
+ |
+ StatelessRejector* rejector_ptr = rejector.get(); |
+ rejector_ptr->crypto_config_->ValidateClientHello( |
+ rejector_ptr->chlo_, rejector_ptr->client_address_.address(), |
+ rejector_ptr->server_address_.address(), rejector_ptr->version_, |
+ rejector_ptr->clock_, &rejector_ptr->proof_, |
+ new ValidateCallback(std::move(rejector), std::move(cb))); |
} |
void StatelessRejector::ProcessClientHello( |
const CryptoHandshakeMessage& client_hello, |
- const ValidateClientHelloResultCallback::Result& result) { |
+ const ValidateClientHelloResultCallback::Result& result, |
+ std::unique_ptr<StatelessRejector> rejector, |
+ std::unique_ptr<StatelessRejector::ProcessDoneCallback> cb) { |
QuicCryptoNegotiatedParameters params; |
DiversificationNonce diversification_nonce; |
QuicErrorCode error = crypto_config_->ProcessClientHello( |
@@ -93,15 +113,13 @@ void StatelessRejector::ProcessClientHello( |
chlo_packet_size_, &reply_, &diversification_nonce, &error_details_); |
if (error != QUIC_NO_ERROR) { |
error_ = error; |
- return; |
- } |
- |
- if (reply_.tag() == kSREJ) { |
+ state_ = FAILED; |
+ } else if (reply_.tag() == kSREJ) { |
state_ = REJECTED; |
- return; |
+ } else { |
+ state_ = ACCEPTED; |
} |
- |
- state_ = ACCEPTED; |
+ cb->Run(std::move(rejector)); |
} |
} // namespace net |