| OLD | NEW | 
|    1 // Copyright 2016 The Chromium Authors. All rights reserved. |    1 // Copyright 2016 The Chromium Authors. All rights reserved. | 
|    2 // Use of this source code is governed by a BSD-style license that can be |    2 // Use of this source code is governed by a BSD-style license that can be | 
|    3 // found in the LICENSE file. |    3 // found in the LICENSE file. | 
|    4  |    4  | 
|    5 package prpc |    5 package prpc | 
|    6  |    6  | 
|    7 // This file implements encoding of RPC results to HTTP responses. |    7 // This file implements encoding of RPC results to HTTP responses. | 
|    8  |    8  | 
|    9 import ( |    9 import ( | 
|   10         "bytes" |   10         "bytes" | 
|   11         "io" |  | 
|   12         "net/http" |   11         "net/http" | 
|   13         "sort" |   12         "sort" | 
|   14  |   13  | 
|   15         "github.com/golang/protobuf/jsonpb" |   14         "github.com/golang/protobuf/jsonpb" | 
|   16         "github.com/golang/protobuf/proto" |   15         "github.com/golang/protobuf/proto" | 
|   17         "golang.org/x/net/context" |   16         "golang.org/x/net/context" | 
|   18         "google.golang.org/grpc" |   17         "google.golang.org/grpc" | 
|   19         "google.golang.org/grpc/codes" |   18         "google.golang.org/grpc/codes" | 
|   20  |  | 
|   21         "github.com/luci/luci-go/common/logging" |  | 
|   22 ) |   19 ) | 
|   23  |   20  | 
|   24 const ( |   21 const ( | 
|   25         headerAccept = "Accept" |   22         headerAccept = "Accept" | 
 |   23         csrfPrefix   = ")]}'\n" | 
|   26 ) |   24 ) | 
|   27  |   25  | 
|   28 // responseFormat returns the format to be used in a response. |   26 // responseFormat returns the format to be used in a response. | 
|   29 // Can return only formatBinary (preferred), formatJSONPB or formatText. |   27 // Can return only formatBinary (preferred), formatJSONPB or formatText. | 
|   30 // In case of an error, format is undefined and the error has an HTTP status. |   28 // In case of an error, format is undefined. | 
|   31 func responseFormat(acceptHeader string) (format, *httpError) { |   29 func responseFormat(acceptHeader string) (format, *protocolError) { | 
|   32         if acceptHeader == "" { |   30         if acceptHeader == "" { | 
|   33                 return formatBinary, nil |   31                 return formatBinary, nil | 
|   34         } |   32         } | 
|   35  |   33  | 
|   36         parsed, err := parseAccept(acceptHeader) |   34         parsed, err := parseAccept(acceptHeader) | 
|   37         if err != nil { |   35         if err != nil { | 
|   38                 return formatBinary, errorf(http.StatusBadRequest, "Accept heade
     r: %s", err) |   36                 return formatBinary, errorf(http.StatusBadRequest, "Accept heade
     r: %s", err) | 
|   39         } |   37         } | 
|   40         assert(len(parsed) > 0) |  | 
|   41         formats := make(acceptFormatSlice, 0, len(parsed)) |   38         formats := make(acceptFormatSlice, 0, len(parsed)) | 
|   42         for _, at := range parsed { |   39         for _, at := range parsed { | 
|   43                 f, err := parseFormat(at.MediaType, at.MediaTypeParams) |   40                 f, err := parseFormat(at.MediaType, at.MediaTypeParams) | 
|   44                 if err != nil { |   41                 if err != nil { | 
|   45                         // Ignore invalid format. Check further. |   42                         // Ignore invalid format. Check further. | 
|   46                         continue |   43                         continue | 
|   47                 } |   44                 } | 
|   48                 switch f { |   45                 switch f { | 
|   49  |   46  | 
|   50                 case formatBinary, formatJSONPB, formatText: |   47                 case formatBinary, formatJSONPB, formatText: | 
|   51                         // fine |   48                         // fine | 
|   52  |   49  | 
|   53                 case formatUnspecified: |   50                 case formatUnspecified: | 
|   54                         f = formatBinary // prefer binary |   51                         f = formatBinary // prefer binary | 
|   55  |   52  | 
|   56 »       »       case formatUnrecognized: |   53 »       »       default: | 
|   57                         continue |   54                         continue | 
|   58  |  | 
|   59                 default: |  | 
|   60                         panicf("cannot happen") |  | 
|   61                 } |   55                 } | 
|   62  |   56  | 
|   63                 assert(f == formatBinary || f == formatJSONPB || f == formatText
     ) |  | 
|   64                 formats = append(formats, acceptFormat{f, at.QualityFactor}) |   57                 formats = append(formats, acceptFormat{f, at.QualityFactor}) | 
|   65         } |   58         } | 
|   66         if len(formats) == 0 { |   59         if len(formats) == 0 { | 
|   67                 return formatBinary, errorf( |   60                 return formatBinary, errorf( | 
|   68                         http.StatusNotAcceptable, |   61                         http.StatusNotAcceptable, | 
|   69                         "Accept header: specified media types are not not suppor
     ted. Supported types: %q, %q, %q, %q.", |   62                         "Accept header: specified media types are not not suppor
     ted. Supported types: %q, %q, %q, %q.", | 
|   70                         mtPRPCBinary, |   63                         mtPRPCBinary, | 
|   71                         mtPRPCJSNOPB, |   64                         mtPRPCJSNOPB, | 
|   72                         mtPRPCText, |   65                         mtPRPCText, | 
|   73                         mtJSON, |   66                         mtJSON, | 
|   74                 ) |   67                 ) | 
|   75         } |   68         } | 
|   76         sort.Sort(formats) // order by quality factor and format preference. |   69         sort.Sort(formats) // order by quality factor and format preference. | 
|   77         return formats[0].Format, nil |   70         return formats[0].Format, nil | 
|   78 } |   71 } | 
|   79  |   72  | 
|   80 // writeMessage writes a protobuf message to response in the specified format. |   73 // respondMessage encodes msg to a response in the specified format. | 
|   81 func writeMessage(w http.ResponseWriter, msg proto.Message, format format) error
      { |   74 func respondMessage(msg proto.Message, format format) *response { | 
|   82         if msg == nil { |   75         if msg == nil { | 
|   83 »       »       panic("msg is nil") |   76 »       »       return errResponse(codes.Internal, 0, "pRPC: responseMessage: ms
     g is nil") | 
|   84         } |   77         } | 
|   85 »       var ( |   78 »       res := response{header: http.Header{}} | 
|   86 »       »       contentType string |   79 »       var err error | 
|   87 »       »       res         []byte |  | 
|   88 »       »       err         error |  | 
|   89 »       ) |  | 
|   90         switch format { |   80         switch format { | 
|   91         case formatBinary: |   81         case formatBinary: | 
|   92 »       »       contentType = mtPRPCBinary |   82 »       »       res.header.Set(headerContentType, mtPRPCBinary) | 
|   93 »       »       res, err = proto.Marshal(msg) |   83 »       »       res.body, err = proto.Marshal(msg) | 
|   94  |   84  | 
|   95         case formatJSONPB: |   85         case formatJSONPB: | 
|   96 »       »       contentType = mtPRPCJSNOPB |   86 »       »       res.header.Set(headerContentType, mtPRPCJSNOPB) | 
|   97 »       »       m := jsonpb.Marshaler{Indent: "\t"} |  | 
|   98                 var buf bytes.Buffer |   87                 var buf bytes.Buffer | 
 |   88                 buf.WriteString(csrfPrefix) | 
 |   89                 m := jsonpb.Marshaler{} | 
|   99                 err = m.Marshal(&buf, msg) |   90                 err = m.Marshal(&buf, msg) | 
|  100 »       »       buf.WriteString("\n") |   91 »       »       res.body = buf.Bytes() | 
|  101 »       »       res = buf.Bytes() |   92 »       »       res.newLine = true | 
|  102  |   93  | 
|  103         case formatText: |   94         case formatText: | 
|  104 »       »       contentType = mtPRPCText |   95 »       »       res.header.Set(headerContentType, mtPRPCText) | 
|  105                 var buf bytes.Buffer |   96                 var buf bytes.Buffer | 
|  106                 err = proto.MarshalText(&buf, msg) |   97                 err = proto.MarshalText(&buf, msg) | 
|  107 »       »       res = buf.Bytes() |   98 »       »       res.body = buf.Bytes() | 
 |   99  | 
 |  100 »       default: | 
 |  101 »       »       return errResponse(codes.Internal, 0, "pRPC: responseMessage: in
     valid format %s", format) | 
 |  102  | 
|  108         } |  103         } | 
|  109         if err != nil { |  104         if err != nil { | 
|  110 »       »       return err |  105 »       »       return errResponse(codes.Internal, 0, err.Error()) | 
|  111         } |  106         } | 
|  112 »       w.Header().Set(headerContentType, contentType) |  107  | 
|  113 »       _, err = w.Write(res) |  108 »       return &res | 
|  114 »       return err |  109 } | 
 |  110  | 
 |  111 // respondProtocolError creates a response for a pRPC protocol error. | 
 |  112 func respondProtocolError(err *protocolError) *response { | 
 |  113 »       return errResponse(codes.InvalidArgument, err.status, err.err.Error()) | 
 |  114 } | 
 |  115  | 
 |  116 // errorCode returns a most appropriate gRPC code for an error | 
 |  117 func errorCode(err error) codes.Code { | 
 |  118 »       switch err { | 
 |  119 »       case context.DeadlineExceeded: | 
 |  120 »       »       return codes.DeadlineExceeded | 
 |  121  | 
 |  122 »       case context.Canceled: | 
 |  123 »       »       return codes.Canceled | 
 |  124  | 
 |  125 »       default: | 
 |  126 »       »       return grpc.Code(err) | 
 |  127 »       } | 
|  115 } |  128 } | 
|  116  |  129  | 
|  117 // codeToStatus maps gRPC codes to HTTP statuses. |  130 // codeToStatus maps gRPC codes to HTTP statuses. | 
|  118 // This map may need to be corrected when |  131 // This map may need to be corrected when | 
|  119 // https://github.com/grpc/grpc-common/issues/210 |  132 // https://github.com/grpc/grpc-common/issues/210 | 
|  120 // is closed. |  133 // is closed. | 
|  121 var codeToStatus = map[codes.Code]int{ |  134 var codeToStatus = map[codes.Code]int{ | 
|  122         codes.OK:                 http.StatusOK, |  135         codes.OK:                 http.StatusOK, | 
|  123         codes.Canceled:           http.StatusNoContent, |  136         codes.Canceled:           http.StatusNoContent, | 
|  124         codes.Unknown:            http.StatusInternalServerError, |  | 
|  125         codes.InvalidArgument:    http.StatusBadRequest, |  137         codes.InvalidArgument:    http.StatusBadRequest, | 
|  126         codes.DeadlineExceeded:   http.StatusServiceUnavailable, |  138         codes.DeadlineExceeded:   http.StatusServiceUnavailable, | 
|  127         codes.NotFound:           http.StatusNotFound, |  139         codes.NotFound:           http.StatusNotFound, | 
|  128         codes.AlreadyExists:      http.StatusConflict, |  140         codes.AlreadyExists:      http.StatusConflict, | 
|  129         codes.PermissionDenied:   http.StatusForbidden, |  141         codes.PermissionDenied:   http.StatusForbidden, | 
|  130         codes.Unauthenticated:    http.StatusUnauthorized, |  142         codes.Unauthenticated:    http.StatusUnauthorized, | 
|  131         codes.ResourceExhausted:  http.StatusServiceUnavailable, |  143         codes.ResourceExhausted:  http.StatusServiceUnavailable, | 
|  132         codes.FailedPrecondition: http.StatusPreconditionFailed, |  144         codes.FailedPrecondition: http.StatusPreconditionFailed, | 
|  133         codes.Aborted:            http.StatusInternalServerError, |  | 
|  134         codes.OutOfRange:         http.StatusBadRequest, |  145         codes.OutOfRange:         http.StatusBadRequest, | 
|  135         codes.Unimplemented:      http.StatusNotImplemented, |  146         codes.Unimplemented:      http.StatusNotImplemented, | 
|  136         codes.Internal:           http.StatusInternalServerError, |  | 
|  137         codes.Unavailable:        http.StatusServiceUnavailable, |  147         codes.Unavailable:        http.StatusServiceUnavailable, | 
|  138         codes.DataLoss:           http.StatusInternalServerError, |  | 
|  139 } |  148 } | 
|  140  |  149  | 
|  141 // ErrorStatus returns HTTP status for an error. |  150 // codeStatus maps gRPC codes to HTTP status codes. | 
|  142 // In particular, it maps gRPC codes to HTTP statuses. |  151 // Falls back to http.StatusInternalServerError. | 
|  143 // Status of nil is 200. |  152 func codeStatus(code codes.Code) int { | 
|  144 // |  153 »       if status, ok := codeToStatus[code]; ok { | 
|  145 // See also grpc.Code. |  154 »       »       return status | 
|  146 func ErrorStatus(err error) int { |  | 
|  147 »       if err, ok := err.(*httpError); ok { |  | 
|  148 »       »       return err.status |  | 
|  149         } |  155         } | 
|  150  |  156 »       return http.StatusInternalServerError | 
|  151 »       status, ok := codeToStatus[grpc.Code(err)] |  | 
|  152 »       if !ok { |  | 
|  153 »       »       status = http.StatusInternalServerError |  | 
|  154 »       } |  | 
|  155 »       return status |  | 
|  156 } |  157 } | 
|  157  |  | 
|  158 // ErrorDesc returns the error description of err if it was produced by pRPC or 
     gRPC. |  | 
|  159 // Otherwise, it returns err.Error() or empty string when err is nil. |  | 
|  160 // |  | 
|  161 // See also grpc.ErrorDesc. |  | 
|  162 func ErrorDesc(err error) string { |  | 
|  163         if err == nil { |  | 
|  164                 return "" |  | 
|  165         } |  | 
|  166         if e, ok := err.(*httpError); ok { |  | 
|  167                 err = e.err |  | 
|  168         } |  | 
|  169         return grpc.ErrorDesc(err) |  | 
|  170 } |  | 
|  171  |  | 
|  172 // writeError writes an error to an HTTP response. |  | 
|  173 // |  | 
|  174 // HTTP status is determined by ErrorStatus. |  | 
|  175 // If it is http.StatusInternalServerError, prints only "Internal server error", |  | 
|  176 // otherwise uses ErrorDesc. |  | 
|  177 // |  | 
|  178 // Logs all errors with status >= 500. |  | 
|  179 func writeError(c context.Context, w http.ResponseWriter, err error) { |  | 
|  180         if err == nil { |  | 
|  181                 panic("err is nil") |  | 
|  182         } |  | 
|  183  |  | 
|  184         status := ErrorStatus(err) |  | 
|  185         if status >= 500 { |  | 
|  186                 logging.Errorf(c, "HTTP %d: %s", status, ErrorDesc(err)) |  | 
|  187         } |  | 
|  188  |  | 
|  189         w.Header().Set(headerContentType, "text/plain") |  | 
|  190         w.WriteHeader(status) |  | 
|  191  |  | 
|  192         var body string |  | 
|  193         if status == http.StatusInternalServerError { |  | 
|  194                 body = "Internal server error" |  | 
|  195         } else { |  | 
|  196                 body = ErrorDesc(err) |  | 
|  197         } |  | 
|  198         if _, err := io.WriteString(w, body+"\n"); err != nil { |  | 
|  199                 logging.Errorf(c, "could not write error: %s", err) |  | 
|  200         } |  | 
|  201 } |  | 
|  202  |  | 
|  203 func assert(condition bool) { |  | 
|  204         if !condition { |  | 
|  205                 panicf("assertion failed") |  | 
|  206         } |  | 
|  207 } |  | 
| OLD | NEW |