Browse Source

[utils] Add `partial_application` decorator function

Thx: yt-dlp/yt-dlp#10653
dirkf 1 month ago
parent
commit
23a848c314
2 changed files with 49 additions and 0 deletions
  1. 16 0
      test/test_utils.py
  2. 33 0
      youtube_dl/utils.py

+ 16 - 0
test/test_utils.py

@@ -69,6 +69,7 @@ from youtube_dl.utils import (
     parse_iso8601,
     parse_resolution,
     parse_qs,
+    partial_application,
     pkcs1pad,
     prepend_extension,
     read_batch_urls,
@@ -1723,6 +1724,21 @@ Line 1
             'a', 'b', 'c', 'd',
             from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d')
 
+    def test_partial_application(self):
+        test_fn = partial_application(lambda x, kwarg=None: '{0}, kwarg={1!r}'.format(x, kwarg))
+        self.assertTrue(
+            callable(test_fn(kwarg=10)),
+            'missing positional parameter should apply partially')
+        self.assertEqual(
+            test_fn(10, kwarg=0.1), '10, kwarg=0.1',
+            'positionally passed argument should call function')
+        self.assertEqual(
+            test_fn(x=10), '10, kwarg=None',
+            'keyword passed positional should call function')
+        self.assertEqual(
+            test_fn(kwarg=0.1)(10), '10, kwarg=0.1',
+            'call after partial application should call the function')
+
 
 if __name__ == '__main__':
     unittest.main()

+ 33 - 0
youtube_dl/utils.py

@@ -1861,6 +1861,39 @@ def write_json_file(obj, fn):
         raise
 
 
+class partial_application(object):
+    """Allow a function to use pre-set argument values"""
+
+    # see _try_bind_args()
+    try:
+        inspect.signature
+
+        @staticmethod
+        def required_args(fn):
+            return [
+                param.name for param in inspect.signature(fn).parameters.values()
+                if (param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+                    and param.default is inspect.Parameter.empty)]
+
+    except AttributeError:
+
+        # Py < 3.3
+        @staticmethod
+        def required_args(fn):
+            fn_args = inspect.getargspec(fn)
+            n_defaults = len(fn_args.defaults or [])
+            return (fn_args.args or [])[:-n_defaults if n_defaults > 0 else None]
+
+    def __new__(cls, func):
+        @functools.wraps(func)
+        def wrapped(*args, **kwargs):
+            if set(cls.required_args(func)[len(args):]).difference(kwargs):
+                return functools.partial(func, *args, **kwargs)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+
 if sys.version_info >= (2, 7):
     def find_xpath_attr(node, xpath, key, val=None):
         """ Find the xpath xpath[@key=val] """