Browse Source

[utils] Process bytestrings in urljoin (closes #12369)

Sergey M․ 8 years ago
parent
commit
4b5de77bdb
2 changed files with 9 additions and 1 deletion
  1. 3 0
      test/test_utils.py
  2. 6 1
      youtube_dl/utils.py

+ 3 - 0
test/test_utils.py

@@ -455,6 +455,9 @@ class TestUtil(unittest.TestCase):
 
 
     def test_urljoin(self):
     def test_urljoin(self):
         self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
         self.assertEqual(urljoin('http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+        self.assertEqual(urljoin(b'http://foo.de/', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+        self.assertEqual(urljoin('http://foo.de/', b'/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
+        self.assertEqual(urljoin(b'http://foo.de/', b'/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
         self.assertEqual(urljoin('//foo.de/', '/a/b/c.txt'), '//foo.de/a/b/c.txt')
         self.assertEqual(urljoin('//foo.de/', '/a/b/c.txt'), '//foo.de/a/b/c.txt')
         self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
         self.assertEqual(urljoin('http://foo.de/', 'a/b/c.txt'), 'http://foo.de/a/b/c.txt')
         self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')
         self.assertEqual(urljoin('http://foo.de', '/a/b/c.txt'), 'http://foo.de/a/b/c.txt')

+ 6 - 1
youtube_dl/utils.py

@@ -1748,11 +1748,16 @@ def base_url(url):
 
 
 
 
 def urljoin(base, path):
     """Join ``path`` onto ``base``, accepting text or UTF-8 bytestrings.

     Returns:
         * ``path`` itself (decoded to text if it was bytes) when it already
           starts with ``http://``, ``https://`` or ``//``; ``base`` is not
           inspected in that case.
         * The result of ``compat_urlparse.urljoin(base, path)`` when
           ``base`` starts with one of those prefixes.
         * ``None`` when either argument is unusable: not a string, an empty
           ``path``, bytes that are not valid UTF-8, or a ``base`` without
           one of those prefixes.
     """
     def _as_text(value):
         # Bytestrings are decoded as UTF-8. Malformed bytes map to None so
         # that they are rejected the same way as any other unusable input,
         # instead of raising UnicodeDecodeError out of this helper.
         if isinstance(value, bytes):
             try:
                 return value.decode('utf-8')
             except UnicodeDecodeError:
                 return None
         return value

     path = _as_text(path)
     if not isinstance(path, compat_str) or not path:
         return None
     # An absolute or protocol-relative path needs no base at all.
     if re.match(r'^(?:https?:)?//', path):
         return path
     # base is only examined once we know it is actually needed.
     base = _as_text(base)
     if not isinstance(base, compat_str) or not re.match(
             r'^(?:https?:)?//', base):
         return None
     return compat_urlparse.urljoin(base, path)