Browse Source

Merge pull request #7296 from jaimeMF/xml_attrib_unicode

Use a wrapper around xml.etree.ElementTree.fromstring in python 2.x (…
Sergey M 9 years ago
parent
commit
30eecc6a04

+ 17 - 0
test/test_compat.py

@@ -13,8 +13,10 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from youtube_dl.utils import get_filesystem_encoding
 from youtube_dl.utils import get_filesystem_encoding
 from youtube_dl.compat import (
 from youtube_dl.compat import (
     compat_getenv,
     compat_getenv,
+    compat_etree_fromstring,
     compat_expanduser,
     compat_expanduser,
     compat_shlex_split,
     compat_shlex_split,
+    compat_str,
     compat_urllib_parse_unquote,
     compat_urllib_parse_unquote,
     compat_urllib_parse_unquote_plus,
     compat_urllib_parse_unquote_plus,
 )
 )
@@ -71,5 +73,20 @@ class TestCompat(unittest.TestCase):
     def test_compat_shlex_split(self):
     def test_compat_shlex_split(self):
         self.assertEqual(compat_shlex_split('-option "one two"'), ['-option', 'one two'])
         self.assertEqual(compat_shlex_split('-option "one two"'), ['-option', 'one two'])
 
 
+    def test_compat_etree_fromstring(self):
+        xml = '''
+            <root foo="bar" spam="中文">
+                <normal>foo</normal>
+                <chinese>中文</chinese>
+                <foo><bar>spam</bar></foo>
+            </root>
+        '''
+        doc = compat_etree_fromstring(xml.encode('utf-8'))
+        self.assertTrue(isinstance(doc.attrib['foo'], compat_str))
+        self.assertTrue(isinstance(doc.attrib['spam'], compat_str))
+        self.assertTrue(isinstance(doc.find('normal').text, compat_str))
+        self.assertTrue(isinstance(doc.find('chinese').text, compat_str))
+        self.assertTrue(isinstance(doc.find('foo/bar').text, compat_str))
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     unittest.main()
     unittest.main()

+ 7 - 4
test/test_utils.py

@@ -68,6 +68,9 @@ from youtube_dl.utils import (
     cli_valueless_option,
     cli_valueless_option,
     cli_bool_option,
     cli_bool_option,
 )
 )
+from youtube_dl.compat import (
+    compat_etree_fromstring,
+)
 
 
 
 
 class TestUtil(unittest.TestCase):
 class TestUtil(unittest.TestCase):
@@ -242,7 +245,7 @@ class TestUtil(unittest.TestCase):
             <node x="b" y="d" />
             <node x="b" y="d" />
             <node x="" />
             <node x="" />
         </root>'''
         </root>'''
-        doc = xml.etree.ElementTree.fromstring(testxml)
+        doc = compat_etree_fromstring(testxml)
 
 
         self.assertEqual(find_xpath_attr(doc, './/fourohfour', 'n'), None)
         self.assertEqual(find_xpath_attr(doc, './/fourohfour', 'n'), None)
         self.assertEqual(find_xpath_attr(doc, './/fourohfour', 'n', 'v'), None)
         self.assertEqual(find_xpath_attr(doc, './/fourohfour', 'n', 'v'), None)
@@ -263,7 +266,7 @@ class TestUtil(unittest.TestCase):
                 <url>http://server.com/download.mp3</url>
                 <url>http://server.com/download.mp3</url>
             </media:song>
             </media:song>
         </root>'''
         </root>'''
-        doc = xml.etree.ElementTree.fromstring(testxml)
+        doc = compat_etree_fromstring(testxml)
         find = lambda p: doc.find(xpath_with_ns(p, {'media': 'http://example.com/'}))
         find = lambda p: doc.find(xpath_with_ns(p, {'media': 'http://example.com/'}))
         self.assertTrue(find('media:song') is not None)
         self.assertTrue(find('media:song') is not None)
         self.assertEqual(find('media:song/media:author').text, 'The Author')
         self.assertEqual(find('media:song/media:author').text, 'The Author')
@@ -292,7 +295,7 @@ class TestUtil(unittest.TestCase):
                 <p>Foo</p>
                 <p>Foo</p>
             </div>
             </div>
         </root>'''
         </root>'''
-        doc = xml.etree.ElementTree.fromstring(testxml)
+        doc = compat_etree_fromstring(testxml)
         self.assertEqual(xpath_text(doc, 'div/p'), 'Foo')
         self.assertEqual(xpath_text(doc, 'div/p'), 'Foo')
         self.assertEqual(xpath_text(doc, 'div/bar', default='default'), 'default')
         self.assertEqual(xpath_text(doc, 'div/bar', default='default'), 'default')
         self.assertTrue(xpath_text(doc, 'div/bar') is None)
         self.assertTrue(xpath_text(doc, 'div/bar') is None)
@@ -304,7 +307,7 @@ class TestUtil(unittest.TestCase):
                 <p x="a">Foo</p>
                 <p x="a">Foo</p>
             </div>
             </div>
         </root>'''
         </root>'''
-        doc = xml.etree.ElementTree.fromstring(testxml)
+        doc = compat_etree_fromstring(testxml)
         self.assertEqual(xpath_attr(doc, 'div/p', 'x'), 'a')
         self.assertEqual(xpath_attr(doc, 'div/p', 'x'), 'a')
         self.assertEqual(xpath_attr(doc, 'div/bar', 'x'), None)
         self.assertEqual(xpath_attr(doc, 'div/bar', 'x'), None)
         self.assertEqual(xpath_attr(doc, 'div/p', 'y'), None)
         self.assertEqual(xpath_attr(doc, 'div/p', 'y'), None)

+ 39 - 0
youtube_dl/compat.py

@@ -14,6 +14,7 @@ import socket
 import subprocess
 import subprocess
 import sys
 import sys
 import itertools
 import itertools
+import xml.etree.ElementTree
 
 
 
 
 try:
 try:
@@ -212,6 +213,43 @@ try:
 except ImportError:  # Python 2.6
 except ImportError:  # Python 2.6
     from xml.parsers.expat import ExpatError as compat_xml_parse_error
     from xml.parsers.expat import ExpatError as compat_xml_parse_error
 
 
+if sys.version_info[0] >= 3:
+    compat_etree_fromstring = xml.etree.ElementTree.fromstring
+else:
+    # python 2.x tries to encode unicode strings with ascii (see the
+    # XMLParser._fixtext method)
+    etree = xml.etree.ElementTree
+
+    try:
+        _etree_iter = etree.Element.iter
+    except AttributeError:  # Python <=2.6
+        def _etree_iter(root):
+            for el in root.findall('*'):
+                yield el
+                for sub in _etree_iter(el):
+                    yield sub
+
+    # on 2.6 XML doesn't have a parser argument, function copied from CPython
+    # 2.7 source
+    def _XML(text, parser=None):
+        if not parser:
+            parser = etree.XMLParser(target=etree.TreeBuilder())
+        parser.feed(text)
+        return parser.close()
+
+    def _element_factory(*args, **kwargs):
+        el = etree.Element(*args, **kwargs)
+        for k, v in el.items():
+            if isinstance(v, bytes):
+                el.set(k, v.decode('utf-8'))
+        return el
+
+    def compat_etree_fromstring(text):
+        doc = _XML(text, parser=etree.XMLParser(target=etree.TreeBuilder(element_factory=_element_factory)))
+        for el in _etree_iter(doc):
+            if el.text is not None and isinstance(el.text, bytes):
+                el.text = el.text.decode('utf-8')
+        return doc
 
 
 try:
 try:
     from urllib.parse import parse_qs as compat_parse_qs
     from urllib.parse import parse_qs as compat_parse_qs
@@ -507,6 +545,7 @@ __all__ = [
     'compat_chr',
     'compat_chr',
     'compat_cookiejar',
     'compat_cookiejar',
     'compat_cookies',
     'compat_cookies',
+    'compat_etree_fromstring',
     'compat_expanduser',
     'compat_expanduser',
     'compat_get_terminal_size',
     'compat_get_terminal_size',
     'compat_getenv',
     'compat_getenv',

+ 2 - 2
youtube_dl/downloader/f4m.py

@@ -5,10 +5,10 @@ import io
 import itertools
 import itertools
 import os
 import os
 import time
 import time
-import xml.etree.ElementTree as etree
 
 
 from .fragment import FragmentFD
 from .fragment import FragmentFD
 from ..compat import (
 from ..compat import (
+    compat_etree_fromstring,
     compat_urlparse,
     compat_urlparse,
     compat_urllib_error,
     compat_urllib_error,
     compat_urllib_parse_urlparse,
     compat_urllib_parse_urlparse,
@@ -290,7 +290,7 @@ class F4mFD(FragmentFD):
         man_url = urlh.geturl()
         man_url = urlh.geturl()
         manifest = urlh.read()
         manifest = urlh.read()
 
 
-        doc = etree.fromstring(manifest)
+        doc = compat_etree_fromstring(manifest)
         formats = [(int(f.attrib.get('bitrate', -1)), f)
         formats = [(int(f.attrib.get('bitrate', -1)), f)
                    for f in self._get_unencrypted_media(doc)]
                    for f in self._get_unencrypted_media(doc)]
         if requested_bitrate is None:
         if requested_bitrate is None:

+ 2 - 2
youtube_dl/extractor/ard.py

@@ -14,8 +14,8 @@ from ..utils import (
     parse_duration,
     parse_duration,
     unified_strdate,
     unified_strdate,
     xpath_text,
     xpath_text,
-    parse_xml,
 )
 )
+from ..compat import compat_etree_fromstring
 
 
 
 
 class ARDMediathekIE(InfoExtractor):
 class ARDMediathekIE(InfoExtractor):
@@ -161,7 +161,7 @@ class ARDMediathekIE(InfoExtractor):
             raise ExtractorError('This program is only suitable for those aged 12 and older. Video %s is therefore only available between 20 pm and 6 am.' % video_id, expected=True)
             raise ExtractorError('This program is only suitable for those aged 12 and older. Video %s is therefore only available between 20 pm and 6 am.' % video_id, expected=True)
 
 
         if re.search(r'[\?&]rss($|[=&])', url):
         if re.search(r'[\?&]rss($|[=&])', url):
-            doc = parse_xml(webpage)
+            doc = compat_etree_fromstring(webpage.encode('utf-8'))
             if doc.tag == 'rss':
             if doc.tag == 'rss':
                 return GenericIE()._extract_rss(url, video_id, doc)
                 return GenericIE()._extract_rss(url, video_id, doc)
 
 

+ 5 - 3
youtube_dl/extractor/bbc.py

@@ -2,7 +2,6 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 
 
 import re
 import re
-import xml.etree.ElementTree
 
 
 from .common import InfoExtractor
 from .common import InfoExtractor
 from ..utils import (
 from ..utils import (
@@ -14,7 +13,10 @@ from ..utils import (
     remove_end,
     remove_end,
     unescapeHTML,
     unescapeHTML,
 )
 )
-from ..compat import compat_HTTPError
+from ..compat import (
+    compat_etree_fromstring,
+    compat_HTTPError,
+)
 
 
 
 
 class BBCCoUkIE(InfoExtractor):
 class BBCCoUkIE(InfoExtractor):
@@ -344,7 +346,7 @@ class BBCCoUkIE(InfoExtractor):
                 url, programme_id, 'Downloading media selection XML')
                 url, programme_id, 'Downloading media selection XML')
         except ExtractorError as ee:
         except ExtractorError as ee:
             if isinstance(ee.cause, compat_HTTPError) and ee.cause.code == 403:
             if isinstance(ee.cause, compat_HTTPError) and ee.cause.code == 403:
-                media_selection = xml.etree.ElementTree.fromstring(ee.cause.read().decode('utf-8'))
+                media_selection = compat_etree_fromstring(ee.cause.read().decode('utf-8'))
             else:
             else:
                 raise
                 raise
         return self._process_media_selector(media_selection, programme_id)
         return self._process_media_selector(media_selection, programme_id)

+ 4 - 2
youtube_dl/extractor/bilibili.py

@@ -4,9 +4,11 @@ from __future__ import unicode_literals
 import re
 import re
 import itertools
 import itertools
 import json
 import json
-import xml.etree.ElementTree as ET
 
 
 from .common import InfoExtractor
 from .common import InfoExtractor
+from ..compat import (
+    compat_etree_fromstring,
+)
 from ..utils import (
 from ..utils import (
     int_or_none,
     int_or_none,
     unified_strdate,
     unified_strdate,
@@ -88,7 +90,7 @@ class BiliBiliIE(InfoExtractor):
         except ValueError:
         except ValueError:
             pass
             pass
 
 
-        lq_doc = ET.fromstring(lq_page)
+        lq_doc = compat_etree_fromstring(lq_page)
         lq_durls = lq_doc.findall('./durl')
         lq_durls = lq_doc.findall('./durl')
 
 
         hq_doc = self._download_xml(
         hq_doc = self._download_xml(

+ 2 - 2
youtube_dl/extractor/brightcove.py

@@ -3,10 +3,10 @@ from __future__ import unicode_literals
 
 
 import re
 import re
 import json
 import json
-import xml.etree.ElementTree
 
 
 from .common import InfoExtractor
 from .common import InfoExtractor
 from ..compat import (
 from ..compat import (
+    compat_etree_fromstring,
     compat_parse_qs,
     compat_parse_qs,
     compat_str,
     compat_str,
     compat_urllib_parse,
     compat_urllib_parse,
@@ -119,7 +119,7 @@ class BrightcoveIE(InfoExtractor):
         object_str = fix_xml_ampersands(object_str)
         object_str = fix_xml_ampersands(object_str)
 
 
         try:
         try:
-            object_doc = xml.etree.ElementTree.fromstring(object_str.encode('utf-8'))
+            object_doc = compat_etree_fromstring(object_str.encode('utf-8'))
         except compat_xml_parse_error:
         except compat_xml_parse_error:
             return
             return
 
 

+ 2 - 2
youtube_dl/extractor/common.py

@@ -10,7 +10,6 @@ import re
 import socket
 import socket
 import sys
 import sys
 import time
 import time
-import xml.etree.ElementTree
 
 
 from ..compat import (
 from ..compat import (
     compat_cookiejar,
     compat_cookiejar,
@@ -23,6 +22,7 @@ from ..compat import (
     compat_urllib_request,
     compat_urllib_request,
     compat_urlparse,
     compat_urlparse,
     compat_str,
     compat_str,
+    compat_etree_fromstring,
 )
 )
 from ..utils import (
 from ..utils import (
     NO_DEFAULT,
     NO_DEFAULT,
@@ -461,7 +461,7 @@ class InfoExtractor(object):
             return xml_string
             return xml_string
         if transform_source:
         if transform_source:
             xml_string = transform_source(xml_string)
             xml_string = transform_source(xml_string)
-        return xml.etree.ElementTree.fromstring(xml_string.encode('utf-8'))
+        return compat_etree_fromstring(xml_string.encode('utf-8'))
 
 
     def _download_json(self, url_or_request, video_id,
     def _download_json(self, url_or_request, video_id,
                        note='Downloading JSON metadata',
                        note='Downloading JSON metadata',

+ 2 - 2
youtube_dl/extractor/crunchyroll.py

@@ -5,12 +5,12 @@ import re
 import json
 import json
 import base64
 import base64
 import zlib
 import zlib
-import xml.etree.ElementTree
 
 
 from hashlib import sha1
 from hashlib import sha1
 from math import pow, sqrt, floor
 from math import pow, sqrt, floor
 from .common import InfoExtractor
 from .common import InfoExtractor
 from ..compat import (
 from ..compat import (
+    compat_etree_fromstring,
     compat_urllib_parse,
     compat_urllib_parse,
     compat_urllib_parse_unquote,
     compat_urllib_parse_unquote,
     compat_urllib_request,
     compat_urllib_request,
@@ -234,7 +234,7 @@ Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
         return output
         return output
 
 
     def _extract_subtitles(self, subtitle):
     def _extract_subtitles(self, subtitle):
-        sub_root = xml.etree.ElementTree.fromstring(subtitle)
+        sub_root = compat_etree_fromstring(subtitle)
         return [{
         return [{
             'ext': 'srt',
             'ext': 'srt',
             'data': self._convert_subtitles_to_srt(sub_root),
             'data': self._convert_subtitles_to_srt(sub_root),

+ 2 - 2
youtube_dl/extractor/generic.py

@@ -9,6 +9,7 @@ import sys
 from .common import InfoExtractor
 from .common import InfoExtractor
 from .youtube import YoutubeIE
 from .youtube import YoutubeIE
 from ..compat import (
 from ..compat import (
+    compat_etree_fromstring,
     compat_urllib_parse_unquote,
     compat_urllib_parse_unquote,
     compat_urllib_request,
     compat_urllib_request,
     compat_urlparse,
     compat_urlparse,
@@ -21,7 +22,6 @@ from ..utils import (
     HEADRequest,
     HEADRequest,
     is_html,
     is_html,
     orderedSet,
     orderedSet,
-    parse_xml,
     smuggle_url,
     smuggle_url,
     unescapeHTML,
     unescapeHTML,
     unified_strdate,
     unified_strdate,
@@ -1238,7 +1238,7 @@ class GenericIE(InfoExtractor):
 
 
         # Is it an RSS feed, a SMIL file or a XSPF playlist?
         # Is it an RSS feed, a SMIL file or a XSPF playlist?
         try:
         try:
-            doc = parse_xml(webpage)
+            doc = compat_etree_fromstring(webpage.encode('utf-8'))
             if doc.tag == 'rss':
             if doc.tag == 'rss':
                 return self._extract_rss(url, video_id, doc)
                 return self._extract_rss(url, video_id, doc)
             elif re.match(r'^(?:{[^}]+})?smil$', doc.tag):
             elif re.match(r'^(?:{[^}]+})?smil$', doc.tag):

+ 3 - 3
youtube_dl/extractor/vevo.py

@@ -1,10 +1,10 @@
 from __future__ import unicode_literals
 from __future__ import unicode_literals
 
 
 import re
 import re
-import xml.etree.ElementTree
 
 
 from .common import InfoExtractor
 from .common import InfoExtractor
 from ..compat import (
 from ..compat import (
+    compat_etree_fromstring,
     compat_urllib_request,
     compat_urllib_request,
 )
 )
 from ..utils import (
 from ..utils import (
@@ -97,7 +97,7 @@ class VevoIE(InfoExtractor):
         if last_version['version'] == -1:
         if last_version['version'] == -1:
             raise ExtractorError('Unable to extract last version of the video')
             raise ExtractorError('Unable to extract last version of the video')
 
 
-        renditions = xml.etree.ElementTree.fromstring(last_version['data'])
+        renditions = compat_etree_fromstring(last_version['data'])
         formats = []
         formats = []
         # Already sorted from worst to best quality
         # Already sorted from worst to best quality
         for rend in renditions.findall('rendition'):
         for rend in renditions.findall('rendition'):
@@ -114,7 +114,7 @@ class VevoIE(InfoExtractor):
 
 
     def _formats_from_smil(self, smil_xml):
     def _formats_from_smil(self, smil_xml):
         formats = []
         formats = []
-        smil_doc = xml.etree.ElementTree.fromstring(smil_xml.encode('utf-8'))
+        smil_doc = compat_etree_fromstring(smil_xml.encode('utf-8'))
         els = smil_doc.findall('.//{http://www.w3.org/2001/SMIL20/Language}video')
         els = smil_doc.findall('.//{http://www.w3.org/2001/SMIL20/Language}video')
         for el in els:
         for el in els:
             src = el.attrib['src']
             src = el.attrib['src']

+ 2 - 24
youtube_dl/utils.py

@@ -36,6 +36,7 @@ import zlib
 from .compat import (
 from .compat import (
     compat_basestring,
     compat_basestring,
     compat_chr,
     compat_chr,
+    compat_etree_fromstring,
     compat_html_entities,
     compat_html_entities,
     compat_http_client,
     compat_http_client,
     compat_kwargs,
     compat_kwargs,
@@ -1665,29 +1666,6 @@ def encode_dict(d, encoding='utf-8'):
     return dict((k.encode(encoding), v.encode(encoding)) for k, v in d.items())
     return dict((k.encode(encoding), v.encode(encoding)) for k, v in d.items())
 
 
 
 
-try:
-    etree_iter = xml.etree.ElementTree.Element.iter
-except AttributeError:  # Python <=2.6
-    etree_iter = lambda n: n.findall('.//*')
-
-
-def parse_xml(s):
-    class TreeBuilder(xml.etree.ElementTree.TreeBuilder):
-        def doctype(self, name, pubid, system):
-            pass  # Ignore doctypes
-
-    parser = xml.etree.ElementTree.XMLParser(target=TreeBuilder())
-    kwargs = {'parser': parser} if sys.version_info >= (2, 7) else {}
-    tree = xml.etree.ElementTree.XML(s.encode('utf-8'), **kwargs)
-    # Fix up XML parser in Python 2.x
-    if sys.version_info < (3, 0):
-        for n in etree_iter(tree):
-            if n.text is not None:
-                if not isinstance(n.text, compat_str):
-                    n.text = n.text.decode('utf-8')
-    return tree
-
-
 US_RATINGS = {
 US_RATINGS = {
     'G': 0,
     'G': 0,
     'PG': 10,
     'PG': 10,
@@ -1988,7 +1966,7 @@ def dfxp2srt(dfxp_data):
 
 
         return out
         return out
 
 
-    dfxp = xml.etree.ElementTree.fromstring(dfxp_data.encode('utf-8'))
+    dfxp = compat_etree_fromstring(dfxp_data.encode('utf-8'))
     out = []
     out = []
     paras = dfxp.findall(_x('.//ttml:p')) or dfxp.findall(_x('.//ttaf1:p')) or dfxp.findall('.//p')
     paras = dfxp.findall(_x('.//ttml:p')) or dfxp.findall(_x('.//ttaf1:p')) or dfxp.findall('.//p')