Browse Source

[utils] Support traversal helper functions `require`, `value`, `unpack`

Thx: yt-dlp/yt-dlp#10653
dirkf 1 month ago
parent
commit
68fe8c1781
3 changed files with 63 additions and 4 deletions
  1. 35 4
      test/test_traversal.py
  2. 3 0
      youtube_dl/traversal.py
  3. 25 0
      youtube_dl/utils.py

+ 35 - 4
test/test_traversal.py

@@ -15,8 +15,11 @@ import re
 from youtube_dl.traversal import (
     dict_get,
     get_first,
+    require,
     T,
     traverse_obj,
+    unpack,
+    value,
 )
 from youtube_dl.compat import (
     compat_chr as chr,
@@ -27,7 +30,9 @@ from youtube_dl.compat import (
     compat_zip as zip,
 )
 from youtube_dl.utils import (
+    ExtractorError,
     int_or_none,
+    join_nonempty,
     str_or_none,
 )
 
@@ -462,8 +467,8 @@ class TestTraversal(_TestCase):
         }),
         values = dict((str(k), v) for k, v in values.items())
 
-        for key, value in values.items():
-            self.assertEqual(traverse_obj(morsel, key), value,
+        for key, val in values.items():
+            self.assertEqual(traverse_obj(morsel, key), val,
                              msg='Morsel should provide access to all values')
         values = list(values.values())
         self.assertMaybeCountEqual(traverse_obj(morsel, Ellipsis), values,
@@ -481,8 +486,31 @@ class TestTraversal(_TestCase):
             [True, 1, 1.1, 'str', {0: 0}, [1]],
             '`filter` should filter falsy values')
 
-    def test_get_first(self):
-        self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
+
+class TestTraversalHelpers(_TestCase):
+    def test_traversal_require(self):
+        with self.assertRaises(ExtractorError, msg='Missing `value` should raise'):
+            traverse_obj(_TEST_DATA, ('None', T(require('value'))))
+        self.assertEqual(
+            traverse_obj(_TEST_DATA, ('str', T(require('value')))), 'str',
+            '`require` should pass through non-`None` values')
+
+    def test_unpack(self):
+        self.assertEqual(
+            unpack(lambda *x: ''.join(map(compat_str, x)))([1, 2, 3]), '123')
+        self.assertEqual(
+            unpack(join_nonempty)([1, 2, 3]), '1-2-3')
+        self.assertEqual(
+            unpack(join_nonempty, delim=' ')([1, 2, 3]), '1 2 3')
+        with self.assertRaises(TypeError):
+            unpack(join_nonempty)()
+        with self.assertRaises(TypeError):
+            unpack()
+
+    def test_value(self):
+        self.assertEqual(
+            traverse_obj(_TEST_DATA, ('str', T(value('other')))), 'other',
+            '`value` should substitute specified value')
 
 
 class TestDictGet(_TestCase):
@@ -508,6 +536,9 @@ class TestDictGet(_TestCase):
             self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
             self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
 
+    def test_get_first(self):
+        self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
+
 
 if __name__ == '__main__':
     unittest.main()

+ 3 - 0
youtube_dl/traversal.py

@@ -5,6 +5,9 @@
 from .utils import (
     dict_get,
     get_first,
+    require,
     T,
     traverse_obj,
+    unpack,
+    value,
 )

+ 25 - 0
youtube_dl/utils.py

@@ -6543,6 +6543,31 @@ def traverse_obj(obj, *paths, **kwargs):
     return None if default is NO_DEFAULT else default
 
 
+def value(value):
+    return lambda _: value
+
+
+class require(ExtractorError):
+    def __init__(self, name, expected=False):
+        super(require, self).__init__(
+            'Unable to extract {0}'.format(name), expected=expected)
+
+    def __call__(self, value):
+        if value is None:
+            raise self
+
+        return value
+
+
+def unpack(func, **kwargs):
+    """Make a function that applies `partial(func, **kwargs)` to its argument as *args"""
+    @functools.wraps(func)
+    def inner(items):
+        return func(*items, **kwargs)
+
+    return inner
+
+
 def T(*x):
     """ For use in yt-dl instead of {type, ...} or set((type, ...)) """
     return set(x)