| OLD | NEW |
| 1 // Copyright 2017 The Chromium Authors. All rights reserved. | 1 // Copyright 2017 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 package webpagereplay | 5 package webpagereplay |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "bufio" | 8 "bufio" |
| 9 "bytes" | 9 "bytes" |
| 10 "compress/gzip" | 10 "compress/gzip" |
| (...skipping 10 matching lines...) Expand all Loading... |
| 21 "sync" | 21 "sync" |
| 22 ) | 22 ) |
| 23 | 23 |
| 24 var ErrNotFound = errors.New("not found") | 24 var ErrNotFound = errors.New("not found") |
| 25 | 25 |
| 26 // ArchivedRequest contains a single request and its response. | 26 // ArchivedRequest contains a single request and its response. |
| 27 // Immutable after creation. | 27 // Immutable after creation. |
| 28 type ArchivedRequest struct { | 28 type ArchivedRequest struct { |
| 29 SerializedRequest []byte | 29 SerializedRequest []byte |
| 30 SerializedResponse []byte // if empty, the request failed | 30 SerializedResponse []byte // if empty, the request failed |
| 31 Proto string | |
| 32 } | 31 } |
| 33 | 32 |
| 34 func serializeRequest(req *http.Request, resp *http.Response, proto string) (*Ar
chivedRequest, error) { | 33 func serializeRequest(req *http.Request, resp *http.Response) (*ArchivedRequest,
error) { |
| 35 url := req.URL.String() | 34 url := req.URL.String() |
| 36 ar := &ArchivedRequest{} | 35 ar := &ArchivedRequest{} |
| 37 { | 36 { |
| 38 var buf bytes.Buffer | 37 var buf bytes.Buffer |
| 39 if err := req.Write(&buf); err != nil { | 38 if err := req.Write(&buf); err != nil { |
| 40 return nil, fmt.Errorf("failed writing request for %s: %
v", url, err) | 39 return nil, fmt.Errorf("failed writing request for %s: %
v", url, err) |
| 41 } | 40 } |
| 42 ar.SerializedRequest = buf.Bytes() | 41 ar.SerializedRequest = buf.Bytes() |
| 43 } | 42 } |
| 44 { | 43 { |
| 45 var buf bytes.Buffer | 44 var buf bytes.Buffer |
| 46 if err := resp.Write(&buf); err != nil { | 45 if err := resp.Write(&buf); err != nil { |
| 47 return nil, fmt.Errorf("failed writing response for %s:
%v", url, err) | 46 return nil, fmt.Errorf("failed writing response for %s:
%v", url, err) |
| 48 } | 47 } |
| 49 ar.SerializedResponse = buf.Bytes() | 48 ar.SerializedResponse = buf.Bytes() |
| 50 } | 49 } |
| 51 ar.Proto = proto | |
| 52 return ar, nil | 50 return ar, nil |
| 53 } | 51 } |
| 54 | 52 |
| 55 func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, string, e
rror) { | 53 func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, error) { |
| 56 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.Serializ
edRequest))) | 54 req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.Serializ
edRequest))) |
| 57 if err != nil { | 55 if err != nil { |
| 58 » » return nil, nil, "", fmt.Errorf("couldn't unmarshal request: %v"
, err) | 56 » » return nil, nil, fmt.Errorf("couldn't unmarshal request: %v", er
r) |
| 59 } | 57 } |
| 60 resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.Serial
izedResponse)), req) | 58 resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.Serial
izedResponse)), req) |
| 61 if err != nil { | 59 if err != nil { |
| 62 if req.Body != nil { | 60 if req.Body != nil { |
| 63 req.Body.Close() | 61 req.Body.Close() |
| 64 } | 62 } |
| 65 » » return nil, nil, "", fmt.Errorf("couldn't unmarshal response: %v
", err) | 63 » » return nil, nil, fmt.Errorf("couldn't unmarshal response: %v", e
rr) |
| 66 } | 64 } |
| 67 » proto := ar.Proto | 65 » return req, resp, nil |
| 68 » return req, resp, proto, nil | |
| 69 } | 66 } |
| 70 | 67 |
| 71 // Archive contains an archive of requests. Immutable except when embedded in a
WritableArchive. | 68 // Archive contains an archive of requests. Immutable except when embedded in a
WritableArchive. |
| 72 // Fields are exported to enabled JSON encoding. | 69 // Fields are exported to enabled JSON encoding. |
| 73 type Archive struct { | 70 type Archive struct { |
| 74 // Requests maps host(url) => url => []request. | 71 // Requests maps host(url) => url => []request. |
| 75 // The two-level mapping makes it easier to search for similar requests. | 72 // The two-level mapping makes it easier to search for similar requests. |
| 76 // There may be multiple requests for a given URL. | 73 // There may be multiple requests for a given URL. |
| 77 Requests map[string]map[string][]*ArchivedRequest | 74 Requests map[string]map[string][]*ArchivedRequest |
| 78 » Certs map[string][]byte | 75 » // Maps host string to DER encoded certs. |
| 76 » Certs map[string][]byte |
| 77 » // Maps host string to the negotiated protocol. eg. "http/1.1" or "h2" |
| 78 » // If absent, will default to "http/1.1". |
| 79 » NegotiatedProtocol map[string]string |
| 79 } | 80 } |
| 80 | 81 |
| 81 func newArchive() Archive { | 82 func newArchive() Archive { |
| 82 return Archive{Requests: make(map[string]map[string][]*ArchivedRequest)} | 83 return Archive{Requests: make(map[string]map[string][]*ArchivedRequest)} |
| 83 } | 84 } |
| 84 | 85 |
| 85 // OpenArchive opens an archive file previously written by OpenWritableArchive. | 86 // OpenArchive opens an archive file previously written by OpenWritableArchive. |
| 86 func OpenArchive(path string) (*Archive, error) { | 87 func OpenArchive(path string) (*Archive, error) { |
| 87 f, err := os.Open(path) | 88 f, err := os.Open(path) |
| 88 if err != nil { | 89 if err != nil { |
| (...skipping 15 matching lines...) Expand all Loading... |
| 104 return nil, fmt.Errorf("json unmarshal failed: %v", err) | 105 return nil, fmt.Errorf("json unmarshal failed: %v", err) |
| 105 } | 106 } |
| 106 return &a, nil | 107 return &a, nil |
| 107 } | 108 } |
| 108 | 109 |
| 109 // ForEach applies f to all requests in the archive. | 110 // ForEach applies f to all requests in the archive. |
| 110 func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { | 111 func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { |
| 111 for _, urlmap := range a.Requests { | 112 for _, urlmap := range a.Requests { |
| 112 for url, requests := range urlmap { | 113 for url, requests := range urlmap { |
| 113 for k, ar := range requests { | 114 for k, ar := range requests { |
| 114 » » » » req, resp, _, err := ar.unmarshal() | 115 » » » » req, resp, err := ar.unmarshal() |
| 115 if err != nil { | 116 if err != nil { |
| 116 log.Printf("Error unmarshaling request #
%d for %s: %v", k, url, err) | 117 log.Printf("Error unmarshaling request #
%d for %s: %v", k, url, err) |
| 117 continue | 118 continue |
| 118 } | 119 } |
| 119 f(req, resp) | 120 f(req, resp) |
| 120 } | 121 } |
| 121 } | 122 } |
| 122 } | 123 } |
| 123 } | 124 } |
| 124 | 125 |
| 125 func (a *Archive) FindHostCert(host string) ([]byte, error) { | 126 // Returns the der encoded cert and negotiated protocol. |
| 127 func (a *Archive) FindHostTlsConfig(host string) ([]byte, string, error) { |
| 126 if cert, ok := a.Certs[host]; ok { | 128 if cert, ok := a.Certs[host]; ok { |
| 127 » » return cert, nil | 129 » » return cert, a.findHostNegotiatedProtocol(host), nil |
| 128 } | 130 } |
| 129 » return nil, ErrNotFound | 131 » return nil, "", ErrNotFound |
| 132 } |
| 133 |
| 134 func (a *Archive) findHostNegotiatedProtocol(host string) string { |
| 135 » if negotiatedProtocol, ok := a.NegotiatedProtocol[host]; ok { |
| 136 » » return negotiatedProtocol |
| 137 » } |
| 138 » return "http/1.1" |
| 130 } | 139 } |
| 131 | 140 |
| 132 // FindRequest searches for the given request in the archive. | 141 // FindRequest searches for the given request in the archive. |
| 133 // Returns ErrNotFound if the request could not be found. Does not consume req.B
ody. | 142 // Returns ErrNotFound if the request could not be found. Does not consume req.B
ody. |
| 134 // TODO: header-based matching and conditional requests | 143 // TODO: header-based matching and conditional requests |
| 135 func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
*http.Response, string, error) { | 144 func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
*http.Response, error) { |
| 136 hostMap := a.Requests[req.Host] | 145 hostMap := a.Requests[req.Host] |
| 137 if len(hostMap) == 0 { | 146 if len(hostMap) == 0 { |
| 138 » » return nil, nil, "", ErrNotFound | 147 » » return nil, nil, ErrNotFound |
| 139 } | 148 } |
| 140 | 149 |
| 141 // Exact match. Note that req may be relative, but hostMap keys are alwa
ys absolute. | 150 // Exact match. Note that req may be relative, but hostMap keys are alwa
ys absolute. |
| 142 u := *req.URL | 151 u := *req.URL |
| 143 if u.Host == "" { | 152 if u.Host == "" { |
| 144 u.Host = req.Host | 153 u.Host = req.Host |
| 145 u.Scheme = scheme | 154 u.Scheme = scheme |
| 146 } | 155 } |
| 147 » if req, resp, proto, err := findExactMatch(hostMap[u.String()], req.Meth
od); err == nil { | 156 » if req, resp, err := findExactMatch(hostMap[u.String()], req.Method); er
r == nil { |
| 148 » » return req, resp, proto, nil | 157 » » return req, resp, nil |
| 149 } | 158 } |
| 150 | 159 |
| 151 // For all URLs with a matching path, pick the URL that has the most mat
ching query parameters. | 160 // For all URLs with a matching path, pick the URL that has the most mat
ching query parameters. |
| 152 // The match ratio is defined to be 2*M/T, where | 161 // The match ratio is defined to be 2*M/T, where |
| 153 // M = number of matches x where a.Query[x]=b.Query[x] | 162 // M = number of matches x where a.Query[x]=b.Query[x] |
| 154 // T = sum(len(a.Query)) + sum(len(b.Query)) | 163 // T = sum(len(a.Query)) + sum(len(b.Query)) |
| 155 aq := req.URL.Query() | 164 aq := req.URL.Query() |
| 156 | 165 |
| 157 var bestURL string | 166 var bestURL string |
| 158 var bestRatio float64 | 167 var bestRatio float64 |
| (...skipping 18 matching lines...) Expand all Loading... |
| 177 if ratio > bestRatio { | 186 if ratio > bestRatio { |
| 178 bestURL = ustr | 187 bestURL = ustr |
| 179 } | 188 } |
| 180 } | 189 } |
| 181 | 190 |
| 182 // TODO: Try each until one succeeds with a matching request method. | 191 // TODO: Try each until one succeeds with a matching request method. |
| 183 if bestURL != "" { | 192 if bestURL != "" { |
| 184 return findExactMatch(hostMap[bestURL], req.Method) | 193 return findExactMatch(hostMap[bestURL], req.Method) |
| 185 } | 194 } |
| 186 | 195 |
| 187 » return nil, nil, "", ErrNotFound | 196 » return nil, nil, ErrNotFound |
| 188 } | 197 } |
| 189 | 198 |
| 190 // findExactMatch returns the first request that exactly matches the given reque
st method. | 199 // findExactMatch returns the first request that exactly matches the given reque
st method. |
| 191 func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request,
*http.Response, string, error) { | 200 func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request,
*http.Response, error) { |
| 192 for _, ar := range requests { | 201 for _, ar := range requests { |
| 193 » » req, resp, proto, err := ar.unmarshal() | 202 » » req, resp, err := ar.unmarshal() |
| 194 if err != nil { | 203 if err != nil { |
| 195 log.Printf("Error unmarshaling request: %v\nAR.Request:
%q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse) | 204 log.Printf("Error unmarshaling request: %v\nAR.Request:
%q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse) |
| 196 continue | 205 continue |
| 197 } | 206 } |
| 198 if req.Method == method { | 207 if req.Method == method { |
| 199 » » » return req, resp, proto, nil | 208 » » » return req, resp, nil |
| 200 } | 209 } |
| 201 } | 210 } |
| 202 | 211 |
| 203 » return nil, nil, "", ErrNotFound | 212 » return nil, nil, ErrNotFound |
| 204 } | 213 } |
| 205 | 214 |
| 206 func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *htt
p.Response, proto string) error { | 215 func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *htt
p.Response) error { |
| 207 » ar, err := serializeRequest(req, resp, proto) | 216 » ar, err := serializeRequest(req, resp) |
| 208 if err != nil { | 217 if err != nil { |
| 209 return err | 218 return err |
| 210 } | 219 } |
| 211 if a.Requests[req.Host] == nil { | 220 if a.Requests[req.Host] == nil { |
| 212 a.Requests[req.Host] = make(map[string][]*ArchivedRequest) | 221 a.Requests[req.Host] = make(map[string][]*ArchivedRequest) |
| 213 } | 222 } |
| 214 // Always use the absolute URL in this mapping. | 223 // Always use the absolute URL in this mapping. |
| 215 u := *req.URL | 224 u := *req.URL |
| 216 if u.Host == "" { | 225 if u.Host == "" { |
| 217 u.Host = req.Host | 226 u.Host = req.Host |
| 218 u.Scheme = scheme | 227 u.Scheme = scheme |
| 219 } | 228 } |
| 220 ustr := u.String() | 229 ustr := u.String() |
| 221 a.Requests[req.Host][ustr] = append(a.Requests[req.Host][ustr], ar) | 230 a.Requests[req.Host][ustr] = append(a.Requests[req.Host][ustr], ar) |
| 222 return nil | 231 return nil |
| 223 } | 232 } |
| 224 | 233 |
| 225 // Edit iterates over all requests in the archive. For each request, it calls f
to | 234 // Edit iterates over all requests in the archive. For each request, it calls f
to |
| 226 // edit the request. If f returns a nil pair, the request is deleted. | 235 // edit the request. If f returns a nil pair, the request is deleted. |
| 227 // The edited archive is returned, leaving the current archive is unchanged. | 236 // The edited archive is returned, leaving the current archive is unchanged. |
| 228 func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
uest, *http.Response, error)) (*Archive, error) { | 237 func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
uest, *http.Response, error)) (*Archive, error) { |
| 229 clone := newArchive() | 238 clone := newArchive() |
| 230 for _, urlmap := range a.Requests { | 239 for _, urlmap := range a.Requests { |
| 231 for ustr, requests := range urlmap { | 240 for ustr, requests := range urlmap { |
| 232 u, _ := url.Parse(ustr) | 241 u, _ := url.Parse(ustr) |
| 233 for k, ar := range requests { | 242 for k, ar := range requests { |
| 234 » » » » oldReq, oldResp, proto, err := ar.unmarshal() | 243 » » » » oldReq, oldResp, err := ar.unmarshal() |
| 235 if err != nil { | 244 if err != nil { |
| 236 return nil, fmt.Errorf("Error unmarshali
ng request #%d for %s: %v", k, ustr, err) | 245 return nil, fmt.Errorf("Error unmarshali
ng request #%d for %s: %v", k, ustr, err) |
| 237 } | 246 } |
| 238 newReq, newResp, err := f(oldReq, oldResp) | 247 newReq, newResp, err := f(oldReq, oldResp) |
| 239 if err != nil { | 248 if err != nil { |
| 240 return nil, err | 249 return nil, err |
| 241 } | 250 } |
| 242 if newReq == nil || newResp == nil { | 251 if newReq == nil || newResp == nil { |
| 243 if newReq != nil || newResp != nil { | 252 if newReq != nil || newResp != nil { |
| 244 panic("programming error: newReq
/newResp must both be nil or non-nil") | 253 panic("programming error: newReq
/newResp must both be nil or non-nil") |
| 245 } | 254 } |
| 246 continue | 255 continue |
| 247 } | 256 } |
| 248 // TODO: allow changing scheme or protocol? | 257 // TODO: allow changing scheme or protocol? |
| 249 » » » » if err := clone.addArchivedRequest(u.Scheme, new
Req, newResp, proto); err != nil { | 258 » » » » if err := clone.addArchivedRequest(u.Scheme, new
Req, newResp); err != nil { |
| 250 return nil, err | 259 return nil, err |
| 251 } | 260 } |
| 252 } | 261 } |
| 253 } | 262 } |
| 254 } | 263 } |
| 255 return &clone, nil | 264 return &clone, nil |
| 256 } | 265 } |
| 257 | 266 |
| 258 // Serialize serializes this archive to the given writer. | 267 // Serialize serializes this archive to the given writer. |
| 259 func (a *Archive) Serialize(w io.Writer) error { | 268 func (a *Archive) Serialize(w io.Writer) error { |
| (...skipping 19 matching lines...) Expand all Loading... |
| 279 if err != nil { | 288 if err != nil { |
| 280 return nil, fmt.Errorf("could not open %s: %v", path, err) | 289 return nil, fmt.Errorf("could not open %s: %v", path, err) |
| 281 } | 290 } |
| 282 return &WritableArchive{Archive: newArchive(), f: f}, nil | 291 return &WritableArchive{Archive: newArchive(), f: f}, nil |
| 283 } | 292 } |
| 284 | 293 |
| 285 // RecordRequest records a request/response pair in the archive. | 294 // RecordRequest records a request/response pair in the archive. |
| 286 func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *
http.Response) error { | 295 func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *
http.Response) error { |
| 287 a.mu.Lock() | 296 a.mu.Lock() |
| 288 defer a.mu.Unlock() | 297 defer a.mu.Unlock() |
| 289 » proto := "" | 298 » return a.addArchivedRequest(scheme, req, resp) |
| 290 » if resp.TLS != nil && resp.TLS.NegotiatedProtocolIsMutual { | |
| 291 » » proto = resp.TLS.NegotiatedProtocol | |
| 292 » } | |
| 293 » return a.addArchivedRequest(scheme, req, resp, proto) | |
| 294 } | 299 } |
| 295 | 300 |
| 296 // RecordCert records a cert in the archive. | 301 // RecordTlsConfig records the cert used and protocol negotiated for a host. |
| 297 func (a *WritableArchive) RecordCert(host string, der_bytes []byte) { | 302 func (a *WritableArchive) RecordTlsConfig(host string, der_bytes []byte, negotia
tedProtocol string) { |
| 298 a.mu.Lock() | 303 a.mu.Lock() |
| 299 defer a.mu.Unlock() | 304 defer a.mu.Unlock() |
| 300 if a.Certs == nil { | 305 if a.Certs == nil { |
| 301 a.Certs = make(map[string][]byte) | 306 a.Certs = make(map[string][]byte) |
| 302 } | 307 } |
| 303 if _, ok := a.Certs[host]; !ok { | 308 if _, ok := a.Certs[host]; !ok { |
| 304 fmt.Printf("Recorded cert for %s", host) | |
| 305 a.Certs[host] = der_bytes | 309 a.Certs[host] = der_bytes |
| 306 } | 310 } |
| 307 } | 311 » if a.NegotiatedProtocol == nil { |
| 308 | 312 » » a.NegotiatedProtocol = make(map[string]string) |
| 309 func (a *WritableArchive) FindHostCert(host string) ([]byte, error) { | |
| 310 » a.mu.Lock() | |
| 311 » defer a.mu.Unlock() | |
| 312 » if cert, ok := a.Archive.Certs[host]; ok { | |
| 313 » » return cert, nil | |
| 314 } | 313 } |
| 315 » return nil, ErrNotFound | 314 » a.NegotiatedProtocol[host] = negotiatedProtocol |
| 316 } | 315 } |
| 317 | 316 |
| 318 // Close flushes the the archive and closes the output file. | 317 // Close flushes the the archive and closes the output file. |
| 319 func (a *WritableArchive) Close() error { | 318 func (a *WritableArchive) Close() error { |
| 320 a.mu.Lock() | 319 a.mu.Lock() |
| 321 defer a.mu.Unlock() | 320 defer a.mu.Unlock() |
| 322 defer func() { a.f = nil }() | 321 defer func() { a.f = nil }() |
| 323 if a.f == nil { | 322 if a.f == nil { |
| 324 return errors.New("already closed") | 323 return errors.New("already closed") |
| 325 } | 324 } |
| 326 | 325 |
| 327 if err := a.Serialize(a.f); err != nil { | 326 if err := a.Serialize(a.f); err != nil { |
| 328 return err | 327 return err |
| 329 } | 328 } |
| 330 return a.f.Close() | 329 return a.f.Close() |
| 331 } | 330 } |
| OLD | NEW |