浏览代码

Introduce get_elements_by_class and get_elements_by_attribute utility functions

Thomas Christlieb 8 年之前
父节点
当前提交
2af12ad9d2
共有 2 个文件被更改,包括 51 次插入10 次删除
  1. 29 0
      test/test_utils.py
  2. 22 10
      youtube_dl/utils.py

+ 29 - 0
test/test_utils.py

@@ -34,6 +34,9 @@ from youtube_dl.utils import (
     find_xpath_attr,
     find_xpath_attr,
     fix_xml_ampersands,
     fix_xml_ampersands,
     get_element_by_class,
     get_element_by_class,
+    get_element_by_attribute,
+    get_elements_by_class,
+    get_elements_by_attribute,
     InAdvancePagedList,
     InAdvancePagedList,
     intlist_to_bytes,
     intlist_to_bytes,
     is_html,
     is_html,
@@ -1124,6 +1127,32 @@ The first line
         self.assertEqual(get_element_by_class('foo', html), 'nice')
         self.assertEqual(get_element_by_class('foo', html), 'nice')
         self.assertEqual(get_element_by_class('no-such-class', html), None)
         self.assertEqual(get_element_by_class('no-such-class', html), None)
 
 
+    def test_get_element_by_attribute(self):
+        html = '''
+            <span class="foo bar">nice</span>
+        '''
+
+        self.assertEqual(get_element_by_attribute('class', 'foo bar', html), 'nice')
+        self.assertEqual(get_element_by_attribute('class', 'foo', html), None)
+        self.assertEqual(get_element_by_attribute('class', 'no-such-foo', html), None)
+
+    def test_get_elements_by_class(self):
+        html = '''
+            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
+        '''
+
+        self.assertEqual(get_elements_by_class('foo', html), ['nice', 'also nice'])
+        self.assertEqual(get_elements_by_class('no-such-class', html), [])
+
+    def test_get_elements_by_attribute(self):
+        html = '''
+            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
+        '''
+
+        self.assertEqual(get_elements_by_attribute('class', 'foo bar', html), ['nice', 'also nice'])
+        self.assertEqual(get_elements_by_attribute('class', 'foo', html), [])
+        self.assertEqual(get_elements_by_attribute('class', 'no-such-foo', html), [])
+
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     unittest.main()
     unittest.main()

+ 22 - 10
youtube_dl/utils.py

@@ -337,17 +337,30 @@ def get_element_by_id(id, html):
 
 
 
 
 def get_element_by_class(class_name, html):
 def get_element_by_class(class_name, html):
-    return get_element_by_attribute(
+    """Return the content of the first tag with the specified class in the passed HTML document"""
+    retval = get_elements_by_class(class_name, html)
+    return retval[0] if retval else None
+
+
+def get_element_by_attribute(attribute, value, html, escape_value=True):
+    retval = get_elements_by_attribute(attribute, value, html, escape_value)
+    return retval[0] if retval else None
+
+
+def get_elements_by_class(class_name, html):
+    """Return the content of all tags with the specified class in the passed HTML document as a list"""
+    return get_elements_by_attribute(
         'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
         'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
         html, escape_value=False)
         html, escape_value=False)
 
 
 
 
-def get_element_by_attribute(attribute, value, html, escape_value=True):
+def get_elements_by_attribute(attribute, value, html, escape_value=True):
     """Return the content of the tag with the specified attribute in the passed HTML document"""
     """Return the content of the tag with the specified attribute in the passed HTML document"""
 
 
     value = re.escape(value) if escape_value else value
     value = re.escape(value) if escape_value else value
 
 
-    m = re.search(r'''(?xs)
+    retlist = []
+    for m in re.finditer(r'''(?xs)
         <([a-zA-Z0-9:._-]+)
         <([a-zA-Z0-9:._-]+)
          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
          \s+%s=['"]?%s['"]?
          \s+%s=['"]?%s['"]?
@@ -355,16 +368,15 @@ def get_element_by_attribute(attribute, value, html, escape_value=True):
         \s*>
         \s*>
         (?P<content>.*?)
         (?P<content>.*?)
         </\1>
         </\1>
-    ''' % (re.escape(attribute), value), html)
+    ''' % (re.escape(attribute), value), html):
+        res = m.group('content')
 
 
-    if not m:
-        return None
-    res = m.group('content')
+        if res.startswith('"') or res.startswith("'"):
+            res = res[1:-1]
 
 
-    if res.startswith('"') or res.startswith("'"):
-        res = res[1:-1]
+        retlist.append(unescapeHTML(res))
 
 
-    return unescapeHTML(res)
+    return retlist
 
 
 
 
 class HTMLAttributeParser(compat_HTMLParser):
 class HTMLAttributeParser(compat_HTMLParser):