| OLD | NEW |
| (Empty) | |
| 1 // Copyright (c) 2016 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "net/tools/domain_security_preload_generator/trie/trie_writer.h" |
| 6 |
| 7 #include <algorithm> |
| 8 |
| 9 #include "base/logging.h" |
| 10 #include "base/strings/string_piece.h" |
| 11 #include "base/strings/string_split.h" |
| 12 #include "base/strings/string_util.h" |
| 13 #include "net/tools/domain_security_preload_generator/trie/trie_bit_buffer.h" |
| 14 |
| 15 namespace net { |
| 16 |
| 17 namespace transport_security_state { |
| 18 |
| 19 namespace { |
| 20 |
| 21 bool CompareReversedEntries(const std::unique_ptr<ReversedEntry>& lhs, |
| 22 const std::unique_ptr<ReversedEntry>& rhs) { |
| 23 return lhs->reversed_name < rhs->reversed_name; |
| 24 } |
| 25 |
| 26 std::string DomainConstant(base::StringPiece input) { |
| 27 std::vector<base::StringPiece> parts = base::SplitStringPiece( |
| 28 input, ".", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL); |
| 29 if (parts.empty()) { |
| 30 return std::string(); |
| 31 } |
| 32 |
| 33 std::string gtld = parts[parts.size() - 1].as_string(); |
| 34 if (parts.size() == 1) { |
| 35 return base::ToUpperASCII(gtld); |
| 36 } |
| 37 |
| 38 std::string domain = base::ToUpperASCII(parts[parts.size() - 2].as_string()); |
| 39 base::ReplaceChars(domain, "-", "_", &domain); |
| 40 |
| 41 return base::ToUpperASCII(domain + "_" + gtld); |
| 42 } |
| 43 |
| 44 } // namespace |
| 45 |
| 46 ReversedEntry::ReversedEntry(std::vector<uint8_t> reversed_name, |
| 47 const DomainSecurityEntry* entry) |
| 48 : reversed_name(reversed_name), entry(entry) {} |
| 49 |
| 50 ReversedEntry::~ReversedEntry() {} |
| 51 |
| 52 TrieWriter::TrieWriter(const HuffmanRepresentationTable& huffman_table, |
| 53 const NameIDMap& domain_ids_map, |
| 54 const NameIDMap& expect_ct_report_uri_map, |
| 55 const NameIDMap& expect_staple_report_uri_map, |
| 56 const NameIDMap& pinsets_map, |
| 57 HuffmanFrequencyTracker* frequency_tracker) |
| 58 : huffman_table_(huffman_table), |
| 59 domain_ids_map_(domain_ids_map), |
| 60 expect_ct_report_uri_map_(expect_ct_report_uri_map), |
| 61 expect_staple_report_uri_map_(expect_staple_report_uri_map), |
| 62 pinsets_map_(pinsets_map), |
| 63 frequency_tracker_(frequency_tracker) {} |
| 64 |
| 65 TrieWriter::~TrieWriter() {} |
| 66 |
| 67 uint32_t TrieWriter::WriteEntries(const DomainSecurityEntries& entries) { |
| 68 ReversedEntries reversed_entries; |
| 69 |
| 70 for (auto const& entry : entries) { |
| 71 std::unique_ptr<ReversedEntry> reversed_entry( |
| 72 new ReversedEntry(ReverseName(entry->hostname), entry.get())); |
| 73 reversed_entries.push_back(std::move(reversed_entry)); |
| 74 } |
| 75 |
| 76 std::stable_sort(reversed_entries.begin(), reversed_entries.end(), |
| 77 CompareReversedEntries); |
| 78 |
| 79 return WriteDispatchTables(reversed_entries.begin(), reversed_entries.end()); |
| 80 } |
| 81 |
| 82 uint32_t TrieWriter::WriteDispatchTables(ReversedEntries::iterator start, |
| 83 ReversedEntries::iterator end) { |
| 84 DCHECK(start != end) << "No entries passed to WriteDispatchTables"; |
| 85 |
| 86 TrieBitBuffer writer; |
| 87 |
| 88 std::vector<uint8_t> prefix = LongestCommonPrefix(start, end); |
| 89 for (size_t i = 0; i < prefix.size(); ++i) { |
| 90 writer.WriteBit(1); |
| 91 } |
| 92 writer.WriteBit(0); |
| 93 |
| 94 if (prefix.size()) { |
| 95 for (size_t i = 0; i < prefix.size(); ++i) { |
| 96 writer.WriteChar(prefix.at(i), huffman_table_, frequency_tracker_); |
| 97 } |
| 98 } |
| 99 |
| 100 RemovePrefix(prefix.size(), start, end); |
| 101 int32_t last_position = -1; |
| 102 |
| 103 while (start != end) { |
| 104 uint8_t candidate = (*start)->reversed_name.at(0); |
| 105 ReversedEntries::iterator sub_entries_end = start + 1; |
| 106 |
| 107 for (; sub_entries_end != end; sub_entries_end++) { |
| 108 if ((*sub_entries_end)->reversed_name.at(0) != candidate) { |
| 109 break; |
| 110 } |
| 111 } |
| 112 |
| 113 writer.WriteChar(candidate, huffman_table_, frequency_tracker_); |
| 114 |
| 115 if (candidate == kTerminalValue) { |
| 116 DCHECK((sub_entries_end - start) == 1) |
| 117 << "Multiple values with the same name"; |
| 118 WriteSecurityEntry((*start)->entry, &writer); |
| 119 } else { |
| 120 RemovePrefix(1, start, sub_entries_end); |
| 121 uint32_t position = WriteDispatchTables(start, sub_entries_end); |
| 122 writer.WritePosition(position, &last_position); |
| 123 } |
| 124 |
| 125 start = sub_entries_end; |
| 126 } |
| 127 |
| 128 writer.WriteChar(kEndOfTableValue, huffman_table_, frequency_tracker_); |
| 129 |
| 130 uint32_t position = buffer_.position(); |
| 131 writer.Flush(); |
| 132 writer.WriteToBitWriter(&buffer_); |
| 133 return position; |
| 134 } |
| 135 |
| 136 void TrieWriter::WriteSecurityEntry(const DomainSecurityEntry* entry, |
| 137 TrieBitBuffer* writer) { |
| 138 uint8_t include_subdomains = 0; |
| 139 if (entry->include_subdomains) { |
| 140 include_subdomains = 1; |
| 141 } |
| 142 writer->WriteBit(include_subdomains); |
| 143 |
| 144 uint8_t force_https = 0; |
| 145 if (entry->force_https) { |
| 146 force_https = 1; |
| 147 } |
| 148 writer->WriteBit(force_https); |
| 149 |
| 150 if (entry->pinset.size()) { |
| 151 writer->WriteBit(1); |
| 152 NameIDMap::const_iterator pin_id_it = pinsets_map_.find(entry->pinset); |
| 153 DCHECK(pin_id_it != pinsets_map_.cend()) << "invalid pinset"; |
| 154 const uint8_t& pin_id = pin_id_it->second; |
| 155 DCHECK(pin_id <= 16) << "too many pinsets"; |
| 156 writer->WriteBits(pin_id, 4); |
| 157 |
| 158 NameIDMap::const_iterator domain_id_it = |
| 159 domain_ids_map_.find(DomainConstant(entry->hostname)); |
| 160 DCHECK(domain_id_it != domain_ids_map_.cend()) << "invalid domain id"; |
| 161 uint32_t domain_id = domain_id_it->second; |
| 162 DCHECK(domain_id < 512) << "too many domain ids"; |
| 163 writer->WriteBits(domain_id, 9); |
| 164 |
| 165 if (!entry->include_subdomains) { |
| 166 uint8_t include_subdomains_for_pinning = 0; |
| 167 if (entry->hpkp_include_subdomains) { |
| 168 include_subdomains_for_pinning = 1; |
| 169 } |
| 170 writer->WriteBit(include_subdomains_for_pinning); |
| 171 } |
| 172 } else { |
| 173 writer->WriteBit(0); |
| 174 } |
| 175 |
| 176 if (entry->expect_ct) { |
| 177 writer->WriteBit(1); |
| 178 NameIDMap::const_iterator expect_ct_report_uri_it = |
| 179 expect_ct_report_uri_map_.find(entry->expect_ct_report_uri); |
| 180 DCHECK(expect_ct_report_uri_it != expect_ct_report_uri_map_.cend()) |
| 181 << "invalid expect-ct report-uri"; |
| 182 const uint8_t& expect_ct_report_id = expect_ct_report_uri_it->second; |
| 183 |
| 184 DCHECK(expect_ct_report_id < 16) << "too many expect-ct ids"; |
| 185 |
| 186 writer->WriteBits(expect_ct_report_id, 4); |
| 187 } else { |
| 188 writer->WriteBit(0); |
| 189 } |
| 190 |
| 191 if (entry->expect_staple) { |
| 192 writer->WriteBit(1); |
| 193 |
| 194 if (entry->expect_staple_include_subdomains) { |
| 195 writer->WriteBit(1); |
| 196 } else { |
| 197 writer->WriteBit(0); |
| 198 } |
| 199 |
| 200 NameIDMap::const_iterator expect_staple_report_uri_it = |
| 201 expect_staple_report_uri_map_.find(entry->expect_staple_report_uri); |
| 202 DCHECK(expect_staple_report_uri_it != expect_staple_report_uri_map_.cend()) |
| 203 << "invalid expect-ct report-uri"; |
| 204 const uint8_t& expect_staple_report_id = |
| 205 expect_staple_report_uri_it->second; |
| 206 DCHECK(expect_staple_report_id < 16) << "too many expect-staple ids"; |
| 207 |
| 208 writer->WriteBits(expect_staple_report_id, 4); |
| 209 } else { |
| 210 writer->WriteBit(0); |
| 211 } |
| 212 } |
| 213 |
| 214 void TrieWriter::RemovePrefix(size_t length, |
| 215 ReversedEntries::iterator start, |
| 216 ReversedEntries::iterator end) { |
| 217 for (ReversedEntries::iterator it = start; it != end; ++it) { |
| 218 (*it)->reversed_name.erase((*it)->reversed_name.begin(), |
| 219 (*it)->reversed_name.begin() + length); |
| 220 } |
| 221 } |
| 222 |
| 223 std::vector<uint8_t> TrieWriter::LongestCommonPrefix( |
| 224 ReversedEntries::iterator start, |
| 225 ReversedEntries::iterator end) const { |
| 226 if (start == end) { |
| 227 return std::vector<uint8_t>(); |
| 228 } |
| 229 |
| 230 std::vector<uint8_t> prefix; |
| 231 for (size_t i = 0;; ++i) { |
| 232 if (i > (*start)->reversed_name.size()) { |
| 233 break; |
| 234 } |
| 235 |
| 236 uint8_t candidate = (*start)->reversed_name.at(i); |
| 237 if (candidate == kTerminalValue) { |
| 238 break; |
| 239 } |
| 240 |
| 241 bool ok = true; |
| 242 for (ReversedEntries::iterator it = start + 1; it != end; ++it) { |
| 243 if (i > (*it)->reversed_name.size() || |
| 244 (*it)->reversed_name.at(i) != candidate) { |
| 245 ok = false; |
| 246 break; |
| 247 } |
| 248 } |
| 249 |
| 250 if (!ok) { |
| 251 break; |
| 252 } |
| 253 |
| 254 prefix.push_back(candidate); |
| 255 } |
| 256 |
| 257 return prefix; |
| 258 } |
| 259 |
| 260 std::vector<uint8_t> TrieWriter::ReverseName( |
| 261 const std::string& hostname) const { |
| 262 size_t hostname_size = hostname.size(); |
| 263 std::vector<uint8_t> reversed_name(hostname_size + 1); |
| 264 |
| 265 for (size_t i = 0; i < hostname_size; ++i) { |
| 266 reversed_name[i] = hostname[hostname_size - i - 1]; |
| 267 } |
| 268 |
| 269 reversed_name[reversed_name.size() - 1] = kTerminalValue; |
| 270 return reversed_name; |
| 271 } |
| 272 |
| 273 uint32_t TrieWriter::position() const { |
| 274 return buffer_.position(); |
| 275 } |
| 276 |
| 277 void TrieWriter::Flush() { |
| 278 buffer_.Flush(); |
| 279 } |
| 280 |
| 281 } // namespace transport_security_state |
| 282 |
| 283 } // namespace net |
| OLD | NEW |