Index: net/test/openssl_helper.cc |
diff --git a/net/test/openssl_helper.cc b/net/test/openssl_helper.cc |
index b3eb20fd3e557864bab7bf745f7fb0b04d5ea0e4..25989cb60166607e0f67e01e176df506b0a2e157 100644 |
--- a/net/test/openssl_helper.cc |
+++ b/net/test/openssl_helper.cc |
@@ -30,9 +30,19 @@ static int verify_cb(int preverify_ok, X509_STORE_CTX *ctx) { |
// Next Protocol Negotiation callback from OpenSSL |
static int next_proto_cb(SSL *ssl, const unsigned char **out, |
unsigned int *outlen, void *arg) { |
+ bool* npn_mispredict = reinterpret_cast<bool*>(arg); |
static char kProtos[] = "\003foo\003bar"; |
- *out = (const unsigned char*) kProtos; |
- *outlen = sizeof(kProtos) - 1; |
+ static char kProtos2[] = "\003baz\003boo"; |
+ static unsigned count = 0; |
+ |
+ if (!*npn_mispredict || count == 0) { |
+ *out = (const unsigned char*) kProtos; |
+ *outlen = sizeof(kProtos) - 1; |
+ } else { |
+ *out = (const unsigned char*) kProtos2; |
+ *outlen = sizeof(kProtos2) - 1; |
+ } |
+ count++; |
return SSL_TLSEXT_ERR_OK; |
} |
@@ -46,6 +56,7 @@ main(int argc, char **argv) { |
bool sni = false, sni_good = false, snap_start = false; |
bool snap_start_recovery = false, sslv3 = false, session_tickets = false; |
bool fail_resume = false, client_cert = false, npn = false; |
+ bool npn_mispredict = false; |
const char* key_file = kDefaultPEMFile; |
const char* cert_file = kDefaultPEMFile; |
@@ -76,6 +87,10 @@ main(int argc, char **argv) { |
} else if (strcmp(argv[i], "npn") == 0) { |
// Advertise NPN |
npn = true; |
+ } else if (strcmp(argv[i], "npn-mispredict") == 0) { |
+ // Advertise NPN |
+ npn = true; |
+ npn_mispredict = true; |
} else if (strcmp(argv[i], "--key-file") == 0) { |
// Use alternative key file |
i++; |
@@ -165,11 +180,13 @@ main(int argc, char **argv) { |
} |
if (npn) |
- SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, NULL); |
+ SSL_CTX_set_next_protos_advertised_cb(ctx, next_proto_cb, &npn_mispredict); |
unsigned connection_limit = 1; |
if (snap_start || session_tickets) |
connection_limit = 2; |
+ if (npn_mispredict) |
+ connection_limit = 3; |
for (unsigned connections = 0; connections < connection_limit; |
connections++) { |
@@ -209,10 +226,17 @@ main(int argc, char **argv) { |
} |
if (npn) { |
- const unsigned char *data; |
- unsigned len; |
+ const unsigned char *data, *expected_data; |
+ unsigned len, expected_len; |
SSL_get0_next_proto_negotiated(server, &data, &len); |
- if (len != 3 || memcmp(data, "bar", 3) != 0) { |
+ if (!npn_mispredict || connections == 0) { |
+ expected_data = (unsigned char*) "foo"; |
+ expected_len = 3; |
+ } else { |
+ expected_data = (unsigned char*) "baz"; |
+ expected_len = 3; |
+ } |
+ if (len != expected_len || memcmp(data, expected_data, len) != 0) { |
fprintf(stderr, "Bad NPN: %d\n", len); |
return 1; |
} |