| OLD | NEW |
| 1 // Copyright 2016 The Chromium Authors. All rights reserved. | 1 // Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 package prpc | 5 package prpc |
| 6 | 6 |
| 7 import ( | 7 import ( |
| 8 "fmt" |
| 8 "net/http" | 9 "net/http" |
| 9 "sort" | 10 "sort" |
| 10 "sync" | 11 "sync" |
| 11 | 12 |
| 12 "github.com/julienschmidt/httprouter" | 13 "github.com/julienschmidt/httprouter" |
| 13 "golang.org/x/net/context" | 14 "golang.org/x/net/context" |
| 14 "google.golang.org/grpc" | 15 "google.golang.org/grpc" |
| 16 "google.golang.org/grpc/codes" |
| 15 | 17 |
| 18 "github.com/luci/luci-go/common/logging" |
| 16 "github.com/luci/luci-go/server/auth" | 19 "github.com/luci/luci-go/server/auth" |
| 17 "github.com/luci/luci-go/server/middleware" | 20 "github.com/luci/luci-go/server/middleware" |
| 18 ) | 21 ) |
| 19 | 22 |
| 20 // Server is a pRPC server to serve RPC requests. | 23 // Server is a pRPC server to serve RPC requests. |
| 21 // Zero value is valid. | 24 // Zero value is valid. |
| 22 type Server struct { | 25 type Server struct { |
| 23 // CustomAuthenticator, if true, disables the forced authentication set
by | 26 // CustomAuthenticator, if true, disables the forced authentication set
by |
| 24 // RegisterDefaultAuth. | 27 // RegisterDefaultAuth. |
| 25 CustomAuthenticator bool | 28 CustomAuthenticator bool |
| (...skipping 22 matching lines...) Expand all Loading... |
| 48 desc: grpcDesc, | 51 desc: grpcDesc, |
| 49 } | 52 } |
| 50 } | 53 } |
| 51 | 54 |
| 52 s.mu.Lock() | 55 s.mu.Lock() |
| 53 defer s.mu.Unlock() | 56 defer s.mu.Unlock() |
| 54 | 57 |
| 55 if s.services == nil { | 58 if s.services == nil { |
| 56 s.services = map[string]*service{} | 59 s.services = map[string]*service{} |
| 57 } else if _, ok := s.services[desc.ServiceName]; ok { | 60 } else if _, ok := s.services[desc.ServiceName]; ok { |
| 58 » » panicf("service %q is already registered", desc.ServiceName) | 61 » » panic(fmt.Errorf("service %q is already registered", desc.Servic
eName)) |
| 59 } | 62 } |
| 60 | 63 |
| 61 s.services[desc.ServiceName] = serv | 64 s.services[desc.ServiceName] = serv |
| 62 } | 65 } |
| 63 | 66 |
| 64 // authenticate forces authentication set by RegisterDefaultAuth. | 67 // authenticate forces authentication set by RegisterDefaultAuth. |
| 65 func (s *Server) authenticate(base middleware.Base) middleware.Base { | 68 func (s *Server) authenticate(base middleware.Base) middleware.Base { |
| 66 a := GetDefaultAuth() | 69 a := GetDefaultAuth() |
| 67 if a == nil { | 70 if a == nil { |
| 68 » » panicf("prpc: CustomAuthenticator is false, but default authenti
cator was not registered. " + | 71 » » panic("prpc: CustomAuthenticator is false, but default authentic
ator was not registered. " + |
| 69 "Forgot to import appengine/gaeauth/server package?") | 72 "Forgot to import appengine/gaeauth/server package?") |
| 70 } | 73 } |
| 71 | 74 |
| 72 return func(h middleware.Handler) httprouter.Handle { | 75 return func(h middleware.Handler) httprouter.Handle { |
| 73 return base(func(c context.Context, w http.ResponseWriter, r *ht
tp.Request, p httprouter.Params) { | 76 return base(func(c context.Context, w http.ResponseWriter, r *ht
tp.Request, p httprouter.Params) { |
| 74 c = auth.SetAuthenticator(c, a) | 77 c = auth.SetAuthenticator(c, a) |
| 75 c, err := a.Authenticate(c, r) | 78 c, err := a.Authenticate(c, r) |
| 76 if err != nil { | 79 if err != nil { |
| 77 » » » » writeError(c, w, withStatus(err, http.StatusUnau
thorized)) | 80 » » » » res := errResponse(codes.Unauthenticated, http.S
tatusUnauthorized, err.Error()) |
| 81 » » » » res.write(c, w) |
| 78 return | 82 return |
| 79 } | 83 } |
| 80 h(c, w, r, p) | 84 h(c, w, r, p) |
| 81 }) | 85 }) |
| 82 } | 86 } |
| 83 } | 87 } |
| 84 | 88 |
| 85 // InstallHandlers installs HTTP POST handlers at | 89 // InstallHandlers installs HTTP handlers at /prpc/:service/:method. |
| 86 // /prpc/{service_name}/{method_name} for all registered services. | 90 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 91 // for pRPC protocol. |
| 87 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { | 92 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { |
| 88 s.mu.Lock() | 93 s.mu.Lock() |
| 89 defer s.mu.Unlock() | 94 defer s.mu.Unlock() |
| 90 | 95 |
| 91 if !s.CustomAuthenticator { | 96 if !s.CustomAuthenticator { |
| 92 base = s.authenticate(base) | 97 base = s.authenticate(base) |
| 93 } | 98 } |
| 94 | 99 |
| 95 » for _, service := range s.services { | 100 » const path = "/prpc/:service/:method" |
| 96 » » for _, m := range service.methods { | 101 » handle := base(s.handle) |
| 97 » » » m.InstallHandlers(r, base) | 102 » r.POST(path, handle) |
| 98 » » } | 103 » r.GET(path, handle) |
| 104 » r.PUT(path, handle) |
| 105 » r.DELETE(path, handle) |
| 106 » r.PATCH(path, handle) |
| 107 } |
| 108 |
| 109 // handle handles RPCs. |
| 110 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 111 // for pRPC protocol. |
| 112 func (s *Server) handle(c context.Context, w http.ResponseWriter, r *http.Reques
t, p httprouter.Params) { |
| 113 » serviceName := p.ByName("service") |
| 114 » methodName := p.ByName("method") |
| 115 » res := s.respond(c, w, r, serviceName, methodName) |
| 116 |
| 117 » c = logging.SetFields(c, logging.Fields{ |
| 118 » » "service": serviceName, |
| 119 » » "method": methodName, |
| 120 » }) |
| 121 » res.write(c, w) |
| 122 } |
| 123 |
| 124 func (s *Server) respond(c context.Context, w http.ResponseWriter, r *http.Reque
st, serviceName, methodName string) *response { |
| 125 » if r.Method != "POST" { |
| 126 » » res := errResponse(codes.Unimplemented, http.StatusMethodNotAllo
wed, "HTTP method must be POST") |
| 127 » » res.header.Set("Allow", "POST") |
| 128 » » return res |
| 99 } | 129 } |
| 130 |
| 131 service := s.services[serviceName] |
| 132 if service == nil { |
| 133 return errResponse( |
| 134 codes.Unimplemented, |
| 135 http.StatusNotImplemented, |
| 136 fmt.Sprintf("service %q is not implemented", serviceName
)) |
| 137 } |
| 138 |
| 139 method := service.methods[methodName] |
| 140 if method == nil { |
| 141 return errResponse( |
| 142 codes.Unimplemented, |
| 143 http.StatusNotImplemented, |
| 144 fmt.Sprintf("method %q in service %q is not implemented"
, methodName, serviceName)) |
| 145 } |
| 146 |
| 147 return method.handle(c, w, r) |
| 100 } | 148 } |
| 101 | 149 |
| 102 // ServiceNames returns a sorted list of full names of all registered services. | 150 // ServiceNames returns a sorted list of full names of all registered services. |
| 103 func (s *Server) ServiceNames() []string { | 151 func (s *Server) ServiceNames() []string { |
| 104 s.mu.Lock() | 152 s.mu.Lock() |
| 105 defer s.mu.Unlock() | 153 defer s.mu.Unlock() |
| 106 | 154 |
| 107 names := make([]string, 0, len(s.services)) | 155 names := make([]string, 0, len(s.services)) |
| 108 for name := range s.services { | 156 for name := range s.services { |
| 109 names = append(names, name) | 157 names = append(names, name) |
| 110 } | 158 } |
| 111 sort.Strings(names) | 159 sort.Strings(names) |
| 112 return names | 160 return names |
| 113 } | 161 } |
| OLD | NEW |