Index: web_page_replay_go/src/webpagereplay/archive.go |
diff --git a/web_page_replay_go/src/webpagereplay/archive.go b/web_page_replay_go/src/webpagereplay/archive.go |
index 5a5dd23e26e112ff30186adca57fe12c66bcb6e5..d0eee49a9e6af6391ff06e1f3b37a2f3a16972f6 100644 |
--- a/web_page_replay_go/src/webpagereplay/archive.go |
+++ b/web_page_replay_go/src/webpagereplay/archive.go |
@@ -28,10 +28,9 @@ var ErrNotFound = errors.New("not found") |
type ArchivedRequest struct { |
SerializedRequest []byte |
SerializedResponse []byte // if empty, the request failed |
- Proto string |
} |
-func serializeRequest(req *http.Request, resp *http.Response, proto string) (*ArchivedRequest, error) { |
+func serializeRequest(req *http.Request, resp *http.Response) (*ArchivedRequest, error) { |
url := req.URL.String() |
ar := &ArchivedRequest{} |
{ |
@@ -48,24 +47,22 @@ func serializeRequest(req *http.Request, resp *http.Response, proto string) (*Ar |
} |
ar.SerializedResponse = buf.Bytes() |
} |
- ar.Proto = proto |
return ar, nil |
} |
-func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, string, error) { |
+func (ar *ArchivedRequest) unmarshal() (*http.Request, *http.Response, error) { |
req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(ar.SerializedRequest))) |
if err != nil { |
- return nil, nil, "", fmt.Errorf("couldn't unmarshal request: %v", err) |
+ return nil, nil, fmt.Errorf("couldn't unmarshal request: %v", err) |
} |
resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(ar.SerializedResponse)), req) |
if err != nil { |
if req.Body != nil { |
req.Body.Close() |
} |
- return nil, nil, "", fmt.Errorf("couldn't unmarshal response: %v", err) |
+ return nil, nil, fmt.Errorf("couldn't unmarshal response: %v", err) |
} |
- proto := ar.Proto |
- return req, resp, proto, nil |
+ return req, resp, nil |
} |
// Archive contains an archive of requests. Immutable except when embedded in a WritableArchive. |
@@ -75,7 +72,11 @@ type Archive struct { |
// The two-level mapping makes it easier to search for similar requests. |
// There may be multiple requests for a given URL. |
Requests map[string]map[string][]*ArchivedRequest |
- Certs map[string][]byte |
+ // Maps host string to DER encoded certs. |
+ Certs map[string][]byte |
+ // Maps host string to the negotiated protocol. eg. "http/1.1" or "h2" |
+ // If absent, will default to "http/1.1". |
+ NegotiatedProtocol map[string]string |
} |
func newArchive() Archive { |
@@ -111,7 +112,7 @@ func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { |
for _, urlmap := range a.Requests { |
for url, requests := range urlmap { |
for k, ar := range requests { |
- req, resp, _, err := ar.unmarshal() |
+ req, resp, err := ar.unmarshal() |
if err != nil { |
log.Printf("Error unmarshaling request #%d for %s: %v", k, url, err) |
continue |
@@ -122,20 +123,28 @@ func (a *Archive) ForEach(f func(req *http.Request, resp *http.Response)) { |
} |
} |
-func (a *Archive) FindHostCert(host string) ([]byte, error) { |
+// Returns the der encoded cert and negotiated protocol. |
+func (a *Archive) FindHostTlsConfig(host string) ([]byte, string, error) { |
if cert, ok := a.Certs[host]; ok { |
- return cert, nil |
+ return cert, a.findHostNegotiatedProtocol(host), nil |
} |
- return nil, ErrNotFound |
+ return nil, "", ErrNotFound |
+} |
+ |
+func (a *Archive) findHostNegotiatedProtocol(host string) string { |
+ if negotiatedProtocol, ok := a.NegotiatedProtocol[host]; ok { |
+ return negotiatedProtocol |
+ } |
+ return "http/1.1" |
} |
// FindRequest searches for the given request in the archive. |
// Returns ErrNotFound if the request could not be found. Does not consume req.Body. |
// TODO: header-based matching and conditional requests |
-func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, *http.Response, string, error) { |
+func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, *http.Response, error) { |
hostMap := a.Requests[req.Host] |
if len(hostMap) == 0 { |
- return nil, nil, "", ErrNotFound |
+ return nil, nil, ErrNotFound |
} |
// Exact match. Note that req may be relative, but hostMap keys are always absolute. |
@@ -144,8 +153,8 @@ func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, |
u.Host = req.Host |
u.Scheme = scheme |
} |
- if req, resp, proto, err := findExactMatch(hostMap[u.String()], req.Method); err == nil { |
- return req, resp, proto, nil |
+ if req, resp, err := findExactMatch(hostMap[u.String()], req.Method); err == nil { |
+ return req, resp, nil |
} |
// For all URLs with a matching path, pick the URL that has the most matching query parameters. |
@@ -184,27 +193,27 @@ func (a *Archive) FindRequest(req *http.Request, scheme string) (*http.Request, |
return findExactMatch(hostMap[bestURL], req.Method) |
} |
- return nil, nil, "", ErrNotFound |
+ return nil, nil, ErrNotFound |
} |
// findExactMatch returns the first request that exactly matches the given request method. |
-func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request, *http.Response, string, error) { |
+func findExactMatch(requests []*ArchivedRequest, method string) (*http.Request, *http.Response, error) { |
for _, ar := range requests { |
- req, resp, proto, err := ar.unmarshal() |
+ req, resp, err := ar.unmarshal() |
if err != nil { |
log.Printf("Error unmarshaling request: %v\nAR.Request: %q\nAR.Response: %q", err, ar.SerializedRequest, ar.SerializedResponse) |
continue |
} |
if req.Method == method { |
- return req, resp, proto, nil |
+ return req, resp, nil |
} |
} |
- return nil, nil, "", ErrNotFound |
+ return nil, nil, ErrNotFound |
} |
-func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *http.Response, proto string) error { |
- ar, err := serializeRequest(req, resp, proto) |
+func (a *Archive) addArchivedRequest(scheme string, req *http.Request, resp *http.Response) error { |
+ ar, err := serializeRequest(req, resp) |
if err != nil { |
return err |
} |
@@ -231,7 +240,7 @@ func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req |
for ustr, requests := range urlmap { |
u, _ := url.Parse(ustr) |
for k, ar := range requests { |
- oldReq, oldResp, proto, err := ar.unmarshal() |
+ oldReq, oldResp, err := ar.unmarshal() |
if err != nil { |
return nil, fmt.Errorf("Error unmarshaling request #%d for %s: %v", k, ustr, err) |
} |
@@ -246,7 +255,7 @@ func (a *Archive) Edit(f func(req *http.Request, resp *http.Response) (*http.Req |
continue |
} |
// TODO: allow changing scheme or protocol? |
- if err := clone.addArchivedRequest(u.Scheme, newReq, newResp, proto); err != nil { |
+ if err := clone.addArchivedRequest(u.Scheme, newReq, newResp); err != nil { |
return nil, err |
} |
} |
@@ -286,33 +295,23 @@ func OpenWritableArchive(path string) (*WritableArchive, error) { |
func (a *WritableArchive) RecordRequest(scheme string, req *http.Request, resp *http.Response) error { |
a.mu.Lock() |
defer a.mu.Unlock() |
- proto := "" |
- if resp.TLS != nil && resp.TLS.NegotiatedProtocolIsMutual { |
- proto = resp.TLS.NegotiatedProtocol |
- } |
- return a.addArchivedRequest(scheme, req, resp, proto) |
+ return a.addArchivedRequest(scheme, req, resp) |
} |
-// RecordCert records a cert in the archive. |
-func (a *WritableArchive) RecordCert(host string, der_bytes []byte) { |
+// RecordTlsConfig records the cert used and protocol negotiated for a host. |
+func (a *WritableArchive) RecordTlsConfig(host string, der_bytes []byte, negotiatedProtocol string) { |
a.mu.Lock() |
defer a.mu.Unlock() |
if a.Certs == nil { |
a.Certs = make(map[string][]byte) |
} |
if _, ok := a.Certs[host]; !ok { |
- fmt.Printf("Recorded cert for %s", host) |
a.Certs[host] = der_bytes |
} |
-} |
- |
-func (a *WritableArchive) FindHostCert(host string) ([]byte, error) { |
- a.mu.Lock() |
- defer a.mu.Unlock() |
- if cert, ok := a.Archive.Certs[host]; ok { |
- return cert, nil |
+ if a.NegotiatedProtocol == nil { |
+ a.NegotiatedProtocol = make(map[string]string) |
} |
- return nil, ErrNotFound |
+ a.NegotiatedProtocol[host] = negotiatedProtocol |
} |
// Close flushes the the archive and closes the output file. |