Browse Source

[JSInterp] Add `_separate_at_op()`

dirkf 5 months ago
parent
commit
16b7e97afa
1 changed file with 75 additions and 51 deletions
  1. 75 51
      youtube_dl/jsinterp.py

+ 75 - 51
youtube_dl/jsinterp.py

@@ -704,6 +704,68 @@ class JSInterpreter(object):
                 _SC_OPERATORS, _LOG_OPERATORS, _COMP_OPERATORS, _OPERATORS, _UNARY_OPERATORS_X))
         return _cached
 
+    def _separate_at_op(self, expr, max_split=None):
+
+        for op, _ in self._all_operators():
+            # hackety: </> have higher priority than <</>>, but don't confuse them
+            skip_delim = (op + op) if op in '<>*?' else None
+            if op == '?':
+                skip_delim = (skip_delim, '?.')
+            separated = list(self._separate(expr, op, skip_delims=skip_delim))
+            if len(separated) < 2:
+                continue
+
+            right_expr = separated.pop()
+            # handle operators that are both unary and binary, minimal BODMAS
+            if op in ('+', '-'):
+                # simplify/adjust consecutive instances of these operators
+                undone = 0
+                separated = [s.strip() for s in separated]
+                while len(separated) > 1 and not separated[-1]:
+                    undone += 1
+                    separated.pop()
+                if op == '-' and undone % 2 != 0:
+                    right_expr = op + right_expr
+                elif op == '+':
+                    while len(separated) > 1 and set(separated[-1]) <= self.OP_CHARS:
+                        right_expr = separated.pop() + right_expr
+                    if separated[-1][-1:] in self.OP_CHARS:
+                        right_expr = separated.pop() + right_expr
+                # hanging op at end of left => unary + (strip) or - (push right)
+                separated.append(right_expr)
+                dm_ops = ('*', '%', '/', '**')
+                dm_chars = set(''.join(dm_ops))
+
+                def yield_terms(s):
+                    skip = False
+                    for i, term in enumerate(s[:-1]):
+                        if skip:
+                            skip = False
+                            continue
+                        if not (dm_chars & set(term)):
+                            yield term
+                            continue
+                        for dm_op in dm_ops:
+                            bodmas = list(self._separate(term, dm_op, skip_delims=skip_delim))
+                            if len(bodmas) > 1 and not bodmas[-1].strip():
+                                bodmas[-1] = (op if op == '-' else '') + s[i + 1]
+                                yield dm_op.join(bodmas)
+                                skip = True
+                                break
+                        else:
+                            if term:
+                                yield term
+
+                    if not skip and s[-1]:
+                        yield s[-1]
+
+                separated = list(yield_terms(separated))
+                right_expr = separated.pop() if len(separated) > 1 else None
+                expr = op.join(separated)
+            if right_expr is None:
+                continue
+            return op, separated, right_expr
+
     def _operator(self, op, left_val, right_expr, expr, local_vars, allow_recursion):
         if op in ('||', '&&'):
             if (op == '&&') ^ _js_ternary(left_val):
@@ -759,51 +821,9 @@ class JSInterpreter(object):
     _FINALLY_RE = re.compile(r'finally\s*\{')
     _SWITCH_RE = re.compile(r'switch\s*\(')
 
-    def handle_operators(self, expr, local_vars, allow_recursion):
-
-        for op, _ in self._all_operators():
-            # hackety: </> have higher priority than <</>>, but don't confuse them
-            skip_delim = (op + op) if op in '<>*?' else None
-            if op == '?':
-                skip_delim = (skip_delim, '?.')
-            separated = list(self._separate(expr, op, skip_delims=skip_delim))
-            if len(separated) < 2:
-                continue
-
-            right_expr = separated.pop()
-            # handle operators that are both unary and binary, minimal BODMAS
-            if op in ('+', '-'):
-                # simplify/adjust consecutive instances of these operators
-                undone = 0
-                separated = [s.strip() for s in separated]
-                while len(separated) > 1 and not separated[-1]:
-                    undone += 1
-                    separated.pop()
-                if op == '-' and undone % 2 != 0:
-                    right_expr = op + right_expr
-                elif op == '+':
-                    while len(separated) > 1 and set(separated[-1]) <= self.OP_CHARS:
-                        right_expr = separated.pop() + right_expr
-                    if separated[-1][-1:] in self.OP_CHARS:
-                        right_expr = separated.pop() + right_expr
-                # hanging op at end of left => unary + (strip) or - (push right)
-                left_val = separated[-1] if separated else ''
-                for dm_op in ('*', '%', '/', '**'):
-                    bodmas = tuple(self._separate(left_val, dm_op, skip_delims=skip_delim))
-                    if len(bodmas) > 1 and not bodmas[-1].strip():
-                        expr = op.join(separated) + op + right_expr
-                        if len(separated) > 1:
-                            separated.pop()
-                            right_expr = op.join((left_val, right_expr))
-                        else:
-                            separated = [op.join((left_val, right_expr))]
-                            right_expr = None
-                        break
-                if right_expr is None:
-                    continue
-
-            left_val = self.interpret_expression(op.join(separated), local_vars, allow_recursion)
-            return self._operator(op, left_val, right_expr, expr, local_vars, allow_recursion), True
+    def _eval_operator(self, op, left_expr, right_expr, expr, local_vars, allow_recursion):
+        left_val = self.interpret_expression(left_expr, local_vars, allow_recursion)
+        return self._operator(op, left_val, right_expr, expr, local_vars, allow_recursion)
 
     @Debugger.wrap_interpreter
     def interpret_statement(self, stmt, local_vars, allow_recursion=100):
@@ -865,9 +885,12 @@ class JSInterpreter(object):
             operand = expr[len(op):]
             if not operand or operand[0] != ' ':
                 continue
-            op_result = self.handle_operators(expr, local_vars, allow_recursion)
-            if op_result:
-                return op_result[0], should_return
+            separated = self._separate_at_op(operand, max_split=1)
+            if separated:
+                next_op, separated, right_expr = separated
+                separated.append(right_expr)
+                operand = next_op.join(separated)
+            return self._eval_operator(op, operand, '', expr, local_vars, allow_recursion), should_return
 
         if expr.startswith('{'):
             inner, outer = self._separate_at_paren(expr)
@@ -1138,9 +1161,10 @@ class JSInterpreter(object):
                 val = self._index(val, idx)
             return val, should_return
 
-        op_result = self.handle_operators(expr, local_vars, allow_recursion)
-        if op_result:
-            return op_result[0], should_return
+        separated = self._separate_at_op(expr)
+        if separated:
+            op, separated, right_expr = separated
+            return self._eval_operator(op, op.join(separated), right_expr, expr, local_vars, allow_recursion), should_return
 
         if md.get('attribute'):
             variable, member, nullish = m.group('var', 'member', 'nullish')