| OLD | NEW |
| 1 // Copyright 2016 The LUCI Authors. All rights reserved. | 1 // Copyright 2016 The LUCI Authors. All rights reserved. |
| 2 // Use of this source code is governed under the Apache License, Version 2.0 | 2 // Use of this source code is governed under the Apache License, Version 2.0 |
| 3 // that can be found in the LICENSE file. | 3 // that can be found in the LICENSE file. |
| 4 | 4 |
| 5 package prpc | 5 package prpc |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "fmt" | 8 "fmt" |
| 9 "net/http" | 9 "net/http" |
| 10 "sort" | 10 "sort" |
| 11 "strings" | 11 "strings" |
| 12 "sync" | 12 "sync" |
| 13 | 13 |
| 14 "github.com/julienschmidt/httprouter" | |
| 15 "golang.org/x/net/context" | 14 "golang.org/x/net/context" |
| 16 "google.golang.org/grpc" | 15 "google.golang.org/grpc" |
| 17 "google.golang.org/grpc/codes" | 16 "google.golang.org/grpc/codes" |
| 18 | 17 |
| 19 "github.com/luci/luci-go/common/errors" | 18 "github.com/luci/luci-go/common/errors" |
| 20 "github.com/luci/luci-go/common/logging" | 19 "github.com/luci/luci-go/common/logging" |
| 21 "github.com/luci/luci-go/server/auth" | 20 "github.com/luci/luci-go/server/auth" |
| 22 » "github.com/luci/luci-go/server/middleware" | 21 » "github.com/luci/luci-go/server/router" |
| 23 | 22 |
| 24 prpccommon "github.com/luci/luci-go/common/prpc" | 23 prpccommon "github.com/luci/luci-go/common/prpc" |
| 25 ) | 24 ) |
| 26 | 25 |
| 27 var ( | 26 var ( |
| 28 // Describe the permitted Access Control requests. | 27 // Describe the permitted Access Control requests. |
| 29 allowHeaders = strings.Join([]string{"Origin", "Content-Type", "Accept"}
, ", ") | 28 allowHeaders = strings.Join([]string{"Origin", "Content-Type", "Accept"}
, ", ") |
| 30 allowMethods = strings.Join([]string{"OPTIONS", "POST"}, ", ") | 29 allowMethods = strings.Join([]string{"OPTIONS", "POST"}, ", ") |
| 31 | 30 |
| 32 // allowPreflightCacheAgeSecs is the amount of time to enable the browse
r to | 31 // allowPreflightCacheAgeSecs is the amount of time to enable the browse
r to |
| (...skipping 67 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 100 if s.services == nil { | 99 if s.services == nil { |
| 101 s.services = map[string]*service{} | 100 s.services = map[string]*service{} |
| 102 } else if _, ok := s.services[desc.ServiceName]; ok { | 101 } else if _, ok := s.services[desc.ServiceName]; ok { |
| 103 panic(fmt.Errorf("service %q is already registered", desc.Servic
eName)) | 102 panic(fmt.Errorf("service %q is already registered", desc.Servic
eName)) |
| 104 } | 103 } |
| 105 | 104 |
| 106 s.services[desc.ServiceName] = serv | 105 s.services[desc.ServiceName] = serv |
| 107 } | 106 } |
| 108 | 107 |
| 109 // authenticate forces authentication set by RegisterDefaultAuth. | 108 // authenticate forces authentication set by RegisterDefaultAuth. |
| 110 func (s *Server) authenticate(base middleware.Base) middleware.Base { | 109 func (s *Server) authenticate() router.Middleware { |
| 111 a := s.Authenticator | 110 a := s.Authenticator |
| 112 if a == nil { | 111 if a == nil { |
| 113 a = GetDefaultAuth() | 112 a = GetDefaultAuth() |
| 114 if a == nil { | 113 if a == nil { |
| 115 panic("prpc: no custom Authenticator was provided and de
fault authenticator was not registered. " + | 114 panic("prpc: no custom Authenticator was provided and de
fault authenticator was not registered. " + |
| 116 "Forgot to import appengine/gaeauth/server packa
ge?") | 115 "Forgot to import appengine/gaeauth/server packa
ge?") |
| 117 } | 116 } |
| 118 } | 117 } |
| 119 | 118 |
| 120 if len(a) == 0 { | 119 if len(a) == 0 { |
| 121 » » return base | 120 » » return nil |
| 122 } | 121 } |
| 123 | 122 |
| 124 » return func(h middleware.Handler) httprouter.Handle { | 123 » return func(c *router.Context, next router.Handler) { |
| 125 » » return base(func(c context.Context, w http.ResponseWriter, r *ht
tp.Request, p httprouter.Params) { | 124 » » c.Context = auth.SetAuthenticator(c.Context, a) |
| 126 » » » c = auth.SetAuthenticator(c, a) | 125 » » var err error |
| 127 » » » switch c, err := a.Authenticate(c, r); { | 126 » » switch c.Context, err = a.Authenticate(c.Context, c.Request); { |
| 128 » » » case errors.IsTransient(err): | 127 » » case errors.IsTransient(err): |
| 129 » » » » res := errResponse(codes.Internal, http.StatusIn
ternalServerError, escapeFmt(err.Error())) | 128 » » » res := errResponse(codes.Internal, http.StatusInternalSe
rverError, escapeFmt(err.Error())) |
| 130 » » » » res.write(c, w) | 129 » » » res.write(c.Context, c.Writer) |
| 131 » » » case err != nil: | 130 » » case err != nil: |
| 132 » » » » res := errResponse(codes.Unauthenticated, http.S
tatusUnauthorized, escapeFmt(err.Error())) | 131 » » » res := errResponse(codes.Unauthenticated, http.StatusUna
uthorized, escapeFmt(err.Error())) |
| 133 » » » » res.write(c, w) | 132 » » » res.write(c.Context, c.Writer) |
| 134 » » » default: | 133 » » default: |
| 135 » » » » h(c, w, r, p) | 134 » » » next(c) |
| 136 » » » } | 135 » » } |
| 137 » » }) | |
| 138 } | 136 } |
| 139 } | 137 } |
| 140 | 138 |
| 141 // InstallHandlers installs HTTP handlers at /prpc/:service/:method. | 139 // InstallHandlers installs HTTP handlers at /prpc/:service/:method. |
| 142 // | 140 // |
| 143 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol | 141 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 144 // for pRPC protocol. | 142 // for pRPC protocol. |
| 145 // | 143 // |
| 146 // The authenticator in 'base' is always replaced by pRPC specific one. For more | 144 // The authenticator in 'base' is always replaced by pRPC specific one. For more |
| 147 // details about the authentication see Server.Authenticator doc. | 145 // details about the authentication see Server.Authenticator doc. |
| 148 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { | 146 func (s *Server) InstallHandlers(r *router.Router, base router.MiddlewareChain)
{ |
| 149 s.mu.Lock() | 147 s.mu.Lock() |
| 150 defer s.mu.Unlock() | 148 defer s.mu.Unlock() |
| 151 | 149 |
| 152 » base = s.authenticate(base) | 150 » rr := r.Subrouter("/prpc/:service/:method") |
| 151 » rr.Use(append(base, s.authenticate())) |
| 153 | 152 |
| 154 » r.POST("/prpc/:service/:method", base(s.handlePOST)) | 153 » rr.POST("", nil, s.handlePOST) |
| 155 » r.OPTIONS("/prpc/:service/:method", base(s.handleOPTIONS)) | 154 » rr.OPTIONS("", nil, s.handleOPTIONS) |
| 156 } | 155 } |
| 157 | 156 |
| 158 // handle handles RPCs. | 157 // handle handles RPCs. |
| 159 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol | 158 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 160 // for pRPC protocol. | 159 // for pRPC protocol. |
| 161 func (s *Server) handlePOST(c context.Context, w http.ResponseWriter, r *http.Re
quest, p httprouter.Params) { | 160 func (s *Server) handlePOST(c *router.Context) { |
| 162 » serviceName := p.ByName("service") | 161 » serviceName := c.Params.ByName("service") |
| 163 » methodName := p.ByName("method") | 162 » methodName := c.Params.ByName("method") |
| 164 » res := s.respond(c, w, r, serviceName, methodName) | 163 » res := s.respond(c.Context, c.Writer, c.Request, serviceName, methodName
) |
| 165 | 164 |
| 166 » c = logging.SetFields(c, logging.Fields{ | 165 » c.Context = logging.SetFields(c.Context, logging.Fields{ |
| 167 "service": serviceName, | 166 "service": serviceName, |
| 168 "method": methodName, | 167 "method": methodName, |
| 169 }) | 168 }) |
| 170 » s.setAccessControlHeaders(c, r, w, false) | 169 » s.setAccessControlHeaders(c.Context, c.Request, c.Writer, false) |
| 171 » res.write(c, w) | 170 » res.write(c.Context, c.Writer) |
| 172 } | 171 } |
| 173 | 172 |
| 174 func (s *Server) handleOPTIONS(c context.Context, w http.ResponseWriter, r *http
.Request, p httprouter.Params) { | 173 func (s *Server) handleOPTIONS(c *router.Context) { |
| 175 » s.setAccessControlHeaders(c, r, w, true) | 174 » s.setAccessControlHeaders(c.Context, c.Request, c.Writer, true) |
| 176 » w.WriteHeader(http.StatusOK) | 175 » c.Writer.WriteHeader(http.StatusOK) |
| 177 } | 176 } |
| 178 | 177 |
| 179 func (s *Server) respond(c context.Context, w http.ResponseWriter, r *http.Reque
st, serviceName, methodName string) *response { | 178 func (s *Server) respond(c context.Context, w http.ResponseWriter, r *http.Reque
st, serviceName, methodName string) *response { |
| 180 service := s.services[serviceName] | 179 service := s.services[serviceName] |
| 181 if service == nil { | 180 if service == nil { |
| 182 return errResponse( | 181 return errResponse( |
| 183 codes.Unimplemented, | 182 codes.Unimplemented, |
| 184 http.StatusNotImplemented, | 183 http.StatusNotImplemented, |
| 185 "service %q is not implemented", | 184 "service %q is not implemented", |
| 186 serviceName) | 185 serviceName) |
| (...skipping 38 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 225 s.mu.Lock() | 224 s.mu.Lock() |
| 226 defer s.mu.Unlock() | 225 defer s.mu.Unlock() |
| 227 | 226 |
| 228 names := make([]string, 0, len(s.services)) | 227 names := make([]string, 0, len(s.services)) |
| 229 for name := range s.services { | 228 for name := range s.services { |
| 230 names = append(names, name) | 229 names = append(names, name) |
| 231 } | 230 } |
| 232 sort.Strings(names) | 231 sort.Strings(names) |
| 233 return names | 232 return names |
| 234 } | 233 } |
| OLD | NEW |