| OLD | NEW | 
|    1 // Copyright 2016 The Chromium Authors. All rights reserved. |    1 // Copyright 2016 The Chromium Authors. All rights reserved. | 
|    2 // Use of this source code is governed by a BSD-style license that can be |    2 // Use of this source code is governed by a BSD-style license that can be | 
|    3 // found in the LICENSE file. |    3 // found in the LICENSE file. | 
|    4  |    4  | 
|    5 package prpc |    5 package prpc | 
|    6  |    6  | 
|    7 import ( |    7 import ( | 
|    8         "fmt" |  | 
|    9         "net/http" |    8         "net/http" | 
|   10         "net/http/httptest" |  | 
|   11         "strings" |    9         "strings" | 
|   12         "testing" |   10         "testing" | 
|   13  |   11  | 
|   14         "github.com/golang/protobuf/proto" |   12         "github.com/golang/protobuf/proto" | 
|   15         "golang.org/x/net/context" |  | 
|   16         "google.golang.org/grpc" |  | 
|   17         "google.golang.org/grpc/codes" |   13         "google.golang.org/grpc/codes" | 
|   18  |   14  | 
|   19         "github.com/luci/luci-go/common/logging" |  | 
|   20         "github.com/luci/luci-go/common/logging/memlogger" |  | 
|   21  |  | 
|   22         . "github.com/luci/luci-go/common/testing/assertions" |   15         . "github.com/luci/luci-go/common/testing/assertions" | 
|   23         . "github.com/smartystreets/goconvey/convey" |   16         . "github.com/smartystreets/goconvey/convey" | 
|   24 ) |   17 ) | 
|   25  |   18  | 
|   26 func TestEncoding(t *testing.T) { |   19 func TestEncoding(t *testing.T) { | 
|   27         t.Parallel() |   20         t.Parallel() | 
|   28  |   21  | 
|   29         Convey("responseFormat", t, func() { |   22         Convey("responseFormat", t, func() { | 
|   30                 test := func(acceptHeader string, expectedFormat format, expecte
     dErr interface{}) { |   23                 test := func(acceptHeader string, expectedFormat format, expecte
     dErr interface{}) { | 
|   31                         acceptHeader = strings.Replace(acceptHeader, "{json}", m
     tPRPCJSNOPB, -1) |   24                         acceptHeader = strings.Replace(acceptHeader, "{json}", m
     tPRPCJSNOPB, -1) | 
| (...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after  Loading... | 
|   64                 test("{json},{binary},*/*;x=y", formatBinary, nil) |   57                 test("{json},{binary},*/*;x=y", formatBinary, nil) | 
|   65                 test("{json},{binary};q=0.9,*/*", formatBinary, nil) |   58                 test("{json},{binary};q=0.9,*/*", formatBinary, nil) | 
|   66                 test("{json},{binary};q=0.9,*/*;q=0.8", formatJSONPB, nil) |   59                 test("{json},{binary};q=0.9,*/*;q=0.8", formatJSONPB, nil) | 
|   67  |   60  | 
|   68                 // supported and unsupported mix |   61                 // supported and unsupported mix | 
|   69                 test("{json},foo/bar", formatJSONPB, nil) |   62                 test("{json},foo/bar", formatJSONPB, nil) | 
|   70                 test("{json};q=0.1,foo/bar", formatJSONPB, nil) |   63                 test("{json};q=0.1,foo/bar", formatJSONPB, nil) | 
|   71                 test("foo/bar;q=0.1,{json}", formatJSONPB, nil) |   64                 test("foo/bar;q=0.1,{json}", formatJSONPB, nil) | 
|   72  |   65  | 
|   73                 // only unsupported types |   66                 // only unsupported types | 
|   74 »       »       const err406 = "HTTP 406: Accept header: specified media types a
     re not not supported" |   67 »       »       const err406 = "pRPC: Accept header: specified media types are n
     ot not supported" | 
|   75                 test(mtPRPC+"; boo=true", 0, err406) |   68                 test(mtPRPC+"; boo=true", 0, err406) | 
|   76                 test(mtPRPC+"; encoding=blah", 0, err406) |   69                 test(mtPRPC+"; encoding=blah", 0, err406) | 
|   77                 test("x", 0, err406) |   70                 test("x", 0, err406) | 
|   78                 test("x,y", 0, err406) |   71                 test("x,y", 0, err406) | 
|   79  |   72  | 
|   80 »       »       test("x//y", 0, "HTTP 400: Accept header: expected token after s
     lash") |   73 »       »       test("x//y", 0, "pRPC: Accept header: expected token after slash
     ") | 
|   81         }) |   74         }) | 
|   82  |   75  | 
|   83 »       Convey("writeMessage", t, func() { |   76 »       Convey("respondMessage", t, func() { | 
|   84                 msg := &HelloReply{Message: "Hi"} |   77                 msg := &HelloReply{Message: "Hi"} | 
|   85  |   78  | 
|   86                 test := func(f format, body []byte, contentType string) { |   79                 test := func(f format, body []byte, contentType string) { | 
|   87                         Convey(contentType, func() { |   80                         Convey(contentType, func() { | 
|   88 »       »       »       »       res := httptest.NewRecorder() |   81 »       »       »       »       res := respondMessage(msg, f) | 
|   89 »       »       »       »       err := writeMessage(res, msg, f) |   82 »       »       »       »       So(res.code, ShouldEqual, codes.OK) | 
|   90 »       »       »       »       So(err, ShouldBeNil) |   83 »       »       »       »       So(res.header, ShouldResembleV, http.Header{ | 
|   91  |   84 »       »       »       »       »       headerContentType: []string{contentType}
     , | 
|   92 »       »       »       »       So(res.Code, ShouldEqual, http.StatusOK) |   85 »       »       »       »       }) | 
|   93 »       »       »       »       So(res.Body.Bytes(), ShouldResembleV, body) |   86 »       »       »       »       So(res.body, ShouldResembleV, body) | 
|   94 »       »       »       »       So(res.Header().Get("Content-Type"), ShouldEqual
     , contentType) |  | 
|   95                         }) |   87                         }) | 
|   96                 } |   88                 } | 
|   97  |   89  | 
|   98                 msgBytes, err := proto.Marshal(msg) |   90                 msgBytes, err := proto.Marshal(msg) | 
|   99                 So(err, ShouldBeNil) |   91                 So(err, ShouldBeNil) | 
|  100  |   92  | 
|  101                 test(formatBinary, msgBytes, mtPRPCBinary) |   93                 test(formatBinary, msgBytes, mtPRPCBinary) | 
|  102 »       »       test(formatJSONPB, []byte("{\n\t\"message\": \"Hi\"\n}\n"), mtPR
     PCJSNOPB) |   94 »       »       test(formatJSONPB, []byte(csrfPrefix+"{\"message\":\"Hi\"}"), mt
     PRPCJSNOPB) | 
|  103                 test(formatText, []byte("message: \"Hi\"\n"), mtPRPCText) |   95                 test(formatText, []byte("message: \"Hi\"\n"), mtPRPCText) | 
|  104         }) |   96         }) | 
|  105  |  | 
|  106         Convey("writeError", t, func() { |  | 
|  107                 test := func(err error, status int, body string, logMsgs ...meml
     ogger.LogEntry) { |  | 
|  108                         Convey(err.Error(), func() { |  | 
|  109                                 c := context.Background() |  | 
|  110                                 c = memlogger.Use(c) |  | 
|  111                                 log := logging.Get(c).(*memlogger.MemLogger) |  | 
|  112  |  | 
|  113                                 rec := httptest.NewRecorder() |  | 
|  114                                 writeError(c, rec, err) |  | 
|  115                                 So(rec.Code, ShouldEqual, status) |  | 
|  116                                 So(rec.Body.String(), ShouldEqual, body) |  | 
|  117  |  | 
|  118                                 actualMsgs := log.Messages() |  | 
|  119                                 for i := range actualMsgs { |  | 
|  120                                         actualMsgs[i].CallDepth = 0 |  | 
|  121                                 } |  | 
|  122                                 So(actualMsgs, ShouldResembleV, logMsgs) |  | 
|  123                         }) |  | 
|  124                 } |  | 
|  125  |  | 
|  126                 test(Errorf(http.StatusNotFound, "not found"), http.StatusNotFou
     nd, "not found\n") |  | 
|  127                 test(grpc.Errorf(codes.NotFound, "not found"), http.StatusNotFou
     nd, "not found\n") |  | 
|  128                 test( |  | 
|  129                         fmt.Errorf("unhandled"), |  | 
|  130                         http.StatusInternalServerError, |  | 
|  131                         "Internal server error\n", |  | 
|  132                         memlogger.LogEntry{Level: logging.Error, Msg: "HTTP 500:
      unhandled"}, |  | 
|  133                 ) |  | 
|  134         }) |  | 
|  135 } |   97 } | 
| OLD | NEW |