Browse Source

[utils] Add {expected_type} and Iterable support to traverse_obj()

dirkf 2 years ago
parent
commit
f47fdb9564
2 changed files with 266 additions and 100 deletions
  1. 124 29
      test/test_utils.py
  2. 142 71
      youtube_dl/utils.py

+ 124 - 29
test/test_utils.py

@@ -79,10 +79,12 @@ from youtube_dl.utils import (
     rot47,
     rot47,
     shell_quote,
     shell_quote,
     smuggle_url,
     smuggle_url,
+    str_or_none,
     str_to_int,
     str_to_int,
     strip_jsonp,
     strip_jsonp,
     strip_or_none,
     strip_or_none,
     subtitles_filename,
     subtitles_filename,
+    T,
     timeconvert,
     timeconvert,
     traverse_obj,
     traverse_obj,
     try_call,
     try_call,
@@ -1566,6 +1568,7 @@ Line 1
         self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
         self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam')
 
 
     def test_traverse_obj(self):
     def test_traverse_obj(self):
+        str = compat_str
         _TEST_DATA = {
         _TEST_DATA = {
             100: 100,
             100: 100,
             1.2: 1.2,
             1.2: 1.2,
@@ -1598,8 +1601,8 @@ Line 1
 
 
         # Test Ellipsis behavior
         # Test Ellipsis behavior
         self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
         self.assertCountEqual(traverse_obj(_TEST_DATA, Ellipsis),
-                              (item for item in _TEST_DATA.values() if item is not None),
-                              msg='`...` should give all values except `None`')
+                              (item for item in _TEST_DATA.values() if item not in (None, {})),
+                              msg='`...` should give all non discarded values')
         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
         self.assertCountEqual(traverse_obj(_TEST_DATA, ('urls', 0, Ellipsis)), _TEST_DATA['urls'][0].values(),
                               msg='`...` selection for dicts should select all values')
                               msg='`...` selection for dicts should select all values')
         self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
         self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'url')),
@@ -1607,13 +1610,51 @@ Line 1
                          msg='nested `...` queries should work')
                          msg='nested `...` queries should work')
         self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
         self.assertCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, Ellipsis, 'index')), range(4),
                               msg='`...` query result should be flattened')
                               msg='`...` query result should be flattened')
+        self.assertEqual(traverse_obj(iter(range(4)), Ellipsis), list(range(4)),
+                         msg='`...` should accept iterables')
 
 
         # Test function as key
         # Test function as key
         self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
         self.assertEqual(traverse_obj(_TEST_DATA, lambda x, y: x == 'urls' and isinstance(y, list)),
                          [_TEST_DATA['urls']],
                          [_TEST_DATA['urls']],
                          msg='function as query key should perform a filter based on (key, value)')
                          msg='function as query key should perform a filter based on (key, value)')
-        self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], compat_str)), ('str',),
-                              msg='exceptions in the query function should be caught')
+        self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
+                              msg='exceptions in the query function should be catched')
+        self.assertEqual(traverse_obj(iter(range(4)), lambda _, x: x % 2 == 0), [0, 2],
+                         msg='function key should accept iterables')
+        if __debug__:
+            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+                traverse_obj(_TEST_DATA, lambda a: Ellipsis)
+            with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
+                traverse_obj(_TEST_DATA, lambda a, b, c: Ellipsis)
+
+        # Test set as key (transformation/type, like `expected_type`)
+        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'],
+                         msg='Function in set should be a transformation')
+        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'],
+                         msg='Type in set should be a type filter')
+        self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA,
+                         msg='A single set should be wrapped into a path')
+        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'],
+                         msg='Transformation function should not raise')
+        self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))),
+                         [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
+                         msg='Function in set should be a transformation')
+        if __debug__:
+            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+                traverse_obj(_TEST_DATA, set())
+            with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
+                traverse_obj(_TEST_DATA, {str.upper, str})
+
+        # Test `slice` as a key
+        _SLICE_DATA = [0, 1, 2, 3, 4]
+        self.assertEqual(traverse_obj(_TEST_DATA, ('dict', slice(1))), None,
+                         msg='slice on a dictionary should not throw')
+        self.assertEqual(traverse_obj(_SLICE_DATA, slice(1)), _SLICE_DATA[:1],
+                         msg='slice key should apply slice to sequence')
+        self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 2)), _SLICE_DATA[1:2],
+                         msg='slice key should apply slice to sequence')
+        self.assertEqual(traverse_obj(_SLICE_DATA, slice(1, 4, 2)), _SLICE_DATA[1:4:2],
+                         msg='slice key should apply slice to sequence')
 
 
         # Test alternative paths
         # Test alternative paths
         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
         self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@@ -1659,15 +1700,23 @@ Line 1
                          {0: ['https://www.example.com/1', 'https://www.example.com/0']},
                          {0: ['https://www.example.com/1', 'https://www.example.com/0']},
                          msg='triple nesting in dict path should be treated as branches')
                          msg='triple nesting in dict path should be treated as branches')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}), {},
-                         msg='remove `None` values when dict key')
+                         msg='remove `None` values when top level dict key fails')
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
         self.assertEqual(traverse_obj(_TEST_DATA, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
-                         msg='do not remove `None` values if `default`')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {0: {}},
-                         msg='do not remove empty values when dict key')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: {}},
-                         msg='do not remove empty values when dict key and a default')
-        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {0: []},
-                         msg='if branch in dict key not successful, return `[]`')
+                         msg='use `default` if key fails and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}), {},
+                         msg='remove empty values when dict key')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 'dict'}, default=Ellipsis), {0: Ellipsis},
+                         msg='use `default` when dict key and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}), {},
+                         msg='remove empty values when nested dict key fails')
+        self.assertEqual(traverse_obj(None, {0: 'fail'}), {},
+                         msg='default to dict if pruned')
+        self.assertEqual(traverse_obj(None, {0: 'fail'}, default=Ellipsis), {0: Ellipsis},
+                         msg='default to dict if pruned and default is given')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 'fail'}}, default=Ellipsis), {0: {0: Ellipsis}},
+                         msg='use nested `default` when nested dict key fails and `default`')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: ('dict', Ellipsis)}), {},
+                         msg='remove key if branch in dict key not successful')
 
 
         # Testing default parameter behavior
         # Testing default parameter behavior
         _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
         _DEFAULT_DATA = {'None': None, 'int': 0, 'list': []}
@@ -1691,20 +1740,55 @@ Line 1
                          msg='if branched but not successful return `[]`, not `default`')
                          msg='if branched but not successful return `[]`, not `default`')
         self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
         self.assertEqual(traverse_obj(_DEFAULT_DATA, ('list', Ellipsis)), [],
                          msg='if branched but object is empty return `[]`, not `default`')
                          msg='if branched but object is empty return `[]`, not `default`')
+        self.assertEqual(traverse_obj(None, Ellipsis), [],
+                         msg='if branched but object is `None` return `[]`, not `default`')
+        self.assertEqual(traverse_obj({0: None}, (0, Ellipsis)), [],
+                         msg='if branched but state is `None` return `[]`, not `default`')
+
+        branching_paths = [
+            ('fail', Ellipsis),
+            (Ellipsis, 'fail'),
+            100 * ('fail',) + (Ellipsis,),
+            (Ellipsis,) + 100 * ('fail',),
+        ]
+        for branching_path in branching_paths:
+            self.assertEqual(traverse_obj({}, branching_path), [],
+                             msg='if branched but state is `None`, return `[]` (not `default`)')
+            self.assertEqual(traverse_obj({}, 'fail', branching_path), [],
+                             msg='if branching in last alternative and previous did not match, return `[]` (not `default`)')
+            self.assertEqual(traverse_obj({0: 'x'}, 0, branching_path), 'x',
+                             msg='if branching in last alternative and previous did match, return single value')
+            self.assertEqual(traverse_obj({0: 'x'}, branching_path, 0), 'x',
+                             msg='if branching in first alternative and non-branching path does match, return single value')
+            self.assertEqual(traverse_obj({}, branching_path, 'fail'), None,
+                             msg='if branching in first alternative and non-branching path does not match, return `default`')
 
 
         # Testing expected_type behavior
         # Testing expected_type behavior
         _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
         _EXPECTED_TYPE_DATA = {'str': 'str', 'int': 0}
-        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=compat_str), 'str',
-                         msg='accept matching `expected_type` type')
-        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int), None,
-                         msg='reject non matching `expected_type` type')
-        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: compat_str(x)), '0',
-                         msg='transform type using type function')
-        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str',
-                                      expected_type=lambda _: 1 / 0), None,
-                         msg='wrap expected_type function in try_call')
-        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=compat_str), ['str'],
-                         msg='eliminate items that expected_type fails on')
+        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=str),
+                         'str', msg='accept matching `expected_type` type')
+        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=int),
+                         None, msg='reject non matching `expected_type` type')
+        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'int', expected_type=lambda x: str(x)),
+                         '0', msg='transform type using type function')
+        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, 'str', expected_type=lambda _: 1 / 0),
+                         None, msg='wrap expected_type function in try_call')
+        self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, Ellipsis, expected_type=str),
+                         ['str'], msg='eliminate items that expected_type fails on')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int),
+                         {0: 100}, msg='type as expected_type should filter dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none),
+                         {0: '100', 1: '1.2'}, msg='function as expected_type should transform dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int),
+                         1, msg='expected_type should not filter non final dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int),
+                         {0: {0: 100}}, msg='expected_type should transform deep dict values')
+        self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(Ellipsis)),
+                         [{0: Ellipsis}, {0: Ellipsis}], msg='expected_type should transform branched dict values')
+        self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int),
+                         [4], msg='expected_type regression for type matching in tuple branching')
+        self.assertEqual(traverse_obj(_TEST_DATA, ['data', Ellipsis], expected_type=int),
+                         [], msg='expected_type regression for type matching in dict result')
 
 
         # Test get_all behavior
         # Test get_all behavior
         _GET_ALL_DATA = {'key': [0, 1, 2]}
         _GET_ALL_DATA = {'key': [0, 1, 2]}
@@ -1749,14 +1833,23 @@ Line 1
                                       _traverse_string=True), '.',
                                       _traverse_string=True), '.',
                          msg='traverse into converted data if `traverse_string`')
                          msg='traverse into converted data if `traverse_string`')
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', Ellipsis),
-                                      _traverse_string=True), list('str'),
-                         msg='`...` branching into string should result in list')
+                                      _traverse_string=True), 'str',
+                         msg='`...` should result in string (same value) if `traverse_string`')
+        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', slice(0, None, 2)),
+                                      _traverse_string=True), 'sr',
+                         msg='`slice` should result in string if `traverse_string`')
+        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda i, v: i or v == "s"),
+                                      _traverse_string=True), 'str',
+                         msg='function should result in string if `traverse_string`')
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
         self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', (0, 2)),
                                       _traverse_string=True), ['s', 'r'],
                                       _traverse_string=True), ['s', 'r'],
-                         msg='branching into string should result in list')
-        self.assertEqual(traverse_obj(_TRAVERSE_STRING_DATA, ('str', lambda _, x: x),
-                                      _traverse_string=True), list('str'),
-                         msg='function branching into string should result in list')
+                         msg='branching should result in list if `traverse_string`')
+        self.assertEqual(traverse_obj({}, (0, Ellipsis), _traverse_string=True), [],
+                         msg='branching should result in list if `traverse_string`')
+        self.assertEqual(traverse_obj({}, (0, lambda x, y: True), _traverse_string=True), [],
+                         msg='branching should result in list if `traverse_string`')
+        self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [],
+                         msg='branching should result in list if `traverse_string`')
 
 
         # Test is_user_input behavior
         # Test is_user_input behavior
         _IS_USER_INPUT_DATA = {'range8': list(range(8))}
         _IS_USER_INPUT_DATA = {'range8': list(range(8))}
@@ -1793,6 +1886,8 @@ Line 1
                          msg='failing str key on a `re.Match` should return `default`')
                          msg='failing str key on a `re.Match` should return `default`')
         self.assertEqual(traverse_obj(mobj, 8), None,
         self.assertEqual(traverse_obj(mobj, 8), None,
                          msg='failing int key on a `re.Match` should return `default`')
                          msg='failing int key on a `re.Match` should return `default`')
+        self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
+                         msg='function on a `re.Match` should give group name as well')
 
 
     def test_get_first(self):
     def test_get_first(self):
         self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')
         self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam')

+ 142 - 71
youtube_dl/utils.py

@@ -16,6 +16,7 @@ import email.header
 import errno
 import errno
 import functools
 import functools
 import gzip
 import gzip
+import inspect
 import io
 import io
 import itertools
 import itertools
 import json
 import json
@@ -3881,7 +3882,7 @@ def detect_exe_version(output, version_re=None, unrecognized='present'):
         return unrecognized
         return unrecognized
 
 
 
 
-class LazyList(compat_collections_abc.Sequence):
+class LazyList(compat_collections_abc.Iterable):
     """Lazy immutable list from an iterable
     """Lazy immutable list from an iterable
     Note that slices of a LazyList are lists and not LazyList"""
     Note that slices of a LazyList are lists and not LazyList"""
 
 
@@ -4223,10 +4224,16 @@ def multipart_encode(data, boundary=None):
     return out, content_type
     return out, content_type
 
 
 
 
-def variadic(x, allowed_types=(compat_str, bytes, dict)):
-    if not isinstance(allowed_types, tuple) and isinstance(allowed_types, compat_collections_abc.Iterable):
+def is_iterable_like(x, allowed_types=compat_collections_abc.Iterable, blocked_types=NO_DEFAULT):
+    if blocked_types is NO_DEFAULT:
+        blocked_types = (compat_str, bytes, compat_collections_abc.Mapping)
+    return isinstance(x, allowed_types) and not isinstance(x, blocked_types)
+
+
+def variadic(x, allowed_types=NO_DEFAULT):
+    if isinstance(allowed_types, compat_collections_abc.Iterable):
         allowed_types = tuple(allowed_types)
         allowed_types = tuple(allowed_types)
-    return x if isinstance(x, compat_collections_abc.Iterable) and not isinstance(x, allowed_types) else (x,)
+    return x if is_iterable_like(x, blocked_types=allowed_types) else (x,)
 
 
 
 
 def dict_get(d, key_or_keys, default=None, skip_false_values=True):
 def dict_get(d, key_or_keys, default=None, skip_false_values=True):
@@ -5993,7 +6000,7 @@ def clean_podcast_url(url):
 
 
 def traverse_obj(obj, *paths, **kwargs):
 def traverse_obj(obj, *paths, **kwargs):
     """
     """
-    Safely traverse nested `dict`s and `Sequence`s
+    Safely traverse nested `dict`s and `Iterable`s
 
 
     >>> obj = [{}, {"key": "value"}]
     >>> obj = [{}, {"key": "value"}]
     >>> traverse_obj(obj, (1, "key"))
     >>> traverse_obj(obj, (1, "key"))
@@ -6001,14 +6008,17 @@ def traverse_obj(obj, *paths, **kwargs):
 
 
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     Each of the provided `paths` is tested and the first producing a valid result will be returned.
     The next path will also be tested if the path branched but no results could be found.
     The next path will also be tested if the path branched but no results could be found.
-    Supported values for traversal are `Mapping`, `Sequence` and `re.Match`.
-    A value of None is treated as the absence of a value.
+    Supported values for traversal are `Mapping`, `Iterable` and `re.Match`.
+    Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded.
 
 
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
     The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`.
 
 
     The keys in the path can be one of:
     The keys in the path can be one of:
         - `None`:           Return the current object.
         - `None`:           Return the current object.
-        - `str`/`int`:      Return `obj[key]`. For `re.Match, return `obj.group(key)`.
+        - `set`:            Requires the only item in the set to be a type or function,
+                            like `{type}`/`{func}`. If a `type`, returns only values
+                            of this type. If a function, returns `func(obj)`.
+        - `str`/`int`:      Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `slice`:          Branch out and return all values in `obj[key]`.
         - `Ellipsis`:       Branch out and return a list of all values.
         - `Ellipsis`:       Branch out and return a list of all values.
         - `tuple`/`list`:   Branch out and return a list of all matching values.
         - `tuple`/`list`:   Branch out and return a list of all matching values.
@@ -6016,6 +6026,9 @@ def traverse_obj(obj, *paths, **kwargs):
         - `function`:       Branch out and return values filtered by the function.
         - `function`:       Branch out and return values filtered by the function.
                             Read as: `[value for key, value in obj if function(key, value)]`.
                             Read as: `[value for key, value in obj if function(key, value)]`.
                             For `Sequence`s, `key` is the index of the value.
                             For `Sequence`s, `key` is the index of the value.
+                            For `Iterable`s, `key` is the enumeration count of the value.
+                            For `re.Match`es, `key` is the group number (0 = full match)
+                            as well as additionally any group names, if given.
         - `dict`            Transform the current object and return a matching dict.
         - `dict`            Transform the current object and return a matching dict.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
                             Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
 
 
@@ -6024,8 +6037,12 @@ def traverse_obj(obj, *paths, **kwargs):
     @params paths           Paths which to traverse by.
     @params paths           Paths which to traverse by.
     Keyword arguments:
     Keyword arguments:
     @param default          Value to return if the paths do not match.
     @param default          Value to return if the paths do not match.
+                            If the last key in the path is a `dict`, it will apply to each value inside
+                            the dict instead, depth first. Try to avoid if using nested `dict` keys.
     @param expected_type    If a `type`, only accept final values of this type.
     @param expected_type    If a `type`, only accept final values of this type.
                             If any other callable, try to call the function on each result.
                             If any other callable, try to call the function on each result.
+                            If the last key in the path is a `dict`, it will apply to each value inside
+                            the dict instead, recursively. This does respect branching paths.
     @param get_all          If `False`, return the first matching result, otherwise all matching ones.
     @param get_all          If `False`, return the first matching result, otherwise all matching ones.
     @param casesense        If `False`, consider string dictionary keys as case insensitive.
     @param casesense        If `False`, consider string dictionary keys as case insensitive.
 
 
@@ -6036,12 +6053,15 @@ def traverse_obj(obj, *paths, **kwargs):
     @param _traverse_string  Whether to traverse into objects as strings.
     @param _traverse_string  Whether to traverse into objects as strings.
                             If `True`, any non-compatible object will first be
                             If `True`, any non-compatible object will first be
                             converted into a string and then traversed into.
                             converted into a string and then traversed into.
+                            The return value of that path will be a string instead,
+                            not respecting any further branching.
 
 
 
 
     @returns                The result of the object traversal.
     @returns                The result of the object traversal.
                             If successful, `get_all=True`, and the path branches at least once,
                             If successful, `get_all=True`, and the path branches at least once,
                             then a list of results is returned instead.
                             then a list of results is returned instead.
                             A list is always returned if the last path branches and no `default` is given.
                             A list is always returned if the last path branches and no `default` is given.
+                            If a path ends on a `dict` that result will always be a `dict`.
     """
     """
 
 
     # parameter defaults
     # parameter defaults
@@ -6055,7 +6075,6 @@ def traverse_obj(obj, *paths, **kwargs):
     # instant compat
     # instant compat
     str = compat_str
     str = compat_str
 
 
-    is_sequence = lambda x: isinstance(x, compat_collections_abc.Sequence) and not isinstance(x, (str, bytes))
     casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
     casefold = lambda k: compat_casefold(k) if isinstance(k, str) else k
 
 
     if isinstance(expected_type, type):
     if isinstance(expected_type, type):
@@ -6063,128 +6082,180 @@ def traverse_obj(obj, *paths, **kwargs):
     else:
     else:
         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
         type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
 
 
+    def lookup_or_none(v, k, getter=None):
+        try:
+            return getter(v, k) if getter else v[k]
+        except IndexError:
+            return None
+
     def from_iterable(iterables):
     def from_iterable(iterables):
         # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
         # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
         for it in iterables:
         for it in iterables:
             for item in it:
             for item in it:
                 yield item
                 yield item
 
 
-    def apply_key(key, obj):
-        if obj is None:
-            return
+    def apply_key(key, obj, is_last):
+        branching = False
+
+        if obj is None and _traverse_string:
+            if key is Ellipsis or callable(key) or isinstance(key, slice):
+                branching = True
+                result = ()
+            else:
+                result = None
 
 
         elif key is None:
         elif key is None:
-            yield obj
+            result = obj
+
+        elif isinstance(key, set):
+            assert len(key) == 1, 'Set should only be used to wrap a single item'
+            item = next(iter(key))
+            if isinstance(item, type):
+                result = obj if isinstance(obj, item) else None
+            else:
+                result = try_call(item, args=(obj,))
 
 
         elif isinstance(key, (list, tuple)):
         elif isinstance(key, (list, tuple)):
-            for branch in key:
-                _, result = apply_path(obj, branch)
-                for item in result:
-                    yield item
+            branching = True
+            result = from_iterable(
+                apply_path(obj, branch, is_last)[0] for branch in key)
 
 
         elif key is Ellipsis:
         elif key is Ellipsis:
-            result = []
+            branching = True
             if isinstance(obj, compat_collections_abc.Mapping):
             if isinstance(obj, compat_collections_abc.Mapping):
                 result = obj.values()
                 result = obj.values()
-            elif is_sequence(obj):
+            elif is_iterable_like(obj):
                 result = obj
                 result = obj
             elif isinstance(obj, compat_re_Match):
             elif isinstance(obj, compat_re_Match):
                 result = obj.groups()
                 result = obj.groups()
             elif _traverse_string:
             elif _traverse_string:
+                branching = False
                 result = str(obj)
                 result = str(obj)
-            for item in result:
-                yield item
+            else:
+                result = ()
 
 
         elif callable(key):
         elif callable(key):
-            if is_sequence(obj):
-                iter_obj = enumerate(obj)
-            elif isinstance(obj, compat_collections_abc.Mapping):
+            branching = True
+            if isinstance(obj, compat_collections_abc.Mapping):
                 iter_obj = obj.items()
                 iter_obj = obj.items()
+            elif is_iterable_like(obj):
+                iter_obj = enumerate(obj)
             elif isinstance(obj, compat_re_Match):
             elif isinstance(obj, compat_re_Match):
-                iter_obj = enumerate(itertools.chain([obj.group()], obj.groups()))
+                iter_obj = itertools.chain(
+                    enumerate(itertools.chain((obj.group(),), obj.groups())),
+                    obj.groupdict().items())
             elif _traverse_string:
             elif _traverse_string:
+                branching = False
                 iter_obj = enumerate(str(obj))
                 iter_obj = enumerate(str(obj))
             else:
             else:
-                return
-            for item in (v for k, v in iter_obj if try_call(key, args=(k, v))):
-                yield item
+                iter_obj = ()
+
+            result = (v for k, v in iter_obj if try_call(key, args=(k, v)))
+            if not branching:  # string traversal
+                result = ''.join(result)
 
 
         elif isinstance(key, dict):
         elif isinstance(key, dict):
-            iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items())
-            yield dict((k, v if v is not None else default) for k, v in iter_obj
-                       if v is not None or default is not NO_DEFAULT)
+            iter_obj = ((k, _traverse_obj(obj, v, False, is_last)) for k, v in key.items())
+            result = dict((k, v if v is not None else default) for k, v in iter_obj
+                          if v is not None or default is not NO_DEFAULT) or None
 
 
         elif isinstance(obj, compat_collections_abc.Mapping):
         elif isinstance(obj, compat_collections_abc.Mapping):
-            yield (obj.get(key) if casesense or (key in obj)
-                   else next((v for k, v in obj.items() if casefold(k) == key), None))
+            result = (try_call(obj.get, args=(key,))
+                      if casesense or try_call(obj.__contains__, args=(key,))
+                      else next((v for k, v in obj.items() if casefold(k) == key), None))
 
 
         elif isinstance(obj, compat_re_Match):
         elif isinstance(obj, compat_re_Match):
+            result = None
             if isinstance(key, int) or casesense:
             if isinstance(key, int) or casesense:
-                try:
-                    yield obj.group(key)
-                    return
-                except IndexError:
-                    pass
-            if not isinstance(key, str):
-                return
+                result = lookup_or_none(obj, key, getter=compat_re_Match.group)
 
 
-            yield next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
+            elif isinstance(key, str):
+                result = next((v for k, v in obj.groupdict().items()
+                              if casefold(k) == key), None)
 
 
         else:
         else:
-            if _is_user_input:
-                key = (int_or_none(key) if ':' not in key
-                       else slice(*map(int_or_none, key.split(':'))))
+            result = None
+            if isinstance(key, (int, slice)):
+                if is_iterable_like(obj, compat_collections_abc.Sequence):
+                    branching = isinstance(key, slice)
+                    result = lookup_or_none(obj, key)
+                elif _traverse_string:
+                    result = lookup_or_none(str(obj), key)
+
+        return branching, result if branching else (result,)
+
+    def lazy_last(iterable):
+        iterator = iter(iterable)
+        prev = next(iterator, NO_DEFAULT)
+        if prev is NO_DEFAULT:
+            return
 
 
-            if not isinstance(key, (int, slice)):
-                return
+        for item in iterator:
+            yield False, prev
+            prev = item
 
 
-            if not is_sequence(obj):
-                if not _traverse_string:
-                    return
-                obj = str(obj)
+        yield True, prev
 
 
-            try:
-                yield obj[key]
-            except IndexError:
-                pass
-
-    def apply_path(start_obj, path):
+    def apply_path(start_obj, path, test_type):
         objs = (start_obj,)
         objs = (start_obj,)
         has_branched = False
         has_branched = False
 
 
-        for key in variadic(path):
-            if _is_user_input and key == ':':
-                key = Ellipsis
+        key = None
+        for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
+            if _is_user_input and isinstance(key, str):
+                if key == ':':
+                    key = Ellipsis
+                elif ':' in key:
+                    key = slice(*map(int_or_none, key.split(':')))
+                elif int_or_none(key) is not None:
+                    key = int(key)
 
 
             if not casesense and isinstance(key, str):
             if not casesense and isinstance(key, str):
                 key = compat_casefold(key)
                 key = compat_casefold(key)
 
 
-            if key is Ellipsis or isinstance(key, (list, tuple)) or callable(key):
-                has_branched = True
+            if __debug__ and callable(key):
+                # Verify function signature
+                inspect.getcallargs(key, None, None)
+
+            new_objs = []
+            for obj in objs:
+                branching, results = apply_key(key, obj, last)
+                has_branched |= branching
+                new_objs.append(results)
+
+            objs = from_iterable(new_objs)
 
 
-            key_func = functools.partial(apply_key, key)
-            objs = from_iterable(map(key_func, objs))
+        if test_type and not isinstance(key, (dict, list, tuple)):
+            objs = map(type_test, objs)
 
 
-        return has_branched, objs
+        return objs, has_branched, isinstance(key, dict)
 
 
-    def _traverse_obj(obj, path, use_list=True):
-        has_branched, results = apply_path(obj, path)
-        results = LazyList(x for x in map(type_test, results) if x is not None)
+    def _traverse_obj(obj, path, allow_empty, test_type):
+        results, has_branched, is_dict = apply_path(obj, path, test_type)
+        results = LazyList(x for x in results if x not in (None, {}))
 
 
         if get_all and has_branched:
         if get_all and has_branched:
-            return results.exhaust() if results or use_list else None
+            if results:
+                return results.exhaust()
+            if allow_empty:
+                return [] if default is NO_DEFAULT else default
+            return None
 
 
-        return results[0] if results else None
+        return results[0] if results else {} if allow_empty and is_dict else None
 
 
     for index, path in enumerate(paths, 1):
     for index, path in enumerate(paths, 1):
-        use_list = default is NO_DEFAULT and index == len(paths)
-        result = _traverse_obj(obj, path, use_list)
+        result = _traverse_obj(obj, path, index == len(paths), True)
         if result is not None:
         if result is not None:
             return result
             return result
 
 
     return None if default is NO_DEFAULT else default
     return None if default is NO_DEFAULT else default
 
 
 
 
+def T(x):
+    """ For use in yt-dl instead of {type} or set((type,)) """
+    return set((x,))
+
+
 def get_first(obj, keys, **kwargs):
 def get_first(obj, keys, **kwargs):
     return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)
     return traverse_obj(obj, (Ellipsis,) + tuple(variadic(keys)), get_all=False, **kwargs)