OLD | NEW |
---|---|
1 // Copyright (c) 2016 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2016 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/tools/transport_security_state_generator/trie/trie_writer.h" | 5 #include "net/tools/transport_security_state_generator/trie/trie_writer.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 | 8 |
9 #include "base/logging.h" | 9 #include "base/logging.h" |
10 #include "base/strings/string_piece.h" | 10 #include "base/strings/string_piece.h" |
(...skipping 46 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
57 HuffmanBuilder* huffman_builder) | 57 HuffmanBuilder* huffman_builder) |
58 : huffman_table_(huffman_table), | 58 : huffman_table_(huffman_table), |
59 domain_ids_map_(domain_ids_map), | 59 domain_ids_map_(domain_ids_map), |
60 expect_ct_report_uri_map_(expect_ct_report_uri_map), | 60 expect_ct_report_uri_map_(expect_ct_report_uri_map), |
61 expect_staple_report_uri_map_(expect_staple_report_uri_map), | 61 expect_staple_report_uri_map_(expect_staple_report_uri_map), |
62 pinsets_map_(pinsets_map), | 62 pinsets_map_(pinsets_map), |
63 huffman_builder_(huffman_builder) {} | 63 huffman_builder_(huffman_builder) {} |
64 | 64 |
65 TrieWriter::~TrieWriter() {} | 65 TrieWriter::~TrieWriter() {} |
66 | 66 |
67 uint32_t TrieWriter::WriteEntries( | 67 bool TrieWriter::WriteEntries(const TransportSecurityStateEntries& entries, |
68 const TransportSecurityStateEntries& entries) { | 68 uint32_t* root_position) { |
69 if (entries.empty()) | |
70 return false; | |
71 | |
69 ReversedEntries reversed_entries; | 72 ReversedEntries reversed_entries; |
70 | |
71 for (auto const& entry : entries) { | 73 for (auto const& entry : entries) { |
72 std::unique_ptr<ReversedEntry> reversed_entry( | 74 std::unique_ptr<ReversedEntry> reversed_entry( |
73 new ReversedEntry(ReverseName(entry->hostname), entry.get())); | 75 new ReversedEntry(ReverseName(entry->hostname), entry.get())); |
74 reversed_entries.push_back(std::move(reversed_entry)); | 76 reversed_entries.push_back(std::move(reversed_entry)); |
75 } | 77 } |
76 | 78 |
77 std::stable_sort(reversed_entries.begin(), reversed_entries.end(), | 79 std::stable_sort(reversed_entries.begin(), reversed_entries.end(), |
78 CompareReversedEntries); | 80 CompareReversedEntries); |
79 | 81 |
80 return WriteDispatchTables(reversed_entries.begin(), reversed_entries.end()); | 82 return WriteDispatchTables(reversed_entries.begin(), reversed_entries.end(), |
83 root_position); | |
81 } | 84 } |
82 | 85 |
83 uint32_t TrieWriter::WriteDispatchTables(ReversedEntries::iterator start, | 86 bool TrieWriter::WriteDispatchTables(ReversedEntries::iterator start, |
84 ReversedEntries::iterator end) { | 87 ReversedEntries::iterator end, |
88 uint32_t* position) { | |
85 DCHECK(start != end) << "No entries passed to WriteDispatchTables"; | 89 DCHECK(start != end) << "No entries passed to WriteDispatchTables"; |
86 | 90 |
87 TrieBitBuffer writer; | 91 TrieBitBuffer writer; |
88 | 92 |
89 std::vector<uint8_t> prefix = LongestCommonPrefix(start, end); | 93 std::vector<uint8_t> prefix = LongestCommonPrefix(start, end); |
90 for (size_t i = 0; i < prefix.size(); ++i) { | 94 for (size_t i = 0; i < prefix.size(); ++i) { |
91 writer.WriteBit(1); | 95 writer.WriteBit(1); |
92 } | 96 } |
93 writer.WriteBit(0); | 97 writer.WriteBit(0); |
94 | 98 |
95 if (prefix.size()) { | 99 if (prefix.size()) { |
96 for (size_t i = 0; i < prefix.size(); ++i) { | 100 for (size_t i = 0; i < prefix.size(); ++i) { |
97 writer.WriteChar(prefix.at(i), huffman_table_, huffman_builder_); | 101 if (!writer.WriteChar(prefix.at(i), huffman_table_, huffman_builder_)) { |
102 return false; | |
103 } | |
98 } | 104 } |
99 } | 105 } |
100 | 106 |
101 RemovePrefix(prefix.size(), start, end); | 107 RemovePrefix(prefix.size(), start, end); |
102 int32_t last_position = -1; | 108 int32_t last_position = -1; |
103 | 109 |
104 while (start != end) { | 110 while (start != end) { |
105 uint8_t candidate = (*start)->reversed_name.at(0); | 111 uint8_t candidate = (*start)->reversed_name.at(0); |
106 ReversedEntries::iterator sub_entries_end = start + 1; | 112 ReversedEntries::iterator sub_entries_end = start + 1; |
107 | 113 |
108 for (; sub_entries_end != end; sub_entries_end++) { | 114 for (; sub_entries_end != end; sub_entries_end++) { |
109 if ((*sub_entries_end)->reversed_name.at(0) != candidate) { | 115 if ((*sub_entries_end)->reversed_name.at(0) != candidate) { |
110 break; | 116 break; |
111 } | 117 } |
112 } | 118 } |
113 | 119 |
114 writer.WriteChar(candidate, huffman_table_, huffman_builder_); | 120 if (!writer.WriteChar(candidate, huffman_table_, huffman_builder_)) { |
121 return false; | |
122 } | |
115 | 123 |
116 if (candidate == kTerminalValue) { | 124 if (candidate == kTerminalValue) { |
117 DCHECK((sub_entries_end - start) == 1) | 125 if (sub_entries_end - start != 1) { |
118 << "Multiple values with the same name"; | 126 return false; |
119 WriteEntry((*start)->entry, &writer); | 127 } |
128 if (!WriteEntry((*start)->entry, &writer)) { | |
129 return false; | |
130 } | |
120 } else { | 131 } else { |
121 RemovePrefix(1, start, sub_entries_end); | 132 RemovePrefix(1, start, sub_entries_end); |
122 uint32_t position = WriteDispatchTables(start, sub_entries_end); | 133 uint32_t table_position; |
123 writer.WritePosition(position, &last_position); | 134 if (!WriteDispatchTables(start, sub_entries_end, &table_position)) { |
135 return false; | |
136 } | |
137 | |
138 writer.WritePosition(table_position, &last_position); | |
124 } | 139 } |
125 | 140 |
126 start = sub_entries_end; | 141 start = sub_entries_end; |
127 } | 142 } |
128 | 143 |
129 writer.WriteChar(kEndOfTableValue, huffman_table_, huffman_builder_); | 144 if (!writer.WriteChar(kEndOfTableValue, huffman_table_, huffman_builder_)) { |
145 return false; | |
146 } | |
130 | 147 |
131 uint32_t position = buffer_.position(); | 148 *position = buffer_.position(); |
132 writer.Flush(); | 149 writer.Flush(); |
133 writer.WriteToBitWriter(&buffer_); | 150 writer.WriteToBitWriter(&buffer_); |
134 return position; | 151 return true; |
135 } | 152 } |
136 | 153 |
137 void TrieWriter::WriteEntry(const TransportSecurityStateEntry* entry, | 154 bool TrieWriter::WriteEntry(const TransportSecurityStateEntry* entry, |
138 TrieBitBuffer* writer) { | 155 TrieBitBuffer* writer) { |
139 uint8_t include_subdomains = 0; | 156 uint8_t include_subdomains = 0; |
140 if (entry->include_subdomains) { | 157 if (entry->include_subdomains) { |
141 include_subdomains = 1; | 158 include_subdomains = 1; |
142 } | 159 } |
143 writer->WriteBit(include_subdomains); | 160 writer->WriteBit(include_subdomains); |
144 | 161 |
145 uint8_t force_https = 0; | 162 uint8_t force_https = 0; |
146 if (entry->force_https) { | 163 if (entry->force_https) { |
147 force_https = 1; | 164 force_https = 1; |
148 } | 165 } |
149 writer->WriteBit(force_https); | 166 writer->WriteBit(force_https); |
150 | 167 |
151 if (entry->pinset.size()) { | 168 if (entry->pinset.size()) { |
152 writer->WriteBit(1); | 169 writer->WriteBit(1); |
170 | |
153 NameIDMap::const_iterator pin_id_it = pinsets_map_.find(entry->pinset); | 171 NameIDMap::const_iterator pin_id_it = pinsets_map_.find(entry->pinset); |
154 DCHECK(pin_id_it != pinsets_map_.cend()) << "invalid pinset"; | 172 if (pin_id_it == pinsets_map_.cend()) { |
173 return false; | |
174 } | |
175 | |
155 const uint8_t& pin_id = pin_id_it->second; | 176 const uint8_t& pin_id = pin_id_it->second; |
156 DCHECK(pin_id <= 16) << "too many pinsets"; | 177 if (pin_id > 15) { |
martijnc
2017/02/08 20:58:21
The check was incorrect before, 4 bits can encode
| |
178 return false; | |
179 } | |
180 | |
157 writer->WriteBits(pin_id, 4); | 181 writer->WriteBits(pin_id, 4); |
158 | 182 |
159 NameIDMap::const_iterator domain_id_it = | 183 NameIDMap::const_iterator domain_id_it = |
160 domain_ids_map_.find(DomainConstant(entry->hostname)); | 184 domain_ids_map_.find(DomainConstant(entry->hostname)); |
161 DCHECK(domain_id_it != domain_ids_map_.cend()) << "invalid domain id"; | 185 if (domain_id_it == domain_ids_map_.cend()) { |
186 return false; | |
187 } | |
188 | |
162 uint32_t domain_id = domain_id_it->second; | 189 uint32_t domain_id = domain_id_it->second; |
163 DCHECK(domain_id < 512) << "too many domain ids"; | 190 if (domain_id > 511) { |
191 return false; | |
192 } | |
193 | |
164 writer->WriteBits(domain_id, 9); | 194 writer->WriteBits(domain_id, 9); |
165 | 195 |
166 if (!entry->include_subdomains) { | 196 if (!entry->include_subdomains) { |
167 uint8_t include_subdomains_for_pinning = 0; | 197 uint8_t include_subdomains_for_pinning = 0; |
168 if (entry->hpkp_include_subdomains) { | 198 if (entry->hpkp_include_subdomains) { |
169 include_subdomains_for_pinning = 1; | 199 include_subdomains_for_pinning = 1; |
170 } | 200 } |
171 writer->WriteBit(include_subdomains_for_pinning); | 201 writer->WriteBit(include_subdomains_for_pinning); |
172 } | 202 } |
173 } else { | 203 } else { |
174 writer->WriteBit(0); | 204 writer->WriteBit(0); |
175 } | 205 } |
176 | 206 |
177 if (entry->expect_ct) { | 207 if (entry->expect_ct) { |
178 writer->WriteBit(1); | 208 writer->WriteBit(1); |
179 NameIDMap::const_iterator expect_ct_report_uri_it = | 209 NameIDMap::const_iterator expect_ct_report_uri_it = |
180 expect_ct_report_uri_map_.find(entry->expect_ct_report_uri); | 210 expect_ct_report_uri_map_.find(entry->expect_ct_report_uri); |
181 DCHECK(expect_ct_report_uri_it != expect_ct_report_uri_map_.cend()) | 211 if (expect_ct_report_uri_it == expect_ct_report_uri_map_.cend()) { |
182 << "invalid expect-ct report-uri"; | 212 return false; |
213 } | |
214 | |
183 const uint8_t& expect_ct_report_id = expect_ct_report_uri_it->second; | 215 const uint8_t& expect_ct_report_id = expect_ct_report_uri_it->second; |
184 | 216 if (expect_ct_report_id > 15) { |
185 DCHECK(expect_ct_report_id < 16) << "too many expect-ct ids"; | 217 return false; |
218 } | |
186 | 219 |
187 writer->WriteBits(expect_ct_report_id, 4); | 220 writer->WriteBits(expect_ct_report_id, 4); |
188 } else { | 221 } else { |
189 writer->WriteBit(0); | 222 writer->WriteBit(0); |
190 } | 223 } |
191 | 224 |
192 if (entry->expect_staple) { | 225 if (entry->expect_staple) { |
193 writer->WriteBit(1); | 226 writer->WriteBit(1); |
194 | 227 |
195 if (entry->expect_staple_include_subdomains) { | 228 if (entry->expect_staple_include_subdomains) { |
196 writer->WriteBit(1); | 229 writer->WriteBit(1); |
197 } else { | 230 } else { |
198 writer->WriteBit(0); | 231 writer->WriteBit(0); |
199 } | 232 } |
200 | 233 |
201 NameIDMap::const_iterator expect_staple_report_uri_it = | 234 NameIDMap::const_iterator expect_staple_report_uri_it = |
202 expect_staple_report_uri_map_.find(entry->expect_staple_report_uri); | 235 expect_staple_report_uri_map_.find(entry->expect_staple_report_uri); |
203 DCHECK(expect_staple_report_uri_it != expect_staple_report_uri_map_.cend()) | 236 if (expect_staple_report_uri_it == expect_staple_report_uri_map_.cend()) { |
204 << "invalid expect-ct report-uri"; | 237 return false; |
238 } | |
239 | |
205 const uint8_t& expect_staple_report_id = | 240 const uint8_t& expect_staple_report_id = |
206 expect_staple_report_uri_it->second; | 241 expect_staple_report_uri_it->second; |
207 DCHECK(expect_staple_report_id < 16) << "too many expect-staple ids"; | 242 if (expect_staple_report_id > 15) { |
243 return false; | |
244 } | |
208 | 245 |
209 writer->WriteBits(expect_staple_report_id, 4); | 246 writer->WriteBits(expect_staple_report_id, 4); |
210 } else { | 247 } else { |
211 writer->WriteBit(0); | 248 writer->WriteBit(0); |
212 } | 249 } |
250 | |
251 return true; | |
213 } | 252 } |
214 | 253 |
215 void TrieWriter::RemovePrefix(size_t length, | 254 void TrieWriter::RemovePrefix(size_t length, |
216 ReversedEntries::iterator start, | 255 ReversedEntries::iterator start, |
217 ReversedEntries::iterator end) { | 256 ReversedEntries::iterator end) { |
218 for (ReversedEntries::iterator it = start; it != end; ++it) { | 257 for (ReversedEntries::iterator it = start; it != end; ++it) { |
219 (*it)->reversed_name.erase((*it)->reversed_name.begin(), | 258 (*it)->reversed_name.erase((*it)->reversed_name.begin(), |
220 (*it)->reversed_name.begin() + length); | 259 (*it)->reversed_name.begin() + length); |
221 } | 260 } |
222 } | 261 } |
223 | 262 |
224 std::vector<uint8_t> TrieWriter::LongestCommonPrefix( | 263 std::vector<uint8_t> TrieWriter::LongestCommonPrefix( |
225 ReversedEntries::iterator start, | 264 ReversedEntries::const_iterator start, |
226 ReversedEntries::iterator end) const { | 265 ReversedEntries::const_iterator end) const { |
227 if (start == end) { | 266 if (start == end) { |
228 return std::vector<uint8_t>(); | 267 return std::vector<uint8_t>(); |
229 } | 268 } |
230 | 269 |
231 std::vector<uint8_t> prefix; | 270 std::vector<uint8_t> prefix; |
232 for (size_t i = 0;; ++i) { | 271 for (size_t i = 0;; ++i) { |
233 if (i > (*start)->reversed_name.size()) { | 272 if (i > (*start)->reversed_name.size()) { |
234 break; | 273 break; |
235 } | 274 } |
236 | 275 |
237 uint8_t candidate = (*start)->reversed_name.at(i); | 276 uint8_t candidate = (*start)->reversed_name.at(i); |
238 if (candidate == kTerminalValue) { | 277 if (candidate == kTerminalValue) { |
239 break; | 278 break; |
240 } | 279 } |
241 | 280 |
242 bool ok = true; | 281 bool ok = true; |
243 for (ReversedEntries::iterator it = start + 1; it != end; ++it) { | 282 for (ReversedEntries::const_iterator it = start + 1; it != end; ++it) { |
244 if (i > (*it)->reversed_name.size() || | 283 if (i > (*it)->reversed_name.size() || |
245 (*it)->reversed_name.at(i) != candidate) { | 284 (*it)->reversed_name.at(i) != candidate) { |
246 ok = false; | 285 ok = false; |
247 break; | 286 break; |
248 } | 287 } |
249 } | 288 } |
250 | 289 |
251 if (!ok) { | 290 if (!ok) { |
252 break; | 291 break; |
253 } | 292 } |
(...skipping 21 matching lines...) Expand all Loading... | |
275 return buffer_.position(); | 314 return buffer_.position(); |
276 } | 315 } |
277 | 316 |
278 void TrieWriter::Flush() { | 317 void TrieWriter::Flush() { |
279 buffer_.Flush(); | 318 buffer_.Flush(); |
280 } | 319 } |
281 | 320 |
282 } // namespace transport_security_state | 321 } // namespace transport_security_state |
283 | 322 |
284 } // namespace net | 323 } // namespace net |
OLD | NEW |