Allow multi-module processing.
authorJames Ward <j.ward@acfr.usyd.edu.au>
Tue, 8 Sep 2015 10:22:27 +0000 (20:22 +1000)
committerJames Ward <j.ward@acfr.usyd.edu.au>
Thu, 8 Oct 2015 05:02:57 +0000 (16:02 +1100)
When generating code, the list of modules is passed through Pyasn1Backend
in order to allow references across modules. Modules in this list
are also added to the output as import statements.

The main function in pyasn1gen.py has been modified to take a command
line switch to output multiple modules to separate files rather than
stdout.

asn1ate/pyasn1gen.py
asn1ate/sema.py
testdata/multi_module.asn

index 6cd0a0f..f5f5465 100644 (file)
@@ -26,6 +26,7 @@
 from __future__ import print_function  # Python 2 compatibility
 
 import sys
+import argparse
 import keyword
 from asn1ate import parser
 from asn1ate.support import pygen
@@ -80,8 +81,9 @@ class Pyasn1Backend(object):
     type assignment involves a constructed type, it is filled with inline
     definitions.
     """
-    def __init__(self, sema_module, out_stream):
+    def __init__(self, sema_module, out_stream, referenced_modules):
         self.sema_module = sema_module
+        self.referenced_modules = referenced_modules
         self.writer = pygen.PythonWriter(out_stream)
 
         self.decl_generators = {
@@ -121,6 +123,9 @@ class Pyasn1Backend(object):
 
     def generate_code(self):
         self.writer.write_line('from pyasn1.type import univ, char, namedtype, namedval, tag, constraint, useful')
+        for module in self.referenced_modules:
+            if not module is self.sema_module:
+                self.writer.write_line('import ' + _sanitize_module(module.name))
         self.writer.write_blanks(2)
 
         # Generate _OID if sema_module contains any object identifier values.
@@ -283,7 +288,10 @@ class Pyasn1Backend(object):
         return type_expr
 
     def inline_defined_type(self, t):
-        return _translate_type(t.type_name) + '()'
+        translated_type = _translate_type(t.type_name) + '()'
+        if t.module_name and t.module_name != self.sema_module.name:
+            translated_type = _sanitize_module(t.module_name) + '.' + translated_type
+        return translated_type 
 
     def inline_constructed_type(self, t):
         fragment = self.writer.get_fragment()
@@ -335,7 +343,7 @@ class Pyasn1Backend(object):
     def build_tag_expr(self, tag_def):
         context = _translate_tag_class(tag_def.class_name)
 
-        tagged_type_decl = self.sema_module.resolve_type_decl(tag_def.type_decl)
+        tagged_type_decl = self.sema_module.resolve_type_decl(tag_def.type_decl, self.referenced_modules)
         if isinstance(tagged_type_decl, ConstructedType):
             tag_format = 'tag.tagFormatConstructed'
         else:
@@ -389,14 +397,14 @@ class Pyasn1Backend(object):
             return self.build_object_identifier_value(value)
         else:
             value_type = _translate_type(type_decl.type_name)
-            root_type = self.sema_module.resolve_type_decl(type_decl)
+            root_type = self.sema_module.resolve_type_decl(type_decl, self.referenced_modules)
             return '%s(%s)' % (value_type, build_value_expr(root_type.type_name, value))
 
     def inline_component_type(self, t):
         if t.components_of_type:
             # COMPONENTS OF works like a literal include, so just
             # expand all components of the referenced type.
-            included_type_decl = self.sema_module.resolve_type_decl(t.components_of_type)
+            included_type_decl = self.sema_module.resolve_type_decl(t.components_of_type, self.referenced_modules)
             included_content = self.inline_component_types(included_type_decl.components)
 
             # Strip trailing newline from inline_component_types
@@ -474,8 +482,8 @@ class Pyasn1Backend(object):
         return str(fragment)
 
 
-def generate_pyasn1(sema_module, out_stream):
-    return Pyasn1Backend(sema_module, out_stream).generate_code()
+def generate_pyasn1(sema_module, out_stream, referenced_modules):
+    return Pyasn1Backend(sema_module, out_stream, referenced_modules).generate_code()
 
 
 # Translation tables from ASN.1 primitives to pyasn1 primitives
@@ -567,23 +575,40 @@ def _sanitize_identifier(name):
     return name
 
 
+def _sanitize_module(name):
+    """ Sanitize ASN.1 module identifiers so that they're PEP8 compliant identifiers.
+    """
+    return _sanitize_identifier(name).lower()
+
 # Simplistic command-line driver
-def main(args):
-    with open(args[0]) as f:
-        asn1def = f.read()
+def main():
+    arg_parser = argparse.ArgumentParser(description='Generate Python classes from an ASN.1 definition file. Output to stdout by default.')
+    arg_parser.add_argument('file', metavar='file', type=argparse.FileType('r'),
+                            help='the ASN.1 file to process')
+    arg_parser.add_argument('--split', action='store_true',
+                            help='output multiple modules to separate files')
+    args = arg_parser.parse_args()
+    asn1def = args.file.read()
 
     parse_tree = parser.parse_asn1(asn1def)
 
     modules = build_semantic_model(parse_tree)
-    if len(modules) > 1:
+    if len(modules) > 1 and not args.split:
         print('WARNING: More than one module generated to the same stream.', file=sys.stderr)
 
+    output_file = sys.stdout
     for module in modules:
-        print(pygen.auto_generated_header())
-        generate_pyasn1(module, sys.stdout)
+        try:
+            if args.split:
+                output_file = open(_sanitize_module(module.name) + '.py', 'w')
+            print(pygen.auto_generated_header(), file=output_file)
+            generate_pyasn1(module, output_file, modules)
+        finally:
+            if output_file != sys.stdout:
+                output_file.close()
 
     return 0
 
 
 if __name__ == '__main__':
-    sys.exit(main(sys.argv[1:]))
+    sys.exit(main())
index 9c7b118..6d8bd21 100644 (file)
@@ -271,14 +271,24 @@ class Module(SemaNode):
 
         return self._user_types
 
-    def resolve_type_decl(self, type_decl):
+    def resolve_type_decl(self, type_decl, referenced_modules):
         """ Recursively resolve user-defined types to their built-in
         declaration.
         """
-        user_types = self.user_types()
-
         if isinstance(type_decl, ReferencedType):
-            return self.resolve_type_decl(user_types[type_decl.type_name])
+            module = None
+            if not type_decl.module_name or type_decl.module_name == self.name:
+                module = self
+            else:
+                # Find the referenced module
+                for ref_mod in referenced_modules:
+                    if ref_mod.name == type_decl.module_name:
+                        module = ref_mod
+                        break
+            if not module:
+                raise Exception('Unrecognized referenced module %s in %s.' % (type_decl.module_name,
+                                                                              [module.name for module in referenced_modules]))
+            return module.resolve_type_decl(module.user_types()[type_decl.type_name], referenced_modules)
         else:
             return type_decl
 
@@ -496,9 +506,12 @@ class ReferencedType(SemaNode):
 
 class DefinedType(ReferencedType):
     def __init__(self, elements):
-        # TODO: Module references are not resolved at the moment,
-        # and I'm not sure how to handle them.
+        self.constraint = None
+        self.module_name = None
+
         module_ref, type_ref, size_constraint = elements
+        if module_ref:
+            self.module_name = module_ref.elements[0]
         self.type_name = type_ref
         if size_constraint:
             self.constraint = _create_sema_node(size_constraint)
@@ -535,8 +548,6 @@ class SelectionType(ReferencedType):
 
 class ReferencedValue(SemaNode):
     def __init__(self, elements):
-        # TODO: Module references are not resolved at the moment,
-        # and I'm not sure how to handle them.
         if len(elements) > 1 and elements[0].ty == 'ModuleReference':
             self.module_reference = elements[0].elements[0]
             self.name = elements[1]
@@ -548,7 +559,9 @@ class ReferencedValue(SemaNode):
         return self.name
 
     def __str__(self):
-        return self.name
+        if not self.module_reference:
+            return self.name
+        return '%s.%s' % (self.module_reference, self.name)
 
     __repr__ = __str__
 
index 82024a2..2b1d5f9 100644 (file)
@@ -6,6 +6,8 @@ OneSequence ::= SEQUENCE {
     second BOOLEAN
 }
 
+value INTEGER ::= 123
+
 END
 
 
@@ -17,4 +19,11 @@ AnotherSequence ::= SEQUENCE {
     second BOOLEAN
 }
 
+YetAnotherSequence ::= SEQUENCE {
+    first [0] Module1.OneSequence,
+    second [1] AnotherSequence
+}
+
+anothervalue INTEGER ::= Module1.value
+
 END