| OLD | NEW |
| (Empty) | |
| 1 // Copyright (c) 2016 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. |
| 4 |
| 5 #include "net/tools/domain_security_preload_generator/trie/trie_writer.h" |
| 6 |
| 7 #include <algorithm> |
| 8 |
| 9 #include "base/logging.h" |
| 10 #include "base/strings/string_piece.h" |
| 11 #include "base/strings/string_split.h" |
| 12 #include "base/strings/string_util.h" |
| 13 #include "net/tools/domain_security_preload_generator/trie/trie_bit_buffer.h" |
| 14 |
| 15 namespace net { |
| 16 |
| 17 namespace { |
| 18 |
| 19 bool CompareReversedEntries(const std::unique_ptr<ReversedEntry>& lhs, |
| 20 const std::unique_ptr<ReversedEntry>& rhs) { |
| 21 return lhs->reversed_name < rhs->reversed_name; |
| 22 } |
| 23 |
| 24 std::string DomainConstant(base::StringPiece input) { |
| 25 std::vector<base::StringPiece> parts = base::SplitStringPiece( |
| 26 input, ".", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL); |
| 27 std::string gtld = parts[parts.size() - 1].as_string(); |
| 28 |
| 29 if (parts.size() == 1) { |
| 30 return base::ToUpperASCII(gtld); |
| 31 } |
| 32 |
| 33 std::string domain = base::ToUpperASCII(parts[parts.size() - 2].as_string()); |
| 34 base::ReplaceChars(domain, "-", "_", &domain); |
| 35 |
| 36 return base::ToUpperASCII(domain + "_" + gtld); |
| 37 } |
| 38 |
| 39 } // namespace |
| 40 |
| 41 ReversedEntry::ReversedEntry(std::vector<uint8_t> reversed_name, |
| 42 const DomainSecurityEntry* entry) |
| 43 : reversed_name(reversed_name), entry(entry) {} |
| 44 |
| 45 ReversedEntry::~ReversedEntry() {} |
| 46 |
| 47 TrieWriter::TrieWriter(const HuffmanRepresentationTable& huffman_table, |
| 48 const NameIDMap& domain_ids_map, |
| 49 const NameIDMap& expect_ct_report_uri_map, |
| 50 const NameIDMap& expect_staple_report_uri_map, |
| 51 const NameIDMap& pinsets_map, |
| 52 HuffmanFrequencyTracker* frequency_tracker) |
| 53 : huffman_table_(huffman_table), |
| 54 domain_ids_map_(domain_ids_map), |
| 55 expect_ct_report_uri_map_(expect_ct_report_uri_map), |
| 56 expect_staple_report_uri_map_(expect_staple_report_uri_map), |
| 57 pinsets_map_(pinsets_map), |
| 58 frequency_tracker_(frequency_tracker) {} |
| 59 |
| 60 TrieWriter::~TrieWriter() {} |
| 61 |
| 62 int TrieWriter::WriteEntries(const DomainSecurityEntries& entries) { |
| 63 ReversedEntries reversed_entries; |
| 64 |
| 65 for (auto const& entry : entries) { |
| 66 std::unique_ptr<ReversedEntry> reversed_entry( |
| 67 new ReversedEntry(ReverseName(entry->hostname()), entry.get())); |
| 68 reversed_entries.push_back(std::move(reversed_entry)); |
| 69 } |
| 70 |
| 71 std::stable_sort(reversed_entries.begin(), reversed_entries.end(), |
| 72 CompareReversedEntries); |
| 73 |
| 74 return WriteDispatchTables(reversed_entries.begin(), reversed_entries.end(), |
| 75 0); |
| 76 } |
| 77 |
| 78 int TrieWriter::WriteDispatchTables(ReversedEntries::iterator start, |
| 79 ReversedEntries::iterator end, |
| 80 int depth) { |
| 81 CHECK(start != end) << "No entries passed to WriteDispatchTables"; |
| 82 |
| 83 TrieBitBuffer writer; |
| 84 |
| 85 std::vector<uint8_t> prefix = LongestCommonPrefix(start, end); |
| 86 for (size_t i = 0; i < prefix.size(); ++i) { |
| 87 writer.WriteBit(1); |
| 88 } |
| 89 writer.WriteBit(0); |
| 90 |
| 91 if (prefix.size()) { |
| 92 for (size_t i = 0; i < prefix.size(); ++i) { |
| 93 writer.WriteChar(prefix.at(i), huffman_table_, frequency_tracker_); |
| 94 depth++; |
| 95 } |
| 96 } |
| 97 |
| 98 RemovePrefix(prefix.size(), start, end); |
| 99 int last_position = -1; |
| 100 |
| 101 while (start != end) { |
| 102 uint8_t candidate = (*start)->reversed_name.at(0); |
| 103 ReversedEntries::iterator sub_entries_end = start + 1; |
| 104 |
| 105 for (; sub_entries_end != end; sub_entries_end++) { |
| 106 if ((*sub_entries_end)->reversed_name.at(0) != candidate) { |
| 107 break; |
| 108 } |
| 109 } |
| 110 |
| 111 writer.WriteChar(candidate, huffman_table_, frequency_tracker_); |
| 112 |
| 113 if (candidate == kTerminalValue) { |
| 114 CHECK((sub_entries_end - start) == 1) |
| 115 << "Multiple values with the same name"; |
| 116 WriteSecurityEntry((*start)->entry, &writer); |
| 117 } else { |
| 118 RemovePrefix(1, start, sub_entries_end); |
| 119 int position = WriteDispatchTables(start, sub_entries_end, depth + 2); |
| 120 writer.WritePosition(position, &last_position); |
| 121 } |
| 122 |
| 123 start = sub_entries_end; |
| 124 } |
| 125 |
| 126 writer.WriteChar(kEndOfTableValue, huffman_table_, frequency_tracker_); |
| 127 |
| 128 uint32_t position = buffer_.position(); |
| 129 writer.Close(); |
| 130 writer.WriteToBitWriter(buffer_); |
| 131 return position; |
| 132 } |
| 133 |
| 134 void TrieWriter::WriteSecurityEntry(const DomainSecurityEntry* entry, |
| 135 TrieBitBuffer* writer) { |
| 136 uint8_t include_subdomains = 0; |
| 137 if (entry->include_subdomains()) { |
| 138 include_subdomains = 1; |
| 139 } |
| 140 writer->WriteBit(include_subdomains); |
| 141 |
| 142 uint8_t force_https = 0; |
| 143 if (entry->force_https()) { |
| 144 force_https = 1; |
| 145 } |
| 146 writer->WriteBit(force_https); |
| 147 |
| 148 if (entry->pinset().size()) { |
| 149 writer->WriteBit(1); |
| 150 NameIDMap::const_iterator pin_id_it = pinsets_map_.find(entry->pinset()); |
| 151 CHECK(pin_id_it != pinsets_map_.cend()) << "invalid pinset"; |
| 152 const uint8_t& pin_id = pin_id_it->second; |
| 153 CHECK(pin_id <= 16) << "too many pinsets"; |
| 154 writer->WriteBits(pin_id, 4); |
| 155 |
| 156 NameIDMap::const_iterator domain_id_it = |
| 157 domain_ids_map_.find(DomainConstant(entry->hostname())); |
| 158 CHECK(domain_id_it != domain_ids_map_.cend()) << "invalid domain id"; |
| 159 uint32_t domain_id = domain_id_it->second; |
| 160 CHECK(domain_id < 512) << "too many domain ids"; |
| 161 writer->WriteBits(domain_id, 9); |
| 162 |
| 163 if (!entry->include_subdomains()) { |
| 164 uint8_t include_subdomains_for_pinning = 0; |
| 165 if (entry->hpkp_include_subdomains()) { |
| 166 include_subdomains_for_pinning = 1; |
| 167 } |
| 168 writer->WriteBit(include_subdomains_for_pinning); |
| 169 } |
| 170 } else { |
| 171 writer->WriteBit(0); |
| 172 } |
| 173 |
| 174 if (entry->expect_ct()) { |
| 175 writer->WriteBit(1); |
| 176 NameIDMap::const_iterator expect_ct_report_uri_it = |
| 177 expect_ct_report_uri_map_.find(entry->expect_ct_report_uri()); |
| 178 CHECK(expect_ct_report_uri_it != expect_ct_report_uri_map_.cend()) |
| 179 << "invalid expect-ct report-uri"; |
| 180 const uint8_t& expect_ct_report_id = expect_ct_report_uri_it->second; |
| 181 |
| 182 CHECK(expect_ct_report_id < 16) << "too many expect-ct ids"; |
| 183 |
| 184 writer->WriteBits(expect_ct_report_id, 4); |
| 185 } else { |
| 186 writer->WriteBit(0); |
| 187 } |
| 188 |
| 189 if (entry->expect_staple()) { |
| 190 writer->WriteBit(1); |
| 191 |
| 192 if (entry->expect_staple_include_subdomains()) { |
| 193 writer->WriteBit(1); |
| 194 } else { |
| 195 writer->WriteBit(0); |
| 196 } |
| 197 |
| 198 NameIDMap::const_iterator expect_staple_report_uri_it = |
| 199 expect_staple_report_uri_map_.find(entry->expect_staple_report_uri()); |
| 200 CHECK(expect_staple_report_uri_it != expect_staple_report_uri_map_.cend()) |
| 201 << "invalid expect-ct report-uri"; |
| 202 const uint8_t& expect_staple_report_id = |
| 203 expect_staple_report_uri_it->second; |
| 204 CHECK(expect_staple_report_id < 16) << "too many expect-staple ids"; |
| 205 |
| 206 writer->WriteBits(expect_staple_report_id, 4); |
| 207 } else { |
| 208 writer->WriteBit(0); |
| 209 } |
| 210 } |
| 211 |
| 212 void TrieWriter::RemovePrefix(size_t length, |
| 213 ReversedEntries::iterator start, |
| 214 ReversedEntries::iterator end) { |
| 215 for (ReversedEntries::iterator it = start; it != end; ++it) { |
| 216 (*it)->reversed_name.erase((*it)->reversed_name.begin(), |
| 217 (*it)->reversed_name.begin() + length); |
| 218 } |
| 219 } |
| 220 |
| 221 std::vector<uint8_t> TrieWriter::LongestCommonPrefix( |
| 222 ReversedEntries::iterator start, |
| 223 ReversedEntries::iterator end) const { |
| 224 if (start == end) { |
| 225 return std::vector<uint8_t>(); |
| 226 } |
| 227 |
| 228 std::vector<uint8_t> prefix; |
| 229 for (size_t i = 0;; ++i) { |
| 230 if (i > (*start)->reversed_name.size()) { |
| 231 break; |
| 232 } |
| 233 |
| 234 uint8_t candidate = (*start)->reversed_name.at(i); |
| 235 if (candidate == kTerminalValue) { |
| 236 break; |
| 237 } |
| 238 |
| 239 bool ok = true; |
| 240 for (ReversedEntries::iterator it = start + 1; it != end; ++it) { |
| 241 if (i > (*it)->reversed_name.size() || |
| 242 (*it)->reversed_name.at(i) != candidate) { |
| 243 ok = false; |
| 244 break; |
| 245 } |
| 246 } |
| 247 |
| 248 if (!ok) { |
| 249 break; |
| 250 } |
| 251 |
| 252 prefix.push_back(candidate); |
| 253 } |
| 254 |
| 255 return prefix; |
| 256 } |
| 257 |
| 258 std::vector<uint8_t> TrieWriter::ReverseName( |
| 259 const std::string& hostname) const { |
| 260 size_t hostname_size = hostname.size(); |
| 261 std::vector<uint8_t> reversed_name(hostname_size + 1); |
| 262 |
| 263 for (size_t i = 0; i < hostname_size; ++i) { |
| 264 reversed_name[i] = hostname[hostname_size - i - 1]; |
| 265 } |
| 266 |
| 267 reversed_name[reversed_name.size() - 1] = kTerminalValue; |
| 268 return reversed_name; |
| 269 } |
| 270 |
| 271 uint32_t TrieWriter::position() const { |
| 272 return buffer_.position(); |
| 273 } |
| 274 |
| 275 void TrieWriter::close() { |
| 276 buffer_.Close(); |
| 277 } |
| 278 |
| 279 } // namespace net |
| OLD | NEW |