Browse Source

Merge pull request #2783 from bentiss/openapi_gen

Openapi gen fixes
Lauri Ojansivu 5 years ago
parent
commit
e868f9dcee
1 changed files with 141 additions and 43 deletions
  1. 141 43
      openapi/generate_openapi.py

+ 141 - 43
openapi/generate_openapi.py

@@ -3,9 +3,15 @@
 import argparse
 import argparse
 import esprima
 import esprima
 import json
 import json
+import logging
 import os
 import os
 import re
 import re
 import sys
 import sys
+import traceback
+
+
+logger = logging.getLogger(__name__)
+err_context = 3
 
 
 
 
 def get_req_body_elems(obj, elems):
 def get_req_body_elems(obj, elems):
@@ -156,16 +162,25 @@ class EntryPoint(object):
     def compute_path(self):
     def compute_path(self):
         return self._path.value.rstrip('/')
         return self._path.value.rstrip('/')
 
 
-    def error(self, message):
+    def log(self, message, level):
         if self._raw_doc is None:
         if self._raw_doc is None:
-            sys.stderr.write('in {},\n'.format(self.schema.name))
-            sys.stderr.write('{}\n'.format(message))
+            logger.log(level, 'in {},'.format(self.schema.name))
+            logger.log(level, message)
             return
             return
-        sys.stderr.write('in {}, lines {}-{}\n'.format(self.schema.name,
-                                                       self._raw_doc.loc.start.line,
-                                                       self._raw_doc.loc.end.line))
-        sys.stderr.write('{}\n'.format(self._raw_doc.value))
-        sys.stderr.write('{}\n'.format(message))
+        logger.log(level, 'in {}, lines {}-{}'.format(self.schema.name,
+                                                      self._raw_doc.loc.start.line,
+                                                      self._raw_doc.loc.end.line))
+        logger.log(level, self._raw_doc.value)
+        logger.log(level, message)
+
+    def error(self, message):
+        return self.log(message, logging.ERROR)
+
+    def warn(self, message):
+        return self.log(message, logging.WARNING)
+
+    def info(self, message):
+        return self.log(message, logging.INFO)
 
 
     @property
     @property
     def doc(self):
     def doc(self):
@@ -235,7 +250,7 @@ class EntryPoint(object):
                 if name.startswith('{'):
                 if name.startswith('{'):
                     param_type = name.strip('{}')
                     param_type = name.strip('{}')
                     if param_type not in ['string', 'number', 'boolean', 'integer', 'array', 'file']:
                     if param_type not in ['string', 'number', 'boolean', 'integer', 'array', 'file']:
-                        self.error('Warning, unknown type {}\n allowed values: string, number, boolean, integer, array, file'.format(param_type))
+                        self.warn('unknown type {}\n allowed values: string, number, boolean, integer, array, file'.format(param_type))
                     try:
                     try:
                         name, desc = desc.split(maxsplit=1)
                         name, desc = desc.split(maxsplit=1)
                     except ValueError:
                     except ValueError:
@@ -248,7 +263,7 @@ class EntryPoint(object):
 
 
                 # we should not have 2 identical parameter names
                 # we should not have 2 identical parameter names
                 if tag in params:
                 if tag in params:
-                    self.error('Warning, overwriting parameter {}'.format(name))
+                    self.warn('overwriting parameter {}'.format(name))
 
 
                 params[name] = (param_type, optional, desc)
                 params[name] = (param_type, optional, desc)
 
 
@@ -278,7 +293,7 @@ class EntryPoint(object):
 
 
             # we should not have 2 identical tags but @param or @tag
             # we should not have 2 identical tags but @param or @tag
             if tag in self._doc:
             if tag in self._doc:
-                self.error('Warning, overwriting tag {}'.format(tag))
+                self.warn('overwriting tag {}'.format(tag))
 
 
             self._doc[tag] = data
             self._doc[tag] = data
 
 
@@ -301,7 +316,7 @@ class EntryPoint(object):
                     current_data = ''
                     current_data = ''
                     line = data
                     line = data
                 else:
                 else:
-                    self.error('Unknown tag {}, ignoring'.format(tag))
+                    self.info('Unknown tag {}, ignoring'.format(tag))
 
 
             current_data += line + '\n'
             current_data += line + '\n'
 
 
@@ -441,7 +456,7 @@ class EntryPoint(object):
 
 
 
 
 class SchemaProperty(object):
 class SchemaProperty(object):
-    def __init__(self, statement, schema):
+    def __init__(self, statement, schema, context):
         self.schema = schema
         self.schema = schema
         self.statement = statement
         self.statement = statement
         self.name = statement.key.name or statement.key.value
         self.name = statement.key.name or statement.key.value
@@ -449,22 +464,75 @@ class SchemaProperty(object):
         self.blackbox = False
         self.blackbox = False
         self.required = True
         self.required = True
         for p in statement.value.properties:
         for p in statement.value.properties:
-            if p.key.name == 'type':
-                if p.value.type == 'Identifier':
-                    self.type = p.value.name.lower()
-                elif p.value.type == 'ArrayExpression':
-                    self.type = 'array'
-                    self.elements = [e.name.lower() for e in p.value.elements]
-
-            elif p.key.name == 'allowedValues':
-                self.type = 'enum'
-                self.enum = [e.value.lower() for e in p.value.elements]
-
-            elif p.key.name == 'blackbox':
-                self.blackbox = True
-
-            elif p.key.name == 'optional' and p.value.value:
-                self.required = False
+            try:
+                if p.key.name == 'type':
+                    if p.value.type == 'Identifier':
+                        self.type = p.value.name.lower()
+                    elif p.value.type == 'ArrayExpression':
+                        self.type = 'array'
+                        self.elements = [e.name.lower() for e in p.value.elements]
+
+                elif p.key.name == 'allowedValues':
+                    self.type = 'enum'
+                    if p.value.type == 'ArrayExpression':
+                        self.enum = [e.value.lower() for e in p.value.elements]
+                    elif p.value.type == 'Identifier':
+                        # tree wide lookout for the identifier
+                        def find_variable(elem, match):
+                            if isinstance(elem, list):
+                                for value in elem:
+                                    ret = find_variable(value, match)
+                                    if ret is not None:
+                                        return ret
+
+                            try:
+                                items = elem.items()
+                            except AttributeError:
+                                return None
+                            except TypeError:
+                                return None
+
+                            if (elem.type == 'VariableDeclarator' and
+                               elem.id.name == match):
+                                return elem
+
+                            for type, value in items:
+                                ret = find_variable(value, match)
+                                if ret is not None:
+                                    return ret
+
+                            return None
+
+                        elem = find_variable(context.program.body, p.value.name)
+
+                        if elem.init.type != 'ArrayExpression':
+                            raise TypeError('can not find "{}"'.format(p.value.name))
+
+                        self.enum = [e.value.lower() for e in elem.init.elements]
+
+                elif p.key.name == 'blackbox':
+                    self.blackbox = True
+
+                elif p.key.name == 'optional' and p.value.value:
+                    self.required = False
+            except Exception:
+                input = ''
+                for line in range(p.loc.start.line - err_context, p.loc.end.line + 1 + err_context):
+                    if line < p.loc.start.line or line > p.loc.end.line:
+                        input += '. '
+                    else:
+                        input += '>>'
+                    input += context.text_at(line, line)
+                input = ''.join(input)
+                logger.error('{}:{}-{} can not parse {}:\n{}'.format(context.path,
+                                                                     p.loc.start.line,
+                                                                     p.loc.end.line,
+                                                                     p.type,
+                                                                     input))
+                logger.error('esprima tree:\n{}'.format(p))
+
+                logger.error(traceback.format_exc())
+                sys.exit(1)
 
 
         self._doc = None
         self._doc = None
         self._raw_doc = None
         self._raw_doc = None
@@ -574,7 +642,7 @@ class SchemaProperty(object):
 
 
 
 
 class Schemas(object):
 class Schemas(object):
-    def __init__(self, data=None, jsdocs=None, name=None):
+    def __init__(self, context, data=None, jsdocs=None, name=None):
         self.name = name
         self.name = name
         self._data = data
         self._data = data
         self.fields = None
         self.fields = None
@@ -585,7 +653,7 @@ class Schemas(object):
                 self.name = data.expression.callee.object.name
                 self.name = data.expression.callee.object.name
 
 
             content = data.expression.arguments[0].arguments[0]
             content = data.expression.arguments[0].arguments[0]
-            self.fields = [SchemaProperty(p, self) for p in content.properties]
+            self.fields = [SchemaProperty(p, self, context) for p in content.properties]
 
 
         self._doc = None
         self._doc = None
         self._raw_doc = None
         self._raw_doc = None
@@ -665,6 +733,27 @@ class Schemas(object):
                 print('      - {}'.format(f))
                 print('      - {}'.format(f))
 
 
 
 
+class Context(object):
+    def __init__(self, path):
+        self.path = path
+
+        with open(path) as f:
+            self._txt = f.readlines()
+
+        data = ''.join(self._txt)
+        self.program = esprima.parseModule(data,
+                                           options={
+                                               'comment': True,
+                                               'loc': True
+                                           })
+
+    def txt_for(self, statement):
+        return self.text_at(statement.loc.start.line, statement.loc.end.line)
+
+    def text_at(self, begin, end):
+        return ''.join(self._txt[begin - 1:end])
+
+
 def parse_schemas(schemas_dir):
 def parse_schemas(schemas_dir):
 
 
     schemas = {}
     schemas = {}
@@ -674,17 +763,19 @@ def parse_schemas(schemas_dir):
         files.sort()
         files.sort()
         for filename in files:
         for filename in files:
             path = os.path.join(root, filename)
             path = os.path.join(root, filename)
-            with open(path) as f:
-                data = ''.join(f.readlines())
-                try:
-                    # if the file failed, it's likely it doesn't contain a schema
-                    program = esprima.parseModule(data, options={'comment': True, 'loc': True})
-                except:
-                    continue
+            try:
+                # if the file failed, it's likely it doesn't contain a schema
+                context = Context(path)
+            except:
+                continue
+
+            program = context.program
 
 
-                current_schema = None
-                jsdocs = [c for c in program.comments
-                          if c.type == 'Block' and c.value.startswith('*\n')]
+            current_schema = None
+            jsdocs = [c for c in program.comments
+                      if c.type == 'Block' and c.value.startswith('*\n')]
+
+            try:
 
 
                 for statement in program.body:
                 for statement in program.body:
 
 
@@ -697,7 +788,7 @@ def parse_schemas(schemas_dir):
                        statement.expression.arguments[0].type == 'NewExpression' and
                        statement.expression.arguments[0].type == 'NewExpression' and
                        statement.expression.arguments[0].callee.name == 'SimpleSchema'):
                        statement.expression.arguments[0].callee.name == 'SimpleSchema'):
 
 
-                        schema = Schemas(statement, jsdocs)
+                        schema = Schemas(context, statement, jsdocs)
                         current_schema = schema.name
                         current_schema = schema.name
                         schemas[current_schema] = schema
                         schemas[current_schema] = schema
 
 
@@ -717,7 +808,7 @@ def parse_schemas(schemas_dir):
                             if len(data) > 0:
                             if len(data) > 0:
                                 if current_schema is None:
                                 if current_schema is None:
                                     current_schema = filename
                                     current_schema = filename
-                                    schemas[current_schema] = Schemas(name=current_schema)
+                                    schemas[current_schema] = Schemas(context, name=current_schema)
 
 
                                 schema_entry_points = [EntryPoint(schemas[current_schema], d)
                                 schema_entry_points = [EntryPoint(schemas[current_schema], d)
                                                        for d in data]
                                                        for d in data]
@@ -730,6 +821,13 @@ def parse_schemas(schemas_dir):
                                              if j.loc.end.line + 1 == operation.loc.start.line]
                                              if j.loc.end.line + 1 == operation.loc.start.line]
                                     if bool(jsdoc):
                                     if bool(jsdoc):
                                         entry_point.doc = jsdoc[0]
                                         entry_point.doc = jsdoc[0]
+            except TypeError:
+                logger.warning(context.txt_for(statement))
+                logger.error('{}:{}-{} can not parse {}'.format(path,
+                                                                statement.loc.start.line,
+                                                                statement.loc.end.line,
+                                                                statement.type))
+                raise
 
 
     return schemas, entry_points
     return schemas, entry_points