Index: third_party/mojo/src/mojo/public/tools/bindings/generators/mojom_go_generator.py |
diff --git a/third_party/mojo/src/mojo/public/tools/bindings/generators/mojom_go_generator.py b/third_party/mojo/src/mojo/public/tools/bindings/generators/mojom_go_generator.py |
index 11962767e4b816a7119a1648e5fed38cd42b16f7..bdcb84f65a340e16a6a5c33b5602d105c90d14e1 100644 |
--- a/third_party/mojo/src/mojo/public/tools/bindings/generators/mojom_go_generator.py |
+++ b/third_party/mojo/src/mojo/public/tools/bindings/generators/mojom_go_generator.py |
@@ -57,6 +57,8 @@ _kind_infos = { |
mojom.NULLABLE_STRING: KindInfo('string', 'String', 'String', 64), |
} |
+_imports = {} |
+ |
def GetBitSize(kind): |
if isinstance(kind, (mojom.Array, mojom.Map, mojom.Struct)): |
return 64 |
@@ -66,14 +68,18 @@ def GetBitSize(kind): |
kind = mojom.INT32 |
return _kind_infos[kind].bit_size |
+# Returns go type corresponding to provided kind. If |nullable| is true |
+# and kind is nullable adds an '*' to type (example: ?string -> *string). |
def GetGoType(kind, nullable = True): |
if nullable and mojom.IsNullableKind(kind): |
return '*%s' % GetNonNullableGoType(kind) |
return GetNonNullableGoType(kind) |
+# Returns go type corresponding to provided kind. Ignores nullability of |
+# top-level kind. |
def GetNonNullableGoType(kind): |
if mojom.IsStructKind(kind): |
- return '%s' % FormatName(kind.name) |
+ return '%s' % GetFullName(kind) |
if mojom.IsArrayKind(kind): |
if kind.length: |
return '[%s]%s' % (kind.length, GetGoType(kind.kind)) |
@@ -86,6 +92,8 @@ def GetNonNullableGoType(kind): |
return GetNameForNestedElement(kind) |
return _kind_infos[kind].go_type |
+# Splits name to lower-cased parts used for camel-casing |
+# (example: HTTPEntry2FooBar -> ['http', 'entry2', 'foo', 'bar']). |
def NameToComponent(name): |
# insert '_' between anything and a Title name (e.g, HTTPEntry2FooBar -> |
# HTTP_Entry2_FooBar) |
@@ -98,24 +106,44 @@ def NameToComponent(name): |
def UpperCamelCase(name): |
return ''.join([x.capitalize() for x in NameToComponent(name)]) |
+# Formats a name. If |exported| is true makes name camel-cased with first |
+# letter capital, otherwise does no camel-casing and makes first letter |
+# lower-cased (which is used for making internal names more readable). |
def FormatName(name, exported=True): |
if exported: |
return UpperCamelCase(name) |
# Leave '_' symbols for unexported names. |
return name[0].lower() + name[1:] |
+# Returns full name of an imported element based on prebuilt dict |_imports|. |
+# If the |element| is not imported returns formatted name of it. |
+# |element| should have attr 'name'. |exported| argument is used to make |
+# |FormatName()| calls only. |
+def GetFullName(element, exported=True): |
+ if not hasattr(element, 'imported_from') or not element.imported_from: |
+ return FormatName(element.name, exported) |
+ path = 'gen/mojom' |
+ if element.imported_from['namespace']: |
+ path = '/'.join([path] + element.imported_from['namespace'].split('.')) |
+ if path in _imports: |
+ return '%s.%s' % (_imports[path], FormatName(element.name, exported)) |
+ return FormatName(element.name, exported) |
+ |
+# Returns a name for nested elements like enum field or constant. |
+# The returned name consists of camel-cased parts separated by '_'. |
def GetNameForNestedElement(element): |
if element.parent_kind: |
return "%s_%s" % (GetNameForElement(element.parent_kind), |
FormatName(element.name)) |
- return FormatName(element.name) |
+ return GetFullName(element) |
def GetNameForElement(element, exported=True): |
- if (mojom.IsInterfaceKind(element) or mojom.IsStructKind(element) or |
- isinstance(element, (mojom.EnumField, |
- mojom.Field, |
- mojom.Method, |
- mojom.Parameter))): |
+ if (mojom.IsInterfaceKind(element) or mojom.IsStructKind(element)): |
+ return GetFullName(element, exported) |
+ if isinstance(element, (mojom.EnumField, |
+ mojom.Field, |
+ mojom.Method, |
+ mojom.Parameter)): |
return FormatName(element.name, exported) |
if isinstance(element, (mojom.Enum, |
mojom.Constant, |
@@ -147,14 +175,14 @@ def EncodeSuffix(kind): |
return EncodeSuffix(mojom.MSGPIPE) |
return _kind_infos[kind].encode_suffix |
-def GetPackage(module): |
- if module.namespace: |
- return module.namespace.split('.')[-1] |
+def GetPackage(namespace): |
+ if namespace: |
+ return namespace.split('.')[-1] |
return 'mojom' |
-def GetPackagePath(module): |
+def GetPackagePath(namespace): |
path = 'mojom' |
- for i in module.namespace.split('.'): |
+ for i in namespace.split('.'): |
path = os.path.join(path, i) |
return path |
@@ -182,6 +210,74 @@ def GetResponseStructFromMethod(method): |
struct.versions = pack.GetVersionInfo(struct.packed) |
return struct |
+def GetAllConstants(module): |
+ data = [module] + module.structs + module.interfaces |
+ constants = [x.constants for x in data] |
+ return [i for i in chain.from_iterable(constants)] |
+ |
+def GetAllEnums(module): |
+ data = [module] + module.structs + module.interfaces |
+ enums = [x.enums for x in data] |
+ return [i for i in chain.from_iterable(enums)] |
+ |
+# Adds an import required to use the provided |element|. |
+# The required import is stored at '_imports'. |
+def AddImport(module, element): |
+ if not hasattr(element, 'imported_from') or not element.imported_from: |
+ return |
+ if isinstance(element, mojom.Kind) and mojom.IsAnyHandleKind(element): |
+ return |
+ imported = element.imported_from |
+ if imported['namespace'] == module.namespace: |
+ return |
+ path = 'gen/mojom' |
+ name = 'mojom' |
+ if imported['namespace']: |
+ path = '/'.join([path] + imported['namespace'].split('.')) |
+ name = '_'.join([name] + imported['namespace'].split('.')) |
+ while (name in _imports.values() and _imports[path] != path): |
+ name += '_' |
+ _imports[path] = name |
+ |
+# Scans |module| for elements that require imports and adds all found imports |
+# to '_imports' dict. Returns a list of imports that should include the |
+# generated go file. |
+def GetImports(module): |
+ # Imports can only be used in structs, constants, enums, interfaces. |
+ all_structs = list(module.structs) |
+ for i in module.interfaces: |
+ AddImport(module, i) |
+ for method in i.methods: |
+ all_structs.append(GetStructFromMethod(method)) |
+ if method.response_parameters: |
+ all_structs.append(GetResponseStructFromMethod(method)) |
+ |
+ if len(all_structs) > 0: |
+ _imports['mojo/public/go/bindings'] = 'bindings' |
+ for struct in all_structs: |
+ for field in struct.fields: |
+ AddImport(module, field.kind) |
+# TODO(rogulenko): add these after generating constants and struct defaults. |
+# if field.default: |
+# AddImport(module, field.default) |
+ |
+ for enum in GetAllEnums(module): |
+ for field in enum.fields: |
+ if field.value: |
+ AddImport(module, field.value) |
+ |
+# TODO(rogulenko): add these after generating constants and struct defaults. |
+# for constant in GetAllConstants(module): |
+# AddImport(module, constant.value) |
+ |
+ imports_list = [] |
+ for i in _imports: |
+ if i.split('/')[-1] == _imports[i]: |
+ imports_list.append('"%s"' % i) |
+ else: |
+ imports_list.append('%s "%s"' % (_imports[i], i)) |
+ return sorted(imports_list) |
+ |
class Generator(generator.Generator): |
go_filters = { |
'array': lambda kind: mojom.Array(kind), |
@@ -204,16 +300,12 @@ class Generator(generator.Generator): |
'tab_indent': lambda s, size = 1: ('\n' + '\t' * size).join(s.splitlines()) |
} |
- def GetAllEnums(self): |
- data = [self.module] + self.GetStructs() + self.module.interfaces |
- enums = [x.enums for x in data] |
- return [i for i in chain.from_iterable(enums)] |
- |
def GetParameters(self): |
return { |
- 'enums': self.GetAllEnums(), |
+ 'enums': GetAllEnums(self.module), |
+ 'imports': GetImports(self.module), |
'interfaces': self.module.interfaces, |
- 'package': GetPackage(self.module), |
+ 'package': GetPackage(self.module.namespace), |
'structs': self.GetStructs(), |
} |
@@ -223,7 +315,7 @@ class Generator(generator.Generator): |
def GenerateFiles(self, args): |
self.Write(self.GenerateSource(), os.path.join("go", "src", "gen", |
- GetPackagePath(self.module), '%s.go' % self.module.name)) |
+ GetPackagePath(self.module.namespace), '%s.go' % self.module.name)) |
def GetJinjaParameters(self): |
return { |