Skip to content

Instantly share code, notes, and snippets.

@cellularmitosis
Last active April 4, 2024 21:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cellularmitosis/b07802f19c37f64196ecdf08d0342e80 to your computer and use it in GitHub Desktop.
Save cellularmitosis/b07802f19c37f64196ecdf08d0342e80 to your computer and use it in GitHub Desktop.
import lxml
from lxml import html
import urllib.request
import os

class HTMLNode:
    """A wrapper around lxml nodes with a fluent interface which supports optional chaining.

    Attribute access is resolved dynamically by __getattr__:

    * A known tag name (e.g. ``node.div``) yields a callable which returns the
      first matching direct child, or a nil node when there is none.
    * A known tag name plus 's' (e.g. ``node.divs``) yields a callable which
      returns a list of all matching direct children.
    * Anything else is a 'leaf' lookup: first a property of the wrapped lxml
      node (e.g. ``text``), then an HTML attribute (e.g. ``href``), else None.

    A node whose ``lxml_node`` is None is a 'nil' node: child lookups on it
    find nothing and leaf lookups return None, so chains never raise.
    """

    # Tag names which __getattr__ exposes as child-lookup methods.
    _HTML_TAGS = frozenset([
        'html', 'head', 'meta', 'link', 'script',
        'body', 'div', 'span', 'p', 'a', 'img', 'br',
        'table', 'tr', 'td', 'ul', 'li', 'dl', 'dt', 'dd',
        'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'b', 'i', 'u', 'strong', 'em',
    ])

    def __init__(self, arg=None):
        """Build a node from one of several kinds of input.

        arg may be:
          * None: creates a nil node.
          * an lxml.html.HtmlElement: wrapped as-is.
          * bytes: decoded (utf-8, falling back to latin-1), then treated as str.
          * str starting with 'http://' or 'https://': fetched with urllib.
          * str naming an existing path: read from disk.
          * any other str: parsed as HTML markup.

        Raises TypeError (a subclass of Exception) for any other type.
        """
        if arg is None:
            # arg is None, this is a NilNode.
            self.lxml_node = None
            return
        if isinstance(arg, lxml.html.HtmlElement):
            # arg is already an lxml node.
            self.lxml_node = arg
            return
        if isinstance(arg, bytes):
            arg = self._decode(arg)
        if not isinstance(arg, str):
            # Wrap arg in a tuple so that a tuple argument can't break the formatting.
            raise TypeError("Don't know how to handle %s" % (arg,))

        if arg.startswith(('http://', 'https://')):
            # arg is a URL.  Close the response once the body has been read.
            with urllib.request.urlopen(arg) as response:
                markup = self._decode(response.read())
        elif os.path.exists(arg):
            # arg is a file path.
            with open(arg, 'rb') as fd:
                markup = self._decode(fd.read())
        else:
            # assume arg is html.
            markup = arg
        # Fetched / loaded content goes straight to lxml: it is never
        # re-interpreted as another URL or file path.
        self.lxml_node = html.fromstring(markup)

    @staticmethod
    def _decode(raw):
        """Decode raw bytes, trying utf-8 with latin-1 as a fallback."""
        try:
            return raw.decode('utf-8')
        except UnicodeDecodeError:
            return raw.decode('latin-1')

    def __getattr__(self, attr):
        """Resolve tag lookups, lxml properties and HTML attributes (see class docstring).

        Only called when normal attribute lookup fails.
        """
        # 'lxml_node' only reaches here when it hasn't been assigned yet (e.g.
        # while copy / pickle rebuild an instance); reading self.lxml_node below
        # would then recurse forever.  Dunder probes such as __setstate__ must
        # get AttributeError rather than None so hasattr()-style checks work.
        if attr == 'lxml_node' or (attr.startswith('__') and attr.endswith('__')):
            raise AttributeError(attr)

        # 'attr' is an HTML tag: this is a 'branch' of the tree, so return a
        # callable which produces another HTMLNode (or a list of them).
        if attr in self._HTML_TAGS:
            return lambda id=None, class_=None: self._node(tag=attr, id=id, class_=class_)
        if attr.endswith('s') and attr[:-1] in self._HTML_TAGS:
            return lambda id=None, class_=None: self._nodes(tag=attr[:-1], id=id, class_=class_)

        # This is a 'leaf' of the tree, so we return a value or None.
        if self.lxml_node is None:
            # NilNodes have no values.
            return None
        # If this is an lxml node property, e.g. 'text', defer to lxml.
        if hasattr(self.lxml_node, attr):
            return getattr(self.lxml_node, attr)
        # Try to interpret attr as an HTML attribute, e.g. 'id', 'class', 'href', etc.
        if attr == 'class_':
            # 'class' is a python keyword, so we use 'class_' instead.
            attr = 'class'
        elif attr.startswith('data_'):
            # Python identifiers can't use dashes, so underscores stand in for
            # them in custom data-* attributes.  str.replace returns a new
            # string, so the result must be assigned.
            attr = attr.replace('_', '-')
        # Returns None when the attribute doesn't exist on this HTML node.
        return self.lxml_node.get(attr)

    def _nodes(self, tag, id=None, class_=None):
        """Return the direct children with the given tag as a list of HTMLNodes.

        When id and / or class_ are given, a child must satisfy every filter
        which was supplied.  A nil node has no children.
        """
        if self.lxml_node is None:
            return []
        matches = []
        # Iterating an lxml element yields its direct children.
        for child in self.lxml_node:
            if child.tag != tag:
                continue
            if id is not None and not self._lxml_node_has_id(child, id):
                continue
            if class_ is not None and not self._lxml_node_has_class(child, class_):
                continue
            matches.append(HTMLNode(child))
        return matches

    def _node(self, tag, id=None, class_=None):
        """Return the first child matched by _nodes(), or a NilNode if there is none."""
        matches = self._nodes(tag, id, class_)
        return matches[0] if matches else NilNode()

    def _lxml_node_has_class(self, node, class_):
        """Return True if class_ is one of the whitespace-separated classes of node."""
        return class_ in (node.get('class') or '').split()

    def _lxml_node_has_id(self, node, id):
        """Return True if node has an 'id' attribute equal to id."""
        return node.get('id') == id

    def __str__(self):
        """Return a short tag-like summary of this node, or "nil" for a nil node.

        The closing tag is only emitted when the node has non-empty text.
        """
        if self.lxml_node is None:
            return "nil"
        tag = self.lxml_node.tag
        pieces = ["<", tag]
        if self.id is not None:
            pieces.append(' id="%s"' % self.id)
        if self.class_ is not None:
            pieces.append(' class="%s"' % self.class_)
        pieces.append(">")
        text = self.text
        if text:
            pieces.append(text)
            pieces.append("</" + tag + ">")
        return "".join(pieces)

    def __repr__(self):
        """Same as str(self)."""
        return str(self)

class NilNode(HTMLNode):
    """Placeholder node returned by HTMLNode._node() when no child matches.

    Adds nothing to HTMLNode; constructed with no argument, its lxml_node is None.
    """

Init from an HTML string:

root = HTMLNode('<html></html>')

Init from bytes:

root = HTMLNode(b'<html></html>')

Init from a URL:

import urllib.request
root = HTMLNode('http://leopard.sh')
root = HTMLNode(b'http://leopard.sh')

Init from a file path:

root = HTMLNode('/tmp/foo.html')
root = HTMLNode(b'/tmp/foo.html')

Init from an lxml HTML node:

from lxml import html
lnode = html.fromstring('<html></html>')
root = HTMLNode(lnode)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment