Index: media/cdm/json_web_key.cc |
diff --git a/media/cdm/json_web_key.cc b/media/cdm/json_web_key.cc |
index 446d389ac8b4b6323393f938471f6210384072f2..fffcee4871d931543858478914ab63c0814c816e 100644 |
--- a/media/cdm/json_web_key.cc |
+++ b/media/cdm/json_web_key.cc |
@@ -24,13 +24,17 @@ const char kKeyTag[] = "k"; |
const char kKeyIdTag[] = "kid"; |
const char kKeyIdsTag[] = "kids"; |
const char kBase64Padding = '='; |
+const char kBase64Plus[] = "+"; |
+const char kBase64UrlPlusReplacement[] = "-"; |
+const char kBase64Slash[] = "/"; |
+const char kBase64UrlSlashReplacement[] = "_"; |
const char kTypeTag[] = "type"; |
const char kTemporarySession[] = "temporary"; |
const char kPersistentLicenseSession[] = "persistent-license"; |
const char kPersistentReleaseMessageSession[] = "persistent-release-message"; |
-// Encodes |input| into a base64 string without padding. |
-static std::string EncodeBase64(const uint8* input, int input_length) { |
+// Encodes |input| into a base64url string without padding. |
+static std::string EncodeBase64Url(const uint8* input, int input_length) { |
std::string encoded_text; |
base::Base64Encode( |
std::string(reinterpret_cast<const char*>(input), input_length), |
@@ -41,16 +45,32 @@ static std::string EncodeBase64(const uint8* input, int input_length) { |
if (found != std::string::npos) |
encoded_text.erase(found + 1); |
+ // base64url encoding means the characters '-' and '_' must be used |
+ // instead of '+' and '/', respectively. |
+ base::ReplaceChars(encoded_text, kBase64Plus, kBase64UrlPlusReplacement, |
+ &encoded_text); |
+ base::ReplaceChars(encoded_text, kBase64Slash, kBase64UrlSlashReplacement, |
+ &encoded_text); |
+ |
return encoded_text; |
} |
-// Decodes an unpadded base64 string. Returns empty string on error. |
-static std::string DecodeBase64(const std::string& encoded_text) { |
+// Decodes a base64url string. Returns empty string on error. |
+static std::string DecodeBase64Url(const std::string& encoded_text) { |
// EME spec doesn't allow padding characters. |
if (encoded_text.find_first_of(kBase64Padding) != std::string::npos) { |
DVLOG(1) << "Padding characters not allowed: " << encoded_text; |
return std::string(); |
} |
+ // TODO(jrummell): Enable once blink tests updated to use base64url encoding. |
+ //if (encoded_text.find(kBase64Plus) != std::string::npos) { |
+ // DVLOG(1) << "Base64 '+' characters not allowed: " << encoded_text; |
+ // return std::string(); |
+ //} |
+ //if (encoded_text.find(kBase64Slash) != std::string::npos) { |
+ // DVLOG(1) << "Base64 '/' characters not allowed: " << encoded_text; |
+ // return std::string(); |
+ //} |
// Since base::Base64Decode() requires padding characters, add them so length |
// of |encoded_text| is exactly a multiple of 4. |
@@ -59,6 +79,14 @@ static std::string DecodeBase64(const std::string& encoded_text) { |
if (num_last_grouping_chars > 0) |
modified_text.append(4 - num_last_grouping_chars, kBase64Padding); |
+ // base64url encoding means the characters '-' and '_' must be used |
+ // instead of '+' and '/', respectively, so replace them before calling |
+ // base::Base64Decode(). |
+ base::ReplaceChars(modified_text, kBase64UrlPlusReplacement, kBase64Plus, |
+ &modified_text); |
+ base::ReplaceChars(modified_text, kBase64UrlSlashReplacement, kBase64Slash, |
+ &modified_text); |
+ |
std::string decoded_text; |
if (!base::Base64Decode(modified_text, &decoded_text)) { |
DVLOG(1) << "Base64 decoding failed on: " << modified_text; |
@@ -71,8 +99,8 @@ static std::string DecodeBase64(const std::string& encoded_text) { |
std::string GenerateJWKSet(const uint8* key, int key_length, |
const uint8* key_id, int key_id_length) { |
// Both |key| and |key_id| need to be base64 encoded strings in the JWK. |
- std::string key_base64 = EncodeBase64(key, key_length); |
- std::string key_id_base64 = EncodeBase64(key_id, key_id_length); |
+ std::string key_base64 = EncodeBase64Url(key, key_length); |
+ std::string key_id_base64 = EncodeBase64Url(key_id, key_id_length); |
// Create the JWK, and wrap it into a JWK Set. |
scoped_ptr<base::DictionaryValue> jwk(new base::DictionaryValue()); |
@@ -121,13 +149,13 @@ static bool ConvertJwkToKeyPair(const base::DictionaryValue& jwk, |
} |
// Key ID and key are base64-encoded strings, so decode them. |
- std::string raw_key_id = DecodeBase64(encoded_key_id); |
+ std::string raw_key_id = DecodeBase64Url(encoded_key_id); |
if (raw_key_id.empty()) { |
DVLOG(1) << "Invalid '" << kKeyIdTag << "' value: " << encoded_key_id; |
return false; |
} |
- std::string raw_key = DecodeBase64(encoded_key); |
+ std::string raw_key = DecodeBase64Url(encoded_key); |
if (raw_key.empty()) { |
DVLOG(1) << "Invalid '" << kKeyTag << "' value: " << encoded_key; |
return false; |
@@ -213,7 +241,7 @@ void CreateLicenseRequest(const uint8* key_id, |
// Create the license request. |
scoped_ptr<base::DictionaryValue> request(new base::DictionaryValue()); |
scoped_ptr<base::ListValue> list(new base::ListValue()); |
- list->AppendString(EncodeBase64(key_id, key_id_length)); |
+ list->AppendString(EncodeBase64Url(key_id, key_id_length)); |
request->Set(kKeyIdsTag, list.release()); |
switch (session_type) { |
@@ -275,7 +303,7 @@ bool ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8>& license, |
return false; |
} |
- std::string decoded_string = DecodeBase64(encoded_key); |
+ std::string decoded_string = DecodeBase64Url(encoded_key); |
if (decoded_string.empty()) { |
DVLOG(1) << "Invalid '" << kKeyIdsTag << "' value: " << encoded_key; |
return false; |