Sfoglia il codice sorgente

[utils] Add replace_extension

Sergey M․ 10 anni fa
parent
commit
b3ed15b760
2 ha cambiato i file con 16 aggiunte e 0 eliminazioni
  1. 9 0
      test/test_utils.py
  2. 7 0
      youtube_dl/utils.py

+ 9 - 0
test/test_utils.py

@@ -42,6 +42,7 @@ from youtube_dl.utils import (
     sanitize_path,
     sanitize_path,
     sanitize_url_path_consecutive_slashes,
     sanitize_url_path_consecutive_slashes,
     prepend_extension,
     prepend_extension,
+    replace_extension,
     shell_quote,
     shell_quote,
     smuggle_url,
     smuggle_url,
     str_to_int,
     str_to_int,
@@ -202,6 +203,14 @@ class TestUtil(unittest.TestCase):
         self.assertEqual(prepend_extension('.abc', 'temp'), '.abc.temp')
         self.assertEqual(prepend_extension('.abc', 'temp'), '.abc.temp')
         self.assertEqual(prepend_extension('.abc.ext', 'temp'), '.abc.temp.ext')
         self.assertEqual(prepend_extension('.abc.ext', 'temp'), '.abc.temp.ext')
 
 
+    def test_replace_extension(self):
+        self.assertEqual(replace_extension('abc.ext', 'temp'), 'abc.temp')
+        self.assertEqual(replace_extension('abc.ext', 'temp', 'ext'), 'abc.temp')
+        self.assertEqual(replace_extension('abc.unexpected_ext', 'temp', 'ext'), 'abc.unexpected_ext.temp')
+        self.assertEqual(replace_extension('abc', 'temp'), 'abc.temp')
+        self.assertEqual(replace_extension('.abc', 'temp'), '.abc.temp')
+        self.assertEqual(replace_extension('.abc.ext', 'temp'), '.abc.temp')
+
     def test_ordered_set(self):
     def test_ordered_set(self):
         self.assertEqual(orderedSet([1, 1, 2, 3, 4, 4, 5, 6, 7, 3, 5]), [1, 2, 3, 4, 5, 6, 7])
         self.assertEqual(orderedSet([1, 1, 2, 3, 4, 4, 5, 6, 7, 3, 5]), [1, 2, 3, 4, 5, 6, 7])
         self.assertEqual(orderedSet([]), [])
         self.assertEqual(orderedSet([]), [])

+ 7 - 0
youtube_dl/utils.py

@@ -1357,6 +1357,13 @@ def prepend_extension(filename, ext, expected_real_ext=None):
         else '{0}.{1}'.format(filename, ext))
         else '{0}.{1}'.format(filename, ext))
 
 
 
 
+def replace_extension(filename, ext, expected_real_ext=None):
+    name, real_ext = os.path.splitext(filename)
+    return '{0}.{1}'.format(
+        name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
+        ext)
+
+
 def check_executable(exe, args=[]):
 def check_executable(exe, args=[]):
     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
     """ Checks if the given binary is installed somewhere in PATH, and returns its name.
     args can be a list of arguments for a short output (like -version) """
     args can be a list of arguments for a short output (like -version) """