OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/quic/test_tools/crypto_test_utils.h" | 5 #include "net/quic/test_tools/crypto_test_utils.h" |
6 | 6 |
7 #include "net/quic/crypto/channel_id.h" | 7 #include "net/quic/crypto/channel_id.h" |
8 #include "net/quic/crypto/common_cert_set.h" | 8 #include "net/quic/crypto/common_cert_set.h" |
9 #include "net/quic/crypto/crypto_handshake.h" | 9 #include "net/quic/crypto/crypto_handshake.h" |
10 #include "net/quic/crypto/quic_crypto_server_config.h" | 10 #include "net/quic/crypto/quic_crypto_server_config.h" |
(...skipping 23 matching lines...) Expand all Loading... |
34 const char kServerHostname[] = "test.example.com"; | 34 const char kServerHostname[] = "test.example.com"; |
35 const uint16 kServerPort = 80; | 35 const uint16 kServerPort = 80; |
36 | 36 |
37 // CryptoFramerVisitor is a framer visitor that records handshake messages. | 37 // CryptoFramerVisitor is a framer visitor that records handshake messages. |
38 class CryptoFramerVisitor : public CryptoFramerVisitorInterface { | 38 class CryptoFramerVisitor : public CryptoFramerVisitorInterface { |
39 public: | 39 public: |
40 CryptoFramerVisitor() | 40 CryptoFramerVisitor() |
41 : error_(false) { | 41 : error_(false) { |
42 } | 42 } |
43 | 43 |
44 virtual void OnError(CryptoFramer* framer) override { error_ = true; } | 44 void OnError(CryptoFramer* framer) override { error_ = true; } |
45 | 45 |
46 virtual void OnHandshakeMessage( | 46 void OnHandshakeMessage(const CryptoHandshakeMessage& message) override { |
47 const CryptoHandshakeMessage& message) override { | |
48 messages_.push_back(message); | 47 messages_.push_back(message); |
49 } | 48 } |
50 | 49 |
51 bool error() const { | 50 bool error() const { |
52 return error_; | 51 return error_; |
53 } | 52 } |
54 | 53 |
55 const vector<CryptoHandshakeMessage>& messages() const { | 54 const vector<CryptoHandshakeMessage>& messages() const { |
56 return messages_; | 55 return messages_; |
57 } | 56 } |
(...skipping 72 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
130 } | 129 } |
131 | 130 |
132 // A ChannelIDSource that works in asynchronous mode unless the |callback| | 131 // A ChannelIDSource that works in asynchronous mode unless the |callback| |
133 // argument to GetChannelIDKey is nullptr. | 132 // argument to GetChannelIDKey is nullptr. |
134 class AsyncTestChannelIDSource : public ChannelIDSource, | 133 class AsyncTestChannelIDSource : public ChannelIDSource, |
135 public CryptoTestUtils::CallbackSource { | 134 public CryptoTestUtils::CallbackSource { |
136 public: | 135 public: |
137 // Takes ownership of |sync_source|, a synchronous ChannelIDSource. | 136 // Takes ownership of |sync_source|, a synchronous ChannelIDSource. |
138 explicit AsyncTestChannelIDSource(ChannelIDSource* sync_source) | 137 explicit AsyncTestChannelIDSource(ChannelIDSource* sync_source) |
139 : sync_source_(sync_source) {} | 138 : sync_source_(sync_source) {} |
140 virtual ~AsyncTestChannelIDSource() {} | 139 ~AsyncTestChannelIDSource() override {} |
141 | 140 |
142 // ChannelIDSource implementation. | 141 // ChannelIDSource implementation. |
143 virtual QuicAsyncStatus GetChannelIDKey( | 142 QuicAsyncStatus GetChannelIDKey(const string& hostname, |
144 const string& hostname, | 143 scoped_ptr<ChannelIDKey>* channel_id_key, |
145 scoped_ptr<ChannelIDKey>* channel_id_key, | 144 ChannelIDSourceCallback* callback) override { |
146 ChannelIDSourceCallback* callback) override { | |
147 // Synchronous mode. | 145 // Synchronous mode. |
148 if (!callback) { | 146 if (!callback) { |
149 return sync_source_->GetChannelIDKey(hostname, channel_id_key, nullptr); | 147 return sync_source_->GetChannelIDKey(hostname, channel_id_key, nullptr); |
150 } | 148 } |
151 | 149 |
152 // Asynchronous mode. | 150 // Asynchronous mode. |
153 QuicAsyncStatus status = | 151 QuicAsyncStatus status = |
154 sync_source_->GetChannelIDKey(hostname, &channel_id_key_, nullptr); | 152 sync_source_->GetChannelIDKey(hostname, &channel_id_key_, nullptr); |
155 if (status != QUIC_SUCCESS) { | 153 if (status != QUIC_SUCCESS) { |
156 return QUIC_FAILURE; | 154 return QUIC_FAILURE; |
157 } | 155 } |
158 callback_.reset(callback); | 156 callback_.reset(callback); |
159 return QUIC_PENDING; | 157 return QUIC_PENDING; |
160 } | 158 } |
161 | 159 |
162 // CallbackSource implementation. | 160 // CallbackSource implementation. |
163 virtual void RunPendingCallbacks() override { | 161 void RunPendingCallbacks() override { |
164 if (callback_.get()) { | 162 if (callback_.get()) { |
165 callback_->Run(&channel_id_key_); | 163 callback_->Run(&channel_id_key_); |
166 callback_.reset(); | 164 callback_.reset(); |
167 } | 165 } |
168 } | 166 } |
169 | 167 |
170 private: | 168 private: |
171 scoped_ptr<ChannelIDSource> sync_source_; | 169 scoped_ptr<ChannelIDSource> sync_source_; |
172 scoped_ptr<ChannelIDSourceCallback> callback_; | 170 scoped_ptr<ChannelIDSourceCallback> callback_; |
173 scoped_ptr<ChannelIDKey> channel_id_key_; | 171 scoped_ptr<ChannelIDKey> channel_id_key_; |
(...skipping 172 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
346 } | 344 } |
347 | 345 |
348 class MockCommonCertSets : public CommonCertSets { | 346 class MockCommonCertSets : public CommonCertSets { |
349 public: | 347 public: |
350 MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index) | 348 MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index) |
351 : cert_(cert.as_string()), | 349 : cert_(cert.as_string()), |
352 hash_(hash), | 350 hash_(hash), |
353 index_(index) { | 351 index_(index) { |
354 } | 352 } |
355 | 353 |
356 virtual StringPiece GetCommonHashes() const override { | 354 StringPiece GetCommonHashes() const override { |
357 CHECK(false) << "not implemented"; | 355 CHECK(false) << "not implemented"; |
358 return StringPiece(); | 356 return StringPiece(); |
359 } | 357 } |
360 | 358 |
361 virtual StringPiece GetCert(uint64 hash, uint32 index) const override { | 359 StringPiece GetCert(uint64 hash, uint32 index) const override { |
362 if (hash == hash_ && index == index_) { | 360 if (hash == hash_ && index == index_) { |
363 return cert_; | 361 return cert_; |
364 } | 362 } |
365 return StringPiece(); | 363 return StringPiece(); |
366 } | 364 } |
367 | 365 |
368 virtual bool MatchCert(StringPiece cert, | 366 bool MatchCert(StringPiece cert, |
369 StringPiece common_set_hashes, | 367 StringPiece common_set_hashes, |
370 uint64* out_hash, | 368 uint64* out_hash, |
371 uint32* out_index) const override { | 369 uint32* out_index) const override { |
372 if (cert != cert_) { | 370 if (cert != cert_) { |
373 return false; | 371 return false; |
374 } | 372 } |
375 | 373 |
376 if (common_set_hashes.size() % sizeof(uint64) != 0) { | 374 if (common_set_hashes.size() % sizeof(uint64) != 0) { |
377 return false; | 375 return false; |
378 } | 376 } |
379 bool client_has_set = false; | 377 bool client_has_set = false; |
380 for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) { | 378 for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) { |
381 uint64 hash; | 379 uint64 hash; |
(...skipping 242 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
624 scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg)); | 622 scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg)); |
625 scoped_ptr<CryptoHandshakeMessage> parsed( | 623 scoped_ptr<CryptoHandshakeMessage> parsed( |
626 CryptoFramer::ParseMessage(bytes->AsStringPiece())); | 624 CryptoFramer::ParseMessage(bytes->AsStringPiece())); |
627 CHECK(parsed.get()); | 625 CHECK(parsed.get()); |
628 | 626 |
629 return *parsed; | 627 return *parsed; |
630 } | 628 } |
631 | 629 |
632 } // namespace test | 630 } // namespace test |
633 } // namespace net | 631 } // namespace net |
OLD | NEW |