| Index: tools/cmd/cproto/transform.go
|
| diff --git a/tools/cmd/cproto/transform.go b/tools/cmd/cproto/transform.go
|
| index 8555ecfedeccc13c145e882fe066bcfe9a59cafd..626eafcebfbba17dc57aa929b875a7363344f581 100644
|
| --- a/tools/cmd/cproto/transform.go
|
| +++ b/tools/cmd/cproto/transform.go
|
| @@ -8,45 +8,55 @@ package main
|
|
|
| import (
|
| "bytes"
|
| + "fmt"
|
| + "io/ioutil"
|
| + "strings"
|
| + "text/template"
|
| + "unicode/utf8"
|
| +
|
| "go/ast"
|
| + "go/format"
|
| "go/parser"
|
| "go/printer"
|
| "go/token"
|
| - "io/ioutil"
|
| - "strings"
|
| )
|
|
|
| const (
|
| - prpcPackagePath = `github.com/luci/luci-go/server/prpc`
|
| + serverPrpcPackagePath = `github.com/luci/luci-go/server/prpc`
|
| + commonPrpcPackagePath = `github.com/luci/luci-go/common/prpc`
|
| )
|
|
|
| var (
|
| - prpcPkg = ast.NewIdent("prpc")
|
| - registrarName = ast.NewIdent("Registrar")
|
| + serverPrpcPkg = ast.NewIdent("prpc")
|
| + commonPrpcPkg = ast.NewIdent("prpccommon")
|
| )
|
|
|
| type transformer struct {
|
| + fset *token.FileSet
|
| inPRPCPackage bool
|
| PackageName string
|
| }
|
|
|
| // transformGoFile rewrites a .go file to work with prpc.
|
| func (t *transformer) transformGoFile(filename string) error {
|
| - fset := token.NewFileSet()
|
| - file, err := parser.ParseFile(fset, filename, nil, parser.ParseComments)
|
| + t.fset = token.NewFileSet()
|
| + file, err := parser.ParseFile(t.fset, filename, nil, parser.ParseComments)
|
| if err != nil {
|
| return err
|
| }
|
|
|
| t.PackageName = file.Name.Name
|
| - t.inPRPCPackage, err = isInPackage(filename, prpcPackagePath)
|
| + t.inPRPCPackage, err = isInPackage(filename, serverPrpcPackagePath)
|
| if err != nil {
|
| return err
|
| }
|
| - t.transformFile(file)
|
| +
|
| + if err := t.transformFile(file); err != nil {
|
| + return err
|
| + }
|
|
|
| var buf bytes.Buffer
|
| - if err := printer.Fprint(&buf, fset, file); err != nil {
|
| + if err := printer.Fprint(&buf, t.fset, file); err != nil {
|
| return err
|
| }
|
| formatted, err := gofmt(buf.Bytes())
|
| @@ -57,19 +67,28 @@ func (t *transformer) transformGoFile(filename string) error {
|
| return ioutil.WriteFile(filename, formatted, 0666)
|
| }
|
|
|
| -func (t *transformer) transformFile(file *ast.File) {
|
| +func (t *transformer) transformFile(file *ast.File) error {
|
| if t.transformRegisterServerFuncs(file) && !t.inPRPCPackage {
|
| - t.insertPrpcImport(file)
|
| + t.insertImport(file, serverPrpcPkg, serverPrpcPackagePath)
|
| + }
|
| + changed, err := t.generateClients(file)
|
| + if err != nil {
|
| + return err
|
| }
|
| + if changed {
|
| + t.insertImport(file, commonPrpcPkg, commonPrpcPackagePath)
|
| + }
|
| + return nil
|
| }
|
|
|
| // transformRegisterServerFuncs finds RegisterXXXServer functions and
|
| // checks its first parameter type to prpc.Registrar.
|
| // Returns true if modified ast.
|
| func (t *transformer) transformRegisterServerFuncs(file *ast.File) bool {
|
| + registrarName := ast.NewIdent("Registrar")
|
| var registrarType ast.Expr = registrarName
|
| if !t.inPRPCPackage {
|
| - registrarType = &ast.SelectorExpr{prpcPkg, registrarName}
|
| + registrarType = &ast.SelectorExpr{serverPrpcPkg, registrarName}
|
| }
|
|
|
| changed := false
|
| @@ -93,12 +112,133 @@ func (t *transformer) transformRegisterServerFuncs(file *ast.File) bool {
|
| return changed
|
| }
|
|
|
| -func (t *transformer) insertPrpcImport(file *ast.File) {
|
| +// generateClients finds client interface declarations
|
| +// and inserts pRPC implementations after them.
|
| +func (t *transformer) generateClients(file *ast.File) (bool, error) {
|
| + changed := false
|
| + for i := len(file.Decls) - 1; i >= 0; i-- {
|
| + genDecl, ok := file.Decls[i].(*ast.GenDecl)
|
| + if !ok || genDecl.Tok != token.TYPE {
|
| + continue
|
| + }
|
| + for _, spec := range genDecl.Specs {
|
| + spec := spec.(*ast.TypeSpec)
|
| + const suffix = "Client"
|
| + if !strings.HasSuffix(spec.Name.Name, suffix) {
|
| + continue
|
| + }
|
| + serviceName := strings.TrimSuffix(spec.Name.Name, suffix)
|
| +
|
| + iface, ok := spec.Type.(*ast.InterfaceType)
|
| + if !ok {
|
| + continue
|
| + }
|
| +
|
| + newDecls, err := t.generateClient(file.Name.Name, serviceName, iface)
|
| + if err != nil {
|
| + return false, err
|
| + }
|
| + file.Decls = append(file.Decls[:i+1], append(newDecls, file.Decls[i+1:]...)...)
|
| + changed = true
|
| + }
|
| + }
|
| + return changed, nil
|
| +}
|
| +
|
| +var clientCodeTemplate = template.Must(template.New("").Parse(`
|
| +package template
|
| +
|
| +type {{$.StructName}} struct {
|
| + client *prpccommon.Client
|
| +}
|
| +
|
| +func New{{.Service}}PRPCClient(client *prpccommon.Client) {{.Service}}Client {
|
| + return &{{$.StructName}}{client}
|
| +}
|
| +
|
| +{{range .Methods}}
|
| +func (c *{{$.StructName}}) {{.Name}}(ctx context.Context, in *{{.InputMessage}}, opts ...grpc.CallOption) (*{{.OutputMessage}}, error) {
|
| + out := new({{.OutputMessage}})
|
| + err := c.client.Call(ctx, "{{$.Pkg}}.{{$.Service}}", "{{.Name}}", in, out, opts...)
|
| + if err != nil {
|
| + return nil, err
|
| + }
|
| + return out, nil
|
| +}
|
| +{{end}}
|
| +`))
|
| +
|
| +// generateClient generates pRPC implementation of a client interface.
|
| +func (t *transformer) generateClient(packageName, serviceName string, iface *ast.InterfaceType) ([]ast.Decl, error) {
|
| + // This function used to construct an AST. It was a lot of code.
|
| + // Now it generates code via a template and parses back to AST.
|
| + // Slower, but saner and easier to make changes.
|
| +
|
| + type Method struct {
|
| + Name string
|
| + InputMessage string
|
| + OutputMessage string
|
| + }
|
| + methods := make([]Method, 0, len(iface.Methods.List))
|
| +
|
| + var buf bytes.Buffer
|
| + toGoCode := func(n ast.Node) (string, error) {
|
| + defer buf.Reset()
|
| + err := format.Node(&buf, t.fset, n)
|
| + if err != nil {
|
| + return "", err
|
| + }
|
| + return buf.String(), nil
|
| + }
|
| +
|
| + for _, m := range iface.Methods.List {
|
| + signature, ok := m.Type.(*ast.FuncType)
|
| + if !ok {
|
| + return nil, fmt.Errorf("unexpected embedded interface in %sClient", serviceName)
|
| + }
|
| +
|
| + inStructPtr := signature.Params.List[1].Type.(*ast.StarExpr)
|
| + inStruct, err := toGoCode(inStructPtr.X)
|
| + if err != nil {
|
| + return nil, err
|
| + }
|
| +
|
| + outStructPtr := signature.Results.List[0].Type.(*ast.StarExpr)
|
| + outStruct, err := toGoCode(outStructPtr.X)
|
| + if err != nil {
|
| + return nil, err
|
| + }
|
| +
|
| + methods = append(methods, Method{
|
| + Name: m.Names[0].Name,
|
| + InputMessage: inStruct,
|
| + OutputMessage: outStruct,
|
| + })
|
| + }
|
| +
|
| + err := clientCodeTemplate.Execute(&buf, map[string]interface{}{
|
| + "Pkg": packageName,
|
| + "Service": serviceName,
|
| + "StructName": firstLower(serviceName) + "PRPCClient",
|
| + "Methods": methods,
|
| + })
|
| + if err != nil {
|
| + return nil, fmt.Errorf("client template execution: %s", err)
|
| + }
|
| +
|
| + f, err := parser.ParseFile(t.fset, "", buf.String(), 0)
|
| + if err != nil {
|
| + return nil, fmt.Errorf("client template result parsing: %s. Code: %#v", err, buf.String())
|
| + }
|
| + return f.Decls, nil
|
| +}
|
| +
|
| +func (t *transformer) insertImport(file *ast.File, name *ast.Ident, path string) {
|
| spec := &ast.ImportSpec{
|
| - Name: prpcPkg,
|
| + Name: name,
|
| Path: &ast.BasicLit{
|
| Kind: token.STRING,
|
| - Value: `"` + prpcPackagePath + `"`,
|
| + Value: `"` + path + `"`,
|
| },
|
| }
|
| importDecl := &ast.GenDecl{
|
| @@ -107,3 +247,8 @@ func (t *transformer) insertPrpcImport(file *ast.File) {
|
| }
|
| file.Decls = append([]ast.Decl{importDecl}, file.Decls...)
|
| }
|
| +
|
| +func firstLower(s string) string {
|
| + _, w := utf8.DecodeRuneInString(s)
|
| + return strings.ToLower(s[:w]) + s[w:]
|
| +}
|
|
|