| OLD | NEW | 
|    1 // Copyright 2016 The Chromium Authors. All rights reserved. |    1 // Copyright 2016 The Chromium Authors. All rights reserved. | 
|    2 // Use of this source code is governed by a BSD-style license that can be |    2 // Use of this source code is governed by a BSD-style license that can be | 
|    3 // found in the LICENSE file. |    3 // found in the LICENSE file. | 
|    4  |    4  | 
|    5 package prpc |    5 package prpc | 
|    6  |    6  | 
|    7 import ( |    7 import ( | 
 |    8         "fmt" | 
|    8         "net/http" |    9         "net/http" | 
|    9         "sort" |   10         "sort" | 
|   10         "sync" |   11         "sync" | 
|   11  |   12  | 
|   12         "github.com/julienschmidt/httprouter" |   13         "github.com/julienschmidt/httprouter" | 
|   13         "golang.org/x/net/context" |   14         "golang.org/x/net/context" | 
|   14         "google.golang.org/grpc" |   15         "google.golang.org/grpc" | 
 |   16         "google.golang.org/grpc/codes" | 
|   15  |   17  | 
 |   18         "github.com/luci/luci-go/common/logging" | 
|   16         "github.com/luci/luci-go/server/auth" |   19         "github.com/luci/luci-go/server/auth" | 
|   17         "github.com/luci/luci-go/server/middleware" |   20         "github.com/luci/luci-go/server/middleware" | 
|   18 ) |   21 ) | 
|   19  |   22  | 
|   20 // Server is a pRPC server to serve RPC requests. |   23 // Server is a pRPC server to serve RPC requests. | 
|   21 // Zero value is valid. |   24 // Zero value is valid. | 
|   22 type Server struct { |   25 type Server struct { | 
|   23         // CustomAuthenticator, if true, disables the forced authentication set 
     by |   26         // CustomAuthenticator, if true, disables the forced authentication set 
     by | 
|   24         // RegisterDefaultAuth. |   27         // RegisterDefaultAuth. | 
|   25         CustomAuthenticator bool |   28         CustomAuthenticator bool | 
| (...skipping 22 matching lines...) Expand all  Loading... | 
|   48                         desc:    grpcDesc, |   51                         desc:    grpcDesc, | 
|   49                 } |   52                 } | 
|   50         } |   53         } | 
|   51  |   54  | 
|   52         s.mu.Lock() |   55         s.mu.Lock() | 
|   53         defer s.mu.Unlock() |   56         defer s.mu.Unlock() | 
|   54  |   57  | 
|   55         if s.services == nil { |   58         if s.services == nil { | 
|   56                 s.services = map[string]*service{} |   59                 s.services = map[string]*service{} | 
|   57         } else if _, ok := s.services[desc.ServiceName]; ok { |   60         } else if _, ok := s.services[desc.ServiceName]; ok { | 
|   58 »       »       panicf("service %q is already registered", desc.ServiceName) |   61 »       »       panic(fmt.Errorf("service %q is already registered", desc.Servic
     eName)) | 
|   59         } |   62         } | 
|   60  |   63  | 
|   61         s.services[desc.ServiceName] = serv |   64         s.services[desc.ServiceName] = serv | 
|   62 } |   65 } | 
|   63  |   66  | 
|   64 // authenticate forces authentication set by RegisterDefaultAuth. |   67 // authenticate forces authentication set by RegisterDefaultAuth. | 
|   65 func (s *Server) authenticate(base middleware.Base) middleware.Base { |   68 func (s *Server) authenticate(base middleware.Base) middleware.Base { | 
|   66         a := GetDefaultAuth() |   69         a := GetDefaultAuth() | 
|   67         if a == nil { |   70         if a == nil { | 
|   68 »       »       panicf("prpc: CustomAuthenticator is false, but default authenti
     cator was not registered. " + |   71 »       »       panic("prpc: CustomAuthenticator is false, but default authentic
     ator was not registered. " + | 
|   69                         "Forgot to import appengine/gaeauth/server package?") |   72                         "Forgot to import appengine/gaeauth/server package?") | 
|   70         } |   73         } | 
|   71  |   74  | 
|   72         return func(h middleware.Handler) httprouter.Handle { |   75         return func(h middleware.Handler) httprouter.Handle { | 
|   73                 return base(func(c context.Context, w http.ResponseWriter, r *ht
     tp.Request, p httprouter.Params) { |   76                 return base(func(c context.Context, w http.ResponseWriter, r *ht
     tp.Request, p httprouter.Params) { | 
|   74                         c = auth.SetAuthenticator(c, a) |   77                         c = auth.SetAuthenticator(c, a) | 
|   75                         c, err := a.Authenticate(c, r) |   78                         c, err := a.Authenticate(c, r) | 
|   76                         if err != nil { |   79                         if err != nil { | 
|   77 »       »       »       »       writeError(c, w, withStatus(err, http.StatusUnau
     thorized)) |   80 »       »       »       »       res := errResponse(codes.Unauthenticated, http.S
     tatusUnauthorized, err.Error()) | 
 |   81 »       »       »       »       res.write(c, w) | 
|   78                                 return |   82                                 return | 
|   79                         } |   83                         } | 
|   80                         h(c, w, r, p) |   84                         h(c, w, r, p) | 
|   81                 }) |   85                 }) | 
|   82         } |   86         } | 
|   83 } |   87 } | 
|   84  |   88  | 
|   85 // InstallHandlers installs HTTP POST handlers at |   89 // InstallHandlers installs HTTP handlers at /prpc/:service/:method. | 
|   86 // /prpc/{service_name}/{method_name} for all registered services. |   90 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol | 
 |   91 // for pRPC protocol. | 
|   87 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { |   92 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { | 
|   88         s.mu.Lock() |   93         s.mu.Lock() | 
|   89         defer s.mu.Unlock() |   94         defer s.mu.Unlock() | 
|   90  |   95  | 
|   91         if !s.CustomAuthenticator { |   96         if !s.CustomAuthenticator { | 
|   92                 base = s.authenticate(base) |   97                 base = s.authenticate(base) | 
|   93         } |   98         } | 
|   94  |   99  | 
|   95 »       for _, service := range s.services { |  100 »       r.POST("/prpc/:service/:method", base(s.handle)) | 
|   96 »       »       for _, m := range service.methods { |  101 } | 
|   97 »       »       »       m.InstallHandlers(r, base) |  102  | 
|   98 »       »       } |  103 // handle handles RPCs. | 
 |  104 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol | 
 |  105 // for pRPC protocol. | 
 |  106 func (s *Server) handle(c context.Context, w http.ResponseWriter, r *http.Reques
     t, p httprouter.Params) { | 
 |  107 »       serviceName := p.ByName("service") | 
 |  108 »       methodName := p.ByName("method") | 
 |  109 »       res := s.respond(c, w, r, serviceName, methodName) | 
 |  110  | 
 |  111 »       c = logging.SetFields(c, logging.Fields{ | 
 |  112 »       »       "service": serviceName, | 
 |  113 »       »       "method":  methodName, | 
 |  114 »       }) | 
 |  115 »       res.write(c, w) | 
 |  116 } | 
 |  117  | 
 |  118 func (s *Server) respond(c context.Context, w http.ResponseWriter, r *http.Reque
     st, serviceName, methodName string) *response { | 
 |  119 »       service := s.services[serviceName] | 
 |  120 »       if service == nil { | 
 |  121 »       »       return errResponse( | 
 |  122 »       »       »       codes.Unimplemented, | 
 |  123 »       »       »       http.StatusNotImplemented, | 
 |  124 »       »       »       fmt.Sprintf("service %q is not implemented", serviceName
     )) | 
|   99         } |  125         } | 
 |  126  | 
 |  127         method := service.methods[methodName] | 
 |  128         if method == nil { | 
 |  129                 return errResponse( | 
 |  130                         codes.Unimplemented, | 
 |  131                         http.StatusNotImplemented, | 
 |  132                         fmt.Sprintf("method %q in service %q is not implemented"
     , methodName, serviceName)) | 
 |  133         } | 
 |  134  | 
 |  135         return method.handle(c, w, r) | 
|  100 } |  136 } | 
|  101  |  137  | 
|  102 // ServiceNames returns a sorted list of full names of all registered services. |  138 // ServiceNames returns a sorted list of full names of all registered services. | 
|  103 func (s *Server) ServiceNames() []string { |  139 func (s *Server) ServiceNames() []string { | 
|  104         s.mu.Lock() |  140         s.mu.Lock() | 
|  105         defer s.mu.Unlock() |  141         defer s.mu.Unlock() | 
|  106  |  142  | 
|  107         names := make([]string, 0, len(s.services)) |  143         names := make([]string, 0, len(s.services)) | 
|  108         for name := range s.services { |  144         for name := range s.services { | 
|  109                 names = append(names, name) |  145                 names = append(names, name) | 
|  110         } |  146         } | 
|  111         sort.Strings(names) |  147         sort.Strings(names) | 
|  112         return names |  148         return names | 
|  113 } |  149 } | 
| OLD | NEW |