| Index: web_page_replay_go/src/webpagereplay/archive.go
|
| diff --git a/web_page_replay_go/src/webpagereplay/archive.go b/web_page_replay_go/src/webpagereplay/archive.go
|
| index 5a5dd23e26e112ff30186adca57fe12c66bcb6e5..d0eee49a9e6af6391ff06e1f3b37a2f3a16972f6 100644
|
| --- a/web_page_replay_go/src/webpagereplay/archive.go
|
| +++ b/web_page_replay_go/src/webpagereplay/archive.go
|
| @@ -28,10 +28,9 @@ var ErrNotFound = errors.New("not found")
|
| type ArchivedRequest struct {
|
| SerializedRequest []byte
|
| SerializedResponse []byte // if empty, the request failed
|
| - Proto string
|
| }
|
|
|
| -func serializeRequest(req *http.Request, resp *http.Response, proto string) (*ArchivedRequest, error) {
|
| +func serializeRequest(req *http.Request, resp *http.Response) (*ArchivedRequest, error) {
|
| url := req.URL.String()
|
| ar := &ArchivedRequest{}
|
| {
|
| @@ -48,24 +47,22 @@ func serializeRequest(req *http.Request, resp *http.Response, proto string) (*Ar
|
| }
|
| ar.SerializedResponse = buf.Bytes()
|
| }
|
| - ar.Proto = proto
|
| return ar, nil
|
| }
|
|
|
| -func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, string, error) {
|
| +func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, error) {
|
| req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.SerializedRequest)))
|
| if err != nil {
|
| - return nil, nil, "", fmt.Errorf("couldn't unmarshal request: %v", err)
|
| + return nil, nil, fmt.Errorf("couldn't unmarshal request: %v", err)
|
| }
|
| resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.SerializedResponse)), req)
|
| if err != nil {
|
| if req.Body != nil {
|
| req.Body.Close()
|
| }
|
| - return nil, nil, "", fmt.Errorf("couldn't unmarshal response: %v", err)
|
| + return nil, nil, fmt.Errorf("couldn't unmarshal response: %v", err)
|
| }
|
| - proto := ar.Proto
|
| - return req, resp, proto, nil
|
| + return req, resp, nil
|
| }
|
|
|
| // Archive contains an archive of requests. Immutable except when embedded in a WritableArchive.
|
| @@ -75,7 +72,11 @@ type Archive struct {
|
| // The two-level mapping makes it easier to search for similar requests.
|
| // There may be multiple requests for a given URL.
|
| Requests map[string]map[string][]*ArchivedRequest
|
| - Certs map[string][]byte
|
| + // Maps host string to DER encoded certs.
|
| + Certs map[string][]byte
|
| + // Maps host string to the negotiated protocol. eg. "http/1.1" or "h2"
|
| + // If absent, will default to "http/1.1".
|
| + NegotiatedProtocol map[string]string
|
| }
|
|
|
| func newArchive() Archive {
|
| @@ -111,7 +112,7 @@ func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) {
|
| for _, urlmap := range a.Requests {
|
| for url, requests := range urlmap {
|
| for k, ar := range requests {
|
| - req, resp, _, err := ar.unmarshal()
|
| + req, resp, err := ar.unmarshal()
|
| if err != nil {
|
| log.Printf("Error unmarshaling request #%d for %s: %v", k, url, err)
|
| continue
|
| @@ -122,20 +123,28 @@ func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) {
|
| }
|
| }
|
|
|
| -func (a *Archive) FindHostCert(host string) ([]byte, error) {
|
| +// Returns the der encoded cert and negotiated protocol.
|
| +func (a *Archive) FindHostTlsConfig(host string) ([]byte, string, error) {
|
| if cert, ok := a.Certs[host]; ok {
|
| - return cert, nil
|
| + return cert, a.findHostNegotiatedProtocol(host), nil
|
| }
|
| - return nil, ErrNotFound
|
| + return nil, "", ErrNotFound
|
| +}
|
| +
|
| +func (a *Archive) findHostNegotiatedProtocol(host string) string {
|
| + if negotiatedProtocol, ok := a.NegotiatedProtocol[host]; ok {
|
| + return negotiatedProtocol
|
| + }
|
| + return "http/1.1"
|
| }
|
|
|
| // FindRequest searches for the given request in the archive.
|
| // Returns ErrNotFound if the request could not be found. Does not consume req.Body.
|
| // TODO: header-based matching and conditional requests
|
| -func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, *http.Response, string, error) {
|
| +func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, *http.Response, error) {
|
| hostMap := a.Requests[req.Host]
|
| if len(hostMap) == 0 {
|
| - return nil, nil, "", ErrNotFound
|
| + return nil, nil, ErrNotFound
|
| }
|
|
|
| // Exact match. Note that req may be relative, but hostMap keys are always absolute.
|
| @@ -144,8 +153,8 @@ func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
|
| u.Host = req.Host
|
| u.Scheme = scheme
|
| }
|
| - if req, resp, proto, err := findExactMatch(hostMap[u.String()], req.Method); err == nil {
|
| - return req, resp, proto, nil
|
| + if req, resp, err := findExactMatch(hostMap[u.String()], req.Method); err == nil {
|
| + return req, resp, nil
|
| }
|
|
|
| // For all URLs with a matching path, pick the URL that has the most matching query parameters.
|
| @@ -184,27 +193,27 @@ func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request,
|
| return findExactMatch(hostMap[bestURL], req.Method)
|
| }
|
|
|
| - return nil, nil, "", ErrNotFound
|
| + return nil, nil, ErrNotFound
|
| }
|
|
|
| // findExactMatch returns the first request that exactly matches the given request method.
|
| -func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request, *http.Response, string, error) {
|
| +func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request, *http.Response, error) {
|
| for _, ar := range requests {
|
| - req, resp, proto, err := ar.unmarshal()
|
| + req, resp, err := ar.unmarshal()
|
| if err != nil {
|
| log.Printf("Error unmarshaling request: %v\nAR.Request: %q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse)
|
| continue
|
| }
|
| if req.Method == method {
|
| - return req, resp, proto, nil
|
| + return req, resp, nil
|
| }
|
| }
|
|
|
| - return nil, nil, "", ErrNotFound
|
| + return nil, nil, ErrNotFound
|
| }
|
|
|
| -func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *http.Response, proto string) error {
|
| - ar, err := serializeRequest(req, resp, proto)
|
| +func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *http.Response) error {
|
| + ar, err := serializeRequest(req, resp)
|
| if err != nil {
|
| return err
|
| }
|
| @@ -231,7 +240,7 @@ func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
|
| for ustr, requests := range urlmap {
|
| u, _ := url.Parse(ustr)
|
| for k, ar := range requests {
|
| - oldReq, oldResp, proto, err := ar.unmarshal()
|
| + oldReq, oldResp, err := ar.unmarshal()
|
| if err != nil {
|
| return nil, fmt.Errorf("Error unmarshaling request #%d for %s: %v", k, ustr, err)
|
| }
|
| @@ -246,7 +255,7 @@ func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req
|
| continue
|
| }
|
| // TODO: allow changing scheme or protocol?
|
| - if err := clone.addArchivedRequest(u.Scheme, newReq, newResp, proto); err != nil {
|
| + if err := clone.addArchivedRequest(u.Scheme, newReq, newResp); err != nil {
|
| return nil, err
|
| }
|
| }
|
| @@ -286,33 +295,23 @@ func OpenWritableArchive(path string) (*WritableArchive, error) {
|
| func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *http.Response) error {
|
| a.mu.Lock()
|
| defer a.mu.Unlock()
|
| - proto := ""
|
| - if resp.TLS != nil && resp.TLS.NegotiatedProtocolIsMutual {
|
| - proto = resp.TLS.NegotiatedProtocol
|
| - }
|
| - return a.addArchivedRequest(scheme, req, resp, proto)
|
| + return a.addArchivedRequest(scheme, req, resp)
|
| }
|
|
|
| -// RecordCert records a cert in the archive.
|
| -func (a *WritableArchive) RecordCert(host string, der_bytes []byte) {
|
| +// RecordTlsConfig records the cert used and protocol negotiated for a host.
|
| +func (a *WritableArchive) RecordTlsConfig(host string, der_bytes []byte, negotiatedProtocol string) {
|
| a.mu.Lock()
|
| defer a.mu.Unlock()
|
| if a.Certs == nil {
|
| a.Certs = make(map[string][]byte)
|
| }
|
| if _, ok := a.Certs[host]; !ok {
|
| - fmt.Printf("Recorded cert for %s", host)
|
| a.Certs[host] = der_bytes
|
| }
|
| -}
|
| -
|
| -func (a *WritableArchive) FindHostCert(host string) ([]byte, error) {
|
| - a.mu.Lock()
|
| - defer a.mu.Unlock()
|
| - if cert, ok := a.Archive.Certs[host]; ok {
|
| - return cert, nil
|
| + if a.NegotiatedProtocol == nil {
|
| + a.NegotiatedProtocol = make(map[string]string)
|
| }
|
| - return nil, ErrNotFound
|
| + a.NegotiatedProtocol[host] = negotiatedProtocol
|
| }
|
|
|
| // Close flushes the the archive and closes the output file.
|
|
|