OLD | NEW |
1 // Copyright 2017 The Chromium Authors. All rights reserved. | 1 // Copyright 2017 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 package webpagereplay | 5 package webpagereplay |
6 | 6 |
7 import ( | 7 import ( |
8 "bufio" | 8 "bufio" |
9 "bytes" | 9 "bytes" |
10 "compress/gzip" | 10 "compress/gzip" |
(...skipping 10 matching lines...) Expand all Loading... |
21 "sync" | 21 "sync" |
22 ) | 22 ) |
23 | 23 |
24 var ErrNotFound = errors.New("not found") | 24 var ErrNotFound = errors.New("not found") |
25 | 25 |
26 // ArchivedRequest contains a single request and its response. | 26 // ArchivedRequest contains a single request and its response. |
27 // Immutable after creation. | 27 // Immutable after creation. |
28 type ArchivedRequest struct { | 28 type ArchivedRequest struct { |
29 SerializedRequest []byte | 29 SerializedRequest []byte |
30 SerializedResponse []byte // if empty, the request failed | 30 SerializedResponse []byte // if empty, the request failed |
31 Proto string | |
32 } | 31 } |
33 | 32 |
34 func serializeRequest(req *http.Request, resp *http.Response, proto string) (*Ar
chivedRequest, error) { | 33 func serializeRequest(req *http.Request, resp *http.Response) (*ArchivedRequest,
error) { |
35 url := req.URL.String() | 34 url := req.URL.String() |
36 ar := &ArchivedRequest{} | 35 ar := &ArchivedRequest{} |
37 { | 36 { |
38 var buf bytes.Buffer | 37 var buf bytes.Buffer |
39 if err := req.Write(&buf); err != nil { | 38 if err := req.Write(&buf); err != nil { |
40 return nil, fmt.Errorf("failed writing request for %s: %
v", url, err) | 39 return nil, fmt.Errorf("failed writing request for %s: %
v", url, err) |
41 } | 40 } |
42 ar.SerializedRequest = buf.Bytes() | 41 ar.SerializedRequest = buf.Bytes() |
43 } | 42 } |
44 { | 43 { |
45 var buf bytes.Buffer | 44 var buf bytes.Buffer |
46 if err := resp.Write(&buf); err != nil { | 45 if err := resp.Write(&buf); err != nil { |
47 return nil, fmt.Errorf("failed writing response for %s:
%v", url, err) | 46 return nil, fmt.Errorf("failed writing response for %s:
%v", url, err) |
48 } | 47 } |
49 ar.SerializedResponse = buf.Bytes() | 48 ar.SerializedResponse = buf.Bytes() |
50 } | 49 } |
51 ar.Proto = proto | |
52 return ar, nil | 50 return ar, nil |
53 } | 51 } |
54 | 52 |
55 func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, string, e
rror) { | 53 func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, error) { |
56 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.Serializ
edRequest))) | 54 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.Serializ
edRequest))) |
57 if err != nil { | 55 if err != nil { |
58 » » return nil, nil, "", fmt.Errorf("couldn't unmarshal request: %v"
, err) | 56 » » return nil, nil, fmt.Errorf("couldn't unmarshal request: %v", er
r) |
59 } | 57 } |
60 resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.Serial
izedResponse)), req) | 58 resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.Serial
izedResponse)), req) |
61 if err != nil { | 59 if err != nil { |
62 if req.Body != nil { | 60 if req.Body != nil { |
63 req.Body.Close() | 61 req.Body.Close() |
64 } | 62 } |
65 » » return nil, nil, "", fmt.Errorf("couldn't unmarshal response: %v
", err) | 63 » » return nil, nil, fmt.Errorf("couldn't unmarshal response: %v", e
rr) |
66 } | 64 } |
67 » proto := ar.Proto | 65 » return req, resp, nil |
68 » return req, resp, proto, nil | |
69 } | 66 } |
70 | 67 |
71 // Archive contains an archive of requests. Immutable except when embedded in a
WritableArchive. | 68 // Archive contains an archive of requests. Immutable except when embedded in a
WritableArchive. |
72 // Fields are exported to enabled JSON encoding. | 69 // Fields are exported to enabled JSON encoding. |
73 type Archive struct { | 70 type Archive struct { |
74 // Requests maps host(url) => url => []request. | 71 // Requests maps host(url) => url => []request. |
75 // The two-level mapping makes it easier to search for similar requests. | 72 // The two-level mapping makes it easier to search for similar requests. |
76 // There may be multiple requests for a given URL. | 73 // There may be multiple requests for a given URL. |
77 Requests map[string]map[string][]*ArchivedRequest | 74 Requests map[string]map[string][]*ArchivedRequest |
78 » Certs map[string][]byte | 75 » // Maps host string to DER encoded certs. |
| 76 » Certs map[string][]byte |
| 77 » // Maps host string to the negotiated protocol. eg. "http/1.1" or "h2" |
| 78 » // If absent, will default to "http/1.1". |
| 79 » NegotiatedProtocol map[string]string |
79 } | 80 } |
80 | 81 |
81 func newArchive() Archive { | 82 func newArchive() Archive { |
82 return Archive{Requests: make(map[string]map[string][]*ArchivedRequest)} | 83 return Archive{Requests: make(map[string]map[string][]*ArchivedRequest)} |
83 } | 84 } |
84 | 85 |
85 // OpenArchive opens an archive file previously written by OpenWritableArchive. | 86 // OpenArchive opens an archive file previously written by OpenWritableArchive. |
86 func OpenArchive(path string) (*Archive, error) { | 87 func OpenArchive(path string) (*Archive, error) { |
87 f, err := os.Open(path) | 88 f, err := os.Open(path) |
88 if err != nil { | 89 if err != nil { |
(...skipping 15 matching lines...) Expand all Loading... |
104 return nil, fmt.Errorf("json unmarshal failed: %v", err) | 105 return nil, fmt.Errorf("json unmarshal failed: %v", err) |
105 } | 106 } |
106 return &a, nil | 107 return &a, nil |
107 } | 108 } |
108 | 109 |
109 // ForEach applies f to all requests in the archive. | 110 // ForEach applies f to all requests in the archive. |
110 func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { | 111 func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { |
111 for _, urlmap := range a.Requests { | 112 for _, urlmap := range a.Requests { |
112 for url, requests := range urlmap { | 113 for url, requests := range urlmap { |
113 for k, ar := range requests { | 114 for k, ar := range requests { |
114 » » » » req, resp, _, err := ar.unmarshal() | 115 » » » » req, resp, err := ar.unmarshal() |
115 if err != nil { | 116 if err != nil { |
116 log.Printf("Error unmarshaling request #
%d for %s: %v", k, url, err) | 117 log.Printf("Error unmarshaling request #
%d for %s: %v", k, url, err) |
117 continue | 118 continue |
118 } | 119 } |
119 f(req, resp) | 120 f(req, resp) |
120 } | 121 } |
121 } | 122 } |
122 } | 123 } |
123 } | 124 } |
124 | 125 |
125 func (a *Archive) FindHostCert(host string) ([]byte, error) { | 126 // Returns the der encoded cert and negotiated protocol. |
| 127 func (a *Archive) FindHostTlsConfig(host string) ([]byte, string, error) { |
126 if cert, ok := a.Certs[host]; ok { | 128 if cert, ok := a.Certs[host]; ok { |
127 » » return cert, nil | 129 » » return cert, a.findHostNegotiatedProtocol(host), nil |
128 } | 130 } |
129 » return nil, ErrNotFound | 131 » return nil, "", ErrNotFound |
| 132 } |
| 133 |
| 134 func (a *Archive) findHostNegotiatedProtocol(host string) string { |
| 135 » if negotiatedProtocol, ok := a.NegotiatedProtocol[host]; ok { |
| 136 » » return negotiatedProtocol |
| 137 » } |
| 138 » return "http/1.1" |
130 } | 139 } |
131 | 140 |
132 // FindRequest searches for the given request in the archive. | 141 // FindRequest searches for the given request in the archive. |
133 // Returns ErrNotFound if the request could not be found. Does not consume req.B
ody. | 142 // Returns ErrNotFound if the request could not be found. Does not consume req.B
ody. |
134 // TODO: header-based matching and conditional requests | 143 // TODO: header-based matching and conditional requests |
135 func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
*http.Response, string, error) { | 144 func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
*http.Response, error) { |
136 hostMap := a.Requests[req.Host] | 145 hostMap := a.Requests[req.Host] |
137 if len(hostMap) == 0 { | 146 if len(hostMap) == 0 { |
138 » » return nil, nil, "", ErrNotFound | 147 » » return nil, nil, ErrNotFound |
139 } | 148 } |
140 | 149 |
141 // Exact match. Note that req may be relative, but hostMap keys are alwa
ys absolute. | 150 // Exact match. Note that req may be relative, but hostMap keys are alwa
ys absolute. |
142 u := *req.URL | 151 u := *req.URL |
143 if u.Host == "" { | 152 if u.Host == "" { |
144 u.Host = req.Host | 153 u.Host = req.Host |
145 u.Scheme = scheme | 154 u.Scheme = scheme |
146 } | 155 } |
147 » if req, resp, proto, err := findExactMatch(hostMap[u.String()], req.Meth
od); err == nil { | 156 » if req, resp, err := findExactMatch(hostMap[u.String()], req.Method); er
r == nil { |
148 » » return req, resp, proto, nil | 157 » » return req, resp, nil |
149 } | 158 } |
150 | 159 |
151 // For all URLs with a matching path, pick the URL that has the most mat
ching query parameters. | 160 // For all URLs with a matching path, pick the URL that has the most mat
ching query parameters. |
152 // The match ratio is defined to be 2*M/T, where | 161 // The match ratio is defined to be 2*M/T, where |
153 // M = number of matches x where a.Query[x]=b.Query[x] | 162 // M = number of matches x where a.Query[x]=b.Query[x] |
154 // T = sum(len(a.Query)) + sum(len(b.Query)) | 163 // T = sum(len(a.Query)) + sum(len(b.Query)) |
155 aq := req.URL.Query() | 164 aq := req.URL.Query() |
156 | 165 |
157 var bestURL string | 166 var bestURL string |
158 var bestRatio float64 | 167 var bestRatio float64 |
(...skipping 18 matching lines...) Expand all Loading... |
177 if ratio > bestRatio { | 186 if ratio > bestRatio { |
178 bestURL = ustr | 187 bestURL = ustr |
179 } | 188 } |
180 } | 189 } |
181 | 190 |
182 // TODO: Try each until one succeeds with a matching request method. | 191 // TODO: Try each until one succeeds with a matching request method. |
183 if bestURL != "" { | 192 if bestURL != "" { |
184 return findExactMatch(hostMap[bestURL], req.Method) | 193 return findExactMatch(hostMap[bestURL], req.Method) |
185 } | 194 } |
186 | 195 |
187 » return nil, nil, "", ErrNotFound | 196 » return nil, nil, ErrNotFound |
188 } | 197 } |
189 | 198 |
190 // findExactMatch returns the first request that exactly matches the given reque
st method. | 199 // findExactMatch returns the first request that exactly matches the given reque
st method. |
191 func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request,
*http.Response, string, error) { | 200 func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request,
*http.Response, error) { |
192 for _, ar := range requests { | 201 for _, ar := range requests { |
193 » » req, resp, proto, err := ar.unmarshal() | 202 » » req, resp, err := ar.unmarshal() |
194 if err != nil { | 203 if err != nil { |
195 log.Printf("Error unmarshaling request: %v\nAR.Request:
%q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse) | 204 log.Printf("Error unmarshaling request: %v\nAR.Request:
%q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse) |
196 continue | 205 continue |
197 } | 206 } |
198 if req.Method == method { | 207 if req.Method == method { |
199 » » » return req, resp, proto, nil | 208 » » » return req, resp, nil |
200 } | 209 } |
201 } | 210 } |
202 | 211 |
203 » return nil, nil, "", ErrNotFound | 212 » return nil, nil, ErrNotFound |
204 } | 213 } |
205 | 214 |
206 func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *htt
p.Response, proto string) error { | 215 func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *htt
p.Response) error { |
207 » ar, err := serializeRequest(req, resp, proto) | 216 » ar, err := serializeRequest(req, resp) |
208 if err != nil { | 217 if err != nil { |
209 return err | 218 return err |
210 } | 219 } |
211 if a.Requests[req.Host] == nil { | 220 if a.Requests[req.Host] == nil { |
212 a.Requests[req.Host] = make(map[string][]*ArchivedRequest) | 221 a.Requests[req.Host] = make(map[string][]*ArchivedRequest) |
213 } | 222 } |
214 // Always use the absolute URL in this mapping. | 223 // Always use the absolute URL in this mapping. |
215 u := *req.URL | 224 u := *req.URL |
216 if u.Host == "" { | 225 if u.Host == "" { |
217 u.Host = req.Host | 226 u.Host = req.Host |
218 u.Scheme = scheme | 227 u.Scheme = scheme |
219 } | 228 } |
220 ustr := u.String() | 229 ustr := u.String() |
221 a.Requests[req.Host][ustr] = append(a.Requests[req.Host][ustr], ar) | 230 a.Requests[req.Host][ustr] = append(a.Requests[req.Host][ustr], ar) |
222 return nil | 231 return nil |
223 } | 232 } |
224 | 233 |
225 // Edit iterates over all requests in the archive. For each request, it calls f
to | 234 // Edit iterates over all requests in the archive. For each request, it calls f
to |
226 // edit the request. If f returns a nil pair, the request is deleted. | 235 // edit the request. If f returns a nil pair, the request is deleted. |
227 // The edited archive is returned, leaving the current archive is unchanged. | 236 // The edited archive is returned, leaving the current archive is unchanged. |
228 func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
uest, *http.Response, error)) (*Archive, error) { | 237 func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
uest, *http.Response, error)) (*Archive, error) { |
229 clone := newArchive() | 238 clone := newArchive() |
230 for _, urlmap := range a.Requests { | 239 for _, urlmap := range a.Requests { |
231 for ustr, requests := range urlmap { | 240 for ustr, requests := range urlmap { |
232 u, _ := url.Parse(ustr) | 241 u, _ := url.Parse(ustr) |
233 for k, ar := range requests { | 242 for k, ar := range requests { |
234 » » » » oldReq, oldResp, proto, err := ar.unmarshal() | 243 » » » » oldReq, oldResp, err := ar.unmarshal() |
235 if err != nil { | 244 if err != nil { |
236 return nil, fmt.Errorf("Error unmarshali
ng request #%d for %s: %v", k, ustr, err) | 245 return nil, fmt.Errorf("Error unmarshali
ng request #%d for %s: %v", k, ustr, err) |
237 } | 246 } |
238 newReq, newResp, err := f(oldReq, oldResp) | 247 newReq, newResp, err := f(oldReq, oldResp) |
239 if err != nil { | 248 if err != nil { |
240 return nil, err | 249 return nil, err |
241 } | 250 } |
242 if newReq == nil || newResp == nil { | 251 if newReq == nil || newResp == nil { |
243 if newReq != nil || newResp != nil { | 252 if newReq != nil || newResp != nil { |
244 panic("programming error: newReq
/newResp must both be nil or non-nil") | 253 panic("programming error: newReq
/newResp must both be nil or non-nil") |
245 } | 254 } |
246 continue | 255 continue |
247 } | 256 } |
248 // TODO: allow changing scheme or protocol? | 257 // TODO: allow changing scheme or protocol? |
249 » » » » if err := clone.addArchivedRequest(u.Scheme, new
Req, newResp, proto); err != nil { | 258 » » » » if err := clone.addArchivedRequest(u.Scheme, new
Req, newResp); err != nil { |
250 return nil, err | 259 return nil, err |
251 } | 260 } |
252 } | 261 } |
253 } | 262 } |
254 } | 263 } |
255 return &clone, nil | 264 return &clone, nil |
256 } | 265 } |
257 | 266 |
258 // Serialize serializes this archive to the given writer. | 267 // Serialize serializes this archive to the given writer. |
259 func (a *Archive) Serialize(w io.Writer) error { | 268 func (a *Archive) Serialize(w io.Writer) error { |
(...skipping 19 matching lines...) Expand all Loading... |
279 if err != nil { | 288 if err != nil { |
280 return nil, fmt.Errorf("could not open %s: %v", path, err) | 289 return nil, fmt.Errorf("could not open %s: %v", path, err) |
281 } | 290 } |
282 return &WritableArchive{Archive: newArchive(), f: f}, nil | 291 return &WritableArchive{Archive: newArchive(), f: f}, nil |
283 } | 292 } |
284 | 293 |
285 // RecordRequest records a request/response pair in the archive. | 294 // RecordRequest records a request/response pair in the archive. |
286 func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *
http.Response) error { | 295 func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *
http.Response) error { |
287 a.mu.Lock() | 296 a.mu.Lock() |
288 defer a.mu.Unlock() | 297 defer a.mu.Unlock() |
289 » proto := "" | 298 » return a.addArchivedRequest(scheme, req, resp) |
290 » if resp.TLS != nil && resp.TLS.NegotiatedProtocolIsMutual { | |
291 » » proto = resp.TLS.NegotiatedProtocol | |
292 » } | |
293 » return a.addArchivedRequest(scheme, req, resp, proto) | |
294 } | 299 } |
295 | 300 |
296 // RecordCert records a cert in the archive. | 301 // RecordTlsConfig records the cert used and protocol negotiated for a host. |
297 func (a *WritableArchive) RecordCert(host string, der_bytes []byte) { | 302 func (a *WritableArchive) RecordTlsConfig(host string, der_bytes []byte, negotia
tedProtocol string) { |
298 a.mu.Lock() | 303 a.mu.Lock() |
299 defer a.mu.Unlock() | 304 defer a.mu.Unlock() |
300 if a.Certs == nil { | 305 if a.Certs == nil { |
301 a.Certs = make(map[string][]byte) | 306 a.Certs = make(map[string][]byte) |
302 } | 307 } |
303 if _, ok := a.Certs[host]; !ok { | 308 if _, ok := a.Certs[host]; !ok { |
304 fmt.Printf("Recorded cert for %s", host) | |
305 a.Certs[host] = der_bytes | 309 a.Certs[host] = der_bytes |
306 } | 310 } |
307 } | 311 » if a.NegotiatedProtocol == nil { |
308 | 312 » » a.NegotiatedProtocol = make(map[string]string) |
309 func (a *WritableArchive) FindHostCert(host string) ([]byte, error) { | |
310 » a.mu.Lock() | |
311 » defer a.mu.Unlock() | |
312 » if cert, ok := a.Archive.Certs[host]; ok { | |
313 » » return cert, nil | |
314 } | 313 } |
315 » return nil, ErrNotFound | 314 » a.NegotiatedProtocol[host] = negotiatedProtocol |
316 } | 315 } |
317 | 316 |
318 // Close flushes the the archive and closes the output file. | 317 // Close flushes the the archive and closes the output file. |
319 func (a *WritableArchive) Close() error { | 318 func (a *WritableArchive) Close() error { |
320 a.mu.Lock() | 319 a.mu.Lock() |
321 defer a.mu.Unlock() | 320 defer a.mu.Unlock() |
322 defer func() { a.f = nil }() | 321 defer func() { a.f = nil }() |
323 if a.f == nil { | 322 if a.f == nil { |
324 return errors.New("already closed") | 323 return errors.New("already closed") |
325 } | 324 } |
326 | 325 |
327 if err := a.Serialize(a.f); err != nil { | 326 if err := a.Serialize(a.f); err != nil { |
328 return err | 327 return err |
329 } | 328 } |
330 return a.f.Close() | 329 return a.f.Close() |
331 } | 330 } |
OLD | NEW |