Created
October 19, 2011 08:05
-
-
Save jdennes/1297720 to your computer and use it in GitHub Desktop.
Prototype of a super-simple feedparser web service
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
.DS_Store | |
*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""web.py: makes web apps (http://webpy.org)""" | |
from __future__ import generators | |
__version__ = "0.37" | |
__author__ = [ | |
"Aaron Swartz <me@aaronsw.com>", | |
"Anand Chitipothu <anandology@gmail.com>" | |
] | |
__license__ = "public domain" | |
__contributors__ = "see http://webpy.org/changes" | |
import utils, db, net, wsgi, http, webapi, httpserver, debugerror | |
import template, form | |
import session | |
from utils import * | |
from db import * | |
from net import * | |
from wsgi import * | |
from http import * | |
from webapi import * | |
from httpserver import * | |
from debugerror import * | |
from application import * | |
from browser import * | |
try: | |
import webopenid as openid | |
except ImportError: | |
pass # requires openid module | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A high-speed, production ready, thread pooled, generic HTTP server. | |
Simplest example on how to use this module directly | |
(without using CherryPy's application machinery):: | |
from cherrypy import wsgiserver | |
def my_crazy_app(environ, start_response): | |
status = '200 OK' | |
response_headers = [('Content-type','text/plain')] | |
start_response(status, response_headers) | |
return ['Hello world!'] | |
server = wsgiserver.CherryPyWSGIServer( | |
('0.0.0.0', 8070), my_crazy_app, | |
server_name='www.cherrypy.example') | |
server.start() | |
The CherryPy WSGI server can serve as many WSGI applications | |
as you want in one instance by using a WSGIPathInfoDispatcher:: | |
d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app}) | |
server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d) | |
Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance. | |
This won't call the CherryPy engine (application side) at all, only the | |
HTTP server, which is independent from the rest of CherryPy. Don't | |
let the name "CherryPyWSGIServer" throw you; the name merely reflects | |
its origin, not its coupling. | |
For those of you wanting to understand internals of this module, here's the | |
basic call flow. The server's listening thread runs a very tight loop, | |
sticking incoming connections onto a Queue:: | |
server = CherryPyWSGIServer(...) | |
server.start() | |
while True: | |
tick() | |
# This blocks until a request comes in: | |
child = socket.accept() | |
conn = HTTPConnection(child, ...) | |
server.requests.put(conn) | |
Worker threads are kept in a pool and poll the Queue, popping off and then | |
handling each connection in turn. Each connection can consist of an arbitrary | |
number of requests and their responses, so we run a nested loop:: | |
while True: | |
conn = server.requests.get() | |
conn.communicate() | |
-> while True: | |
req = HTTPRequest(...) | |
req.parse_request() | |
-> # Read the Request-Line, e.g. "GET /page HTTP/1.1" | |
req.rfile.readline() | |
read_headers(req.rfile, req.inheaders) | |
req.respond() | |
-> response = app(...) | |
try: | |
for chunk in response: | |
if chunk: | |
req.write(chunk) | |
finally: | |
if hasattr(response, "close"): | |
response.close() | |
if req.close_connection: | |
return | |
""" | |
CRLF = '\r\n' | |
import os | |
import Queue | |
import re | |
quoted_slash = re.compile("(?i)%2F") | |
import rfc822 | |
import socket | |
import sys | |
if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'): | |
socket.IPPROTO_IPV6 = 41 | |
try: | |
import cStringIO as StringIO | |
except ImportError: | |
import StringIO | |
DEFAULT_BUFFER_SIZE = -1 | |
_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring) | |
import threading | |
import time | |
import traceback | |
def format_exc(limit=None):
    """Return the traceback of the exception being handled as one string.

    Works like ``traceback.print_exc()`` except that the formatted text is
    returned instead of printed (a backport for Python 2.3).

    ``limit`` is passed through to ``traceback.format_exception``.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    try:
        pieces = traceback.format_exception(exc_type, exc_value, exc_tb, limit)
        return ''.join(pieces)
    finally:
        # Release the local references to the exception and its traceback
        # as soon as the text has been produced.
        exc_type = exc_value = exc_tb = None
from urllib import unquote | |
from urlparse import urlparse | |
import warnings | |
import errno | |
def plat_specific_errors(*errnames):
    """Return error numbers for all errors in errnames on this platform.

    The 'errno' module contains different global constants depending on
    the specific platform (OS). This function looks up each name in
    ``errnames`` and returns the numeric values of those that exist here;
    names unknown to this platform are silently skipped.

    The result is always a real, de-duplicated ``list`` so that callers can
    ``append`` extra entries to it (a dict key view would not allow that).
    """
    nums = [getattr(errno, name) for name in errnames if hasattr(errno, name)]
    # dict.fromkeys() drops duplicates (several names may map to one number).
    return list(dict.fromkeys(nums))
# Error numbers meaning "the call was interrupted".
socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR")

# Socket errors that are ignored rather than re-raised when writing to a
# client (see the ``not in socket_errors_to_ignore`` checks in this module).
socket_errors_to_ignore = plat_specific_errors(
    "EPIPE",
    "EBADF", "WSAEBADF",
    "ENOTSOCK", "WSAENOTSOCK",
    "ETIMEDOUT", "WSAETIMEDOUT",
    "ECONNREFUSED", "WSAECONNREFUSED",
    "ECONNRESET", "WSAECONNRESET",
    "ECONNABORTED", "WSAECONNABORTED",
    "ENETRESET", "WSAENETRESET",
    "EHOSTDOWN", "EHOSTUNREACH",
)
# Timeouts may be reported with a message string instead of an errno.
socket_errors_to_ignore.extend([
    "timed out",
    "The read operation timed out",
])

# Error numbers meaning "the operation would block; try again".
socket_errors_nonblocking = plat_specific_errors(
    'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
# Headers whose repeated occurrences are folded into a single
# comma-separated value by read_headers().
comma_separated_headers = [
    'Accept',
    'Accept-Charset',
    'Accept-Encoding',
    'Accept-Language',
    'Accept-Ranges',
    'Allow',
    'Cache-Control',
    'Connection',
    'Content-Encoding',
    'Content-Language',
    'Expect',
    'If-Match',
    'If-None-Match',
    'Pragma',
    'Proxy-Authenticate',
    'TE',
    'Trailer',
    'Transfer-Encoding',
    'Upgrade',
    'Vary',
    'Via',
    'Warning',
    'WWW-Authenticate',
]

import logging

# Make sure the shared ``logging.statistics`` dict exists; never replace one
# that is already there.
if not hasattr(logging, 'statistics'):
    logging.statistics = {}
def read_headers(rfile, hdict=None):
    """Read headers from the given stream into the given header dict.

    If hdict is None, a new header dict is created. Returns the populated
    header dict.

    Header names are normalized with ``str.title()``. Headers which are
    repeated are folded together using a comma if their specification so
    dictates (see ``comma_separated_headers``); other repeated headers keep
    the last value seen. Continuation lines (lines starting with a space or
    tab) are appended to the value of the preceding header, separated by a
    single space.

    This function raises ValueError when the read bytes violate the HTTP spec.
    You should probably return "400 Bad Request" if this happens.
    """
    if hdict is None:
        hdict = {}

    # Name of the most recently parsed header; continuation lines attach to it.
    hname = None

    while True:
        line = rfile.readline()
        if not line:
            # No more data--illegal end of headers
            raise ValueError("Illegal end of headers.")

        if line == CRLF:
            # Normal end of headers
            break
        if not line.endswith(CRLF):
            raise ValueError("HTTP requires CRLF terminators")

        if line[0] in ' \t':
            # It's a continuation line; it must follow a real header line.
            if hname is None:
                raise ValueError("Illegal continuation line.")
            folded = hdict[hname] + " " + line.strip()
            hdict[hname] = folded.strip()
            continue

        try:
            k, v = line.split(":", 1)
        except ValueError:
            raise ValueError("Illegal header line.")
        # TODO: what about TE and WWW-Authenticate?
        hname = k.strip().title()
        v = v.strip()

        if hname in comma_separated_headers:
            existing = hdict.get(hname)
            if existing:
                v = ", ".join((existing, v))
        hdict[hname] = v

    return hdict
class MaxSizeExceeded(Exception):
    """Raised when more bytes have been read than the configured maximum."""
class SizeCheckWrapper(object):
    """Wraps a file-like object, raising MaxSizeExceeded if too large.

    Every byte handed back by the wrapped ``rfile`` is counted in
    ``bytes_read``. As soon as the count goes above ``maxlen`` the read
    raises MaxSizeExceeded. A falsy ``maxlen`` disables the limit.
    """

    def __init__(self, rfile, maxlen):
        self.rfile = rfile
        self.maxlen = maxlen
        self.bytes_read = 0

    def _check_length(self):
        """Raise MaxSizeExceeded if the running total is over the limit."""
        if not self.maxlen:
            return
        if self.bytes_read > self.maxlen:
            raise MaxSizeExceeded()

    def _tally(self, data):
        """Count ``data`` against the limit and hand it back unchanged."""
        self.bytes_read += len(data)
        self._check_length()
        return data

    def read(self, size=None):
        """Read from the wrapped file, counting what was returned."""
        return self._tally(self.rfile.read(size))

    def readline(self, size=None):
        """Read one line, counting what was returned.

        Without an explicit ``size`` the line is pulled in pieces of at most
        256 bytes so that an enormous line trips the limit early instead of
        being read in full first.
        """
        if size is not None:
            return self._tally(self.rfile.readline(size))

        pieces = []
        while True:
            piece = self._tally(self.rfile.readline(256))
            pieces.append(piece)
            # See http://www.cherrypy.org/ticket/421
            # A short piece or a trailing newline means the line is complete.
            if len(piece) < 256 or piece[-1:] == "\n":
                return ''.join(pieces)

    def readlines(self, sizehint=0):
        """Return remaining lines; stop once ``sizehint`` bytes are collected."""
        collected = []
        seen = 0
        while True:
            line = self.readline()
            if not line:
                break
            collected.append(line)
            seen += len(line)
            if 0 < sizehint <= seen:
                break
        return collected

    def close(self):
        self.rfile.close()

    def __iter__(self):
        return self

    def next(self):
        return self._tally(self.rfile.next())
class KnownLengthRFile(object):
    """Wraps a file-like object, returning an empty string when exhausted.

    ``content_length`` is the number of bytes that may be read from
    ``rfile``; ``remaining`` counts down as data is returned. No read method
    ever asks the wrapped file for more than ``remaining`` bytes.
    """

    def __init__(self, rfile, content_length):
        self.rfile = rfile
        self.remaining = content_length

    def _clamp(self, size):
        """Return how many bytes may be requested for a read of ``size``."""
        # None or a negative size both mean "everything that is left".
        if size is None or size < 0:
            return self.remaining
        return min(size, self.remaining)

    def read(self, size=None):
        """Read up to ``size`` bytes (all remaining bytes by default)."""
        if self.remaining == 0:
            return ''
        data = self.rfile.read(self._clamp(size))
        self.remaining -= len(data)
        return data

    def readline(self, size=None):
        """Read one line, limited to ``size`` and to the remaining bytes."""
        if self.remaining == 0:
            return ''
        data = self.rfile.readline(self._clamp(size))
        self.remaining -= len(data)
        return data

    def readlines(self, sizehint=0):
        """Return remaining lines; stop once ``sizehint`` bytes are collected.

        ``sizehint`` bounds the total amount collected, not the length of an
        individual line, so it is not passed to readline().
        """
        total = 0
        lines = []
        line = self.readline()
        while line:
            lines.append(line)
            total += len(line)
            if 0 < sizehint <= total:
                break
            line = self.readline()
        return lines

    def close(self):
        self.rfile.close()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next line, honouring the remaining-byte limit."""
        data = self.readline()
        if not data:
            raise StopIteration
        return data

    # Python 2 spells the iterator method "next".
    next = __next__
class ChunkedRFile(object):
    """Wraps a file-like object, returning an empty string when exhausted.

    This class is intended to provide a conforming wsgi.input value for
    request entities that have been encoded with the 'chunked' transfer
    encoding.

    Chunks are decoded one at a time into ``buffer``; the read methods drain
    that buffer and fetch the next chunk when it runs dry. ``closed`` becomes
    True once the terminating zero-size chunk has been seen. A falsy
    ``maxlen`` disables the size limit.
    """

    def __init__(self, rfile, maxlen, bufsize=8192):
        self.rfile = rfile
        self.maxlen = maxlen
        self.bytes_read = 0
        self.buffer = ''
        self.bufsize = bufsize
        self.closed = False

    def _fetch(self):
        """Decode the next chunk from the stream and append it to the buffer.

        Raises MaxSizeExceeded or IOError when ``maxlen`` would be exceeded,
        and ValueError when the chunk framing is malformed.
        """
        if self.closed:
            return

        line = self.rfile.readline()
        self.bytes_read += len(line)

        if self.maxlen and self.bytes_read > self.maxlen:
            raise MaxSizeExceeded("Request Entity Too Large", self.maxlen)

        # The size line is "<hex size>[;extension]"; extensions are ignored.
        line = line.strip().split(";", 1)

        try:
            chunk_size = line.pop(0)
            chunk_size = int(chunk_size, 16)
        except ValueError:
            raise ValueError("Bad chunked transfer size: " + repr(chunk_size))

        if chunk_size <= 0:
            # Last chunk: the body is complete.
            self.closed = True
            return

        if self.maxlen and self.bytes_read + chunk_size > self.maxlen:
            raise IOError("Request Entity Too Large")

        chunk = self.rfile.read(chunk_size)
        self.bytes_read += len(chunk)
        self.buffer += chunk

        # Every chunk is followed by its own CRLF.
        crlf = self.rfile.read(2)
        if crlf != CRLF:
            raise ValueError(
                "Bad chunked transfer coding (expected '\\r\\n', "
                "got " + repr(crlf) + ")")

    def read(self, size=None):
        """Read up to ``size`` bytes of body; everything if size is falsy."""
        data = ''
        while True:
            if size and len(data) >= size:
                return data

            if not self.buffer:
                self._fetch()
                if not self.buffer:
                    # EOF
                    return data

            if size:
                remaining = size - len(data)
                data += self.buffer[:remaining]
                self.buffer = self.buffer[remaining:]
            else:
                # Consume the whole buffer; it must be emptied or the loop
                # would keep appending the same data forever.
                data += self.buffer
                self.buffer = ''

    def readline(self, size=None):
        """Read one line (newline included), at most ``size`` bytes if given."""
        data = ''
        while True:
            if size and len(data) >= size:
                return data

            if not self.buffer:
                self._fetch()
                if not self.buffer:
                    # EOF
                    return data

            newline_pos = self.buffer.find('\n')
            if newline_pos == -1:
                take = len(self.buffer)
            else:
                # Include the newline itself in the returned line.
                take = newline_pos + 1
            if size:
                take = min(take, size - len(data))

            data += self.buffer[:take]
            self.buffer = self.buffer[take:]
            if data.endswith('\n'):
                return data

    def readlines(self, sizehint=0):
        """Return remaining lines; stop once ``sizehint`` bytes are collected."""
        total = 0
        lines = []
        line = self.readline()
        while line:
            lines.append(line)
            total += len(line)
            if 0 < sizehint <= total:
                break
            line = self.readline()
        return lines

    def read_trailer_lines(self):
        """Yield the raw trailer lines that follow the last chunk.

        Raises ValueError if the body has not been read to the end yet or
        the trailers are malformed, and IOError if ``maxlen`` is exceeded.
        """
        if not self.closed:
            raise ValueError(
                "Cannot read trailers until the request body has been read.")

        while True:
            line = self.rfile.readline()
            if not line:
                # No more data--illegal end of headers
                raise ValueError("Illegal end of headers.")

            self.bytes_read += len(line)
            if self.maxlen and self.bytes_read > self.maxlen:
                raise IOError("Request Entity Too Large")

            if line == CRLF:
                # Normal end of headers
                break
            if not line.endswith(CRLF):
                raise ValueError("HTTP requires CRLF terminators")

            yield line

    def close(self):
        self.rfile.close()

    def __iter__(self):
        """Iterate over the lines of the body until it is exhausted."""
        line = self.readline()
        while line:
            yield line
            line = self.readline()
class HTTPRequest(object):
    """An HTTP Request (and response).

    A single HTTP connection may consist of multiple request/response pairs.
    """

    server = None
    """The HTTPServer object which is receiving this request."""

    conn = None
    """The HTTPConnection object on which this request connected."""

    inheaders = {}
    """A dict of request headers."""

    outheaders = []
    """A list of header tuples to write in the response."""

    ready = False
    """When True, the request has been parsed and is ready to begin generating
    the response. When False, signals the calling Connection that the response
    should not be generated and the connection should close."""

    close_connection = False
    """Signals the calling Connection that the request should close. This does
    not imply an error! The client and/or server may each request that the
    connection be closed."""

    chunked_write = False
    """If True, output will be encoded with the "chunked" transfer-coding.
    This value is set automatically inside send_headers."""

    def __init__(self, server, conn):
        self.server = server
        self.conn = conn

        self.ready = False
        self.started_request = False
        self.scheme = "http"
        if self.server.ssl_adapter is not None:
            self.scheme = "https"
        # Use the lowest-common protocol in case read_request_line errors.
        self.response_protocol = 'HTTP/1.0'
        # Per-instance containers; never share the class-level defaults.
        self.inheaders = {}

        self.status = ""
        self.outheaders = []
        self.sent_headers = False
        self.close_connection = self.__class__.close_connection
        self.chunked_read = False
        self.chunked_write = self.__class__.chunked_write

    def parse_request(self):
        """Parse the next HTTP request start-line and message-headers.

        Sets ``self.ready`` to True only when both the Request-Line and the
        headers were read successfully. On failure an error response has
        already been written (where possible) and ``ready`` stays False.
        """
        self.rfile = SizeCheckWrapper(self.conn.rfile,
                                      self.server.max_request_header_size)
        try:
            success = self.read_request_line()
        except MaxSizeExceeded:
            self.simple_response("414 Request-URI Too Long",
                "The Request-URI sent with the request exceeds the maximum "
                "allowed bytes.")
            return
        else:
            # Do not go on to the headers after a rejected Request-Line.
            if not success:
                return

        try:
            success = self.read_request_headers()
        except MaxSizeExceeded:
            self.simple_response("413 Request Entity Too Large",
                "The headers sent with the request exceed the maximum "
                "allowed bytes.")
            return
        else:
            if not success:
                return

        self.ready = True

    def read_request_line(self):
        """Read and parse the Request-Line. Return True on success.

        On success sets ``uri``, ``method``, ``path``, ``qs``,
        ``request_protocol`` and ``response_protocol`` (and ``scheme`` when
        the URI carries one). Returns False when the client sent nothing or
        the line was rejected; in the latter case an error response has
        already been written.
        """
        # HTTP/1.1 connections are persistent by default. If a client
        # requests a page, then idles (leaves the connection open),
        # then rfile.readline() will raise socket.error("timed out").
        # Note that it does this based on the value given to settimeout(),
        # and doesn't need the client to request or acknowledge the close
        # (although your TCP stack might suffer for it: cf Apache's history
        # with FIN_WAIT_2).
        request_line = self.rfile.readline()

        # Set started_request to True so communicate() knows to send 408
        # from here on out.
        self.started_request = True
        if not request_line:
            # Force self.ready = False so the connection will close.
            self.ready = False
            return False

        if request_line == CRLF:
            # RFC 2616 sec 4.1: "...if the server is reading the protocol
            # stream at the beginning of a message and receives a CRLF
            # first, it should ignore the CRLF."
            # But only ignore one leading line! else we enable a DoS.
            request_line = self.rfile.readline()
            if not request_line:
                self.ready = False
                return False

        if not request_line.endswith(CRLF):
            self.simple_response("400 Bad Request",
                                 "HTTP requires CRLF terminators")
            return False

        try:
            method, uri, req_protocol = request_line.strip().split(" ", 2)
            # req_protocol looks like "HTTP/x.y"; pick out x and y.
            rp = int(req_protocol[5]), int(req_protocol[7])
        except (ValueError, IndexError):
            self.simple_response("400 Bad Request", "Malformed Request-Line")
            return False

        self.uri = uri
        self.method = method

        # uri may be an abs_path (including "http://host.domain.tld");
        scheme, authority, path = self.parse_request_uri(uri)
        if path is None:
            # Authority-form Request-URI (no path component at all).
            path = ''
        if '#' in path:
            self.simple_response("400 Bad Request",
                                 "Illegal #fragment in Request-URI.")
            return False

        if scheme:
            self.scheme = scheme

        qs = ''
        if '?' in path:
            path, qs = path.split('?', 1)

        # Unquote the path+params (e.g. "/this%20path" -> "/this path").
        # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
        #
        # But note that "...a URI must be separated into its components
        # before the escaped characters within those components can be
        # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2
        # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path".
        try:
            atoms = [unquote(x) for x in quoted_slash.split(path)]
        except ValueError as ex:
            self.simple_response("400 Bad Request", ex.args[0])
            return False
        path = "%2F".join(atoms)
        self.path = path

        # Note that, like wsgiref and most other HTTP servers,
        # we "% HEX HEX"-unquote the path but not the query string.
        self.qs = qs

        # Compare request and server HTTP protocol versions, in case our
        # server does not support the requested protocol. Limit our output
        # to min(req, server). We want the following output:
        #     request    server     actual written   supported response
        #     protocol   protocol  response protocol    feature set
        # a     1.0        1.0           1.0                1.0
        # b     1.0        1.1           1.1                1.0
        # c     1.1        1.0           1.0                1.0
        # d     1.1        1.1           1.1                1.1
        # Notice that, in (b), the response will be "HTTP/1.1" even though
        # the client only understands 1.0. RFC 2616 10.5.6 says we should
        # only return 505 if the _major_ version is different.
        sp = int(self.server.protocol[5]), int(self.server.protocol[7])

        if sp[0] != rp[0]:
            self.simple_response("505 HTTP Version Not Supported")
            return False
        self.request_protocol = req_protocol
        self.response_protocol = "HTTP/%s.%s" % min(rp, sp)
        return True

    def read_request_headers(self):
        """Read self.rfile into self.inheaders. Return success.

        Besides reading the headers this validates Content-Length, works
        out whether the connection should be closed afterwards, enables
        chunked reading when requested and answers "Expect: 100-continue".
        When False is returned an error response has already been written.
        """
        # then all the http headers
        try:
            read_headers(self.rfile, self.inheaders)
        except ValueError as ex:
            self.simple_response("400 Bad Request", ex.args[0])
            return False

        # Validate Content-Length once, up front, so that a bogus value
        # yields a 400 instead of an unhandled ValueError later on.
        try:
            cl = int(self.inheaders.get("Content-Length", 0))
            if cl < 0:
                raise ValueError(cl)
        except ValueError:
            self.simple_response("400 Bad Request",
                                 "Malformed Content-Length Header.")
            return False

        mrbs = self.server.max_request_body_size
        if mrbs and cl > mrbs:
            self.simple_response("413 Request Entity Too Large",
                "The entity sent with the request exceeds the maximum "
                "allowed bytes.")
            return False

        # Persistent connection support
        # (header tokens are compared case-insensitively).
        connection = self.inheaders.get("Connection", "").lower()
        if self.response_protocol == "HTTP/1.1":
            # Both server and client are HTTP/1.1
            if connection == "close":
                self.close_connection = True
        else:
            # Either the server or client (or both) are HTTP/1.0
            if connection != "keep-alive":
                self.close_connection = True

        # Transfer-Encoding support
        te = None
        if self.response_protocol == "HTTP/1.1":
            te = self.inheaders.get("Transfer-Encoding")
            if te:
                te = [x.strip().lower() for x in te.split(",") if x.strip()]

        self.chunked_read = False

        if te:
            for enc in te:
                if enc == "chunked":
                    self.chunked_read = True
                else:
                    # Note that, even if we see "chunked", we must reject
                    # if there is an extension we don't recognize.
                    self.simple_response("501 Unimplemented")
                    self.close_connection = True
                    return False

        # From PEP 333:
        # "Servers and gateways that implement HTTP 1.1 must provide
        # transparent support for HTTP 1.1's "expect/continue" mechanism.
        # This may be done in any of several ways:
        #   1. Respond to requests containing an Expect: 100-continue request
        #      with an immediate "100 Continue" response, and proceed normally.
        #   2. Proceed with the request normally, but provide the application
        #      with a wsgi.input stream that will send the "100 Continue"
        #      response if/when the application first attempts to read from
        #      the input stream. The read request must then remain blocked
        #      until the client responds.
        #   3. Wait until the client decides that the server does not support
        #      expect/continue, and sends the request body on its own.
        #      (This is suboptimal, and is not recommended.)
        #
        # We used to do 3, but are now doing 1. Maybe we'll do 2 someday,
        # but it seems like it would be a big slowdown for such a rare case.
        if self.inheaders.get("Expect", "").lower() == "100-continue":
            # Don't use simple_response here, because it emits headers
            # we don't want. See http://www.cherrypy.org/ticket/951
            msg = self.server.protocol + " 100 Continue\r\n\r\n"
            try:
                self.conn.wfile.sendall(msg)
            except socket.error as x:
                if x.args[0] not in socket_errors_to_ignore:
                    raise
        return True

    def parse_request_uri(self, uri):
        """Parse a Request-URI into (scheme, authority, path).

        Note that Request-URI's must be one of::

            Request-URI    = "*" | absoluteURI | abs_path | authority

        Therefore, a Request-URI which starts with a double forward-slash
        cannot be a "net_path"::

            net_path      = "//" authority [ abs_path ]

        Instead, it must be interpreted as an "abs_path" with an empty first
        path segment::

            abs_path      = "/"  path_segments
            path_segments = segment *( "/" segment )
            segment       = *pchar *( ";" param )
            param         = *pchar

        Returns:
            ``(None, None, "*")`` for ``"*"``;
            ``(scheme, authority, path)`` for an absoluteURI, where ``path``
            always starts with "/" (and is "/" when the URI has no path);
            ``(None, None, uri)`` for an abs_path;
            ``(None, uri, None)`` for an authority.
        """
        if uri == "*":
            return None, None, uri

        i = uri.find('://')
        if i > 0 and '?' not in uri[:i]:
            # An absoluteURI.
            # If there's a scheme (and it must be http or https), then:
            # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]]
            scheme, remainder = uri[:i].lower(), uri[i + 3:]
            # partition() copes with a missing path ("http://host"); the
            # separator is put back so the path keeps its leading slash.
            authority, _sep, rest = remainder.partition("/")
            return scheme, authority, "/" + rest

        if uri.startswith('/'):
            # An abs_path.
            return None, None, uri

        # An authority.
        return None, uri, None

    def respond(self):
        """Call the gateway and write its iterable output."""
        mrbs = self.server.max_request_body_size
        if self.chunked_read:
            self.rfile = ChunkedRFile(self.conn.rfile, mrbs)
        else:
            cl = int(self.inheaders.get("Content-Length", 0))
            if mrbs and mrbs < cl:
                if not self.sent_headers:
                    self.simple_response("413 Request Entity Too Large",
                        "The entity sent with the request exceeds the maximum "
                        "allowed bytes.")
                return
            self.rfile = KnownLengthRFile(self.conn.rfile, cl)

        self.server.gateway(self).respond()

        if (self.ready and not self.sent_headers):
            self.sent_headers = True
            self.send_headers()
        if self.chunked_write:
            # Terminating zero-size chunk.
            self.conn.wfile.sendall("0\r\n\r\n")

    def simple_response(self, status, msg=""):
        """Write a simple response back to the client.

        ``status`` is the full status text (e.g. "400 Bad Request") and
        ``msg`` an optional plain-text body. Socket errors listed in
        ``socket_errors_to_ignore`` are swallowed; others propagate.
        """
        status = str(status)
        buf = [self.server.protocol + " " +
               status + CRLF,
               "Content-Length: %s\r\n" % len(msg),
               "Content-Type: text/plain\r\n"]

        if status[:3] in ("413", "414"):
            # Request Entity Too Large / Request-URI Too Long
            self.close_connection = True
            if self.response_protocol == 'HTTP/1.1':
                # This will not be true for 414, since read_request_line
                # usually raises 414 before reading the whole line, and we
                # therefore cannot know the proper response_protocol.
                buf.append("Connection: close\r\n")
            else:
                # HTTP/1.0 had no 413/414 status nor Connection header.
                # Emit 400 instead and trust the message body is enough.
                # NOTE(review): the status line was already placed in buf
                # above, so this reassignment does not change what is sent.
                # Left as is on purpose: making it effective would turn
                # every 414 into a 400 -- confirm the intended behavior.
                status = "400 Bad Request"

        buf.append(CRLF)
        if msg:
            if isinstance(msg, unicode):
                msg = msg.encode("ISO-8859-1")
            buf.append(msg)

        try:
            self.conn.wfile.sendall("".join(buf))
        except socket.error as x:
            if x.args[0] not in socket_errors_to_ignore:
                raise

    def write(self, chunk):
        """Write unbuffered data to the client."""
        if self.chunked_write and chunk:
            # Chunk framing: hex length, CRLF, data, CRLF.
            buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF]
            self.conn.wfile.sendall("".join(buf))
        else:
            self.conn.wfile.sendall(chunk)

    def send_headers(self):
        """Assert, process, and send the HTTP response message-headers.

        You must set self.status, and self.outheaders before calling this.
        """
        hkeys = [key.lower() for key, value in self.outheaders]
        status = int(self.status[:3])

        if status == 413:
            # Request Entity Too Large. Close conn to avoid garbage.
            self.close_connection = True
        elif "content-length" not in hkeys:
            # "All 1xx (informational), 204 (no content),
            # and 304 (not modified) responses MUST NOT
            # include a message-body." So no point chunking.
            if status < 200 or status in (204, 205, 304):
                pass
            else:
                if (self.response_protocol == 'HTTP/1.1'
                        and self.method != 'HEAD'):
                    # Use the chunked transfer-coding
                    self.chunked_write = True
                    self.outheaders.append(("Transfer-Encoding", "chunked"))
                else:
                    # Closing the conn is the only way to determine len.
                    self.close_connection = True

        if "connection" not in hkeys:
            if self.response_protocol == 'HTTP/1.1':
                # Both server and client are HTTP/1.1 or better
                if self.close_connection:
                    self.outheaders.append(("Connection", "close"))
            else:
                # Server and/or client are HTTP/1.0
                if not self.close_connection:
                    self.outheaders.append(("Connection", "Keep-Alive"))

        if (not self.close_connection) and (not self.chunked_read):
            # Read any remaining request body data on the socket.
            # "If an origin server receives a request that does not include an
            # Expect request-header field with the "100-continue" expectation,
            # the request includes a request body, and the server responds
            # with a final status code before reading the entire request body
            # from the transport connection, then the server SHOULD NOT close
            # the transport connection until it has read the entire request,
            # or until the client closes the connection. Otherwise, the client
            # might not reliably receive the response message. However, this
            # requirement is not be construed as preventing a server from
            # defending itself against denial-of-service attacks, or from
            # badly broken client implementations."
            remaining = getattr(self.rfile, 'remaining', 0)
            if remaining > 0:
                self.rfile.read(remaining)

        if "date" not in hkeys:
            self.outheaders.append(("Date", rfc822.formatdate()))

        if "server" not in hkeys:
            self.outheaders.append(("Server", self.server.server_name))

        buf = [self.server.protocol + " " + self.status + CRLF]
        for k, v in self.outheaders:
            buf.append(k + ": " + v + CRLF)
        buf.append(CRLF)
        self.conn.wfile.sendall("".join(buf))
class NoSSLError(Exception):
    """Raised when a client speaks plain HTTP to an HTTPS socket."""
class FatalSSLAlert(Exception):
    """Raised when the SSL implementation signals a fatal alert."""
class CP_fileobject(socket._fileobject): | |
"""Faux file object attached to a socket object.""" | |
def __init__(self, *args, **kwargs):
    """Start the byte counters at zero, then initialise the base file object.

    All arguments are forwarded untouched to ``socket._fileobject``.
    """
    self.bytes_read = 0
    self.bytes_written = 0
    socket._fileobject.__init__(self, *args, **kwargs)
def sendall(self, data):
    """Sendall for non-blocking sockets.

    Calls ``self.send`` until all of ``data`` has been accepted. Errors
    listed in ``socket_errors_nonblocking`` cause a retry; any other
    socket error is re-raised.
    """
    pending = data
    while pending:
        try:
            sent = self.send(pending)
        except socket.error as err:
            if err.args[0] in socket_errors_nonblocking:
                # Nothing could be sent right now; try again.
                continue
            raise
        pending = pending[sent:]
def send(self, data):
    """Send ``data`` on the underlying socket; return the number of bytes sent.

    The count is also added to ``self.bytes_written``.
    """
    count = self._sock.send(data)
    self.bytes_written += count
    return count
def flush(self):
    """Send everything held in the write buffer, leaving the buffer empty."""
    if not self._wbuf:
        return
    pending = "".join(self._wbuf)
    # Reset the buffer before sending, as the data is now owned by `pending`.
    self._wbuf = []
    self.sendall(pending)
def recv(self, size):
    """Receive up to ``size`` bytes from the underlying socket.

    Errors listed in ``socket_errors_nonblocking`` or ``socket_error_eintr``
    cause a retry; any other socket error is re-raised. The number of bytes
    received is added to ``self.bytes_read``.
    """
    while True:
        try:
            chunk = self._sock.recv(size)
        except socket.error as err:
            code = err.args[0]
            if code in socket_errors_nonblocking or code in socket_error_eintr:
                # Transient condition; try the read again.
                continue
            raise
        self.bytes_read += len(chunk)
        return chunk
if not _fileobject_uses_str_type: | |
def read(self, size=-1):
    """Read from the connection, buffering through a StringIO object.

    With a negative ``size`` everything is read until ``self.recv()``
    returns an empty result (EOF). Otherwise at most ``size`` bytes are
    returned; fewer only if EOF is seen first. Buffered bytes beyond
    ``size`` are kept in ``self._rbuf`` for the next call.
    """
    # Use max, disallow tiny reads in a loop as they are very inefficient.
    # We never leave read() with any leftover data from a new recv() call
    # in our internal buffer.
    rbufsize = max(self._rbufsize, self.default_bufsize)
    # Our use of StringIO rather than lists of string objects returned by
    # recv() minimizes memory usage and fragmentation that occurs when
    # rbufsize is large compared to the typical return value of recv().
    buf = self._rbuf
    buf.seek(0, 2)  # seek end
    if size < 0:
        # Read until EOF
        self._rbuf = StringIO.StringIO()  # reset _rbuf. we consume it via buf.
        while True:
            data = self.recv(rbufsize)
            if not data:
                break
            buf.write(data)
        return buf.getvalue()
    else:
        # Read until size bytes or EOF seen, whichever comes first
        buf_len = buf.tell()
        if buf_len >= size:
            # Already have size bytes in our buffer? Extract and return.
            buf.seek(0)
            rv = buf.read(size)
            # Whatever follows the first ``size`` bytes becomes the new
            # buffer contents.
            self._rbuf = StringIO.StringIO()
            self._rbuf.write(buf.read())
            return rv

        self._rbuf = StringIO.StringIO()  # reset _rbuf. we consume it via buf.
        while True:
            left = size - buf_len
            # recv() will malloc the amount of memory given as its
            # parameter even though it often returns much less data
            # than that. The returned data string is short lived
            # as we copy it into a StringIO and free it. This avoids
            # fragmentation issues on many platforms.
            data = self.recv(left)
            if not data:
                break
            n = len(data)
            if n == size and not buf_len:
                # Shortcut. Avoid buffer data copies when:
                # - We have no data in our buffer.
                # AND
                # - Our call to recv returned exactly the
                # number of bytes we were asked to read.
                return data
            if n == left:
                # This chunk completes the request exactly.
                buf.write(data)
                del data  # explicit free
                break
            assert n <= left, "recv(%d) returned %d bytes" % (left, n)
            buf.write(data)
            buf_len += n
            del data  # explicit free
            #assert buf_len == buf.tell()
        return buf.getvalue()
def readline(self, size=-1):
    """Read one line from the connection, buffering through StringIO.

    Returns data up to and including the first ``'\\n'``. With a
    non-negative ``size`` the result is additionally capped at ``size``
    bytes. Reading also stops when ``self.recv()`` returns an empty
    result (EOF). Data received past the returned line is kept in
    ``self._rbuf`` for the next call.
    """
    buf = self._rbuf
    buf.seek(0, 2)  # seek end
    if buf.tell() > 0:
        # check if we already have it in our buffer
        buf.seek(0)
        bline = buf.readline(size)
        if bline.endswith('\n') or len(bline) == size:
            # The buffer alone satisfied the request; keep the remainder.
            self._rbuf = StringIO.StringIO()
            self._rbuf.write(buf.read())
            return bline
        del bline
    if size < 0:
        # Read until \n or EOF, whichever comes first
        if self._rbufsize <= 1:
            # Speed up unbuffered case
            buf.seek(0)
            buffers = [buf.read()]
            self._rbuf = StringIO.StringIO()  # reset _rbuf. we consume it via buf.
            data = None
            recv = self.recv
            # One byte at a time, so nothing past the newline is consumed.
            while data != "\n":
                data = recv(1)
                if not data:
                    break
                buffers.append(data)
            return "".join(buffers)

        buf.seek(0, 2)  # seek end
        self._rbuf = StringIO.StringIO()  # reset _rbuf. we consume it via buf.
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            nl = data.find('\n')
            if nl >= 0:
                nl += 1
                # Line part goes to the result, the rest back to _rbuf.
                buf.write(data[:nl])
                self._rbuf.write(data[nl:])
                del data
                break
            buf.write(data)
        return buf.getvalue()
    else:
        # Read until size bytes or \n or EOF seen, whichever comes first
        buf.seek(0, 2)  # seek end
        buf_len = buf.tell()
        if buf_len >= size:
            buf.seek(0)
            rv = buf.read(size)
            self._rbuf = StringIO.StringIO()
            self._rbuf.write(buf.read())
            return rv
        self._rbuf = StringIO.StringIO()  # reset _rbuf. we consume it via buf.
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            left = size - buf_len
            # did we just receive a newline?
            nl = data.find('\n', 0, left)
            if nl >= 0:
                nl += 1
                # save the excess data to _rbuf
                self._rbuf.write(data[nl:])
                if buf_len:
                    buf.write(data[:nl])
                    break
                else:
                    # Shortcut. Avoid data copy through buf when returning
                    # a substring of our first recv().
                    return data[:nl]
            n = len(data)
            if n == size and not buf_len:
                # Shortcut. Avoid data copy through buf when
                # returning exactly all of our first recv().
                return data
            if n >= left:
                # Size limit reached without a newline; stash the surplus.
                buf.write(data[:left])
                self._rbuf.write(data[left:])
                break
            buf.write(data)
            buf_len += n
            #assert buf_len == buf.tell()
        return buf.getvalue()
else: | |
def read(self, size=-1):
    """Read from the connection using a plain string as the buffer.

    A negative ``size`` drains the connection until ``self.recv()``
    returns an empty result. Otherwise at most ``size`` bytes are
    returned (fewer only at EOF) and any surplus received is stored in
    ``self._rbuf`` for the next call.
    """
    if size < 0:
        # Drain mode: start with what is buffered, then recv until EOF.
        pieces = [self._rbuf]
        self._rbuf = ""
        chunk_size = self._rbufsize
        if chunk_size <= 1:
            # Unbuffered/line-buffered files still read in sensible chunks.
            chunk_size = self.default_bufsize
        chunk = self.recv(chunk_size)
        while chunk:
            pieces.append(chunk)
            chunk = self.recv(chunk_size)
        return "".join(pieces)

    # Bounded mode: stop at ``size`` bytes or EOF, whichever comes first.
    pending = self._rbuf
    have = len(pending)
    if have >= size:
        # The buffer alone can satisfy the request.
        self._rbuf = pending[size:]
        return pending[:size]

    pieces = []
    if pending:
        pieces.append(pending)
    self._rbuf = ""
    while True:
        missing = size - have
        chunk = self.recv(max(self._rbufsize, missing))
        if not chunk:
            break
        received = len(chunk)
        if received >= missing:
            # Keep only what was asked for; stash the rest for later.
            self._rbuf = chunk[missing:]
            pieces.append(chunk[:missing])
            break
        pieces.append(chunk)
        have += received
    return "".join(pieces)
def readline(self, size=-1):
    """Read one line from the connection using a plain string buffer.

    Returns data up to and including the first ``'\\n'``. With a
    non-negative ``size`` the result is additionally capped at ``size``
    bytes. Reading also stops when ``self.recv()`` returns an empty
    result (EOF). Surplus data is stored in ``self._rbuf``.
    """
    data = self._rbuf
    if size < 0:
        # Read until \n or EOF, whichever comes first
        if self._rbufsize <= 1:
            # Speed up unbuffered case
            # Nothing may be buffered here, since reads are one byte each.
            assert data == ""
            buffers = []
            while data != "\n":
                data = self.recv(1)
                if not data:
                    break
                buffers.append(data)
            return "".join(buffers)
        nl = data.find('\n')
        if nl >= 0:
            # A complete line is already buffered.
            nl += 1
            self._rbuf = data[nl:]
            return data[:nl]
        buffers = []
        if data:
            buffers.append(data)
        self._rbuf = ""
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            buffers.append(data)
            nl = data.find('\n')
            if nl >= 0:
                nl += 1
                # Trim the last chunk at the newline; keep the rest buffered.
                self._rbuf = data[nl:]
                buffers[-1] = data[:nl]
                break
        return "".join(buffers)
    else:
        # Read until size bytes or \n or EOF seen, whichever comes first
        nl = data.find('\n', 0, size)
        if nl >= 0:
            nl += 1
            self._rbuf = data[nl:]
            return data[:nl]
        buf_len = len(data)
        if buf_len >= size:
            # No newline within the limit, but enough bytes are buffered.
            self._rbuf = data[size:]
            return data[:size]
        buffers = []
        if data:
            buffers.append(data)
        self._rbuf = ""
        while True:
            data = self.recv(self._rbufsize)
            if not data:
                break
            buffers.append(data)
            left = size - buf_len
            nl = data.find('\n', 0, left)
            if nl >= 0:
                nl += 1
                self._rbuf = data[nl:]
                buffers[-1] = data[:nl]
                break
            n = len(data)
            if n >= left:
                # Size limit reached without a newline; stash the surplus.
                self._rbuf = data[left:]
                buffers[-1] = data[:left]
                break
            buf_len += n
        return "".join(buffers)
class HTTPConnection(object): | |
"""An HTTP connection (active socket). | |
server: the Server object which received this connection. | |
socket: the raw socket object (usually TCP) for this connection. | |
makefile: a fileobject class for reading from the socket. | |
""" | |
remote_addr = None | |
remote_port = None | |
ssl_env = None | |
rbufsize = DEFAULT_BUFFER_SIZE | |
wbufsize = DEFAULT_BUFFER_SIZE | |
RequestHandlerClass = HTTPRequest | |
def __init__(self, server, sock, makefile=CP_fileobject): | |
self.server = server | |
self.socket = sock | |
self.rfile = makefile(sock, "rb", self.rbufsize) | |
self.wfile = makefile(sock, "wb", self.wbufsize) | |
self.requests_seen = 0 | |
def communicate(self): | |
"""Read each request and respond appropriately.""" | |
request_seen = False | |
try: | |
while True: | |
# (re)set req to None so that if something goes wrong in | |
# the RequestHandlerClass constructor, the error doesn't | |
# get written to the previous request. | |
req = None | |
req = self.RequestHandlerClass(self.server, self) | |
# This order of operations should guarantee correct pipelining. | |
req.parse_request() | |
if self.server.stats['Enabled']: | |
self.requests_seen += 1 | |
if not req.ready: | |
# Something went wrong in the parsing (and the server has | |
# probably already made a simple_response). Return and | |
# let the conn close. | |
return | |
request_seen = True | |
req.respond() | |
if req.close_connection: | |
return | |
except socket.error, e: | |
errnum = e.args[0] | |
# sadly SSL sockets return a different (longer) time out string | |
if errnum == 'timed out' or errnum == 'The read operation timed out': | |
# Don't error if we're between requests; only error | |
# if 1) no request has been started at all, or 2) we're | |
# in the middle of a request. | |
# See http://www.cherrypy.org/ticket/853 | |
if (not request_seen) or (req and req.started_request): | |
# Don't bother writing the 408 if the response | |
# has already started being written. | |
if req and not req.sent_headers: | |
try: | |
req.simple_response("408 Request Timeout") | |
except FatalSSLAlert: | |
# Close the connection. | |
return | |
elif errnum not in socket_errors_to_ignore: | |
if req and not req.sent_headers: | |
try: | |
req.simple_response("500 Internal Server Error", | |
format_exc()) | |
except FatalSSLAlert: | |
# Close the connection. | |
return | |
return | |
except (KeyboardInterrupt, SystemExit): | |
raise | |
except FatalSSLAlert: | |
# Close the connection. | |
return | |
except NoSSLError: | |
if req and not req.sent_headers: | |
# Unwrap our wfile | |
self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize) | |
req.simple_response("400 Bad Request", | |
"The client sent a plain HTTP request, but " | |
"this server only speaks HTTPS on this port.") | |
self.linger = True | |
except Exception: | |
if req and not req.sent_headers: | |
try: | |
req.simple_response("500 Internal Server Error", format_exc()) | |
except FatalSSLAlert: | |
# Close the connection. | |
return | |
linger = False | |
def close(self): | |
"""Close the socket underlying this connection.""" | |
self.rfile.close() | |
if not self.linger: | |
# Python's socket module does NOT call close on the kernel socket | |
# when you call socket.close(). We do so manually here because we | |
# want this server to send a FIN TCP segment immediately. Note this | |
# must be called *before* calling socket.close(), because the latter | |
# drops its reference to the kernel socket. | |
if hasattr(self.socket, '_sock'): | |
self.socket._sock.close() | |
self.socket.close() | |
else: | |
# On the other hand, sometimes we want to hang around for a bit | |
# to make sure the client has a chance to read our entire | |
# response. Skipping the close() calls here delays the FIN | |
# packet until the socket object is garbage-collected later. | |
# Someday, perhaps, we'll do the full lingering_close that | |
# Apache does, but not today. | |
pass | |
_SHUTDOWNREQUEST = None | |
class WorkerThread(threading.Thread):
    """Thread which continuously polls a Queue for Connection objects.

    Due to the timing issues of polling a Queue, a WorkerThread does not
    check its own 'ready' flag after it has started. To stop the thread,
    it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
    (one for each running WorkerThread).
    """

    conn = None
    """The current connection pulled off the Queue, or None."""

    server = None
    """The HTTP Server which spawned this thread, and which owns the
    Queue and is placing active connections into it."""

    ready = False
    """A simple flag for the calling server to know when this thread
    has begun polling the Queue."""

    def __init__(self, server):
        """Create an idle worker bound to ``server`` with zeroed statistics."""
        self.ready = False
        self.server = server

        self.requests_seen = 0
        self.bytes_read = 0
        self.bytes_written = 0
        self.start_time = None
        self.work_time = 0

        def in_flight(getter):
            # Contribution of the connection currently being timed, or 0
            # when there is none. The previous ``cond and 0 or value``
            # idiom always evaluated ``value`` (0 is falsy), which raised
            # on an idle worker because ``self.conn`` and
            # ``self.start_time`` are None then.
            if self.start_time is None:
                return 0
            return getter()

        # Each entry is a callable taking the stats dict itself, so that
        # derived values can reference sibling entries.
        self.stats = {
            'Requests': lambda s: self.requests_seen + in_flight(
                lambda: self.conn.requests_seen),
            'Bytes Read': lambda s: self.bytes_read + in_flight(
                lambda: self.conn.rfile.bytes_read),
            'Bytes Written': lambda s: self.bytes_written + in_flight(
                lambda: self.conn.wfile.bytes_written),
            'Work Time': lambda s: self.work_time + in_flight(
                lambda: time.time() - self.start_time),
            'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6),
            'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6),
        }
        threading.Thread.__init__(self)

    def run(self):
        """Serve connections from the server's queue until told to stop.

        Registers this thread's stats with the server, then loops: take a
        connection, let it communicate, close it and fold its counters
        into the thread totals. A _SHUTDOWNREQUEST ends the loop.
        KeyboardInterrupt/SystemExit are handed to ``server.interrupt``.
        """
        self.server.stats['Worker Threads'][self.getName()] = self.stats
        try:
            self.ready = True
            while True:
                conn = self.server.requests.get()
                if conn is _SHUTDOWNREQUEST:
                    return

                self.conn = conn
                if self.server.stats['Enabled']:
                    self.start_time = time.time()
                try:
                    conn.communicate()
                finally:
                    conn.close()
                    # Key the accounting on start_time rather than on the
                    # 'Enabled' flag: if stats were switched on while this
                    # connection was in progress, start_time is still None
                    # and subtracting it would raise TypeError.
                    if self.start_time is not None:
                        self.requests_seen += self.conn.requests_seen
                        self.bytes_read += self.conn.rfile.bytes_read
                        self.bytes_written += self.conn.wfile.bytes_written
                        self.work_time += time.time() - self.start_time
                        self.start_time = None
                    self.conn = None
        except (KeyboardInterrupt, SystemExit):
            # sys.exc_info() is used instead of binding the exception in
            # the except clause so the syntax is valid on Python 2 and 3.
            self.server.interrupt = sys.exc_info()[1]
class ThreadPool(object):
    """A Request Queue for the CherryPyWSGIServer which pools threads.

    ThreadPool objects must provide min, get(), put(obj), start()
    and stop(timeout) attributes.
    """

    def __init__(self, server, min=10, max=-1):
        """Create an empty pool; no threads run until start() is called.

        server: passed to every WorkerThread the pool creates.
        min: number of threads created by start(); shrink() will not ask
            more threads to exit than would leave this many.
        max: upper bound honoured by grow(); values <= 0 mean no limit.
        """
        self.server = server
        self.min = min
        self.max = max
        self._threads = []
        self._queue = Queue.Queue()
        self.get = self._queue.get

    def start(self):
        """Start the pool of threads."""
        for i in range(self.min):
            self._threads.append(WorkerThread(self.server))
        for worker in self._threads:
            worker.setName("CP Server " + worker.getName())
            worker.start()
        # Block until every worker has begun polling the queue.
        for worker in self._threads:
            while not worker.ready:
                time.sleep(.1)

    def _get_idle(self):
        """Number of worker threads which are idle. Read-only."""
        return len([t for t in self._threads if t.conn is None])
    idle = property(_get_idle, doc=_get_idle.__doc__)

    def put(self, obj):
        """Place a connection (or a shutdown request) on the queue."""
        self._queue.put(obj)

    def grow(self, amount):
        """Spawn new worker threads (not above self.max)."""
        for i in range(amount):
            if self.max > 0 and len(self._threads) >= self.max:
                break
            worker = WorkerThread(self.server)
            worker.setName("CP Server " + worker.getName())
            self._threads.append(worker)
            worker.start()

    def shrink(self, amount):
        """Kill off worker threads (not below self.min)."""
        # Remove any dead threads from our list. Iterate over a copy:
        # removing from the list being iterated skips the element that
        # follows each removal, so adjacent dead threads were missed.
        for t in self._threads[:]:
            if not t.isAlive():
                self._threads.remove(t)
                amount -= 1

        if amount > 0:
            for i in range(min(amount, len(self._threads) - self.min)):
                # Put a number of shutdown requests on the queue equal
                # to 'amount'. Once each of those is processed by a worker,
                # that worker will terminate; it is culled from our list
                # by the dead-thread sweep at the top of the next shrink().
                self._queue.put(_SHUTDOWNREQUEST)

    def stop(self, timeout=5):
        """Ask every worker to exit and wait for them to finish.

        timeout: total seconds to wait for all workers. None or a
            negative value waits indefinitely. Once the time is used up,
            the connection of any worker still running has its socket
            shut down for reading, and the worker is then joined without
            a timeout.
        """
        # Must shut down threads here so the code that calls
        # this method can know when all threads are stopped.
        for worker in self._threads:
            self._queue.put(_SHUTDOWNREQUEST)

        # Don't join currentThread (when stop is called inside a request).
        current = threading.currentThread()
        # Test against None explicitly: "timeout and ..." skipped this
        # assignment for timeout == 0, leaving endtime unbound below.
        if timeout is not None and timeout >= 0:
            endtime = time.time() + timeout
        while self._threads:
            worker = self._threads.pop()
            if worker is not current and worker.isAlive():
                try:
                    if timeout is None or timeout < 0:
                        worker.join()
                    else:
                        remaining_time = endtime - time.time()
                        if remaining_time > 0:
                            worker.join(remaining_time)
                        if worker.isAlive():
                            # We exhausted the timeout.
                            # Forcibly shut down the socket.
                            c = worker.conn
                            if c and not c.rfile.closed:
                                try:
                                    c.socket.shutdown(socket.SHUT_RD)
                                except TypeError:
                                    # pyOpenSSL sockets don't take an arg
                                    c.socket.shutdown()
                            worker.join()
                except (AssertionError,
                        # Ignore repeated Ctrl-C.
                        # See http://www.cherrypy.org/ticket/691.
                        KeyboardInterrupt):
                    pass

    def _get_qsize(self):
        """Number of items currently waiting on the queue."""
        return self._queue.qsize()
    qsize = property(_get_qsize)
# Define prevent_socket_inheritance(sock) using whichever mechanism this
# platform offers: fcntl where importable, else ctypes' windll, else a no-op.
try:
    import fcntl
except ImportError:
    try:
        from ctypes import windll, WinError
    except ImportError:
        def prevent_socket_inheritance(sock):
            """Dummy function, since neither fcntl nor ctypes are available."""
            pass
    else:
        def prevent_socket_inheritance(sock):
            """Mark the given socket fd as non-inheritable (Windows)."""
            # Arguments are (handle, mask, flags): mask 1 selects the
            # inherit flag (HANDLE_FLAG_INHERIT) and flags 0 clears it.
            # A zero return value signals failure.
            if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0):
                raise WinError()
else:
    def prevent_socket_inheritance(sock):
        """Mark the given socket fd as non-inheritable (POSIX)."""
        fd = sock.fileno()
        # Add FD_CLOEXEC to the existing descriptor flags rather than
        # overwriting them.
        old_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
        fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
class SSLAdapter(object):
    """Base class for SSL driver library adapters.

    Required methods:

        * ``wrap(sock) -> (wrapped socket, ssl environ dict)``
        * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object``
    """

    def __init__(self, certificate, private_key, certificate_chain=None):
        """Store the certificate, key and optional chain for subclasses."""
        self.certificate = certificate
        self.private_key = private_key
        self.certificate_chain = certificate_chain

    def wrap(self, sock):
        """Return ``(wrapped socket, ssl environ dict)``; subclasses override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it produces a TypeError instead of the intended signal.
        raise NotImplementedError

    def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
        """Return a socket file object for ``sock``; subclasses override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class HTTPServer(object): | |
"""An HTTP server.""" | |
_bind_addr = "127.0.0.1" | |
_interrupt = None | |
gateway = None | |
"""A Gateway instance.""" | |
minthreads = None | |
"""The minimum number of worker threads to create (default 10).""" | |
maxthreads = None | |
"""The maximum number of worker threads to create (default -1 = no limit).""" | |
server_name = None | |
"""The name of the server; defaults to socket.gethostname().""" | |
protocol = "HTTP/1.1" | |
"""The version string to write in the Status-Line of all HTTP responses. | |
For example, "HTTP/1.1" is the default. This also limits the supported | |
features used in the response.""" | |
request_queue_size = 5 | |
"""The 'backlog' arg to socket.listen(); max queued connections (default 5).""" | |
shutdown_timeout = 5 | |
"""The total time, in seconds, to wait for worker threads to cleanly exit.""" | |
timeout = 10 | |
"""The timeout in seconds for accepted connections (default 10).""" | |
version = "CherryPy/3.2.0" | |
"""A version string for the HTTPServer.""" | |
software = None | |
"""The value to set for the SERVER_SOFTWARE entry in the WSGI environ. | |
If None, this defaults to ``'%s Server' % self.version``.""" | |
ready = False | |
"""An internal flag which marks whether the socket is accepting connections.""" | |
max_request_header_size = 0 | |
"""The maximum size, in bytes, for request headers, or 0 for no limit.""" | |
max_request_body_size = 0 | |
"""The maximum size, in bytes, for request bodies, or 0 for no limit.""" | |
nodelay = True | |
"""If True (the default since 3.1), sets the TCP_NODELAY socket option.""" | |
ConnectionClass = HTTPConnection | |
"""The class to use for handling HTTP connections.""" | |
ssl_adapter = None | |
"""An instance of SSLAdapter (or a subclass). | |
You must have the corresponding SSL driver library installed.""" | |
def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1, | |
server_name=None): | |
self.bind_addr = bind_addr | |
self.gateway = gateway | |
self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads) | |
if not server_name: | |
server_name = socket.gethostname() | |
self.server_name = server_name | |
self.clear_stats() | |
def clear_stats(self): | |
self._start_time = None | |
self._run_time = 0 | |
self.stats = { | |
'Enabled': False, | |
'Bind Address': lambda s: repr(self.bind_addr), | |
'Run time': lambda s: (not s['Enabled']) and 0 or self.runtime(), | |
'Accepts': 0, | |
'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), | |
'Queue': lambda s: getattr(self.requests, "qsize", None), | |
'Threads': lambda s: len(getattr(self.requests, "_threads", [])), | |
'Threads Idle': lambda s: getattr(self.requests, "idle", None), | |
'Socket Errors': 0, | |
'Requests': lambda s: (not s['Enabled']) and 0 or sum([w['Requests'](w) for w | |
in s['Worker Threads'].values()], 0), | |
'Bytes Read': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Read'](w) for w | |
in s['Worker Threads'].values()], 0), | |
'Bytes Written': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Written'](w) for w | |
in s['Worker Threads'].values()], 0), | |
'Work Time': lambda s: (not s['Enabled']) and 0 or sum([w['Work Time'](w) for w | |
in s['Worker Threads'].values()], 0), | |
'Read Throughput': lambda s: (not s['Enabled']) and 0 or sum( | |
[w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) | |
for w in s['Worker Threads'].values()], 0), | |
'Write Throughput': lambda s: (not s['Enabled']) and 0 or sum( | |
[w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) | |
for w in s['Worker Threads'].values()], 0), | |
'Worker Threads': {}, | |
} | |
logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats | |
def runtime(self): | |
if self._start_time is None: | |
return self._run_time | |
else: | |
return self._run_time + (time.time() - self._start_time) | |
def __str__(self): | |
return "%s.%s(%r)" % (self.__module__, self.__class__.__name__, | |
self.bind_addr) | |
def _get_bind_addr(self): | |
return self._bind_addr | |
def _set_bind_addr(self, value): | |
if isinstance(value, tuple) and value[0] in ('', None): | |
# Despite the socket module docs, using '' does not | |
# allow AI_PASSIVE to work. Passing None instead | |
# returns '0.0.0.0' like we want. In other words: | |
# host AI_PASSIVE result | |
# '' Y 192.168.x.y | |
# '' N 192.168.x.y | |
# None Y 0.0.0.0 | |
# None N 127.0.0.1 | |
# But since you can get the same effect with an explicit | |
# '0.0.0.0', we deny both the empty string and None as values. | |
raise ValueError("Host values of '' or None are not allowed. " | |
"Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " | |
"to listen on all active interfaces.") | |
self._bind_addr = value | |
bind_addr = property(_get_bind_addr, _set_bind_addr, | |
doc="""The interface on which to listen for connections. | |
For TCP sockets, a (host, port) tuple. Host values may be any IPv4 | |
or IPv6 address, or any valid hostname. The string 'localhost' is a | |
synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). | |
The string '0.0.0.0' is a special IPv4 entry meaning "any active | |
interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for | |
IPv6. The empty string or None are not allowed. | |
For UNIX sockets, supply the filename as a string.""") | |
def start(self): | |
"""Run the server forever.""" | |
# We don't have to trap KeyboardInterrupt or SystemExit here, | |
# because cherrpy.server already does so, calling self.stop() for us. | |
# If you're using this server with another framework, you should | |
# trap those exceptions in whatever code block calls start(). | |
self._interrupt = None | |
if self.software is None: | |
self.software = "%s Server" % self.version | |
# SSL backward compatibility | |
if (self.ssl_adapter is None and | |
getattr(self, 'ssl_certificate', None) and | |
getattr(self, 'ssl_private_key', None)): | |
warnings.warn( | |
"SSL attributes are deprecated in CherryPy 3.2, and will " | |
"be removed in CherryPy 3.3. Use an ssl_adapter attribute " | |
"instead.", | |
DeprecationWarning | |
) | |
try: | |
from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter | |
except ImportError: | |
pass | |
else: | |
self.ssl_adapter = pyOpenSSLAdapter( | |
self.ssl_certificate, self.ssl_private_key, | |
getattr(self, 'ssl_certificate_chain', None)) | |
# Select the appropriate socket | |
if isinstance(self.bind_addr, basestring): | |
# AF_UNIX socket | |
# So we can reuse the socket... | |
try: os.unlink(self.bind_addr) | |
except: pass | |
# So everyone can access the socket... | |
try: os.chmod(self.bind_addr, 0777) | |
except: pass | |
info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)] | |
else: | |
# AF_INET or AF_INET6 socket | |
# Get the correct address family for our host (allows IPv6 addresses) | |
host, port = self.bind_addr | |
try: | |
info = socket.getaddrinfo(host, port, socket.AF_UNSPEC, | |
socket.SOCK_STREAM, 0, socket.AI_PASSIVE) | |
except socket.gaierror: | |
if ':' in self.bind_addr[0]: | |
info = [(socket.AF_INET6, socket.SOCK_STREAM, | |
0, "", self.bind_addr + (0, 0))] | |
else: | |
info = [(socket.AF_INET, socket.SOCK_STREAM, | |
0, "", self.bind_addr)] | |
self.socket = None | |
msg = "No socket could be created" | |
for res in info: | |
af, socktype, proto, canonname, sa = res | |
try: | |
self.bind(af, socktype, proto) | |
except socket.error: | |
if self.socket: | |
self.socket.close() | |
self.socket = None | |
continue | |
break | |
if not self.socket: | |
raise socket.error(msg) | |
# Timeout so KeyboardInterrupt can be caught on Win32 | |
self.socket.settimeout(1) | |
self.socket.listen(self.request_queue_size) | |
# Create worker threads | |
self.requests.start() | |
self.ready = True | |
self._start_time = time.time() | |
while self.ready: | |
self.tick() | |
if self.interrupt: | |
while self.interrupt is True: | |
# Wait for self.stop() to complete. See _set_interrupt. | |
time.sleep(0.1) | |
if self.interrupt: | |
raise self.interrupt | |
def bind(self, family, type, proto=0): | |
"""Create (or recreate) the actual socket object.""" | |
self.socket = socket.socket(family, type, proto) | |
prevent_socket_inheritance(self.socket) | |
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |
if self.nodelay and not isinstance(self.bind_addr, str): | |
self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) | |
if self.ssl_adapter is not None: | |
self.socket = self.ssl_adapter.bind(self.socket) | |
# If listening on the IPV6 any address ('::' = IN6ADDR_ANY), | |
# activate dual-stack. See http://www.cherrypy.org/ticket/871. | |
if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6 | |
and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')): | |
try: | |
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) | |
except (AttributeError, socket.error): | |
# Apparently, the socket option is not available in | |
# this machine's TCP stack | |
pass | |
self.socket.bind(self.bind_addr) | |
def tick(self): | |
"""Accept a new connection and put it on the Queue.""" | |
try: | |
s, addr = self.socket.accept() | |
if self.stats['Enabled']: | |
self.stats['Accepts'] += 1 | |
if not self.ready: | |
return | |
prevent_socket_inheritance(s) | |
if hasattr(s, 'settimeout'): | |
s.settimeout(self.timeout) | |
makefile = CP_fileobject | |
ssl_env = {} | |
# if ssl cert and key are set, we try to be a secure HTTP server | |
if self.ssl_adapter is not None: | |
try: | |
s, ssl_env = self.ssl_adapter.wrap(s) | |
except NoSSLError: | |
msg = ("The client sent a plain HTTP request, but " | |
"this server only speaks HTTPS on this port.") | |
buf = ["%s 400 Bad Request\r\n" % self.protocol, | |
"Content-Length: %s\r\n" % len(msg), | |
"Content-Type: text/plain\r\n\r\n", | |
msg] | |
wfile = CP_fileobject(s, "wb", DEFAULT_BUFFER_SIZE) | |
try: | |
wfile.sendall("".join(buf)) | |
except socket.error, x: | |
if x.args[0] not in socket_errors_to_ignore: | |
raise | |
return | |
if not s: | |
return | |
makefile = self.ssl_adapter.makefile | |
# Re-apply our timeout since we may have a new socket object | |
if hasattr(s, 'settimeout'): | |
s.settimeout(self.timeout) | |
conn = self.ConnectionClass(self, s, makefile) | |
if not isinstance(self.bind_addr, basestring): | |
# optional values | |
# Until we do DNS lookups, omit REMOTE_HOST | |
if addr is None: # sometimes this can happen | |
# figure out if AF_INET or AF_INET6. | |
if len(s.getsockname()) == 2: | |
# AF_INET | |
addr = ('0.0.0.0', 0) | |
else: | |
# AF_INET6 | |
addr = ('::', 0) | |
conn.remote_addr = addr[0] | |
conn.remote_port = addr[1] | |
conn.ssl_env = ssl_env | |
self.requests.put(conn) | |
except socket.timeout: | |
# The only reason for the timeout in start() is so we can | |
# notice keyboard interrupts on Win32, which don't interrupt | |
# accept() by default | |
return | |
except socket.error, x: | |
if self.stats['Enabled']: | |
self.stats['Socket Errors'] += 1 | |
if x.args[0] in socket_error_eintr: | |
# I *think* this is right. EINTR should occur when a signal | |
# is received during the accept() call; all docs say retry | |
# the call, and I *think* I'm reading it right that Python | |
# will then go ahead and poll for and handle the signal | |
# elsewhere. See http://www.cherrypy.org/ticket/707. | |
return | |
if x.args[0] in socket_errors_nonblocking: | |
# Just try again. See http://www.cherrypy.org/ticket/479. | |
return | |
if x.args[0] in socket_errors_to_ignore: | |
# Our socket was closed. | |
# See http://www.cherrypy.org/ticket/686. | |
return | |
raise | |
def _get_interrupt(self): | |
return self._interrupt | |
def _set_interrupt(self, interrupt): | |
self._interrupt = True | |
self.stop() | |
self._interrupt = interrupt | |
interrupt = property(_get_interrupt, _set_interrupt, | |
doc="Set this to an Exception instance to " | |
"interrupt the server.") | |
def stop(self): | |
"""Gracefully shutdown a server that is serving forever.""" | |
self.ready = False | |
if self._start_time is not None: | |
self._run_time += (time.time() - self._start_time) | |
self._start_time = None | |
sock = getattr(self, "socket", None) | |
if sock: | |
if not isinstance(self.bind_addr, basestring): | |
# Touch our own socket to make accept() return immediately. | |
try: | |
host, port = sock.getsockname()[:2] | |
except socket.error, x: | |
if x.args[0] not in socket_errors_to_ignore: | |
# Changed to use error code and not message | |
# See http://www.cherrypy.org/ticket/860. | |
raise | |
else: | |
# Note that we're explicitly NOT using AI_PASSIVE, | |
# here, because we want an actual IP to touch. | |
# localhost won't work if we've bound to a public IP, | |
# but it will if we bound to '0.0.0.0' (INADDR_ANY). | |
for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC, | |
socket.SOCK_STREAM): | |
af, socktype, proto, canonname, sa = res | |
s = None | |
try: | |
s = socket.socket(af, socktype, proto) | |
# See http://groups.google.com/group/cherrypy-users/ | |
# browse_frm/thread/bbfe5eb39c904fe0 | |
s.settimeout(1.0) | |
s.connect((host, port)) | |
s.close() | |
except socket.error: | |
if s: | |
s.close() | |
if hasattr(sock, "close"): | |
sock.close() | |
self.socket = None | |
self.requests.stop(self.shutdown_timeout) | |
class Gateway(object):
    """Base class for objects that produce the response to one request."""

    def __init__(self, req):
        """Remember the request object this gateway will answer."""
        self.req = req

    def respond(self):
        """Produce the response for ``self.req``; subclasses override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it produces a TypeError instead of the intended signal.
        raise NotImplementedError
# These may either be wsgiserver.SSLAdapter subclasses or the string names | |
# of such classes (in which case they will be lazily loaded). | |
ssl_adapters = { | |
'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter', | |
'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter', | |
} | |
def get_ssl_adapter_class(name='pyopenssl'):
    """Return the SSL adapter registered in ``ssl_adapters`` under ``name``.

    The lookup is case-insensitive. An entry stored as a dotted
    "module.path.ClassName" string is resolved by importing the module
    (unless it is already loaded) and fetching the attribute; any other
    entry is returned unchanged.

    Raises:
        KeyError: no adapter is registered under ``name``.
        AttributeError: the module lacks the named attribute.
    """
    adapter = ssl_adapters[name.lower()]
    if not isinstance(adapter, basestring):
        return adapter

    # Split "pkg.module.ClassName" at its final dot.
    dot = adapter.rfind(".")
    mod_path, attr_name = adapter[:dot], adapter[dot + 1:]

    # Reuse an already-imported module; a missing or None entry in
    # sys.modules means it has to be imported now.
    mod = sys.modules.get(mod_path)
    if mod is None:
        # The last [''] is important.
        mod = __import__(mod_path, globals(), locals(), [''])

    # Let an AttributeError propagate outward.
    try:
        return getattr(mod, attr_name)
    except AttributeError:
        raise AttributeError("'%s' object has no attribute '%s'"
                             % (mod_path, attr_name))
# -------------------------------- WSGI Stuff -------------------------------- # | |
class CherryPyWSGIServer(HTTPServer):
    """An HTTPServer that dispatches every request to one WSGI application.

    The gateway class is looked up in ``wsgi_gateways`` by ``wsgi_version``.
    """

    wsgi_version = (1, 0)

    def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
                 max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
        """Configure the server for ``wsgi_app``; start() does the binding.

        ``numthreads`` and ``max`` become the thread pool's minimum and
        maximum sizes. ``server_name`` falls back to socket.gethostname()
        when empty.
        """
        self.requests = ThreadPool(self, min=numthreads or 1, max=max)
        self.wsgi_app = wsgi_app
        self.gateway = wsgi_gateways[self.wsgi_version]

        self.bind_addr = bind_addr
        self.server_name = server_name or socket.gethostname()
        self.request_queue_size = request_queue_size
        self.timeout = timeout
        self.shutdown_timeout = shutdown_timeout
        self.clear_stats()

    # ``numthreads`` is an alias for the thread pool's minimum size.
    def _get_numthreads(self):
        return self.requests.min

    def _set_numthreads(self, value):
        self.requests.min = value

    numthreads = property(_get_numthreads, _set_numthreads)
class WSGIGateway(Gateway): | |
def __init__(self, req):
    """Bind the gateway to ``req`` and build its WSGI environ."""
    self.req = req
    # Flipped to True by start_response(); write() refuses to run before.
    self.started_response = False
    # get_environ() runs here, after req and started_response are set
    # but before remaining_bytes_out exists.
    self.env = self.get_environ()
    # Set by start_response() from a Content-Length response header;
    # None until such a header is seen.
    self.remaining_bytes_out = None
def get_environ(self):
    """Return a new environ dict targeting the given wsgi.version.

    Subclasses must override this method.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    # NotImplemented is a comparison sentinel, not an exception class;
    # raising it produces a TypeError instead of the intended signal.
    raise NotImplementedError
def respond(self):
    """Call the WSGI application and write every non-empty chunk it yields.

    Unicode chunks are encoded as ISO-8859-1 before being written. The
    application's result is closed afterwards if it has a ``close``
    attribute, even when iteration or writing fails.
    """
    app = self.req.server.wsgi_app
    result = app(self.env, self.start_response)
    try:
        for piece in result:
            # "The start_response callable must not actually transmit
            # the response headers. Instead, it must store them for the
            # server or gateway to transmit only after the first
            # iteration of the application return value that yields
            # a NON-EMPTY string, or upon the application's first
            # invocation of the write() callable." (PEP 333)
            if not piece:
                continue
            if isinstance(piece, unicode):
                piece = piece.encode('ISO-8859-1')
            self.write(piece)
    finally:
        if hasattr(result, "close"):
            result.close()
def start_response(self, status, headers, exc_info = None): | |
"""WSGI callable to begin the HTTP response.""" | |
# "The application may call start_response more than once, | |
# if and only if the exc_info argument is provided." | |
if self.started_response and not exc_info: | |
raise AssertionError("WSGI start_response called a second " | |
"time with no exc_info.") | |
self.started_response = True | |
# "if exc_info is provided, and the HTTP headers have already been | |
# sent, start_response must raise an error, and should raise the | |
# exc_info tuple." | |
if self.req.sent_headers: | |
try: | |
raise exc_info[0], exc_info[1], exc_info[2] | |
finally: | |
exc_info = None | |
self.req.status = status | |
for k, v in headers: | |
if not isinstance(k, str): | |
raise TypeError("WSGI response header key %r is not a byte string." % k) | |
if not isinstance(v, str): | |
raise TypeError("WSGI response header value %r is not a byte string." % v) | |
if k.lower() == 'content-length': | |
self.remaining_bytes_out = int(v) | |
self.req.outheaders.extend(headers) | |
return self.write | |
def write(self, chunk): | |
"""WSGI callable to write unbuffered data to the client. | |
This method is also used internally by start_response (to write | |
data from the iterable returned by the WSGI application). | |
""" | |
if not self.started_response: | |
raise AssertionError("WSGI write called before start_response.") | |
chunklen = len(chunk) | |
rbo = self.remaining_bytes_out | |
if rbo is not None and chunklen > rbo: | |
if not self.req.sent_headers: | |
# Whew. We can send a 500 to the client. | |
self.req.simple_response("500 Internal Server Error", | |
"The requested resource returned more bytes than the " | |
"declared Content-Length.") | |
else: | |
# Dang. We have probably already sent data. Truncate the chunk | |
# to fit (so the client doesn't hang) and raise an error later. | |
chunk = chunk[:rbo] | |
if not self.req.sent_headers: | |
self.req.sent_headers = True | |
self.req.send_headers() | |
self.req.write(chunk) | |
if rbo is not None: | |
rbo -= chunklen | |
if rbo < 0: | |
raise ValueError( | |
"Response body exceeds the declared Content-Length.") | |
class WSGIGateway_10(WSGIGateway):
    """Gateway that builds a WSGI (1, 0) environ from the request."""

    def get_environ(self):
        """Return a new environ dict targeting the given wsgi.version"""
        request = self.req
        server = request.server
        connection = request.conn

        environ = {
            # Non-standard entry: lets the application see the protocol the
            # server itself speaks (see http://www.faqs.org/rfcs/rfc2145.html).
            'ACTUAL_SERVER_PROTOCOL': server.protocol,
            'PATH_INFO': request.path,
            'QUERY_STRING': request.qs,
            'REMOTE_ADDR': connection.remote_addr or '',
            'REMOTE_PORT': str(connection.remote_port or ''),
            'REQUEST_METHOD': request.method,
            'REQUEST_URI': request.uri,
            'SCRIPT_NAME': '',
            'SERVER_NAME': server.server_name,
            # Despite the name, this carries the protocol of the REQUEST.
            'SERVER_PROTOCOL': request.request_protocol,
            'SERVER_SOFTWARE': server.software,
            'wsgi.errors': sys.stderr,
            'wsgi.input': request.rfile,
            'wsgi.multiprocess': False,
            'wsgi.multithread': True,
            'wsgi.run_once': False,
            'wsgi.url_scheme': request.scheme,
            'wsgi.version': (1, 0),
        }

        if not isinstance(server.bind_addr, basestring):
            environ["SERVER_PORT"] = str(server.bind_addr[1])
        else:
            # A string bind address is an AF_UNIX socket. WSGI does not
            # address unix domain sockets, so an empty port is reported.
            environ["SERVER_PORT"] = ""

        # Every request header becomes an HTTP_* entry ...
        for name, value in request.inheaders.iteritems():
            environ["HTTP_" + name.upper().replace("-", "_")] = value

        # ... except these two, which are moved to their unprefixed names.
        for http_key, plain_key in (("HTTP_CONTENT_TYPE", "CONTENT_TYPE"),
                                    ("HTTP_CONTENT_LENGTH", "CONTENT_LENGTH")):
            moved = environ.pop(http_key, None)
            if moved is not None:
                environ[plain_key] = moved

        if connection.ssl_env:
            environ.update(connection.ssl_env)

        return environ
class WSGIGateway_u0(WSGIGateway_10):
    """Gateway producing a ('u', 0) environ with unicode keys and values."""

    def get_environ(self):
        """Return a new environ dict targeting the given wsgi.version"""
        byte_env = WSGIGateway_10.get_environ(self)

        env = {}
        for name, value in byte_env.iteritems():
            env[name.decode('ISO-8859-1')] = value
        env[u'wsgi.version'] = ('u', 0)

        uri_keys = [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]

        def decode_uri_parts():
            # Decode the Request-URI pieces with the current url encoding.
            for key in uri_keys:
                env[key] = byte_env[str(key)].decode(env[u'wsgi.url_encoding'])

        # Request-URI
        env.setdefault(u'wsgi.url_encoding', u'utf-8')
        try:
            decode_uri_parts()
        except UnicodeDecodeError:
            # Fall back to latin 1 so apps can transcode if needed.
            env[u'wsgi.url_encoding'] = u'ISO-8859-1'
            decode_uri_parts()

        # Remaining byte-string values are decoded as latin 1, apart from
        # the two keys that are deliberately left untouched.
        untouched = ('REQUEST_URI', 'wsgi.input')
        for name, value in sorted(env.items()):
            if name in untouched:
                continue
            if isinstance(value, str):
                env[name] = value.decode('ISO-8859-1')

        return env
# Maps a wsgi_version tuple to the gateway class that implements it.
wsgi_gateways = dict([
    ((1, 0), WSGIGateway_10),
    (('u', 0), WSGIGateway_u0),
])
class WSGIPathInfoDispatcher(object):
    """A WSGI dispatcher for dispatch based on the PATH_INFO.

    apps: a dict or list of (path_prefix, app) pairs.
    """

    def __init__(self, apps):
        """Store the (prefix, app) pairs, longest prefix first.

        The caller's container is never modified: the pairs are copied
        before being ordered, so any iterable of pairs (list, tuple, dict
        items) is accepted.
        """
        try:
            apps = apps.items()
        except AttributeError:
            pass

        # Sort the apps by len(path), descending. Sorting ascending and then
        # reversing keeps the order that the previous in-place sort/reverse
        # produced for prefixes of equal length.
        ordered = sorted(apps, key=lambda pair: len(pair[0]))
        ordered.reverse()

        # The path_prefix strings must start, but not end, with a slash.
        # Use "" instead of "/".
        self.apps = [(p.rstrip("/"), a) for p, a in ordered]

    def __call__(self, environ, start_response):
        """Dispatch to the first app whose prefix matches PATH_INFO.

        The matched app receives a copy of ``environ`` in which the prefix
        has been moved from PATH_INFO to the end of SCRIPT_NAME. When no
        prefix matches, an empty 404 response is produced.
        """
        path = environ["PATH_INFO"] or "/"
        for p, app in self.apps:
            # The apps list should be sorted by length, descending.
            if path.startswith(p + "/") or path == p:
                environ = environ.copy()
                environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p
                environ["PATH_INFO"] = path[len(p):]
                return app(environ, start_response)

        start_response('404 Not Found', [('Content-Type', 'text/plain'),
                                         ('Content-Length', '0')])
        return ['']
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
application: cmfeedparser | |
version: 1 | |
runtime: python27 | |
api_version: 1 | |
threadsafe: no | |
handlers: | |
- url: /.* | |
script: parse.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Web application | |
(from web.py) | |
""" | |
import webapi as web | |
import webapi, wsgi, utils | |
import debugerror | |
from utils import lstrips, safeunicode | |
import sys | |
import urllib | |
import traceback | |
import itertools | |
import os | |
import types | |
from exceptions import SystemExit | |
try: | |
import wsgiref.handlers | |
except ImportError: | |
pass # don't break people with old Pythons | |
__all__ = [ | |
"application", "auto_application", | |
"subdir_application", "subdomain_application", | |
"loadhook", "unloadhook", | |
"autodelegate" | |
] | |
class application: | |
""" | |
Application to delegate requests based on path. | |
>>> urls = ("/hello", "hello") | |
>>> app = application(urls, globals()) | |
>>> class hello: | |
... def GET(self): return "hello" | |
>>> | |
>>> app.request("/hello").data | |
'hello' | |
""" | |
def __init__(self, mapping=(), fvars={}, autoreload=None): | |
if autoreload is None: | |
autoreload = web.config.get('debug', False) | |
self.init_mapping(mapping) | |
self.fvars = fvars | |
self.processors = [] | |
self.add_processor(loadhook(self._load)) | |
self.add_processor(unloadhook(self._unload)) | |
if autoreload: | |
def main_module_name(): | |
mod = sys.modules['__main__'] | |
file = getattr(mod, '__file__', None) # make sure this works even from python interpreter | |
return file and os.path.splitext(os.path.basename(file))[0] | |
def modname(fvars): | |
"""find name of the module name from fvars.""" | |
file, name = fvars.get('__file__'), fvars.get('__name__') | |
if file is None or name is None: | |
return None | |
if name == '__main__': | |
# Since the __main__ module can't be reloaded, the module has | |
# to be imported using its file name. | |
name = main_module_name() | |
return name | |
mapping_name = utils.dictfind(fvars, mapping) | |
module_name = modname(fvars) | |
def reload_mapping(): | |
"""loadhook to reload mapping and fvars.""" | |
mod = __import__(module_name, None, None, ['']) | |
mapping = getattr(mod, mapping_name, None) | |
if mapping: | |
self.fvars = mod.__dict__ | |
self.init_mapping(mapping) | |
self.add_processor(loadhook(Reloader())) | |
if mapping_name and module_name: | |
self.add_processor(loadhook(reload_mapping)) | |
# load __main__ module usings its filename, so that it can be reloaded. | |
if main_module_name() and '__main__' in sys.argv: | |
try: | |
__import__(main_module_name()) | |
except ImportError: | |
pass | |
def _load(self): | |
web.ctx.app_stack.append(self) | |
def _unload(self): | |
web.ctx.app_stack = web.ctx.app_stack[:-1] | |
if web.ctx.app_stack: | |
# this is a sub-application, revert ctx to earlier state. | |
oldctx = web.ctx.get('_oldctx') | |
if oldctx: | |
web.ctx.home = oldctx.home | |
web.ctx.homepath = oldctx.homepath | |
web.ctx.path = oldctx.path | |
web.ctx.fullpath = oldctx.fullpath | |
def _cleanup(self): | |
# Threads can be recycled by WSGI servers. | |
# Clearing up all thread-local state to avoid interefereing with subsequent requests. | |
utils.ThreadedDict.clear_all() | |
def init_mapping(self, mapping): | |
self.mapping = list(utils.group(mapping, 2)) | |
def add_mapping(self, pattern, classname): | |
self.mapping.append((pattern, classname)) | |
def add_processor(self, processor): | |
""" | |
Adds a processor to the application. | |
>>> urls = ("/(.*)", "echo") | |
>>> app = application(urls, globals()) | |
>>> class echo: | |
... def GET(self, name): return name | |
... | |
>>> | |
>>> def hello(handler): return "hello, " + handler() | |
... | |
>>> app.add_processor(hello) | |
>>> app.request("/web.py").data | |
'hello, web.py' | |
""" | |
self.processors.append(processor) | |
def request(self, localpart='/', method='GET', data=None, | |
host="0.0.0.0:8080", headers=None, https=False, **kw): | |
"""Makes request to this application for the specified path and method. | |
Response will be a storage object with data, status and headers. | |
>>> urls = ("/hello", "hello") | |
>>> app = application(urls, globals()) | |
>>> class hello: | |
... def GET(self): | |
... web.header('Content-Type', 'text/plain') | |
... return "hello" | |
... | |
>>> response = app.request("/hello") | |
>>> response.data | |
'hello' | |
>>> response.status | |
'200 OK' | |
>>> response.headers['Content-Type'] | |
'text/plain' | |
To use https, use https=True. | |
>>> urls = ("/redirect", "redirect") | |
>>> app = application(urls, globals()) | |
>>> class redirect: | |
... def GET(self): raise web.seeother("/foo") | |
... | |
>>> response = app.request("/redirect") | |
>>> response.headers['Location'] | |
'http://0.0.0.0:8080/foo' | |
>>> response = app.request("/redirect", https=True) | |
>>> response.headers['Location'] | |
'https://0.0.0.0:8080/foo' | |
The headers argument specifies HTTP headers as a mapping object | |
such as a dict. | |
>>> urls = ('/ua', 'uaprinter') | |
>>> class uaprinter: | |
... def GET(self): | |
... return 'your user-agent is ' + web.ctx.env['HTTP_USER_AGENT'] | |
... | |
>>> app = application(urls, globals()) | |
>>> app.request('/ua', headers = { | |
... 'User-Agent': 'a small jumping bean/1.0 (compatible)' | |
... }).data | |
'your user-agent is a small jumping bean/1.0 (compatible)' | |
""" | |
path, maybe_query = urllib.splitquery(localpart) | |
query = maybe_query or "" | |
if 'env' in kw: | |
env = kw['env'] | |
else: | |
env = {} | |
env = dict(env, HTTP_HOST=host, REQUEST_METHOD=method, PATH_INFO=path, QUERY_STRING=query, HTTPS=str(https)) | |
headers = headers or {} | |
for k, v in headers.items(): | |
env['HTTP_' + k.upper().replace('-', '_')] = v | |
if 'HTTP_CONTENT_LENGTH' in env: | |
env['CONTENT_LENGTH'] = env.pop('HTTP_CONTENT_LENGTH') | |
if 'HTTP_CONTENT_TYPE' in env: | |
env['CONTENT_TYPE'] = env.pop('HTTP_CONTENT_TYPE') | |
if method not in ["HEAD", "GET"]: | |
data = data or '' | |
import StringIO | |
if isinstance(data, dict): | |
q = urllib.urlencode(data) | |
else: | |
q = data | |
env['wsgi.input'] = StringIO.StringIO(q) | |
if not env.get('CONTENT_TYPE', '').lower().startswith('multipart/') and 'CONTENT_LENGTH' not in env: | |
env['CONTENT_LENGTH'] = len(q) | |
response = web.storage() | |
def start_response(status, headers): | |
response.status = status | |
response.headers = dict(headers) | |
response.header_items = headers | |
response.data = "".join(self.wsgifunc()(env, start_response)) | |
return response | |
def browser(self): | |
import browser | |
return browser.AppBrowser(self) | |
def handle(self): | |
fn, args = self._match(self.mapping, web.ctx.path) | |
return self._delegate(fn, self.fvars, args) | |
def handle_with_processors(self): | |
def process(processors): | |
try: | |
if processors: | |
p, processors = processors[0], processors[1:] | |
return p(lambda: process(processors)) | |
else: | |
return self.handle() | |
except web.HTTPError: | |
raise | |
except (KeyboardInterrupt, SystemExit): | |
raise | |
except: | |
print >> web.debug, traceback.format_exc() | |
raise self.internalerror() | |
# processors must be applied in the resvere order. (??) | |
return process(self.processors) | |
def wsgifunc(self, *middleware): | |
"""Returns a WSGI-compatible function for this application.""" | |
def peep(iterator): | |
"""Peeps into an iterator by doing an iteration | |
and returns an equivalent iterator. | |
""" | |
# wsgi requires the headers first | |
# so we need to do an iteration | |
# and save the result for later | |
try: | |
firstchunk = iterator.next() | |
except StopIteration: | |
firstchunk = '' | |
return itertools.chain([firstchunk], iterator) | |
def is_generator(x): return x and hasattr(x, 'next') | |
def wsgi(env, start_resp): | |
# clear threadlocal to avoid inteference of previous requests | |
self._cleanup() | |
self.load(env) | |
try: | |
# allow uppercase methods only | |
if web.ctx.method.upper() != web.ctx.method: | |
raise web.nomethod() | |
result = self.handle_with_processors() | |
if is_generator(result): | |
result = peep(result) | |
else: | |
result = [result] | |
except web.HTTPError, e: | |
result = [e.data] | |
result = web.safestr(iter(result)) | |
status, headers = web.ctx.status, web.ctx.headers | |
start_resp(status, headers) | |
def cleanup(): | |
self._cleanup() | |
yield '' # force this function to be a generator | |
return itertools.chain(result, cleanup()) | |
for m in middleware: | |
wsgi = m(wsgi) | |
return wsgi | |
def run(self, *middleware): | |
""" | |
Starts handling requests. If called in a CGI or FastCGI context, it will follow | |
that protocol. If called from the command line, it will start an HTTP | |
server on the port named in the first command line argument, or, if there | |
is no argument, on port 8080. | |
`middleware` is a list of WSGI middleware which is applied to the resulting WSGI | |
function. | |
""" | |
return wsgi.runwsgi(self.wsgifunc(*middleware)) | |
def cgirun(self, *middleware): | |
""" | |
Return a CGI handler. This is mostly useful with Google App Engine. | |
There you can just do: | |
main = app.cgirun() | |
""" | |
wsgiapp = self.wsgifunc(*middleware) | |
try: | |
from google.appengine.ext.webapp.util import run_wsgi_app | |
return run_wsgi_app(wsgiapp) | |
except ImportError: | |
# we're not running from within Google App Engine | |
return wsgiref.handlers.CGIHandler().run(wsgiapp) | |
def load(self, env): | |
"""Initializes ctx using env.""" | |
ctx = web.ctx | |
ctx.clear() | |
ctx.status = '200 OK' | |
ctx.headers = [] | |
ctx.output = '' | |
ctx.environ = ctx.env = env | |
ctx.host = env.get('HTTP_HOST') | |
if env.get('wsgi.url_scheme') in ['http', 'https']: | |
ctx.protocol = env['wsgi.url_scheme'] | |
elif env.get('HTTPS', '').lower() in ['on', 'true', '1']: | |
ctx.protocol = 'https' | |
else: | |
ctx.protocol = 'http' | |
ctx.homedomain = ctx.protocol + '://' + env.get('HTTP_HOST', '[unknown]') | |
ctx.homepath = os.environ.get('REAL_SCRIPT_NAME', env.get('SCRIPT_NAME', '')) | |
ctx.home = ctx.homedomain + ctx.homepath | |
#@@ home is changed when the request is handled to a sub-application. | |
#@@ but the real home is required for doing absolute redirects. | |
ctx.realhome = ctx.home | |
ctx.ip = env.get('REMOTE_ADDR') | |
ctx.method = env.get('REQUEST_METHOD') | |
ctx.path = env.get('PATH_INFO') | |
# http://trac.lighttpd.net/trac/ticket/406 requires: | |
if env.get('SERVER_SOFTWARE', '').startswith('lighttpd/'): | |
ctx.path = lstrips(env.get('REQUEST_URI').split('?')[0], ctx.homepath) | |
# Apache and CherryPy webservers unquote the url but lighttpd doesn't. | |
# unquote explicitly for lighttpd to make ctx.path uniform across all servers. | |
ctx.path = urllib.unquote(ctx.path) | |
if env.get('QUERY_STRING'): | |
ctx.query = '?' + env.get('QUERY_STRING', '') | |
else: | |
ctx.query = '' | |
ctx.fullpath = ctx.path + ctx.query | |
for k, v in ctx.iteritems(): | |
if isinstance(v, str): | |
ctx[k] = safeunicode(v) | |
# status must always be str | |
ctx.status = '200 OK' | |
ctx.app_stack = [] | |
def _delegate(self, f, fvars, args=[]): | |
def handle_class(cls): | |
meth = web.ctx.method | |
if meth == 'HEAD' and not hasattr(cls, meth): | |
meth = 'GET' | |
if not hasattr(cls, meth): | |
raise web.nomethod(cls) | |
tocall = getattr(cls(), meth) | |
return tocall(*args) | |
def is_class(o): return isinstance(o, (types.ClassType, type)) | |
if f is None: | |
raise web.notfound() | |
elif isinstance(f, application): | |
return f.handle_with_processors() | |
elif is_class(f): | |
return handle_class(f) | |
elif isinstance(f, basestring): | |
if f.startswith('redirect '): | |
url = f.split(' ', 1)[1] | |
if web.ctx.method == "GET": | |
x = web.ctx.env.get('QUERY_STRING', '') | |
if x: | |
url += '?' + x | |
raise web.redirect(url) | |
elif '.' in f: | |
mod, cls = f.rsplit('.', 1) | |
mod = __import__(mod, None, None, ['']) | |
cls = getattr(mod, cls) | |
else: | |
cls = fvars[f] | |
return handle_class(cls) | |
elif hasattr(f, '__call__'): | |
return f() | |
else: | |
return web.notfound() | |
def _match(self, mapping, value): | |
for pat, what in mapping: | |
if isinstance(what, application): | |
if value.startswith(pat): | |
f = lambda: self._delegate_sub_application(pat, what) | |
return f, None | |
else: | |
continue | |
elif isinstance(what, basestring): | |
what, result = utils.re_subm('^' + pat + '$', what, value) | |
else: | |
result = utils.re_compile('^' + pat + '$').match(value) | |
if result: # it's a match | |
return what, [x for x in result.groups()] | |
return None, None | |
def _delegate_sub_application(self, dir, app): | |
"""Deletes request to sub application `app` rooted at the directory `dir`. | |
The home, homepath, path and fullpath values in web.ctx are updated to mimic request | |
to the subapp and are restored after it is handled. | |
@@Any issues with when used with yield? | |
""" | |
web.ctx._oldctx = web.storage(web.ctx) | |
web.ctx.home += dir | |
web.ctx.homepath += dir | |
web.ctx.path = web.ctx.path[len(dir):] | |
web.ctx.fullpath = web.ctx.fullpath[len(dir):] | |
return app.handle_with_processors() | |
def get_parent_app(self): | |
if self in web.ctx.app_stack: | |
index = web.ctx.app_stack.index(self) | |
if index > 0: | |
return web.ctx.app_stack[index-1] | |
def notfound(self): | |
"""Returns HTTPError with '404 not found' message""" | |
parent = self.get_parent_app() | |
if parent: | |
return parent.notfound() | |
else: | |
return web._NotFound() | |
def internalerror(self): | |
"""Returns HTTPError with '500 internal error' message""" | |
parent = self.get_parent_app() | |
if parent: | |
return parent.internalerror() | |
elif web.config.get('debug'): | |
import debugerror | |
return debugerror.debugerror() | |
else: | |
return web._InternalError() | |
class auto_application(application):
    """Variant of `application` whose url mapping is built automatically:
    every class deriving from `app.page` registers itself through a
    metaclass.

        >>> app = auto_application()
        >>> class hello(app.page):
        ...     def GET(self): return "hello, world"
        ...
        >>> class foo(app.page):
        ...     path = '/foo/.*'
        ...     def GET(self): return "foo"
        >>> app.request("/hello").data
        'hello, world'
        >>> app.request('/foo/bar').data
        'foo'
    """
    def __init__(self):
        application.__init__(self)
        app = self

        class metapage(type):
            def __init__(cls, name, bases, namespace):
                type.__init__(cls, name, bases, namespace)
                # The url defaults to "/<class name>" unless the class body
                # defines its own `path`.
                url = namespace.get('path', '/' + name)

                # A path of None opts the class out of the mapping, which is
                # what an abstract base class needs.
                if url is None:
                    return
                app.add_mapping(url, cls)

        class page:
            path = None
            __metaclass__ = metapage

        self.page = page
# The application class already has the required functionality of subdir_application | |
subdir_application = application | |
class subdomain_application(application):
    """
    Application that picks the handler from the request's host name
    instead of its path.

        >>> urls = ("/hello", "hello")
        >>> app = application(urls, globals())
        >>> class hello:
        ...     def GET(self): return "hello"
        >>>
        >>> mapping = (r"hello\.example\.com", app)
        >>> app2 = subdomain_application(mapping)
        >>> app2.request("/hello", host="hello.example.com").data
        'hello'
        >>> response = app2.request("/hello", host="something.example.com")
        >>> response.status
        '404 Not Found'
        >>> response.data
        'not found'
    """
    def handle(self):
        """Match the host (port removed) against the mapping and delegate."""
        hostname = web.ctx.host.split(':', 1)[0]
        target, groups = self._match(self.mapping, hostname)
        return self._delegate(target, self.fvars, groups)

    def _match(self, mapping, value):
        """Return (target, groups) for the first pattern matching all of
        `value`, or (None, None) when none does."""
        for pattern, target in mapping:
            anchored = '^' + pattern + '$'
            if isinstance(target, basestring):
                target, found = utils.re_subm(anchored, target, value)
            else:
                found = utils.re_compile(anchored).match(value)

            if not found:
                continue
            return target, list(found.groups())
        return None, None
def loadhook(h):
    """
    Wraps the load hook `h` as an application processor: `h` is called
    first, then the handler, whose result is returned.

        >>> app = auto_application()
        >>> def f(): "something done before handling request"
        ...
        >>> app.add_processor(loadhook(f))
    """
    def processor(handler):
        h()
        outcome = handler()
        return outcome

    return processor
def unloadhook(h):
    """
    Wraps the unload hook `h` as an application processor: `h` is called
    once the handler has finished, whether it returned or raised. For a
    result that has a `next` method, `h` is deferred until the iteration
    ends or fails.

        >>> app = auto_application()
        >>> def f(): "something done after handling request"
        ...
        >>> app.add_processor(unloadhook(f))
    """
    def drain(result):
        # Pass the items through unchanged and call the hook when fetching
        # the next item raises (including the final StopIteration).
        source = iter(result)
        while True:
            try:
                item = source.next()
            except:
                h()
                raise
            yield item

    def processor(handler):
        try:
            result = handler()
            lazy = result and hasattr(result, 'next')
        except:
            # run the hook even when handler raises some exception
            h()
            raise

        if lazy:
            return drain(result)

        h()
        return result

    return processor
def autodelegate(prefix=''):
    """
    Returns a method that takes one argument and calls the method named
    prefix+arg, raising `notfound()` if there isn't one. Example:

        urls = ('/prefs/(.*)', 'prefs')

        class prefs:
            GET = autodelegate('GET_')
            def GET_password(self): pass
            def GET_privacy(self): pass

    `GET_password` gets called for `/prefs/password` and `GET_privacy`
    for `/prefs/privacy`.

    If a user visits `/prefs/password/change` then
    `GET_password(self, '/change')` is called.
    """
    def internal(self, arg):
        # Only the first path segment selects the method; whatever follows
        # is handed over as a single argument with its leading slash.
        pieces = arg.split('/', 1)
        func = prefix + pieces[0]
        args = ['/' + rest for rest in pieces[1:]]

        if not hasattr(self, func):
            raise web.notfound()

        try:
            return getattr(self, func)(*args)
        except TypeError:
            raise web.notfound()

    return internal
class Reloader:
    """Checks to see if any loaded modules have changed on disk and,
    if so, reloads them.
    """

    # File suffix of compiled modules.
    if sys.platform.startswith('java'):
        SUFFIX = '$py.class'
    else:
        SUFFIX = '.pyc'

    def __init__(self):
        # Maps each module seen so far to the newest modification time
        # observed for its file(s).
        self.mtimes = {}

    def __call__(self):
        """Check every module currently listed in sys.modules."""
        # Iterate over a copy: reloading a module may change sys.modules.
        for mod in list(sys.modules.values()):
            self.check(mod)

    def check(self, mod):
        """Reload `mod` if its file changed since the previous check.

        Modules without a usable `__file__`, or whose file cannot be
        stat'ed, are ignored. The first time a module is seen only its
        modification time is recorded.
        """
        # jython registers java packages as modules but they either
        # don't have a __file__ attribute or its value is None
        if not (mod and hasattr(mod, '__file__') and mod.__file__):
            return

        try:
            mtime = os.stat(mod.__file__).st_mtime
        except (OSError, IOError):
            return

        # For a compiled file, also look at the source next to it. The source
        # name is obtained by replacing the whole compiled suffix with '.py';
        # merely dropping the last character is only right for '.pyc'.
        suffix = self.__class__.SUFFIX
        if mod.__file__.endswith(suffix):
            source = mod.__file__[:-len(suffix)] + '.py'
            if os.path.exists(source):
                mtime = max(os.stat(source).st_mtime, mtime)

        if mod not in self.mtimes:
            self.mtimes[mod] = mtime
        elif self.mtimes[mod] < mtime:
            try:
                reload(mod)
                self.mtimes[mod] = mtime
            except ImportError:
                # Best effort: keep the old time so the reload is retried.
                pass
if __name__ == "__main__":
    # Run this module's doctests when it is executed as a script.
    from doctest import testmod
    testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Browser to test web applications. | |
(from web.py) | |
""" | |
from utils import re_compile | |
from net import htmlunquote | |
import httplib, urllib, urllib2 | |
import copy | |
from StringIO import StringIO | |
DEBUG = False | |
__all__ = [ | |
"BrowserError", | |
"Browser", "AppBrowser", | |
"AppHandler" | |
] | |
class BrowserError(Exception):
    """Error raised by Browser, e.g. when no link or form is found."""
class Browser: | |
def __init__(self): | |
import cookielib | |
self.cookiejar = cookielib.CookieJar() | |
self._cookie_processor = urllib2.HTTPCookieProcessor(self.cookiejar) | |
self.form = None | |
self.url = "http://0.0.0.0:8080/" | |
self.path = "/" | |
self.status = None | |
self.data = None | |
self._response = None | |
self._forms = None | |
def reset(self): | |
"""Clears all cookies and history.""" | |
self.cookiejar.clear() | |
def build_opener(self): | |
"""Builds the opener using urllib2.build_opener. | |
Subclasses can override this function to prodive custom openers. | |
""" | |
return urllib2.build_opener() | |
def do_request(self, req): | |
if DEBUG: | |
print 'requesting', req.get_method(), req.get_full_url() | |
opener = self.build_opener() | |
opener.add_handler(self._cookie_processor) | |
try: | |
self._response = opener.open(req) | |
except urllib2.HTTPError, e: | |
self._response = e | |
self.url = self._response.geturl() | |
self.path = urllib2.Request(self.url).get_selector() | |
self.data = self._response.read() | |
self.status = self._response.code | |
self._forms = None | |
self.form = None | |
return self.get_response() | |
def open(self, url, data=None, headers={}): | |
"""Opens the specified url.""" | |
url = urllib.basejoin(self.url, url) | |
req = urllib2.Request(url, data, headers) | |
return self.do_request(req) | |
def show(self): | |
"""Opens the current page in real web browser.""" | |
f = open('page.html', 'w') | |
f.write(self.data) | |
f.close() | |
import webbrowser, os | |
url = 'file://' + os.path.abspath('page.html') | |
webbrowser.open(url) | |
def get_response(self): | |
"""Returns a copy of the current response.""" | |
return urllib.addinfourl(StringIO(self.data), self._response.info(), self._response.geturl()) | |
def get_soup(self): | |
"""Returns beautiful soup of the current document.""" | |
import BeautifulSoup | |
return BeautifulSoup.BeautifulSoup(self.data) | |
def get_text(self, e=None): | |
"""Returns content of e or the current document as plain text.""" | |
e = e or self.get_soup() | |
return ''.join([htmlunquote(c) for c in e.recursiveChildGenerator() if isinstance(c, unicode)]) | |
def _get_links(self): | |
soup = self.get_soup() | |
return [a for a in soup.findAll(name='a')] | |
def get_links(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None): | |
"""Returns all links in the document.""" | |
return self._filter_links(self._get_links(), | |
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) | |
def follow_link(self, link=None, text=None, text_regex=None, url=None, url_regex=None, predicate=None): | |
if link is None: | |
links = self._filter_links(self.get_links(), | |
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) | |
link = links and links[0] | |
if link: | |
return self.open(link['href']) | |
else: | |
raise BrowserError("No link found") | |
def find_link(self, text=None, text_regex=None, url=None, url_regex=None, predicate=None): | |
links = self._filter_links(self.get_links(), | |
text=text, text_regex=text_regex, url=url, url_regex=url_regex, predicate=predicate) | |
return links and links[0] or None | |
def _filter_links(self, links, | |
text=None, text_regex=None, | |
url=None, url_regex=None, | |
predicate=None): | |
predicates = [] | |
if text is not None: | |
predicates.append(lambda link: link.string == text) | |
if text_regex is not None: | |
predicates.append(lambda link: re_compile(text_regex).search(link.string or '')) | |
if url is not None: | |
predicates.append(lambda link: link.get('href') == url) | |
if url_regex is not None: | |
predicates.append(lambda link: re_compile(url_regex).search(link.get('href', ''))) | |
if predicate: | |
predicate.append(predicate) | |
def f(link): | |
for p in predicates: | |
if not p(link): | |
return False | |
return True | |
return [link for link in links if f(link)] | |
def get_forms(self): | |
"""Returns all forms in the current document. | |
The returned form objects implement the ClientForm.HTMLForm interface. | |
""" | |
if self._forms is None: | |
import ClientForm | |
self._forms = ClientForm.ParseResponse(self.get_response(), backwards_compat=False) | |
return self._forms | |
def select_form(self, name=None, predicate=None, index=0): | |
"""Selects the specified form.""" | |
forms = self.get_forms() | |
if name is not None: | |
forms = [f for f in forms if f.name == name] | |
if predicate: | |
forms = [f for f in forms if predicate(f)] | |
if forms: | |
self.form = forms[index] | |
return self.form | |
else: | |
raise BrowserError("No form selected.") | |
def submit(self, **kw): | |
"""submits the currently selected form.""" | |
if self.form is None: | |
raise BrowserError("No form selected.") | |
req = self.form.click(**kw) | |
return self.do_request(req) | |
def __getitem__(self, key): | |
return self.form[key] | |
def __setitem__(self, key, value): | |
self.form[key] = value | |
class AppBrowser(Browser):
    """Browser whose requests are answered by a web.py application
    through AppHandler, for use in tests.

        b = AppBrowser(app)
        b.open('/')
        b.follow_link(text='Login')
        b.select_form(name='login')
        b['username'] = 'joe'
        b['password'] = 'secret'
        b.submit()
        assert b.path == '/'
        assert 'Welcome joe' in b.get_text()
    """
    def __init__(self, app):
        Browser.__init__(self)
        self.app = app

    def build_opener(self):
        """Return an opener that routes requests to the application."""
        handler = AppHandler(self.app)
        return urllib2.build_opener(handler)
class AppHandler(urllib2.HTTPHandler):
    """urllib2 handler that answers requests by calling a web.py
    application's `request` method."""
    handler_order = 100

    def __init__(self, app):
        self.app = app

    def http_open(self, req):
        """Pass `req` to the application and wrap what it returns."""
        selector = req.get_selector()
        verb = req.get_method()
        hostname = req.get_host()
        payload = req.get_data()
        header_map = dict(req.header_items())
        secure = req.get_type() == "https"

        result = self.app.request(
            localpart=selector,
            method=verb,
            host=hostname,
            data=payload,
            headers=header_map,
            https=secure
        )
        return self._make_response(result, req.get_full_url())

    def https_open(self, req):
        """https requests are served exactly like http ones."""
        return self.http_open(req)

    # do_request_ is missing on python 2.3; https_request is then left
    # undefined.
    if hasattr(urllib2.HTTPHandler, 'do_request_'):
        https_request = urllib2.HTTPHandler.do_request_

    def _make_response(self, result, url):
        """Build an addinfourl response from the application's result."""
        header_lines = []
        for name, value in result.header_items:
            header_lines.append("%s: %s" % (name, value))
        raw_headers = "\r\n".join(header_lines)

        message = httplib.HTTPMessage(StringIO(raw_headers))
        response = urllib.addinfourl(StringIO(result.data), message, url)

        # The status looks like "200 OK": a numeric code, then the reason.
        code, msg = result.status.split(None, 1)
        response.code = int(code)
        response.msg = msg
        return response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Database API | |
(part of web.py) | |
""" | |
__all__ = [ | |
"UnknownParamstyle", "UnknownDB", "TransactionError", | |
"sqllist", "sqlors", "reparam", "sqlquote", | |
"SQLQuery", "SQLParam", "sqlparam", | |
"SQLLiteral", "sqlliteral", | |
"database", 'DB', | |
] | |
import time | |
try: | |
import datetime | |
except ImportError: | |
datetime = None | |
try: set | |
except NameError: | |
from sets import Set as set | |
from utils import threadeddict, storage, iters, iterbetter, safestr, safeunicode | |
try: | |
# db module can work independent of web.py | |
from webapi import debug, config | |
except: | |
import sys | |
debug = sys.stderr | |
config = storage() | |
class UnknownDB(Exception):
    """Raised when a database is requested for an unsupported dbms name."""
class _ItplError(ValueError):
    """ValueError describing an unfinished expression in an interpolated string.

    `text` is the string being processed and `pos` the character offset
    reported in the message.
    """

    def __init__(self, text, pos):
        ValueError.__init__(self)
        self.pos = pos
        self.text = text

    def __str__(self):
        # %r renders the text exactly as repr() would
        return "unfinished expression in %r at char %d" % (self.text, self.pos)
class TransactionError(Exception):
    """Exception type exported for transaction-related failures."""
class UnknownParamstyle(Exception):
    """
    Raised when a db paramstyle is not one of the supported ones
    (currently supported: qmark, numeric, format, pyformat).
    """
class SQLParam(object):
    """
    A single parameter value inside an SQLQuery.

        >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam("joe")])
        >>> q
        <sql: "SELECT * FROM test WHERE name='joe'">
        >>> q.query()
        'SELECT * FROM test WHERE name=%s'
        >>> q.values()
        ['joe']
    """
    __slots__ = ["value"]

    def __init__(self, value):
        self.value = value

    def get_marker(self, paramstyle='pyformat'):
        """Return the placeholder text used for this parameter under `paramstyle`.

        Raises UnknownParamstyle for any style other than qmark, numeric,
        format, pyformat or None.
        """
        if paramstyle == 'qmark':
            return '?'
        if paramstyle == 'numeric':
            return ':1'
        if paramstyle is None or paramstyle in ('format', 'pyformat'):
            return '%s'
        raise UnknownParamstyle(paramstyle)

    def sqlquery(self):
        """Wrap this parameter in a one-item SQLQuery."""
        return SQLQuery([self])

    def __add__(self, other):
        query = self.sqlquery()
        return query + other

    def __radd__(self, other):
        query = self.sqlquery()
        return other + query

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return '<param: %r>' % (self.value,)

sqlparam = SQLParam
class SQLQuery(object):
    """
    A query made of text fragments and parameters.

    You can pass this sort of thing as a clause in any db function.
    Otherwise, you can pass a dictionary to the keyword argument `vars`
    and the function will call reparam for you.

    Internally, consists of `items`, which is a list of strings and
    SQLParams, which get concatenated to produce the actual query.
    """
    __slots__ = ["items"]

    # tested in sqlquote's docstring
    def __init__(self, items=None):
        r"""Creates a new SQLQuery.

            >>> SQLQuery("x")
            <sql: 'x'>
            >>> q = SQLQuery(['SELECT * FROM ', 'test', ' WHERE x=', SQLParam(1)])
            >>> q
            <sql: 'SELECT * FROM test WHERE x=1'>
            >>> q.query(), q.values()
            ('SELECT * FROM test WHERE x=%s', [1])
            >>> SQLQuery(SQLParam(1))
            <sql: '1'>
        """
        if items is None:
            contents = []
        elif isinstance(items, list):
            # the caller's list object is used as-is, not copied
            contents = items
        elif isinstance(items, SQLParam):
            contents = [items]
        elif isinstance(items, SQLQuery):
            contents = list(items.items)
        else:
            contents = [items]
        self.items = contents

        # A parameter wrapping an SQLLiteral is replaced by the literal's
        # text, so it ends up in the query string rather than in the values.
        for position, entry in enumerate(self.items):
            if isinstance(entry, SQLParam) and isinstance(entry.value, SQLLiteral):
                self.items[position] = entry.value.v

    def append(self, value):
        """Add one item (string or SQLParam) to the end of this query."""
        self.items.append(value)

    def __add__(self, other):
        if isinstance(other, basestring):
            tail = [other]
        elif isinstance(other, SQLQuery):
            tail = other.items
        else:
            return NotImplemented
        return SQLQuery(self.items + tail)

    def __radd__(self, other):
        if not isinstance(other, basestring):
            return NotImplemented
        return SQLQuery([other] + self.items)

    def __iadd__(self, other):
        if isinstance(other, (basestring, SQLParam)):
            self.items.append(other)
            return self
        if isinstance(other, SQLQuery):
            self.items.extend(other.items)
            return self
        return NotImplemented

    def __len__(self):
        return len(self.query())

    def query(self, paramstyle=None):
        """
        Returns the query part of the sql query.

            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.query()
            'SELECT * FROM test WHERE name=%s'
            >>> q.query(paramstyle='qmark')
            'SELECT * FROM test WHERE name=?'
        """
        escape_percent = paramstyle in ('format', 'pyformat')
        parts = []
        for entry in self.items:
            if isinstance(entry, SQLParam):
                parts.append(safestr(entry.get_marker(paramstyle)))
                continue
            text = safestr(entry)
            # automatically escape % characters in the query
            # For backward compatibility, ignore escaping when the query looks already escaped
            if escape_percent and '%' in text and '%%' not in text:
                text = text.replace('%', '%%')
            parts.append(text)
        return "".join(parts)

    def values(self):
        """
        Returns the values of the parameters used in the sql query.

            >>> q = SQLQuery(["SELECT * FROM test WHERE name=", SQLParam('joe')])
            >>> q.values()
            ['joe']
        """
        found = []
        for entry in self.items:
            if isinstance(entry, SQLParam):
                found.append(entry.value)
        return found

    def join(items, sep=' ', prefix=None, suffix=None, target=None):
        """
        Joins multiple queries.

            >>> SQLQuery.join(['a', 'b'], ', ')
            <sql: 'a, b'>

        Optionally, prefix and suffix arguments can be provided.

            >>> SQLQuery.join(['a', 'b'], ', ', prefix='(', suffix=')')
            <sql: '(a, b)'>

        If target argument is provided, the items are appended to target instead of creating a new SQLQuery.
        """
        if target is None:
            target = SQLQuery()
        out = target.items

        if prefix:
            out.append(prefix)

        first = True
        for piece in items:
            if not first:
                out.append(sep)
            first = False
            if isinstance(piece, SQLQuery):
                out.extend(piece.items)
            else:
                out.append(piece)

        if suffix:
            out.append(suffix)
        return target

    join = staticmethod(join)

    def _str(self):
        """Render the query with its values substituted in, for display.

        Falls back to the bare query text when the substitution raises
        ValueError or TypeError.
        """
        try:
            template = self.query()
            rendered = tuple([sqlify(value) for value in self.values()])
            return template % rendered
        except (ValueError, TypeError):
            return self.query()

    def __str__(self):
        return safestr(self._str())

    def __unicode__(self):
        return safeunicode(self._str())

    def __repr__(self):
        return '<sql: %r>' % (str(self),)
class SQLLiteral:
    """
    Marks a value as literal SQL so that `sqlquote` does not quote it.

        >>> sqlquote('NOW()')
        <sql: "'NOW()'">
        >>> sqlquote(SQLLiteral('NOW()'))
        <sql: 'NOW()'>
    """

    def __init__(self, v):
        # `v` is the literal SQL text; SQLQuery reads it back through `.v`
        self.v = v

    def __repr__(self):
        literal = self.v
        return literal

sqlliteral = SQLLiteral
def _sqllist(values):
    """
    Build a parenthesised, comma-separated SQLQuery with one parameter per value.

        >>> _sqllist([1, 2, 3])
        <sql: '(1, 2, 3)'>
    """
    pieces = ['(']
    need_separator = False
    for value in values:
        if need_separator:
            pieces.append(', ')
        pieces.append(sqlparam(value))
        need_separator = True
    pieces.append(')')
    return SQLQuery(pieces)
def reparam(string_, dictionary):
    """
    Takes a string and a dictionary and interpolates the string
    using values from the dictionary. Returns an `SQLQuery` for the result.

    `string_` is split by `_interpolate` into plain chunks and "live"
    chunks; each live chunk is evaluated as a Python expression with
    `dictionary` as its namespace and the result is passed through
    `sqlquote`, so it travels as a query parameter.

        >>> reparam("s = $s", dict(s=True))
        <sql: "s = 't'">
        >>> reparam("s IN $s", dict(s=[1, 2]))
        <sql: 's IN (1, 2)'>
    """
    dictionary = dictionary.copy() # eval mucks with it
    result = []
    for live, chunk in _interpolate(string_):
        if live:
            # SECURITY: `chunk` is evaluated as Python code. The query
            # template (`string_`) must therefore never be built from
            # untrusted input; only the *values* in `dictionary` may be.
            v = eval(chunk, dictionary)
            result.append(sqlquote(v))
        else:
            result.append(chunk)
    return SQLQuery.join(result, '')
def sqlify(obj):
    """
    converts `obj` to its proper SQL version

    None becomes NULL, booleans become 't'/'f', long integers are written
    without the Python 2 `L` suffix, datetimes become their quoted ISO
    format, unicode strings are encoded as UTF-8, and anything else is
    rendered with `repr`.

        >>> sqlify(None)
        'NULL'
        >>> sqlify(True)
        "'t'"
        >>> sqlify(3)
        '3'
    """
    # because `1 == True and hash(1) == hash(True)`
    # we have to do this the hard way...

    if obj is None:
        return 'NULL'
    elif obj is True:
        return "'t'"
    elif obj is False:
        return "'f'"
    elif isinstance(obj, long):
        # repr() of a Python 2 long ends in 'L' (e.g. '3L'), which is not
        # valid SQL; str() gives the plain digits.
        return str(obj)
    elif datetime and isinstance(obj, datetime.datetime):
        return repr(obj.isoformat())
    else:
        if isinstance(obj, unicode):
            obj = obj.encode('utf8')
        return repr(obj)
def sqllist(lst):
    """
    Turns a list of names into a comma-separated string; a string is
    returned unchanged.

        >>> sqllist(['a', 'b'])
        'a, b'
        >>> sqllist('a')
        'a'
        >>> sqllist(u'abc')
        u'abc'
    """
    if not isinstance(lst, basestring):
        return ', '.join(lst)
    return lst
def sqlors(left, lst):
    """
    `left` is a SQL clause like `tablename.arg = `
    and `lst` is a list of values. Returns an `SQLQuery` that ORs
    together the clause for each item in `lst`.

        >>> sqlors('foo = ', [])
        <sql: '1=2'>
        >>> sqlors('foo = ', [1])
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', 1)
        <sql: 'foo = 1'>
        >>> sqlors('foo = ', [1,2,3])
        <sql: '(foo = 1 OR foo = 2 OR foo = 3 OR 1=2)'>
    """
    if isinstance(lst, iters):
        lst = list(lst)
        count = len(lst)
        if count == 0:
            # nothing to match: a condition that is always false
            return SQLQuery("1=2")
        if count == 1:
            # a single value is handled like a scalar below
            lst = lst[0]

    if not isinstance(lst, iters):
        return left + sqlparam(lst)

    pieces = ['(']
    for value in lst:
        pieces.extend([left, sqlparam(value), ' OR '])
    # the trailing always-false term closes the final ' OR '
    pieces.append('1=2)')
    return SQLQuery(pieces)
def sqlwhere(dictionary, grouping=' AND '):
    """
    Converts a `dictionary` to an SQL WHERE clause `SQLQuery`: one
    `key = value` term per entry, joined with `grouping`.

        >>> sqlwhere({'cust_id': 2, 'order_id':3})
        <sql: 'order_id = 3 AND cust_id = 2'>
        >>> sqlwhere({'cust_id': 2, 'order_id':3}, grouping=', ')
        <sql: 'order_id = 3, cust_id = 2'>
        >>> sqlwhere({'a': 'a', 'b': 'b'}).query()
        'a = %s AND b = %s'
    """
    terms = []
    for column, value in dictionary.items():
        terms.append(column + ' = ' + sqlparam(value))
    return SQLQuery.join(terms, grouping)
def sqlquote(a):
    """
    Ensures `a` is quoted properly for use in a SQL query. A list becomes
    a parenthesised list of parameters; anything else a single parameter.

        >>> 'WHERE x = ' + sqlquote(True) + ' AND y = ' + sqlquote(3)
        <sql: "WHERE x = 't' AND y = 3">
        >>> 'WHERE x = ' + sqlquote(True) + ' AND y IN ' + sqlquote([2, 3])
        <sql: "WHERE x = 't' AND y IN (2, 3)">
    """
    if not isinstance(a, list):
        return sqlparam(a).sqlquery()
    return _sqllist(a)
class Transaction:
    """Database transaction.

    Creating an instance starts the transaction and pushes it onto
    `ctx.transactions`. A transaction started while others are already on
    that stack is handled with savepoints, or silently ignored when
    `ctx.get('ignore_nested_transactions')` is true.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        # number of transactions already open; also this one's stack position
        depth = len(ctx.transactions)
        self.transaction_count = depth

        class transaction_engine:
            """Transaction Engine used in top level transactions."""
            def do_transact(self):
                ctx.commit(unload=False)

            def do_commit(self):
                ctx.commit()

            def do_rollback(self):
                ctx.rollback()

        class subtransaction_engine:
            """Transaction Engine used in sub transactions."""
            def query(self, q):
                cursor = ctx.db.cursor()
                ctx.db_execute(cursor, SQLQuery(q % depth))

            def do_transact(self):
                self.query('SAVEPOINT webpy_sp_%s')

            def do_commit(self):
                self.query('RELEASE SAVEPOINT webpy_sp_%s')

            def do_rollback(self):
                self.query('ROLLBACK TO SAVEPOINT webpy_sp_%s')

        class dummy_engine:
            """Transaction Engine used instead of subtransaction_engine
            when sub transactions are not supported."""
            def do_transact(self):
                return None
            do_commit = do_rollback = do_transact

        if not depth:
            engine = transaction_engine()
        elif ctx.get('ignore_nested_transactions'):
            # nested transactions are not supported in some databases
            engine = dummy_engine()
        else:
            engine = subtransaction_engine()

        self.engine = engine
        engine.do_transact()
        ctx.transactions.append(self)

    def __enter__(self):
        return self

    def __exit__(self, exctype, excvalue, traceback):
        # commit on a clean exit, roll back when the block raised
        if exctype is None:
            self.commit()
        else:
            self.rollback()

    def commit(self):
        """Commit, unless this transaction has already been finished."""
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_commit()
        # drop this transaction and everything opened after it
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]

    def rollback(self):
        """Roll back, unless this transaction has already been finished."""
        if len(self.ctx.transactions) <= self.transaction_count:
            return
        self.engine.do_rollback()
        self.ctx.transactions = self.ctx.transactions[:self.transaction_count]
class DB:
    """Database

    Generic wrapper around a DB-API driver module (`db_module`). Connection
    state lives in `self._ctx`, a `threadeddict` (imported from utils;
    presumably per-thread storage -- confirm), and is set up lazily the
    first time the `ctx` property is read.
    """
    def __init__(self, db_module, keywords):
        """Creates a database.

        `db_module` is the driver module whose `connect` is called by
        `_connect`; `keywords` is the dict of arguments handed to it.
        `keywords` is modified in place: the `driver` and `pooling`
        entries are removed.
        """
        # some DB implementations take optional parameter `driver` to use a specific driver module
        # but it should not be passed to connect
        keywords.pop('driver', None)

        self.db_module = db_module
        self.keywords = keywords

        self._ctx = threadeddict()
        # flag to enable/disable printing queries
        self.printing = config.get('debug_sql', config.get('debug', False))
        self.supports_multiple_insert = False

        try:
            import DBUtils
            # enable pooling if DBUtils module is available.
            self.has_pooling = True
        except ImportError:
            self.has_pooling = False

        # Pooling can be disabled by passing pooling=False in the keywords.
        self.has_pooling = self.keywords.pop('pooling', True) and self.has_pooling

    def _getctx(self):
        """Return the context, loading it first when it has no `db` entry."""
        if not self._ctx.get('db'):
            self._load_context(self._ctx)
        return self._ctx
    ctx = property(_getctx)

    def _load_context(self, ctx):
        """Connect and populate `ctx` with the connection and its helpers.

        Sets `dbq_count`, `transactions`, `db`, `db_execute`, `commit` and
        `rollback` on `ctx`.
        """
        ctx.dbq_count = 0
        ctx.transactions = [] # stack of transactions

        if self.has_pooling:
            ctx.db = self._connect_with_pooling(self.keywords)
        else:
            ctx.db = self._connect(self.keywords)
        ctx.db_execute = self._db_execute

        # connections lacking commit/rollback get do-nothing replacements
        if not hasattr(ctx.db, 'commit'):
            ctx.db.commit = lambda: None

        if not hasattr(ctx.db, 'rollback'):
            ctx.db.rollback = lambda: None

        def commit(unload=True):
            # do db commit and release the connection if pooling is enabled.
            ctx.db.commit()
            if unload and self.has_pooling:
                self._unload_context(self._ctx)

        def rollback():
            # do db rollback and release the connection if pooling is enabled.
            ctx.db.rollback()
            if self.has_pooling:
                self._unload_context(self._ctx)

        ctx.commit = commit
        ctx.rollback = rollback

    def _unload_context(self, ctx):
        """Drop the connection from `ctx`; the next `ctx` access reloads it."""
        del ctx.db

    def _connect(self, keywords):
        """Open a connection by calling the driver's `connect(**keywords)`."""
        return self.db_module.connect(**keywords)

    def _connect_with_pooling(self, keywords):
        """Return a connection from a DBUtils pool, creating the pool on first use."""
        def get_pooled_db():
            from DBUtils import PooledDB

            # In DBUtils 0.9.3, `dbapi` argument is renamed as `creator`
            # see Bug#122112

            # NOTE(review): the version parts are compared as lists of
            # strings, so e.g. '10' sorts before '9' -- confirm this is
            # adequate for the DBUtils versions in use.
            if PooledDB.__version__.split('.') < '0.9.3'.split('.'):
                return PooledDB.PooledDB(dbapi=self.db_module, **keywords)
            else:
                return PooledDB.PooledDB(creator=self.db_module, **keywords)

        if getattr(self, '_pooleddb', None) is None:
            self._pooleddb = get_pooled_db()

        return self._pooleddb.connection()

    def _db_cursor(self):
        """Return a new cursor on the current context's connection."""
        return self.ctx.db.cursor()

    def _param_marker(self):
        """Returns parameter marker based on paramstyle attribute of this database."""
        style = getattr(self, 'paramstyle', 'pyformat')

        if style == 'qmark':
            return '?'
        elif style == 'numeric':
            return ':1'
        elif style in ['format', 'pyformat']:
            return '%s'
        raise UnknownParamstyle, style

    def _db_execute(self, cur, sql_query):
        """executes an sql query

        Runs `sql_query` on cursor `cur` and returns what `cur.execute`
        returned. On any exception the innermost open transaction (or,
        when none is open, the context) is rolled back and the exception
        is re-raised. When `self.printing` is set, the query is written to
        `debug` together with its duration and the running query count.
        """

        self.ctx.dbq_count += 1

        try:
            a = time.time()
            query, params = self._process_query(sql_query)
            out = cur.execute(query, params)
            b = time.time()
        except:
            if self.printing:
                print >> debug, 'ERR:', str(sql_query)
            if self.ctx.transactions:
                self.ctx.transactions[-1].rollback()
            else:
                self.ctx.rollback()
            raise

        if self.printing:
            print >> debug, '%s (%s): %s' % (round(b-a, 2), self.ctx.dbq_count, str(sql_query))
        return out

    def _process_query(self, sql_query):
        """Takes the SQLQuery object and returns query string and parameters.
        """
        paramstyle = getattr(self, 'paramstyle', 'pyformat')
        query = sql_query.query(paramstyle)
        params = sql_query.values()
        return query, params

    def _where(self, where, vars):
        """Normalise a `where` argument into something usable in a query.

        An integer becomes `id = <n>`, an SQLQuery is returned unchanged,
        and anything else is interpolated with `reparam` using `vars`.
        """
        if isinstance(where, (int, long)):
            where = "id = " + sqlparam(where)
        #@@@ for backward-compatibility
        # NOTE(review): SQLQuery.__init__ as defined in this file accepts a
        # single `items` argument, so this two-argument call looks like it
        # would raise TypeError -- confirm.
        elif isinstance(where, (list, tuple)) and len(where) == 2:
            where = SQLQuery(where[0], where[1])
        elif isinstance(where, SQLQuery):
            pass
        else:
            where = reparam(where, vars)
        return where

    def query(self, sql_query, vars=None, processed=False, _test=False):
        """
        Execute SQL query `sql_query` using dictionary `vars` to interpolate it.
        If `processed=True`, `vars` is a `reparam`-style list to use
        instead of interpolating.

        When the cursor has a description, returns an `iterbetter` over
        the rows as `storage` objects; otherwise returns the cursor's
        `rowcount`. With `_test=True` the query is returned without being
        executed.

            >>> db = DB(None, {})
            >>> db.query("SELECT * FROM foo", _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.query("SELECT * FROM foo WHERE x = $x", vars=dict(x='f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
            >>> db.query("SELECT * FROM foo WHERE x = " + sqlquote('f'), _test=True)
            <sql: "SELECT * FROM foo WHERE x = 'f'">
        """
        if vars is None: vars = {}

        if not processed and not isinstance(sql_query, SQLQuery):
            sql_query = reparam(sql_query, vars)

        if _test: return sql_query

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, sql_query)

        if db_cursor.description:
            names = [x[0] for x in db_cursor.description]
            def iterwrapper():
                # rows are fetched one at a time as the result is iterated
                row = db_cursor.fetchone()
                while row:
                    yield storage(dict(zip(names, row)))
                    row = db_cursor.fetchone()
            out = iterbetter(iterwrapper())
            out.__len__ = lambda: int(db_cursor.rowcount)
            out.list = lambda: [storage(dict(zip(names, x))) \
                               for x in db_cursor.fetchall()]
        else:
            out = db_cursor.rowcount

        # outside a transaction every statement is committed immediately
        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def select(self, tables, vars=None, what='*', where=None, order=None, group=None,
               limit=None, offset=None, _test=False):
        """
        Selects `what` from `tables` with clauses `where`, `order`,
        `group`, `limit`, and `offset`. Uses vars to interpolate.
        Otherwise, each clause can be a SQLQuery.

            >>> db = DB(None, {})
            >>> db.select('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
            >>> db.select(['foo', 'bar'], where="foo.bar_id = bar.id", limit=5, _test=True)
            <sql: 'SELECT * FROM foo, bar WHERE foo.bar_id = bar.id LIMIT 5'>
        """
        if vars is None: vars = {}
        sql_clauses = self.sql_clauses(what, tables, where, group, order, limit, offset)
        # clauses whose value is None are left out of the query
        clauses = [self.gen_clause(sql, val, vars) for sql, val in sql_clauses if val is not None]
        qout = SQLQuery.join(clauses)
        if _test: return qout
        return self.query(qout, processed=True)

    def where(self, table, what='*', order=None, group=None, limit=None,
              offset=None, _test=False, **kwargs):
        """
        Selects from `table` where keys are equal to values in `kwargs`.

            >>> db = DB(None, {})
            >>> db.where('foo', bar_id=3, _test=True)
            <sql: 'SELECT * FROM foo WHERE bar_id = 3'>
            >>> db.where('foo', source=2, crust='dewey', _test=True)
            <sql: "SELECT * FROM foo WHERE source = 2 AND crust = 'dewey'">
            >>> db.where('foo', _test=True)
            <sql: 'SELECT * FROM foo'>
        """
        where_clauses = []
        for k, v in kwargs.iteritems():
            where_clauses.append(k + ' = ' + sqlquote(v))

        if where_clauses:
            where = SQLQuery.join(where_clauses, " AND ")
        else:
            where = None

        return self.select(table, what=what, order=order,
               group=group, limit=limit, offset=offset, _test=_test,
               where=where)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """Return the (keyword, value) pairs of a SELECT, in the order they are emitted."""
        return (
            ('SELECT', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('LIMIT', limit),
            ('OFFSET', offset))

    def gen_clause(self, sql, val, vars):
        """Build one clause of a query from keyword `sql` and value `val`.

        An integer `val` becomes `id = <n>` for WHERE and a plain number
        otherwise; an SQLQuery is used as-is; anything else is
        interpolated with `reparam` using `vars`.
        """
        if isinstance(val, (int, long)):
            if sql == 'WHERE':
                nout = 'id = ' + sqlquote(val)
            else:
                nout = SQLQuery(val)
        #@@@
        # NOTE(review): as in `_where`, SQLQuery.__init__ in this file takes
        # one `items` argument, so this call looks like it would raise
        # TypeError -- confirm.
        elif isinstance(val, (list, tuple)) and len(val) == 2:
            nout = SQLQuery(val[0], val[1]) # backwards-compatibility
        elif isinstance(val, SQLQuery):
            nout = val
        else:
            nout = reparam(val, vars)

        def xjoin(a, b):
            # join with a space only when both parts are non-empty
            if a and b: return a + ' ' + b
            else: return a or b

        return xjoin(sql, nout)

    def insert(self, tablename, seqname=None, _test=False, **values):
        """
        Inserts `values` into `tablename`. Returns current sequence ID.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

        The returned value is the first column of the row fetched after
        the insert, or None when fetching fails.

            >>> db = DB(None, {})
            >>> q = db.insert('foo', name='bob', age=2, created=SQLLiteral('NOW()'), _test=True)
            >>> q
            <sql: "INSERT INTO foo (age, name, created) VALUES (2, 'bob', NOW())">
            >>> q.query()
            'INSERT INTO foo (age, name, created) VALUES (%s, %s, NOW())'
            >>> q.values()
            [2, 'bob']
        """
        def q(x): return "(" + x + ")"

        if values:
            _keys = SQLQuery.join(values.keys(), ', ')
            _values = SQLQuery.join([sqlparam(v) for v in values.values()], ', ')
            sql_query = "INSERT INTO %s " % tablename + q(_keys) + ' VALUES ' + q(_values)
        else:
            sql_query = SQLQuery(self._get_insert_default_values_query(tablename))

        if _test: return sql_query

        db_cursor = self._db_cursor()
        if seqname is not False:
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        try:
            out = db_cursor.fetchone()[0]
        except Exception:
            out = None

        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def _get_insert_default_values_query(self, table):
        """Return the SQL used to insert a row consisting only of default values."""
        return "INSERT INTO %s DEFAULT VALUES" % table

    def multiple_insert(self, tablename, values, seqname=None, _test=False):
        """
        Inserts multiple rows into `tablename`. The `values` must be a list of dictionaries,
        one for each row to be inserted, each with the same set of keys.
        Returns the list of ids of the inserted rows.
        Set `seqname` to the ID if it's not the default, or to `False`
        if there isn't one.

        When `supports_multiple_insert` is false, the rows are inserted
        one by one through `insert`. Raises ValueError when the rows do
        not all report the same keys.

            >>> db = DB(None, {})
            >>> db.supports_multiple_insert = True
            >>> values = [{"name": "foo", "email": "foo@example.com"}, {"name": "bar", "email": "bar@example.com"}]
            >>> db.multiple_insert('person', values=values, _test=True)
            <sql: "INSERT INTO person (name, email) VALUES ('foo', 'foo@example.com'), ('bar', 'bar@example.com')">
        """
        if not values:
            return []

        if not self.supports_multiple_insert:
            out = [self.insert(tablename, seqname=seqname, _test=_test, **v) for v in values]
            if seqname is False:
                return None
            else:
                return out

        keys = values[0].keys()
        #@@ make sure all keys are valid

        # make sure all rows have same keys.
        # NOTE(review): the key lists are compared in order, so two rows with
        # the same keys but a different key order would be rejected -- confirm.
        for v in values:
            if v.keys() != keys:
                raise ValueError, 'Bad data'

        sql_query = SQLQuery('INSERT INTO %s (%s) VALUES ' % (tablename, ', '.join(keys)))

        for i, row in enumerate(values):
            if i != 0:
                sql_query.append(", ")
            SQLQuery.join([SQLParam(row[k]) for k in keys], sep=", ", target=sql_query, prefix="(", suffix=")")

        if _test: return sql_query

        db_cursor = self._db_cursor()
        if seqname is not False:
            sql_query = self._process_insert_query(sql_query, tablename, seqname)

        if isinstance(sql_query, tuple):
            # for some databases, a separate query has to be made to find
            # the id of the inserted row.
            q1, q2 = sql_query
            self._db_execute(db_cursor, q1)
            self._db_execute(db_cursor, q2)
        else:
            self._db_execute(db_cursor, sql_query)

        try:
            out = db_cursor.fetchone()[0]
            # the ids are taken to be the len(values) consecutive numbers
            # ending at the fetched id
            out = range(out-len(values)+1, out+1)
        except Exception:
            out = None

        if not self.ctx.transactions:
            self.ctx.commit()
        return out

    def update(self, tables, where, vars=None, _test=False, **values):
        """
        Update `tables` with clause `where` (interpolated using `vars`)
        and setting `values`. Returns the cursor's `rowcount`.

            >>> db = DB(None, {})
            >>> name = 'Joseph'
            >>> q = db.update('foo', where='name = $name', name='bob', age=2,
            ...     created=SQLLiteral('NOW()'), vars=locals(), _test=True)
            >>> q
            <sql: "UPDATE foo SET age = 2, name = 'bob', created = NOW() WHERE name = 'Joseph'">
            >>> q.query()
            'UPDATE foo SET age = %s, name = %s, created = NOW() WHERE name = %s'
            >>> q.values()
            [2, 'bob', 'Joseph']
        """
        if vars is None: vars = {}
        where = self._where(where, vars)

        query = (
          "UPDATE " + sqllist(tables) +
          " SET " + sqlwhere(values, ', ') +
          " WHERE " + where)

        if _test: return query

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, query)
        if not self.ctx.transactions:
            self.ctx.commit()
        return db_cursor.rowcount

    def delete(self, table, where, using=None, vars=None, _test=False):
        """
        Deletes from `table` with clauses `where` and `using`.
        Returns the cursor's `rowcount`.

            >>> db = DB(None, {})
            >>> name = 'Joe'
            >>> db.delete('foo', where='name = $name', vars=locals(), _test=True)
            <sql: "DELETE FROM foo WHERE name = 'Joe'">
        """
        if vars is None: vars = {}
        where = self._where(where, vars)

        q = 'DELETE FROM ' + table
        if using: q += ' USING ' + sqllist(using)
        if where: q += ' WHERE ' + where

        if _test: return q

        db_cursor = self._db_cursor()
        self._db_execute(db_cursor, q)
        if not self.ctx.transactions:
            self.ctx.commit()
        return db_cursor.rowcount

    def _process_insert_query(self, query, tablename, seqname):
        """Hook for subclasses to extend an INSERT so the new id can be read back.

        May return a single query or a `(insert_query, id_query)` tuple;
        this base implementation returns `query` unchanged.
        """
        return query

    def transaction(self):
        """Start a transaction."""
        return Transaction(self.ctx)
class PostgresDB(DB):
    """Postgres driver."""

    def __init__(self, **keywords):
        """Pick a postgres driver module and map web.py keywords onto its names.

        `pw` is passed on as `password` and `db` as `database`; `driver`
        may name the preferred driver module.
        """
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')

        preferred = keywords.pop('driver', None)
        driver = import_driver(["psycopg2", "psycopg", "pgdb"], preferred=preferred)
        if driver.__name__ == "psycopg2":
            import psycopg2.extensions
            psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)

        # if db is not provided postgres driver will take it from PGDATABASE environment variable
        if 'db' in keywords:
            keywords['database'] = keywords.pop('db')

        self.dbname = "postgres"
        self.paramstyle = driver.paramstyle
        DB.__init__(self, driver, keywords)
        self.supports_multiple_insert = True
        # cache for _get_all_sequences(); filled on first use
        self._sequences = None

    def _process_insert_query(self, query, tablename, seqname):
        """Append a `currval` lookup to `query` when a sequence name is known."""
        if seqname is None:
            # when seqname is not provided guess the seqname and make sure it exists
            guessed = tablename + "_id_seq"
            if guessed in self._get_all_sequences():
                seqname = guessed

        if seqname:
            query += "; SELECT currval('%s')" % seqname

        return query

    def _get_all_sequences(self):
        """Query postgres to find names of all sequences used in this database."""
        if self._sequences is None:
            sql = "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S'"
            names = [row.relname for row in self.query(sql)]
            self._sequences = set(names)
        return self._sequences

    def _connect(self, keywords):
        """Connect and switch the client encoding to UTF-8."""
        connection = DB._connect(self, keywords)
        try:
            connection.set_client_encoding('UTF8')
        except AttributeError:
            # fallback for pgdb driver
            connection.cursor().execute("set client_encoding to 'UTF-8'")
        return connection

    def _connect_with_pooling(self, keywords):
        """Get a pooled connection and set UTF-8 on the connection it wraps."""
        pooled = DB._connect_with_pooling(self, keywords)
        raw_connection = pooled._con._con
        raw_connection.set_client_encoding('UTF8')
        return pooled
class MySQLDB(DB):
    """MySQL driver, backed by the MySQLdb module."""

    def __init__(self, **keywords):
        """Map web.py keywords onto MySQLdb's names and apply the charset default.

        `pw` is passed on as `passwd`. `charset` defaults to 'utf8';
        passing `charset=None` removes the entry altogether.
        """
        import MySQLdb as db
        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')

        keywords.setdefault('charset', 'utf8')
        if keywords['charset'] is None:
            del keywords['charset']

        # it's both, like psycopg
        self.paramstyle = 'pyformat'
        db.paramstyle = 'pyformat'
        self.dbname = "mysql"
        DB.__init__(self, db, keywords)
        self.supports_multiple_insert = True

    def _process_insert_query(self, query, tablename, seqname):
        """Pair the insert with a second query that reads the new row's id."""
        id_query = SQLQuery('SELECT last_insert_id();')
        return query, id_query

    def _get_insert_default_values_query(self, table):
        """Return MySQL's form of an insert consisting only of default values."""
        return "INSERT INTO %s () VALUES()" % table
def import_driver(drivers, preferred=None):
    """Import and return the first module from `drivers` that is available.

    When `preferred` is given, only that module is tried.
    Raises ImportError when none of the candidates can be imported.
    """
    candidates = drivers
    if preferred:
        candidates = [preferred]

    for name in candidates:
        try:
            # a non-empty fromlist makes __import__ return the named
            # submodule itself for dotted names
            module = __import__(name, None, None, ['x'])
        except ImportError:
            continue
        return module

    raise ImportError("Unable to import " + " or ".join(candidates))
class SqliteDB(DB):
    """SQLite driver."""

    def __init__(self, **keywords):
        """Pick a sqlite driver module and map `db` onto its `database` argument."""
        preferred = keywords.pop('driver', None)
        db = import_driver(["sqlite3", "pysqlite2.dbapi2", "sqlite"], preferred=preferred)

        if db.__name__ in ("sqlite3", "pysqlite2.dbapi2"):
            db.paramstyle = 'qmark'

        # sqlite driver doesn't create datetime objects for timestamp columns unless `detect_types` option is passed.
        # It seems to be supported in sqlite3 and pysqlite2 drivers, not sure about sqlite.
        keywords.setdefault('detect_types', db.PARSE_DECLTYPES)

        self.paramstyle = db.paramstyle
        keywords['database'] = keywords.pop('db')
        self.dbname = "sqlite"
        DB.__init__(self, db, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        """Pair the insert with a second query that reads the new row's id."""
        id_query = SQLQuery('SELECT last_insert_rowid();')
        return query, id_query

    def query(self, *a, **kw):
        """Run `DB.query`, removing `__len__` from iterable results."""
        result = DB.query(self, *a, **kw)
        if isinstance(result, iterbetter):
            del result.__len__
        return result
class FirebirdDB(DB):
    """Firebird Database.
    """

    def __init__(self, **keywords):
        """Map web.py keywords onto the kinterbasdb names.

        `pw` is passed on as `passwd` and `db` as `database`. When the
        kinterbasdb module cannot be imported, the driver module is None.
        """
        try:
            import kinterbasdb as db
        except Exception:
            db = None

        if 'pw' in keywords:
            keywords['passwd'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        DB.__init__(self, db, keywords)

    def delete(self, table, where=None, using=None, vars=None, _test=False):
        """Delete from `table`; any `using` argument is discarded."""
        # firebird doesn't support using clause
        return DB.delete(self, table, where, None, vars, _test)

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """Return the SELECT clause pairs, with the limit and offset expressed as FIRST/SKIP."""
        clauses = (
            ('SELECT', ''),
            ('FIRST', limit),
            ('SKIP', offset),
            ('', what),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order)
        )
        return clauses
class MSSQLDB(DB):
    """MSSQL driver, backed by the pymssql module."""

    def __init__(self, **keywords):
        """Map `pw` onto `password` and `db` onto `database` for pymssql."""
        import pymssql as db
        if 'pw' in keywords:
            keywords['password'] = keywords.pop('pw')
        keywords['database'] = keywords.pop('db')
        self.dbname = "mssql"
        DB.__init__(self, db, keywords)

    def _process_query(self, sql_query):
        """Takes the SQLQuery object and returns query string and parameters.
        """
        # MSSQLDB expects params to be a tuple.
        # Overwriting the default implementation to convert params to tuple.
        style = getattr(self, 'paramstyle', 'pyformat')
        text = sql_query.query(style)
        params = tuple(sql_query.values())
        return text, params

    def sql_clauses(self, what, tables, where, group, order, limit, offset):
        """Return the SELECT clause pairs, with the limit expressed as TOP."""
        clauses = (
            ('SELECT', what),
            ('TOP', limit),
            ('FROM', sqllist(tables)),
            ('WHERE', where),
            ('GROUP BY', group),
            ('ORDER BY', order),
            ('OFFSET', offset))
        return clauses

    def _test(self):
        """Test LIMIT.

            Fake presence of pymssql module for running tests.
            >>> import sys
            >>> sys.modules['pymssql'] = sys.modules['sys']

            MSSQL has TOP clause instead of LIMIT clause.
            >>> db = MSSQLDB(db='test', user='joe', pw='secret')
            >>> db.select('foo', limit=4, _test=True)
            <sql: 'SELECT * TOP 4 FROM foo'>
        """
class OracleDB(DB):
    """Oracle database, accessed through the `cx_Oracle` module."""

    def __init__(self, **keywords):
        """Prepare the connection keywords and initialise the base `DB`.

        The optional `pw` keyword is renamed to `password`, the required
        `db` keyword is renamed to `dsn`, and any `pooling` keyword is
        discarded. The driver module's `paramstyle` is set to 'numeric'
        and mirrored on the instance.
        """
        import cx_Oracle as db
        try:
            keywords['password'] = keywords.pop('pw')
        except KeyError:
            pass

        #@@ TODO: use db.makedsn if host, port is specified
        keywords['dsn'] = keywords.pop('db')
        self.dbname = 'oracle'
        db.paramstyle = 'numeric'
        self.paramstyle = db.paramstyle

        # oracle doesn't support pooling
        if 'pooling' in keywords:
            del keywords['pooling']
        DB.__init__(self, db, keywords)

    def _process_insert_query(self, query, tablename, seqname):
        """Append a `currval` lookup for `seqname` to the INSERT query.

        When `seqname` is None the query is returned unchanged.
        """
        # It is not possible to get seq name from table name in Oracle
        if seqname is None:
            return query
        currval_lookup = "; SELECT %s.currval FROM dual" % seqname
        return query + currval_lookup
# Registry mapping a database name (the `dbn` parameter of `database`) to
# the class or callable that creates it; filled in by `register_database`.
_databases = {}

def database(dburl=None, **params):
    """Creates appropriate database using params.

    Pooling will be enabled if DBUtils module is available.
    Pooling can be disabled by passing pooling=False in params.

    `params` must contain `dbn`, the registered name of the database type;
    every other keyword is passed on to the registered class. `dburl` is
    accepted but not used by this function.

    Raises KeyError when `dbn` is not supplied and UnknownDB when no
    database has been registered under that name.
    """
    dbn = params.pop('dbn')
    if dbn not in _databases:
        raise UnknownDB(dbn)
    return _databases[dbn](**params)
def register_database(name, clazz):
    """
    Make `clazz` available to `database()` under the name `name`.

    An existing registration with the same name is replaced.

        >>> class LegacyDB(DB):
        ...     def __init__(self, **params):
        ...        pass
        ...
        >>> register_database('legacy', LegacyDB)
        >>> db = database(dbn='legacy', db='test', user='joe', passwd='secret')
    """
    _databases.update({name: clazz})
# Register the database classes defined in this module under the names
# accepted as the `dbn` parameter of `database()`.
register_database('mysql', MySQLDB)
register_database('postgres', PostgresDB)
register_database('sqlite', SqliteDB)
register_database('firebird', FirebirdDB)
register_database('mssql', MSSQLDB)
register_database('oracle', OracleDB)
def _interpolate(format):
    """
    Takes a format string and returns a list of 2-tuples of the form
    (boolean, string) where boolean says whether string should be evaled
    or not.

    `$name`, `$name.attr`, `$name(...)`, `$name[...]` and `${...}` produce
    chunks flagged 1; everything else is flagged 0. `$$` yields a single
    literal `$`. A `$` that is the last character of the string is kept as
    literal text.

    Raises _ItplError when an expression after `$` cannot be tokenized.

    from <http://lfw.org/python/Itpl.py> (public domain, Ka-Ping Yee)
    """
    def matchorfail(text, pos):
        # Imported here so that text without any `$` expression never
        # needs the tokenizer.
        from tokenize import tokenprog
        match = tokenprog.match(text, pos)
        if match is None:
            raise _ItplError(text, pos)
        return match, match.end()

    namechars = "abcdefghijklmnopqrstuvwxyz" \
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
    chunks = []
    pos = 0

    while 1:
        dollar = format.find("$", pos)
        if dollar < 0:
            break
        # A slice (not an index) so that a trailing "$" yields "" instead
        # of raising IndexError.
        nextchar = format[dollar + 1:dollar + 2]

        if nextchar == "{":
            chunks.append((0, format[pos:dollar]))
            pos, level = dollar + 2, 1
            # Consume tokens until the matching closing brace.
            while level:
                match, pos = matchorfail(format, pos)
                tstart, tend = match.regs[3]
                token = format[tstart:tend]
                if token == "{":
                    level = level + 1
                elif token == "}":
                    level = level - 1
            chunks.append((1, format[dollar + 2:pos - 1]))

        # `"" in namechars` is true, so the empty case is excluded first.
        elif nextchar and nextchar in namechars:
            chunks.append((0, format[pos:dollar]))
            match, pos = matchorfail(format, dollar + 1)
            # Extend the expression over attribute access, calls and
            # subscripts that directly follow the name.
            while pos < len(format):
                if format[pos] == "." and \
                        pos + 1 < len(format) and format[pos + 1] in namechars:
                    match, pos = matchorfail(format, pos + 1)
                elif format[pos] in "([":
                    pos, level = pos + 1, 1
                    while level:
                        match, pos = matchorfail(format, pos)
                        tstart, tend = match.regs[3]
                        token = format[tstart:tend]
                        if token[0] in "([":
                            level = level + 1
                        elif token[0] in ")]":
                            level = level - 1
                else:
                    break
            chunks.append((1, format[dollar + 1:pos]))

        else:
            # Literal dollar sign; "$$" is skipped as a pair.
            chunks.append((0, format[pos:dollar + 1]))
            pos = dollar + 1 + (nextchar == "$")

    if pos < len(format):
        chunks.append((0, format[pos:]))
    return chunks
# When executed as a script, run the doctests embedded in this module.
if __name__ == "__main__":
    import doctest
    doctest.testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pretty debug errors | |
(part of web.py) | |
portions adapted from Django <djangoproject.com> | |
Copyright (c) 2005, the Lawrence Journal-World | |
Used under the modified BSD license: | |
http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5 | |
""" | |
__all__ = ["debugerror", "djangoerror", "emailerrors"] | |
import sys, urlparse, pprint, traceback | |
from template import Template | |
from net import websafe | |
from utils import sendmail, safestr | |
import webapi as web | |
import os, os.path | |
# Directory containing this module: join the current working directory with
# __file__, then drop the last path component (the file name).
whereami = os.path.join(os.getcwd(), __file__)
whereami = os.path.sep.join(whereami.split(os.path.sep)[:-1])
djangoerror_t = """\ | |
$def with (exception_type, exception_value, frames) | |
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN" "http://www.w3.org/TR/html4/loose.dtd"> | |
<html lang="en"> | |
<head> | |
<meta http-equiv="content-type" content="text/html; charset=utf-8" /> | |
<meta name="robots" content="NONE,NOARCHIVE" /> | |
<title>$exception_type at $ctx.path</title> | |
<style type="text/css"> | |
html * { padding:0; margin:0; } | |
body * { padding:10px 20px; } | |
body * * { padding:0; } | |
body { font:small sans-serif; } | |
body>div { border-bottom:1px solid #ddd; } | |
h1 { font-weight:normal; } | |
h2 { margin-bottom:.8em; } | |
h2 span { font-size:80%; color:#666; font-weight:normal; } | |
h3 { margin:1em 0 .5em 0; } | |
h4 { margin:0 0 .5em 0; font-weight: normal; } | |
table { | |
border:1px solid #ccc; border-collapse: collapse; background:white; } | |
tbody td, tbody th { vertical-align:top; padding:2px 3px; } | |
thead th { | |
padding:1px 6px 1px 3px; background:#fefefe; text-align:left; | |
font-weight:normal; font-size:11px; border:1px solid #ddd; } | |
tbody th { text-align:right; color:#666; padding-right:.5em; } | |
table.vars { margin:5px 0 2px 40px; } | |
table.vars td, table.req td { font-family:monospace; } | |
table td.code { width:100%;} | |
table td.code div { overflow:hidden; } | |
table.source th { color:#666; } | |
table.source td { | |
font-family:monospace; white-space:pre; border-bottom:1px solid #eee; } | |
ul.traceback { list-style-type:none; } | |
ul.traceback li.frame { margin-bottom:1em; } | |
div.context { margin: 10px 0; } | |
div.context ol { | |
padding-left:30px; margin:0 10px; list-style-position: inside; } | |
div.context ol li { | |
font-family:monospace; white-space:pre; color:#666; cursor:pointer; } | |
div.context ol.context-line li { color:black; background-color:#ccc; } | |
div.context ol.context-line li span { float: right; } | |
div.commands { margin-left: 40px; } | |
div.commands a { color:black; text-decoration:none; } | |
#summary { background: #ffc; } | |
#summary h2 { font-weight: normal; color: #666; } | |
#explanation { background:#eee; } | |
#template, #template-not-exist { background:#f6f6f6; } | |
#template-not-exist ul { margin: 0 0 0 20px; } | |
#traceback { background:#eee; } | |
#requestinfo { background:#f6f6f6; padding-left:120px; } | |
#summary table { border:none; background:transparent; } | |
#requestinfo h2, #requestinfo h3 { position:relative; margin-left:-100px; } | |
#requestinfo h3 { margin-bottom:-1em; } | |
.error { background: #ffc; } | |
.specific { color:#cc3300; font-weight:bold; } | |
</style> | |
<script type="text/javascript"> | |
//<!-- | |
function getElementsByClassName(oElm, strTagName, strClassName){ | |
// Written by Jonathan Snook, http://www.snook.ca/jon; | |
// Add-ons by Robert Nyman, http://www.robertnyman.com | |
var arrElements = (strTagName == "*" && document.all)? document.all : | |
oElm.getElementsByTagName(strTagName); | |
var arrReturnElements = new Array(); | |
strClassName = strClassName.replace(/\-/g, "\\-"); | |
var oRegExp = new RegExp("(^|\\s)" + strClassName + "(\\s|$$)"); | |
var oElement; | |
for(var i=0; i<arrElements.length; i++){ | |
oElement = arrElements[i]; | |
if(oRegExp.test(oElement.className)){ | |
arrReturnElements.push(oElement); | |
} | |
} | |
return (arrReturnElements) | |
} | |
function hideAll(elems) { | |
for (var e = 0; e < elems.length; e++) { | |
elems[e].style.display = 'none'; | |
} | |
} | |
window.onload = function() { | |
hideAll(getElementsByClassName(document, 'table', 'vars')); | |
hideAll(getElementsByClassName(document, 'ol', 'pre-context')); | |
hideAll(getElementsByClassName(document, 'ol', 'post-context')); | |
} | |
function toggle() { | |
for (var i = 0; i < arguments.length; i++) { | |
var e = document.getElementById(arguments[i]); | |
if (e) { | |
e.style.display = e.style.display == 'none' ? 'block' : 'none'; | |
} | |
} | |
return false; | |
} | |
function varToggle(link, id) { | |
toggle('v' + id); | |
var s = link.getElementsByTagName('span')[0]; | |
var uarr = String.fromCharCode(0x25b6); | |
var darr = String.fromCharCode(0x25bc); | |
s.innerHTML = s.innerHTML == uarr ? darr : uarr; | |
return false; | |
} | |
//--> | |
</script> | |
</head> | |
<body> | |
$def dicttable (d, kls='req', id=None): | |
$ items = d and d.items() or [] | |
$items.sort() | |
$:dicttable_items(items, kls, id) | |
$def dicttable_items(items, kls='req', id=None): | |
$if items: | |
<table class="$kls" | |
$if id: id="$id" | |
><thead><tr><th>Variable</th><th>Value</th></tr></thead> | |
<tbody> | |
$for k, v in items: | |
<tr><td>$k</td><td class="code"><div>$prettify(v)</div></td></tr> | |
</tbody> | |
</table> | |
$else: | |
<p>No data.</p> | |
<div id="summary"> | |
<h1>$exception_type at $ctx.path</h1> | |
<h2>$exception_value</h2> | |
<table><tr> | |
<th>Python</th> | |
<td>$frames[0].filename in $frames[0].function, line $frames[0].lineno</td> | |
</tr><tr> | |
<th>Web</th> | |
<td>$ctx.method $ctx.home$ctx.path</td> | |
</tr></table> | |
</div> | |
<div id="traceback"> | |
<h2>Traceback <span>(innermost first)</span></h2> | |
<ul class="traceback"> | |
$for frame in frames: | |
<li class="frame"> | |
<code>$frame.filename</code> in <code>$frame.function</code> | |
$if frame.context_line is not None: | |
<div class="context" id="c$frame.id"> | |
$if frame.pre_context: | |
<ol start="$frame.pre_context_lineno" class="pre-context" id="pre$frame.id"> | |
$for line in frame.pre_context: | |
<li onclick="toggle('pre$frame.id', 'post$frame.id')">$line</li> | |
</ol> | |
<ol start="$frame.lineno" class="context-line"><li onclick="toggle('pre$frame.id', 'post$frame.id')">$frame.context_line <span>...</span></li></ol> | |
$if frame.post_context: | |
<ol start='${frame.lineno + 1}' class="post-context" id="post$frame.id"> | |
$for line in frame.post_context: | |
<li onclick="toggle('pre$frame.id', 'post$frame.id')">$line</li> | |
</ol> | |
</div> | |
$if frame.vars: | |
<div class="commands"> | |
<a href='#' onclick="return varToggle(this, '$frame.id')"><span>▶</span> Local vars</a> | |
$# $inspect.formatargvalues(*inspect.getargvalues(frame['tb'].tb_frame)) | |
</div> | |
$:dicttable(frame.vars, kls='vars', id=('v' + str(frame.id))) | |
</li> | |
</ul> | |
</div> | |
<div id="requestinfo"> | |
$if ctx.output or ctx.headers: | |
<h2>Response so far</h2> | |
<h3>HEADERS</h3> | |
$:dicttable_items(ctx.headers) | |
<h3>BODY</h3> | |
<p class="req" style="padding-bottom: 2em"><code> | |
$ctx.output | |
</code></p> | |
<h2>Request information</h2> | |
<h3>INPUT</h3> | |
$:dicttable(web.input(_unicode=False)) | |
<h3 id="cookie-info">COOKIES</h3> | |
$:dicttable(web.cookies()) | |
<h3 id="meta-info">META</h3> | |
$ newctx = [(k, v) for (k, v) in ctx.iteritems() if not k.startswith('_') and not isinstance(v, dict)] | |
$:dicttable(dict(newctx)) | |
<h3 id="meta-info">ENVIRONMENT</h3> | |
$:dicttable(ctx.env) | |
</div> | |
<div id="explanation"> | |
<p> | |
You're seeing this error because you have <code>web.config.debug</code> | |
set to <code>True</code>. Set that to <code>False</code> if you don't want to see this. | |
</p> | |
</div> | |
</body> | |
</html> | |
""" | |
# Compiled error-page template; created on the first call to djangoerror().
djangoerror_r = None

def djangoerror():
    """Render the debug page for the exception currently being handled.

    Walks the traceback from `sys.exc_info()`, builds one record per frame
    (frames whose locals contain `__hidetraceback__` are left out), and
    renders `djangoerror_t` with the exception type, the exception value
    and the frame records, innermost frame first.

    The `lineno` and `pre_context_lineno` values stored in each record are
    1-based line numbers, which is what the template's `<ol start=...>`
    attributes and its "line N" summary display.

    Returns whatever calling the compiled template returns.
    """
    def _get_lines_from_file(filename, lineno, context_lines):
        """
        Returns context_lines before and after lineno from file.
        Returns (pre_context_lineno, pre_context, context_line, post_context).

        `lineno` and the returned `pre_context_lineno` are 0-based indexes
        into the file's lines. When the file cannot be read or the index is
        out of range, (None, [], None, []) is returned.
        """
        try:
            source_file = open(filename)
            try:
                source = source_file.readlines()
            finally:
                # Close explicitly instead of relying on garbage collection.
                source_file.close()
            lower_bound = max(0, lineno - context_lines)
            upper_bound = lineno + context_lines

            pre_context = \
                [line.strip('\n') for line in source[lower_bound:lineno]]
            context_line = source[lineno].strip('\n')
            post_context = \
                [line.strip('\n') for line in source[lineno + 1:upper_bound]]

            return lower_bound, pre_context, context_line, post_context
        except (OSError, IOError, IndexError):
            return None, [], None, []

    exception_type, exception_value, tback = sys.exc_info()
    frames = []
    while tback is not None:
        frame_locals = tback.tb_frame.f_locals
        filename = tback.tb_frame.f_code.co_filename
        function = tback.tb_frame.f_code.co_name

        # tb_lineno is 1-based; convert to a 0-based index into the file.
        line_index = tback.tb_lineno - 1

        # hack to get correct line number for templates
        line_index += frame_locals.get("__lineoffset__", 0)

        pre_context_index, pre_context, context_line, post_context = \
            _get_lines_from_file(filename, line_index, 7)

        if '__hidetraceback__' not in frame_locals:
            # Convert back to 1-based numbers for display.
            if pre_context_index is None:
                pre_context_lineno = None
            else:
                pre_context_lineno = pre_context_index + 1
            frames.append(web.storage({
                'tback': tback,
                'filename': filename,
                'function': function,
                'lineno': line_index + 1,
                'vars': frame_locals,
                'id': id(tback),
                'pre_context': pre_context,
                'context_line': context_line,
                'post_context': post_context,
                'pre_context_lineno': pre_context_lineno,
            }))
        tback = tback.tb_next
    frames.reverse()

    def prettify(x):
        """Pretty-print `x`, or describe the error if that fails."""
        try:
            out = pprint.pformat(x)
        except Exception:
            # sys.exc_info() keeps this working on every Python version.
            e = sys.exc_info()[1]
            out = '[could not display: <' + e.__class__.__name__ + \
                  ': '+str(e)+'>]'
        return out

    global djangoerror_r
    if djangoerror_r is None:
        djangoerror_r = Template(djangoerror_t, filename=__file__, filter=websafe)

    t = djangoerror_r
    # Names the template refers to that are not among its arguments.
    template_globals = {'ctx': web.ctx, 'web':web, 'dict':dict, 'str':str, 'prettify': prettify}
    t.t.func_globals.update(template_globals)
    return t(exception_type, exception_value, frames)
def debugerror():
    """
    Drop-in replacement for `internalerror` that answers with a page full
    of debugging detail aimed at the programmer.

    The page produced by `djangoerror()` is wrapped in `web._InternalError`
    and returned.

    (Based on the beautiful 500 page from [Django](http://djangoproject.com/),
    designed by [Wilson Miner](http://wilsonminer.com/).)
    """
    error_page = djangoerror()
    return web._InternalError(error_page)
def emailerrors(to_address, olderror, from_address=None):
    """
    Build an `internalerror` handler that calls `olderror` and also emails
    the error to `to_address`, to aid in debugging production websites.

    `from_address` defaults to `to_address`. Each email carries the request
    line and a plain-text traceback in its body, plus the `djangoerror`
    page as an attachment named "bug.html".

    Returns the wrapping handler; the handler returns what `olderror`
    returned.
    """
    if not from_address:
        from_address = to_address

    def emailerrors_internal():
        error = olderror()
        exc_type, exc_value, exc_tb = sys.exc_info()
        tb_txt = ''.join(
            traceback.format_exception(exc_type, exc_value, exc_tb))
        path = web.ctx.path
        request = web.ctx.method + ' ' + web.ctx.home + web.ctx.fullpath

        subject_fields = {
            'error_name': exc_type,
            'error_value': exc_value,
            'path': path,
        }
        message = "\n%s\n\n%s\n\n" % (request, tb_txt)
        html_attachment = {
            'filename': "bug.html",
            'content': safestr(djangoerror()),
        }

        sendmail(
            "your buggy site <%s>" % from_address,
            "the bugfixer <%s>" % to_address,
            "bug: %(error_name)s: %(error_value)s (%(path)s)" % subject_fields,
            message,
            attachments=[html_attachment],
        )
        return error

    return emailerrors_internal
# Demo: when run as a script, serve a one-URL application whose handler
# fails, with debugerror installed as the application's internalerror.
if __name__ == "__main__":
    urls = (
        '/', 'index'
    )
    from application import application
    app = application(urls, globals())
    app.internalerror = debugerror

    class index:
        def GET(self):
            # Deliberately references an undefined name to raise an error.
            thisdoesnotexist

    app.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
"""Universal feed parser | |
Handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds | |
Visit http://feedparser.org/ for the latest version | |
Visit http://feedparser.org/docs/ for the latest documentation | |
Required: Python 2.4 or later | |
Recommended: CJKCodecs and iconv_codec <http://cjkpython.i18n.org/> | |
""" | |
__version__ = "5.0.1" | |
__license__ = """Copyright (c) 2002-2008, Mark Pilgrim, All rights reserved. | |
Redistribution and use in source and binary forms, with or without modification, | |
are permitted provided that the following conditions are met: | |
* Redistributions of source code must retain the above copyright notice, | |
this list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above copyright notice, | |
this list of conditions and the following disclaimer in the documentation | |
and/or other materials provided with the distribution. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 'AS IS' | |
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | |
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE | |
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | |
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | |
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | |
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | |
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | |
POSSIBILITY OF SUCH DAMAGE.""" | |
__author__ = "Mark Pilgrim <http://diveintomark.org/>" | |
__contributors__ = ["Jason Diamond <http://injektilo.org/>", | |
"John Beimler <http://john.beimler.org/>", | |
"Fazal Majid <http://www.majid.info/mylos/weblog/>", | |
"Aaron Swartz <http://aaronsw.com/>", | |
"Kevin Marks <http://epeus.blogspot.com/>", | |
"Sam Ruby <http://intertwingly.net/>", | |
"Ade Oshineye <http://blog.oshineye.com/>", | |
"Martin Pool <http://sourcefrog.net/>", | |
"Kurt McKee <http://kurtmckee.org/>"] | |
_debug = 0 | |
# HTTP "User-Agent" header to send to servers when downloading feeds. | |
# If you are embedding feedparser in a larger application, you should | |
# change this to your application name and URL. | |
USER_AGENT = "UniversalFeedParser/%s +http://feedparser.org/" % __version__ | |
# HTTP "Accept" header to send to servers when downloading feeds. If you don't | |
# want to send an Accept header, set this to None. | |
ACCEPT_HEADER = "application/atom+xml,application/rdf+xml,application/rss+xml,application/x-netcdf,application/xml;q=0.9,text/xml;q=0.2,*/*;q=0.1" | |
# List of preferred XML parsers, by SAX driver name. These will be tried first, | |
# but if they're not installed, Python will keep searching through its own list | |
# of pre-installed parsers until it finds one that supports everything we need. | |
PREFERRED_XML_PARSERS = ["drv_libxml2"] | |
# If you want feedparser to automatically run HTML markup through HTML Tidy, set | |
# this to 1. Requires mxTidy <http://www.egenix.com/files/python/mxTidy.html> | |
# or utidylib <http://utidylib.berlios.de/>. | |
TIDY_MARKUP = 0 | |
# List of Python interfaces for HTML Tidy, in order of preference. Only useful | |
# if TIDY_MARKUP = 1 | |
PREFERRED_TIDY_INTERFACES = ["uTidy", "mxTidy"] | |
# If you want feedparser to automatically resolve all relative URIs, set this | |
# to 1. | |
RESOLVE_RELATIVE_URIS = 1 | |
# If you want feedparser to automatically sanitize all potentially unsafe | |
# HTML content, set this to 1. | |
SANITIZE_HTML = 1 | |
# ---------- Python 3 modules (make it work if possible) ---------- | |
try: | |
import rfc822 | |
except ImportError: | |
from email import _parseaddr as rfc822 | |
try: | |
# Python 3.1 introduces bytes.maketrans and simultaneously | |
# deprecates string.maketrans; use bytes.maketrans if possible | |
_maketrans = bytes.maketrans | |
except (NameError, AttributeError): | |
import string | |
_maketrans = string.maketrans | |
# base64 support for Atom feeds that contain embedded binary data | |
try: | |
import base64, binascii | |
# Python 3.1 deprecates decodestring in favor of decodebytes | |
_base64decode = getattr(base64, 'decodebytes', base64.decodestring) | |
except: | |
base64 = binascii = None | |
def _s2bytes(s):
    """Return `s` encoded as UTF-8 bytes on Python 3, else `s` unchanged."""
    try:
        encoded = bytes(s, 'utf8')
    except (NameError, TypeError):
        # NameError: Python 2.5 and below have no `bytes` at all.
        # TypeError: on Python 2.6 and above `bytes` is `str`, which does
        # not accept an encoding argument.
        return s
    return encoded
def _l2bytes(l):
    """Turn a sequence of ints into bytes (Python 3) or a str (Python 2)."""
    try:
        # On Python 2.6 and above `bytes` exists but is `str`, and
        # bytes([65]) would give '[65]' instead of 'A'.
        has_real_bytes = bytes is not str
    except NameError:
        # Python 2.5 and below: no `bytes` name at all.
        has_real_bytes = False
    if has_real_bytes:
        return bytes(l)
    return ''.join(map(chr, l))
# If you want feedparser to allow all URL schemes, set this to () | |
# List culled from Python's urlparse documentation at: | |
# http://docs.python.org/library/urlparse.html | |
# as well as from "URI scheme" at Wikipedia: | |
# https://secure.wikimedia.org/wikipedia/en/wiki/URI_scheme | |
# Many more will likely need to be added! | |
ACCEPTABLE_URI_SCHEMES = ( | |
'file', 'ftp', 'gopher', 'h323', 'hdl', 'http', 'https', 'imap', 'mailto', | |
'mms', 'news', 'nntp', 'prospero', 'rsync', 'rtsp', 'rtspu', 'sftp', | |
'shttp', 'sip', 'sips', 'snews', 'svn', 'svn+ssh', 'telnet', 'wais', | |
# Additional common-but-unofficial schemes | |
'aim', 'callto', 'cvs', 'facetime', 'feed', 'git', 'gtalk', 'irc', 'ircs', | |
'irc6', 'itms', 'mms', 'msnim', 'skype', 'ssh', 'smb', 'svn', 'ymsg', | |
) | |
#ACCEPTABLE_URI_SCHEMES = () | |
# ---------- required modules (should come with any Python distribution) ---------- | |
import sgmllib, re, sys, copy, urlparse, time, types, cgi, urllib, urllib2, datetime | |
try: | |
from io import BytesIO as _StringIO | |
except ImportError: | |
try: | |
from cStringIO import StringIO as _StringIO | |
except: | |
from StringIO import StringIO as _StringIO | |
# ---------- optional modules (feedparser will work without these, but with reduced functionality) ---------- | |
# gzip is included with most Python distributions, but may not be available if you compiled your own | |
try: | |
import gzip | |
except: | |
gzip = None | |
try: | |
import zlib | |
except: | |
zlib = None | |
# If a real XML parser is available, feedparser will attempt to use it. feedparser has | |
# been tested with the built-in SAX parser, PyXML, and libxml2. On platforms where the | |
# Python distribution does not come with an XML parser (such as Mac OS X 10.2 and some | |
# versions of FreeBSD), feedparser will quietly fall back on regex-based parsing. | |
try: | |
import xml.sax | |
xml.sax.make_parser(PREFERRED_XML_PARSERS) # test for valid parsers | |
from xml.sax.saxutils import escape as _xmlescape | |
_XML_AVAILABLE = 1 | |
except: | |
_XML_AVAILABLE = 0 | |
def _xmlescape(data, entities={}):
    """Escape `&`, `>` and `<` in `data`; fallback for saxutils.escape.

    `entities` optionally supplies further replacements, either as a
    mapping of string -> replacement or as an iterable of
    (string, replacement) pairs. They are applied after the three built-in
    replacements. `entities` is only read, never modified.

    Returns the escaped string.
    """
    # `&` must be handled first so the entities added below are not
    # escaped a second time.
    data = data.replace('&', '&amp;')
    data = data.replace('>', '&gt;')
    data = data.replace('<', '&lt;')
    # Iterating a mapping directly yields only its keys, so use items().
    if hasattr(entities, 'items'):
        replacements = entities.items()
    else:
        replacements = entities
    for char, entity in replacements:
        data = data.replace(char, entity)
    return data
# cjkcodecs and iconv_codec provide support for more character encodings. | |
# Both are available from http://cjkpython.i18n.org/ | |
try: | |
import cjkcodecs.aliases | |
except: | |
pass | |
try: | |
import iconv_codec | |
except: | |
pass | |
# chardet library auto-detects character encodings | |
# Download from http://chardet.feedparser.org/ | |
try: | |
import chardet | |
if _debug: | |
import chardet.constants | |
chardet.constants._debug = 1 | |
except: | |
chardet = None | |
# reversible htmlentitydefs mappings for Python 2.2
try: | |
from htmlentitydefs import name2codepoint, codepoint2name | |
except: | |
import htmlentitydefs | |
name2codepoint={} | |
codepoint2name={} | |
for (name,codepoint) in htmlentitydefs.entitydefs.iteritems(): | |
if codepoint.startswith('&#'): codepoint=unichr(int(codepoint[2:-1])) | |
name2codepoint[name]=ord(codepoint) | |
codepoint2name[ord(codepoint)]=name | |
# BeautifulSoup parser used for parsing microformats from embedded HTML content | |
# http://www.crummy.com/software/BeautifulSoup/ | |
# feedparser is tested with BeautifulSoup 3.0.x, but it might work with the | |
# older 2.x series. If it doesn't, and you can figure out why, I'll accept a | |
# patch and modify the compatibility statement accordingly. | |
try: | |
import BeautifulSoup | |
except: | |
BeautifulSoup = None | |
# ---------- don't touch these ---------- | |
# Exception types defined by this module. CharacterEncodingOverride,
# CharacterEncodingUnknown and NonXMLContentType share the
# ThingsNobodyCaresAboutButMe base class; UndeclaredNamespace derives
# directly from Exception.
class ThingsNobodyCaresAboutButMe(Exception): pass
class CharacterEncodingOverride(ThingsNobodyCaresAboutButMe): pass
class CharacterEncodingUnknown(ThingsNobodyCaresAboutButMe): pass
class NonXMLContentType(ThingsNobodyCaresAboutButMe): pass
class UndeclaredNamespace(Exception): pass
# Replace three of sgmllib's module-level regexes with the patterns below.
sgmllib.tagfind = re.compile('[a-zA-Z][-_.:a-zA-Z0-9]*')
sgmllib.special = re.compile('<!')
sgmllib.charref = re.compile('&#(\d+|[xX][0-9a-fA-F]+);')

# Install the replacement end-bracket matcher when searching ' <' with the
# current sgmllib.endbracket reports a non-zero position.
if sgmllib.endbracket.search(' <').start(0):
    class EndBracketRegEx:
        """Stand-in for sgmllib.endbracket exposing a `search` method."""
        def __init__(self):
            # Overriding the built-in sgmllib.endbracket regex allows the
            # parser to find angle brackets embedded in element attributes.
            self.endbracket = re.compile('''([^'"<>]|"[^"]*"(?=>|/|\s|\w+=)|'[^']*'(?=>|/|\s|\w+=))*(?=[<>])|.*?(?=[<>])''')
        def search(self,string,index=0):
            # The pattern is anchored at `index` with match(), not scanned
            # for with search().
            match = self.endbracket.match(string,index)
            if match is not None:
                # Returning a new object in the calling thread's context
                # resolves a thread-safety issue.
                return EndBracketMatch(match)
            return None
    class EndBracketMatch:
        """Wraps a match object so that start() reports the bracket position."""
        def __init__(self, match):
            self.match = match
        def start(self, n):
            # Both alternatives of the pattern stop at a lookahead for '<'
            # or '>', so for the whole match (n == 0) the END of the match
            # is the position of that bracket.
            return self.match.end(n)
    sgmllib.endbracket = EndBracketRegEx()
SUPPORTED_VERSIONS = {'': 'unknown', | |
'rss090': 'RSS 0.90', | |
'rss091n': 'RSS 0.91 (Netscape)', | |
'rss091u': 'RSS 0.91 (Userland)', | |
'rss092': 'RSS 0.92', | |
'rss093': 'RSS 0.93', | |
'rss094': 'RSS 0.94', | |
'rss20': 'RSS 2.0', | |
'rss10': 'RSS 1.0', | |
'rss': 'RSS (unknown version)', | |
'atom01': 'Atom 0.1', | |
'atom02': 'Atom 0.2', | |
'atom03': 'Atom 0.3', | |
'atom10': 'Atom 1.0', | |
'atom': 'Atom (unknown version)', | |
'cdf': 'CDF', | |
'hotrss': 'Hot RSS' | |
} | |
# Pick the base class used by FeedParserDict: the built-in dict when it
# exists, otherwise the UserDict class plus a minimal dict() replacement.
try:
    UserDict = dict
except NameError:
    # Python 2.1 does not have dict
    from UserDict import UserDict
    def dict(aList):
        # Build a plain dictionary from an iterable of (key, value) pairs.
        rc = {}
        for k, v in aList:
            rc[k] = v
        return rc
class FeedParserDict(UserDict):
    """Dictionary whose items are also reachable as attributes and whose
    legacy key names are translated through `keymap`.

    Reading a legacy key returns the value stored under its current name;
    where `keymap` lists several current names, the first one present is
    used. Writing a legacy key stores the value under the (first) current
    name. The keys 'category', 'enclosures', 'license' and 'categories'
    are computed from the 'tags' and 'links' items when read.
    """

    # Legacy key -> current key, or list of candidate current keys.
    keymap = {'channel': 'feed',
              'items': 'entries',
              'guid': 'id',
              'date': 'updated',
              'date_parsed': 'updated_parsed',
              'description': ['summary', 'subtitle'],
              'url': ['href'],
              'modified': 'updated',
              'modified_parsed': 'updated_parsed',
              'issued': 'published',
              'issued_parsed': 'published_parsed',
              'copyright': 'rights',
              'copyright_detail': 'rights_detail',
              'tagline': 'subtitle',
              'tagline_detail': 'subtitle_detail'}

    def __getitem__(self, key):
        """Return the value for `key`, honouring computed and legacy keys.

        Raises KeyError when neither `key` nor any of its mapped names is
        present.
        """
        if key == 'category':
            return UserDict.__getitem__(self, 'tags')[0]['term']
        if key == 'enclosures':
            def norel(link):
                # Copy of the link without its 'rel' entry.
                return FeedParserDict([(name, value) for (name, value) in link.items() if name != 'rel'])
            return [norel(link) for link in UserDict.__getitem__(self, 'links') if link['rel'] == 'enclosure']
        if key == 'license':
            for link in UserDict.__getitem__(self, 'links'):
                if link['rel'] == 'license' and 'href' in link:
                    return link['href']
            # No license link: fall through to the ordinary lookup below.
        if key == 'categories':
            return [(tag['scheme'], tag['term']) for tag in UserDict.__getitem__(self, 'tags')]
        realkey = self.keymap.get(key, key)
        if isinstance(realkey, list):
            for k in realkey:
                if UserDict.__contains__(self, k):
                    return UserDict.__getitem__(self, k)
            if UserDict.__contains__(self, key):
                return UserDict.__getitem__(self, key)
            # Looking up the list itself would raise TypeError (lists are
            # unhashable); a missing key must be reported as KeyError.
            raise KeyError(key)
        if UserDict.__contains__(self, key):
            return UserDict.__getitem__(self, key)
        return UserDict.__getitem__(self, realkey)

    def __setitem__(self, key, value):
        """Store `value` under the current name for `key`."""
        key = self.keymap.get(key, key)
        if isinstance(key, list):
            # Several candidate names: the first one is the storage key.
            key = key[0]
        return UserDict.__setitem__(self, key, value)

    def get(self, key, default=None):
        """Return self[key] when `key` is available, else `default`."""
        if self.has_key(key):
            return self[key]
        else:
            return default

    def setdefault(self, key, value):
        """Store `value` under `key` unless present; return self[key]."""
        if not self.has_key(key):
            self[key] = value
        return self[key]

    def has_key(self, key):
        """Return true when `key` is readable as an attribute or stored."""
        try:
            return hasattr(self, key) or UserDict.__contains__(self, key)
        except AttributeError:
            return False
    # This alias prevents the 2to3 tool from changing the semantics of the
    # __contains__ function below and exhausting the maximum recursion depth
    __has_key = has_key

    def __getattr__(self, key):
        """Fall back to item lookup for attributes that are not found.

        Names starting with an underscore are never looked up as items.
        Raises AttributeError when the lookup fails for any reason.
        """
        try:
            return self.__dict__[key]
        except KeyError:
            pass
        if key.startswith('_'):
            raise AttributeError("object has no attribute '%s'" % key)
        try:
            return self.__getitem__(key)
        except Exception:
            raise AttributeError("object has no attribute '%s'" % key)

    def __setattr__(self, key, value):
        """Store attributes as items, except private names and 'data'."""
        if key.startswith('_') or key == 'data':
            self.__dict__[key] = value
        else:
            return self.__setitem__(key, value)

    def __contains__(self, key):
        return self.__has_key(key)
def zopeCompatibilityHack():
    """Replace the module-level FeedParserDict with a plain-dict factory.

    After this call, `FeedParserDict(aDict)` returns an ordinary dictionary
    holding a copy of `aDict`'s items (empty when `aDict` is omitted or
    falsy).
    """
    global FeedParserDict
    del FeedParserDict

    def FeedParserDict(aDict=None):
        plain = {}
        plain.update(aDict or {})
        return plain
# Translation table used by _ebcdic_to_ascii(); built on first use.
_ebcdic_to_ascii_map = None

def _ebcdic_to_ascii(s):
    """Return `s` translated through the EBCDIC-to-ASCII table.

    The table is built once, on the first call, and cached in the
    module-level `_ebcdic_to_ascii_map`.
    """
    global _ebcdic_to_ascii_map
    if _ebcdic_to_ascii_map:
        return s.translate(_ebcdic_to_ascii_map)

    # Entry i is the output code for input byte value i.
    ebcdic_table = (
        0,1,2,3,156,9,134,127,151,141,142,11,12,13,14,15,
        16,17,18,19,157,133,8,135,24,25,146,143,28,29,30,31,
        128,129,130,131,132,10,23,27,136,137,138,139,140,5,6,7,
        144,145,22,147,148,149,150,4,152,153,154,155,20,21,158,26,
        32,160,161,162,163,164,165,166,167,168,91,46,60,40,43,33,
        38,169,170,171,172,173,174,175,176,177,93,36,42,41,59,94,
        45,47,178,179,180,181,182,183,184,185,124,44,37,95,62,63,
        186,187,188,189,190,191,192,193,194,96,58,35,64,39,61,34,
        195,97,98,99,100,101,102,103,104,105,196,197,198,199,200,201,
        202,106,107,108,109,110,111,112,113,114,203,204,205,206,207,208,
        209,126,115,116,117,118,119,120,121,122,210,211,212,213,214,215,
        216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,
        123,65,66,67,68,69,70,71,72,73,232,233,234,235,236,237,
        125,74,75,76,77,78,79,80,81,82,238,239,240,241,242,243,
        92,159,83,84,85,86,87,88,89,90,244,245,246,247,248,249,
        48,49,50,51,52,53,54,55,56,57,250,251,252,253,254,255
        )
    all_bytes = _l2bytes(range(256))
    _ebcdic_to_ascii_map = _maketrans(all_bytes, _l2bytes(ebcdic_table))
    return s.translate(_ebcdic_to_ascii_map)
_cp1252 = { | |
unichr(128): unichr(8364), # euro sign | |
unichr(130): unichr(8218), # single low-9 quotation mark | |
unichr(131): unichr( 402), # latin small letter f with hook | |
unichr(132): unichr(8222), # double low-9 quotation mark | |
unichr(133): unichr(8230), # horizontal ellipsis | |
unichr(134): unichr(8224), # dagger | |
unichr(135): unichr(8225), # double dagger | |
unichr(136): unichr( 710), # modifier letter circumflex accent | |
unichr(137): unichr(8240), # per mille sign | |
unichr(138): unichr( 352), # latin capital letter s with caron | |
unichr(139): unichr(8249), # single left-pointing angle quotation mark | |
unichr(140): unichr( 338), # latin capital ligature oe | |
unichr(142): unichr( 381), # latin capital letter z with caron | |
unichr(145): unichr(8216), # left single quotation mark | |
unichr(146): unichr(8217), # right single quotation mark | |
unichr(147): unichr(8220), # left double quotation mark | |
unichr(148): unichr(8221), # right double quotation mark | |
unichr(149): unichr(8226), # bullet | |
unichr(150): unichr(8211), # en dash | |
unichr(151): unichr(8212), # em dash | |
unichr(152): unichr( 732), # small tilde | |
unichr(153): unichr(8482), # trade mark sign | |
unichr(154): unichr( 353), # latin small letter s with caron | |
unichr(155): unichr(8250), # single right-pointing angle quotation mark | |
unichr(156): unichr( 339), # latin small ligature oe | |
unichr(158): unichr( 382), # latin small letter z with caron | |
unichr(159): unichr( 376)} # latin capital letter y with diaeresis | |
# Matches a URI scheme followed by "://" and any surplus slashes. The
# hyphen sits last in the character class so it is a literal; written as
# "+-." it formed a range that also accepted ",".
_urifixer = re.compile(r'^([A-Za-z][A-Za-z0-9+.-]*://)(/*)(.*?)')

def _urljoin(base, uri):
    """Join `uri` onto `base`, tolerating some malformed URIs.

    Slashes beyond the "://" after the scheme are removed from `uri`
    first. If joining fails, each component of `uri` is percent-quoted
    and the join is attempted once more; an error from that second
    attempt propagates to the caller.
    """
    uri = _urifixer.sub(r'\1\3', uri)
    try:
        return urlparse.urljoin(base, uri)
    except Exception:
        uri = urlparse.urlunparse([urllib.quote(part) for part in urlparse.urlparse(uri)])
        return urlparse.urljoin(base, uri)
class _FeedParserMixin: | |
namespaces = {'': '', | |
'http://backend.userland.com/rss': '', | |
'http://blogs.law.harvard.edu/tech/rss': '', | |
'http://purl.org/rss/1.0/': '', | |
'http://my.netscape.com/rdf/simple/0.9/': '', | |
'http://example.com/newformat#': '', | |
'http://example.com/necho': '', | |
'http://purl.org/echo/': '', | |
'uri/of/echo/namespace#': '', | |
'http://purl.org/pie/': '', | |
'http://purl.org/atom/ns#': '', | |
'http://www.w3.org/2005/Atom': '', | |
'http://purl.org/rss/1.0/modules/rss091#': '', | |
'http://webns.net/mvcb/': 'admin', | |
'http://purl.org/rss/1.0/modules/aggregation/': 'ag', | |
'http://purl.org/rss/1.0/modules/annotate/': 'annotate', | |
'http://media.tangent.org/rss/1.0/': 'audio', | |
'http://backend.userland.com/blogChannelModule': 'blogChannel', | |
'http://web.resource.org/cc/': 'cc', | |
'http://backend.userland.com/creativeCommonsRssModule': 'creativeCommons', | |
'http://purl.org/rss/1.0/modules/company': 'co', | |
'http://purl.org/rss/1.0/modules/content/': 'content', | |
'http://my.theinfo.org/changed/1.0/rss/': 'cp', | |
'http://purl.org/dc/elements/1.1/': 'dc', | |
'http://purl.org/dc/terms/': 'dcterms', | |
'http://purl.org/rss/1.0/modules/email/': 'email', | |
'http://purl.org/rss/1.0/modules/event/': 'ev', | |
'http://rssnamespace.org/feedburner/ext/1.0': 'feedburner', | |
'http://freshmeat.net/rss/fm/': 'fm', | |
'http://xmlns.com/foaf/0.1/': 'foaf', | |
'http://www.w3.org/2003/01/geo/wgs84_pos#': 'geo', | |
'http://postneo.com/icbm/': 'icbm', | |
'http://purl.org/rss/1.0/modules/image/': 'image', | |
'http://www.itunes.com/DTDs/PodCast-1.0.dtd': 'itunes', | |
'http://example.com/DTDs/PodCast-1.0.dtd': 'itunes', | |
'http://purl.org/rss/1.0/modules/link/': 'l', | |
'http://search.yahoo.com/mrss': 'media', | |
#Version 1.1.2 of the Media RSS spec added the trailing slash on the namespace | |
'http://search.yahoo.com/mrss/': 'media', | |
'http://madskills.com/public/xml/rss/module/pingback/': 'pingback', | |
'http://prismstandard.org/namespaces/1.2/basic/': 'prism', | |
'http://www.w3.org/1999/02/22-rdf-syntax-ns#': 'rdf', | |
'http://www.w3.org/2000/01/rdf-schema#': 'rdfs', | |
'http://purl.org/rss/1.0/modules/reference/': 'ref', | |
'http://purl.org/rss/1.0/modules/richequiv/': 'reqv', | |
'http://purl.org/rss/1.0/modules/search/': 'search', | |
'http://purl.org/rss/1.0/modules/slash/': 'slash', | |
'http://schemas.xmlsoap.org/soap/envelope/': 'soap', | |
'http://purl.org/rss/1.0/modules/servicestatus/': 'ss', | |
'http://hacks.benhammersley.com/rss/streaming/': 'str', | |
'http://purl.org/rss/1.0/modules/subscription/': 'sub', | |
'http://purl.org/rss/1.0/modules/syndication/': 'sy', | |
'http://schemas.pocketsoap.com/rss/myDescModule/': 'szf', | |
'http://purl.org/rss/1.0/modules/taxonomy/': 'taxo', | |
'http://purl.org/rss/1.0/modules/threading/': 'thr', | |
'http://purl.org/rss/1.0/modules/textinput/': 'ti', | |
'http://madskills.com/public/xml/rss/module/trackback/':'trackback', | |
'http://wellformedweb.org/commentAPI/': 'wfw', | |
'http://purl.org/rss/1.0/modules/wiki/': 'wiki', | |
'http://www.w3.org/1999/xhtml': 'xhtml', | |
'http://www.w3.org/1999/xlink': 'xlink', | |
'http://www.w3.org/XML/1998/namespace': 'xml' | |
} | |
_matchnamespaces = {} | |
can_be_relative_uri = ['link', 'id', 'wfw_comment', 'wfw_commentrss', 'docs', 'url', 'href', 'comments', 'icon', 'logo'] | |
can_contain_relative_uris = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] | |
can_contain_dangerous_markup = ['content', 'title', 'summary', 'info', 'tagline', 'subtitle', 'copyright', 'rights', 'description'] | |
html_types = ['text/html', 'application/xhtml+xml'] | |
def __init__(self, baseuri=None, baselang=None, encoding='utf-8'):
    """Create empty result containers and reset all parser state.

    baseuri  -- initial base URI ('' is stored when none is given)
    baselang -- initial language; when given it is also stored, with
                underscores replaced by hyphens, as feeddata['language']
    encoding -- character encoding recorded on the parser
    """
    if _debug:
        sys.stderr.write('initializing FeedParser\n')
    # _matchnamespaces is only filled while it is still empty; it receives
    # the entries of `namespaces` keyed by lower-cased URI
    if not self._matchnamespaces:
        for uri in self.namespaces:
            self._matchnamespaces[uri.lower()] = self.namespaces[uri]

    # results
    self.feeddata = FeedParserDict()  # feed-level data
    self.encoding = encoding          # character encoding
    self.entries = []                 # list of entry-level data
    self.version = ''                 # feed type/version, see SUPPORTED_VERSIONS
    self.namespacesInUse = {}         # dictionary of namespaces defined by the feed

    # the following are used internally to track state;
    # this is really out of control and should be refactored
    for flag in ('infeed', 'inentry', 'incontent', 'intextinput', 'inimage',
                 'inauthor', 'incontributor', 'inpublisher', 'insource'):
        setattr(self, flag, 0)
    self.sourcedata = FeedParserDict()
    self.contentparams = FeedParserDict()
    self._summaryKey = None
    self.namespacemap = {}
    self.elementstack = []
    self.basestack = []
    self.langstack = []
    self.baseuri = baseuri or ''
    self.lang = baselang or None
    self.svgOK = 0
    self.hasTitle = 0
    if baselang:
        self.feeddata['language'] = baselang.replace('_', '-')
def unknown_starttag(self, tag, attrs):
    """Process an opening tag.

    tag   -- element name, possibly of the form "prefix:name"
    attrs -- sequence of (name, value) attribute pairs

    Normalizes the attributes, updates the xml:base / xml:lang stacks and
    the namespace tables, passes the tag through as literal markup while
    inside inline XHTML content, and otherwise calls the matching
    ``_start_<prefix>_<name>`` handler.  When no handler can be called the
    element is pushed as plain text (no attributes) or its attribute
    dictionary is stored in the current context under the element name.
    """
    if _debug:
        sys.stderr.write('start %s with %s\n' % (tag, attrs))
    # normalize attrs
    attrs = [(k.lower(), v) for k, v in attrs]
    attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs]
    # the sgml parser doesn't handle entities in attributes, but
    # strict xml parsers do -- account for this difference
    if isinstance(self, _LooseFeedParser):
        attrs = [(k, v.replace('&amp;', '&')) for k, v in attrs]

    # track xml:base and xml:lang
    attrsD = dict(attrs)
    baseuri = attrsD.get('xml:base', attrsD.get('base')) or self.baseuri
    if type(baseuri) != type(u''):
        try:
            baseuri = unicode(baseuri, self.encoding)
        except Exception:
            # fall back to an encoding that accepts every byte value
            baseuri = unicode(baseuri, 'iso-8859-1')
    # ensure that self.baseuri is always an absolute URI that
    # uses a whitelisted URI scheme (e.g. not `javscript:`)
    if self.baseuri:
        self.baseuri = _makeSafeAbsoluteURI(self.baseuri, baseuri) or self.baseuri
    else:
        self.baseuri = _urljoin(self.baseuri, baseuri)
    lang = attrsD.get('xml:lang', attrsD.get('lang'))
    if lang == '':
        # xml:lang could be explicitly set to '', we need to capture that
        lang = None
    elif lang is None:
        # if no xml:lang is specified, use parent lang
        lang = self.lang
    if lang:
        if tag in ('feed', 'rss', 'rdf:RDF'):
            self.feeddata['language'] = lang.replace('_', '-')
    self.lang = lang
    self.basestack.append(self.baseuri)
    self.langstack.append(lang)

    # track namespaces
    for prefix, uri in attrs:
        if prefix.startswith('xmlns:'):
            self.trackNamespace(prefix[6:], uri)
        elif prefix == 'xmlns':
            self.trackNamespace(None, uri)

    # track inline content
    if self.incontent and self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'):
        if tag in ['xhtml:div', 'div']:
            return  # typepad does this 10/2007
        # element declared itself as escaped markup, but it isn't really
        self.contentparams['type'] = 'application/xhtml+xml'
    if self.incontent and self.contentparams.get('type') == 'application/xhtml+xml':
        if ':' in tag:
            prefix, tag = tag.split(':', 1)
            namespace = self.namespacesInUse.get(prefix, '')
            if tag == 'math' and namespace == 'http://www.w3.org/1998/Math/MathML':
                attrs.append(('xmlns', namespace))
            if tag == 'svg' and namespace == 'http://www.w3.org/2000/svg':
                attrs.append(('xmlns', namespace))
        if tag == 'svg':
            self.svgOK += 1
        return self.handle_data('<%s%s>' % (tag, self.strattrs(attrs)), escape=0)

    # match namespaces
    if ':' in tag:
        prefix, suffix = tag.split(':', 1)
    else:
        prefix, suffix = '', tag
    prefix = self.namespacemap.get(prefix, prefix)
    if prefix:
        prefix = prefix + '_'

    # special hack for better tracking of empty textinput/image elements in illformed feeds
    if (not prefix) and tag not in ('title', 'link', 'description', 'name'):
        self.intextinput = 0
    if (not prefix) and tag not in ('title', 'link', 'description', 'url', 'href', 'width', 'height'):
        self.inimage = 0

    # call special handler (if defined) or default handler
    methodname = '_start_' + prefix + suffix
    try:
        method = getattr(self, methodname)
        return method(attrsD)
    except AttributeError:
        # Since there's no handler or something has gone wrong we explicitly
        # add the element and its attributes.  Note that an AttributeError
        # raised inside the handler itself also lands here.
        unknown_tag = prefix + suffix
        if len(attrsD) == 0:
            # No attributes so merge it into the enclosing dictionary
            return self.push(unknown_tag, 1)
        else:
            # Has attributes so create it in its own dictionary
            context = self._getContext()
            context[unknown_tag] = attrsD
def unknown_endtag(self, tag):
    """Process a closing tag.

    Calls the matching ``_end_<prefix>_<name>`` handler, or ``self.pop`` when
    that raises AttributeError (which is also forced while inside <svg>),
    emits the literal closing tag while inside inline XHTML content, and
    finally unwinds the xml:base / xml:lang stacks by one level.
    """
    if _debug:
        sys.stderr.write('end %s\n' % tag)

    # match namespaces
    if ':' in tag:
        prefix, suffix = tag.split(':', 1)
    else:
        prefix, suffix = '', tag
    prefix = self.namespacemap.get(prefix, prefix)
    if prefix:
        prefix += '_'
    if self.svgOK and suffix == 'svg':
        self.svgOK -= 1

    # call special handler (if defined) or default handler
    element = prefix + suffix
    try:
        if self.svgOK:
            raise AttributeError()
        handler = getattr(self, '_end_' + element)
        handler()
    except AttributeError:
        self.pop(element)

    # track inline content
    if self.incontent:
        if self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'):
            # element declared itself as escaped markup, but it isn't really
            if tag in ['xhtml:div', 'div']:
                return  # typepad does this 10/2007
            self.contentparams['type'] = 'application/xhtml+xml'
        if self.contentparams.get('type') == 'application/xhtml+xml':
            self.handle_data('</%s>' % tag.split(':')[-1], escape=0)

    # track xml:base and xml:lang going out of scope
    if self.basestack:
        self.basestack.pop()
        if self.basestack and self.basestack[-1]:
            self.baseuri = self.basestack[-1]
    if self.langstack:
        self.langstack.pop()
        if self.langstack:
            self.lang = self.langstack[-1]
def handle_charref(self, ref):
    """Handle one numeric character reference.

    ``ref`` is the text between "&#" and ";" -- e.g. for '&#160;', ref is
    '160'.  References to the five markup characters (" & ' < >) are kept
    as references; any other is appended to the current element as the
    UTF-8 encoding of the character.  Nothing happens when no element is
    open.
    """
    if not self.elementstack:
        return
    ref = ref.lower()
    markup_refs = ('34', '38', '39', '60', '62', 'x22', 'x26', 'x27', 'x3c', 'x3e')
    if ref in markup_refs:
        piece = '&#%s;' % ref
    else:
        if ref[0] == 'x':
            codepoint = int(ref[1:], 16)
        else:
            codepoint = int(ref)
        piece = unichr(codepoint).encode('utf-8')
    self.elementstack[-1][2].append(piece)
def handle_entityref(self, ref):
    """Handle one named entity reference.

    ``ref`` is the entity name -- e.g. for '&copy;', ref is 'copy'.

    * The five XML entities (lt, gt, quot, amp, apos) stay as references.
    * A name found in ``self.entities`` is replaced by its stored value.
      When that value is itself a well-formed numeric character reference
      ("&#169;" or "&#xA9;") it is handed to ``handle_charref``; feeding it
      back into this method, as was done before, treated "&#169;" as an
      entity *name* and appended "&&#169;;".
    * Any other name is looked up in ``name2codepoint``; unknown names are
      appended as the original reference text.

    Nothing happens when no element is open.
    """
    if not self.elementstack:
        return
    if _debug:
        sys.stderr.write('entering handle_entityref with %s\n' % ref)
    if ref in ('lt', 'gt', 'quot', 'amp', 'apos'):
        text = '&%s;' % ref
    elif ref in self.entities:
        text = self.entities[ref]
        if text.startswith('&#') and text.endswith(';'):
            charref = text[2:-1]
            # only delegate digits / hex digits, so int() cannot fail there
            if re.match(r'(?:[0-9]+|[xX][0-9a-fA-F]+)$', charref):
                return self.handle_charref(charref)
    else:
        try:
            codepoint = name2codepoint[ref]
        except KeyError:
            text = '&%s;' % ref
        else:
            text = unichr(codepoint).encode('utf-8')
    self.elementstack[-1][2].append(text)
def handle_data(self, text, escape=1):
    """Append a block of plain text to the element currently being built.

    Called for text outside of any tag that contains no character or entity
    references.  When ``escape`` is true and the current content type is
    'application/xhtml+xml', the text is XML-escaped first.  Nothing happens
    when no element is open.
    """
    if not self.elementstack:
        return
    pieces = self.elementstack[-1][2]
    if escape:
        if self.contentparams.get('type') == 'application/xhtml+xml':
            text = _xmlescape(text)
    pieces.append(text)
def handle_comment(self, text):
    """Ignore a comment, e.g. <!-- insert message here -->."""
    return None

def handle_pi(self, text):
    """Ignore a processing instruction, e.g. <?instruction>."""
    return None

def handle_decl(self, text):
    """Ignore a declaration."""
    return None
def parse_declaration(self, i):
    """Parse the declaration starting at ``self.rawdata[i]``.

    Overrides the internal declaration handler to handle CDATA blocks: the
    content of a complete block is XML-escaped and passed to ``handle_data``
    with escaping switched off.  Returns the index just past the
    declaration; for an unterminated CDATA block that is the length of the
    data, and for any other declaration without a closing '>' it is -1.
    """
    if _debug:
        sys.stderr.write('entering parse_declaration\n')
    data = self.rawdata
    if data[i:i+9] != '<![CDATA[':
        end = data.find('>', i)
        if end < 0:
            # the declaration is incomplete
            return end
        return end + 1
    end = data.find(']]>', i)
    if end == -1:
        # CDATA block began but didn't finish
        return len(data)
    self.handle_data(_xmlescape(data[i+9:end]), 0)
    return end + 3
def mapContentType(self, contentType):
    """Return ``contentType`` lower-cased, with the shorthands 'text',
    'plain', 'html' and 'xhtml' expanded to full MIME types."""
    shorthands = {
        'text': 'text/plain',
        'plain': 'text/plain',
        'html': 'text/html',
        'xhtml': 'application/xhtml+xml',
    }
    contentType = contentType.lower()
    return shorthands.get(contentType, contentType)
def trackNamespace(self, prefix, uri):
    """Record a namespace declaration (``prefix`` is None for the default).

    While no feed version is set yet, three well-known namespace URIs set
    it.  A URI found in ``_matchnamespaces`` maps ``prefix`` to the standard
    prefix in ``namespacemap`` and is recorded in ``namespacesInUse`` under
    that standard prefix; any other URI is recorded under the declared
    prefix ('' for None).
    """
    loweruri = uri.lower()
    if not self.version:
        # the three URIs are distinct, so at most one branch applies
        if (prefix, loweruri) == (None, 'http://my.netscape.com/rdf/simple/0.9/'):
            self.version = 'rss090'
        elif loweruri == 'http://purl.org/rss/1.0/':
            self.version = 'rss10'
        elif loweruri == 'http://www.w3.org/2005/atom':
            self.version = 'atom10'
    if 'backend.userland.com/rss' in loweruri:
        # match any backend.userland.com namespace
        uri = 'http://backend.userland.com/rss'
        loweruri = uri
    if loweruri in self._matchnamespaces:
        standard_prefix = self._matchnamespaces[loweruri]
        self.namespacemap[prefix] = standard_prefix
        self.namespacesInUse[standard_prefix] = uri
    else:
        self.namespacesInUse[prefix or ''] = uri
def resolveURI(self, uri):
    """Return ``uri`` joined onto the current base URI ('' when unset)."""
    base = self.baseuri or ''
    return _urljoin(base, uri)
def decodeEntities(self, element, data):
    """Return ``data`` unchanged; ``element`` is not used here."""
    unchanged = data
    return unchanged
def strattrs(self, attrs):
    """Serialize (name, value) pairs as ' name="value"' attribute text.

    Values are XML-escaped and double quotes are additionally written as
    &quot; so a value can never terminate the surrounding quotes.  (The
    escape map previously replaced '"' with itself, which did nothing.)
    """
    return ''.join([' %s="%s"' % (t[0], _xmlescape(t[1], {'"': '&quot;'})) for t in attrs])
def push(self, element, expectingText):
    """Open a new frame on the element stack.

    A frame is the list [element name, expectingText flag, text pieces].
    """
    frame = [element, expectingText, []]
    self.elementstack.append(frame)
def pop(self, element, stripWhitespace=1):
    """Close ``element`` and return its accumulated text.

    Returns None without changing anything when the stack is empty or its
    top frame belongs to a different element.  Otherwise the frame is
    removed and its pieces are joined (and stripped unless
    ``stripWhitespace`` is false).  If the frame was pushed with a false
    ``expectingText`` that string is returned right away.  For text
    elements the value is additionally base64-decoded, resolved against the
    base URI, scanned for microformats, sanitized and converted to unicode
    as the content parameters and module settings dictate, and then stored
    in the current entry / feed / source context before being returned.
    """
    if not self.elementstack:
        return
    if self.elementstack[-1][0] != element:
        return

    element, expectingText, pieces = self.elementstack.pop()

    if self.version == 'atom10' and self.contentparams.get('type', 'text') == 'application/xhtml+xml':
        # remove enclosing child element, but only if it is a <div> and
        # only if all the remaining content is nested underneath it.
        # This means that the divs would be retained in the following:
        #    <div>foo</div><div>bar</div>
        while pieces and len(pieces) > 1 and not pieces[-1].strip():
            del pieces[-1]
        while pieces and len(pieces) > 1 and not pieces[0].strip():
            del pieces[0]
        if pieces and (pieces[0] == '<div>' or pieces[0].startswith('<div ')) and pieces[-1] == '</div>':
            depth = 0
            for piece in pieces[:-1]:
                if piece.startswith('</'):
                    depth -= 1
                    if depth == 0:
                        break
                elif piece.startswith('<') and not piece.endswith('/>'):
                    depth += 1
            else:
                # the first <div> stayed open up to the final </div>
                pieces = pieces[1:-1]

    # Ensure each piece is a str for Python 3
    for (i, v) in enumerate(pieces):
        if not isinstance(v, basestring):
            pieces[i] = v.decode('utf-8')

    output = ''.join(pieces)
    if stripWhitespace:
        output = output.strip()
    if not expectingText:
        return output

    # decode base64 content
    if base64 and self.contentparams.get('base64', 0):
        try:
            output = _base64decode(output)
        except (binascii.Error, binascii.Incomplete):
            pass
        except TypeError:
            # In Python 3, base64 takes and outputs bytes, not str
            # This may not be the most correct way to accomplish this
            output = _base64decode(output.encode('utf-8')).decode('utf-8')

    # resolve relative URIs
    if (element in self.can_be_relative_uri) and output:
        output = self.resolveURI(output)

    # decode entities within embedded markup
    if not self.contentparams.get('base64', 0):
        output = self.decodeEntities(element, output)

    if self.lookslikehtml(output):
        self.contentparams['type'] = 'text/html'

    # remove temporary cruft from contentparams
    try:
        del self.contentparams['mode']
    except KeyError:
        pass
    try:
        del self.contentparams['base64']
    except KeyError:
        pass

    is_htmlish = self.mapContentType(self.contentparams.get('type', 'text/html')) in self.html_types
    # resolve relative URIs within embedded markup
    if is_htmlish and RESOLVE_RELATIVE_URIS:
        if element in self.can_contain_relative_uris:
            output = _resolveRelativeURIs(output, self.baseuri, self.encoding, self.contentparams.get('type', 'text/html'))

    # parse microformats
    # (must do this before sanitizing because some microformats
    # rely on elements that we sanitize)
    if is_htmlish and element in ['content', 'description', 'summary']:
        mfresults = _parseMicroformats(output, self.baseuri, self.encoding)
        if mfresults:
            for tag in mfresults.get('tags', []):
                self._addTag(tag['term'], tag['scheme'], tag['label'])
            for enclosure in mfresults.get('enclosures', []):
                self._start_enclosure(enclosure)
            for xfn in mfresults.get('xfn', []):
                self._addXFN(xfn['relationships'], xfn['href'], xfn['name'])
            vcard = mfresults.get('vcard')
            if vcard:
                self._getContext()['vcard'] = vcard

    # sanitize embedded markup
    if is_htmlish and SANITIZE_HTML:
        if element in self.can_contain_dangerous_markup:
            output = _sanitizeHTML(output, self.encoding, self.contentparams.get('type', 'text/html'))

    if self.encoding and type(output) != type(u''):
        try:
            output = unicode(output, self.encoding)
        except Exception:
            # best effort: keep the undecoded value
            pass

    # address common error where people take data that is already
    # utf-8, presume that it is iso-8859-1, and re-encode it.
    if self.encoding in ('utf-8', 'utf-8_INVALID_PYTHON_3') and type(output) == type(u''):
        try:
            output = unicode(output.encode('iso-8859-1'), 'utf-8')
        except Exception:
            # not double-encoded after all; keep the value as it is
            pass

    # map win-1252 extensions to the proper code points
    if type(output) == type(u''):
        # a single dict lookup per character instead of building
        # _cp1252.keys() for every character
        output = u''.join([_cp1252.get(c, c) for c in output])

    # categories/tags/keywords/whatever are handled in _end_category
    if element == 'category':
        return output

    if element == 'title' and self.hasTitle:
        return output

    # store output in appropriate place(s)
    if self.inentry and not self.insource:
        if element == 'content':
            self.entries[-1].setdefault(element, [])
            contentparams = copy.deepcopy(self.contentparams)
            contentparams['value'] = output
            self.entries[-1][element].append(contentparams)
        elif element == 'link':
            if not self.inimage:
                # query variables in urls in link elements are improperly
                # converted from `?a=1&b=2` to `?a=1&b;=2` as if they're
                # unhandled character references. fix this special case.
                output = re.sub(r"&([A-Za-z0-9_]+);", r"&\g<1>", output)
                self.entries[-1][element] = output
                if output:
                    self.entries[-1]['links'][-1]['href'] = output
        else:
            if element == 'description':
                element = 'summary'
            self.entries[-1][element] = output
            if self.incontent:
                contentparams = copy.deepcopy(self.contentparams)
                contentparams['value'] = output
                self.entries[-1][element + '_detail'] = contentparams
    elif (self.infeed or self.insource):
        context = self._getContext()
        if element == 'description':
            element = 'subtitle'
        context[element] = output
        if element == 'link':
            # fix query variables; see the entry-level branch for the explanation
            output = re.sub(r"&([A-Za-z0-9_]+);", r"&\g<1>", output)
            context[element] = output
            context['links'][-1]['href'] = output
        elif self.incontent:
            contentparams = copy.deepcopy(self.contentparams)
            contentparams['value'] = output
            context[element + '_detail'] = contentparams
    return output
def pushContent(self, tag, attrsD, defaultContentType, expectingText):
    """Begin a content-bearing element.

    Increments ``incontent``, normalizes underscores in the current language
    to hyphens, rebuilds ``contentparams`` (type, language, base, base64)
    from the element's attributes and pushes ``tag`` on the element stack.
    """
    self.incontent += 1
    if self.lang:
        self.lang = self.lang.replace('_', '-')
    declared_type = attrsD.get('type', defaultContentType)
    params = FeedParserDict({
        'type': self.mapContentType(declared_type),
        'language': self.lang,
        'base': self.baseuri})
    # contentparams must be in place before _isBase64 runs, as it reads it
    self.contentparams = params
    params['base64'] = self._isBase64(attrsD, params)
    self.push(tag, expectingText)
def popContent(self, tag):
    """End a content-bearing element and return what ``pop`` returned.

    Decrements ``incontent`` and empties ``contentparams`` afterwards.
    """
    popped = self.pop(tag)
    self.incontent -= 1
    self.contentparams.clear()
    return popped
# a number of elements in a number of RSS variants are nominally plain | |
# text, but this is routinely ignored. This is an attempt to detect | |
# the most common cases. As false positives often result in silent | |
# data loss, this function errs on the conservative side. | |
def lookslikehtml(self, s):
    """Return 1 when nominally plain text ``s`` appears to be HTML.

    Returns None for Atom feeds, for content not typed 'text/plain', for
    text with neither a close tag nor an entity reference, and whenever a
    tag is not an acceptable sanitizer element or a named entity is not a
    defined HTML entity.
    """
    if self.version.startswith('atom'):
        return
    if self.contentparams.get('type', 'text/html') != 'text/plain':
        return

    # must have a close tag or a entity reference to qualify
    if not (re.search(r'</(\w+)>', s) or re.search(r"&#?\w+;", s)):
        return

    # all tags must be in a restricted subset of valid HTML tags
    for tagname in re.findall(r'</?(\w+)', s):
        if tagname.lower() not in _HTMLSanitizer.acceptable_elements:
            return

    # all entities must have been defined as valid HTML entities
    from htmlentitydefs import entitydefs
    for entity in re.findall(r'&(\w+);', s):
        if entity not in entitydefs.keys():
            return

    return 1
def _mapToStandardPrefix(self, name):
    """Return ``name`` with the part before its first colon translated
    through ``namespacemap``; names without a colon are returned as given."""
    if ':' in name:
        prefix, suffix = name.split(':', 1)
        name = self.namespacemap.get(prefix, prefix) + ':' + suffix
    return name
def _getAttribute(self, attrsD, name):
    """Look up ``name`` in ``attrsD`` after translating its namespace prefix
    with ``_mapToStandardPrefix``; returns None when it is absent."""
    standard_name = self._mapToStandardPrefix(name)
    return attrsD.get(standard_name)
def _isBase64(self, attrsD, contentparams):
    """Return 1 when the current content is to be treated as base64, else 0.

    An explicit mode="base64" attribute always wins.  Otherwise the type in
    ``self.contentparams`` decides: 'text/...' types and types ending in
    '+xml' or '/xml' are not base64, everything else is.  The
    ``contentparams`` argument itself is not consulted.
    """
    if attrsD.get('mode', '') == 'base64':
        return 1
    content_type = self.contentparams['type']
    if content_type.startswith('text/'):
        return 0
    for xml_suffix in ('+xml', '/xml'):
        if content_type.endswith(xml_suffix):
            return 0
    return 1
def _itsAnHrefDamnIt(self, attrsD):
    """Fold 'url' / 'uri' attributes into 'href' and return ``attrsD``.

    The value is taken from 'url' if that key is present, else from 'uri',
    else from 'href'.  When it is truthy, 'url' and 'uri' are removed and
    the value is stored under 'href'; otherwise nothing is changed.  The
    dictionary is modified in place.
    """
    href = attrsD.get('url', attrsD.get('uri', attrsD.get('href', None)))
    if href:
        for alias in ('url', 'uri'):
            try:
                del attrsD[alias]
            except KeyError:
                pass
        attrsD['href'] = href
    return attrsD
def _save(self, key, value, overwrite=False):
    """Store ``value`` under ``key`` in the current context.

    An existing value is kept unless ``overwrite`` is true.
    """
    context = self._getContext()
    if not overwrite:
        context.setdefault(key, value)
    else:
        context[key] = value
def _start_rss(self, attrsD):
    """Handle <rss>: derive the feed version from its version attribute.

    If we're here then this is an RSS feed.  A version that is unset or
    does not start with 'rss' is a mistake and gets corrected; a version
    that already starts with 'rss' is left alone.
    """
    if self.version and self.version.startswith('rss'):
        return
    known_versions = {'0.91': 'rss091u',
                      '0.92': 'rss092',
                      '0.93': 'rss093',
                      '0.94': 'rss094'}
    declared = attrsD.get('version', '')
    mapped = known_versions.get(declared)
    if mapped:
        self.version = mapped
    elif declared.startswith('2.'):
        self.version = 'rss20'
    else:
        self.version = 'rss'
def _start_dlhottitles(self, attrsD):
    """Handle <dlhottitles>: set the feed version to 'hotrss'."""
    setattr(self, 'version', 'hotrss')
def _start_channel(self, attrsD):
    """Handle <channel> / <feedinfo>: flag that the parser is inside the
    feed and let ``_cdf_common`` process the element's attributes."""
    self.infeed = 1
    self._cdf_common(attrsD)
_start_feedinfo = _start_channel
def _cdf_common(self, attrsD):
    """Turn 'lastmod' and 'href' attributes into element content.

    For each attribute that is present, the corresponding start handler
    (``_start_modified`` / ``_start_link``) is called with no attributes,
    the attribute value is installed as the text of the element now on top
    of the stack, and the matching end handler is called.
    """
    for attribute, element in (('lastmod', 'modified'), ('href', 'link')):
        if not attrsD.has_key(attribute):
            continue
        getattr(self, '_start_' + element)({})
        self.elementstack[-1][-1] = attrsD[attribute]
        getattr(self, '_end_' + element)()
def _start_feed(self, attrsD):
    """Handle <feed>: flag that the parser is inside the feed and, when no
    version is set yet, derive one from the version attribute ('atom' when
    the attribute is missing or not one of 0.1, 0.2, 0.3)."""
    self.infeed = 1
    if self.version:
        return
    atom_versions = {'0.1': 'atom01',
                     '0.2': 'atom02',
                     '0.3': 'atom03'}
    self.version = atom_versions.get(attrsD.get('version')) or 'atom'
def _end_channel(self):
    """Handle </channel> / </feed>: clear the inside-the-feed flag."""
    setattr(self, 'infeed', 0)
_end_feed = _end_channel
def _start_image(self, attrsD):
    """Handle <image>: outside of an entry, make sure the current context
    has an 'image' dictionary; then flag image mode, reset the title flag
    and push the element."""
    context = self._getContext()
    if not self.inentry:
        context.setdefault('image', FeedParserDict())
    self.hasTitle = 0
    self.inimage = 1
    self.push('image', 0)

def _end_image(self):
    """Handle </image>: pop the element and leave image mode."""
    self.pop('image')
    self.inimage = 0
def _start_textinput(self, attrsD):
    """Handle <textinput>: make sure the current context has a 'textinput'
    dictionary, flag text-input mode, reset the title flag and push the
    element."""
    self._getContext().setdefault('textinput', FeedParserDict())
    self.hasTitle = 0
    self.intextinput = 1
    self.push('textinput', 0)
_start_textInput = _start_textinput

def _end_textinput(self):
    """Handle </textinput>: pop the element and leave text-input mode."""
    self.pop('textinput')
    self.intextinput = 0
_end_textInput = _end_textinput
def _start_author(self, attrsD):
    """Handle an author-like opening tag: flag author mode, push 'author'
    as a text element and append a fresh dictionary to the context's
    'authors' list (creating the list when needed)."""
    self.inauthor = 1
    self.push('author', 1)
    # Append a new FeedParserDict when expecting an author
    authors = self._getContext().setdefault('authors', [])
    authors.append(FeedParserDict())
_start_managingeditor = _start_author
_start_dc_author = _start_author
_start_dc_creator = _start_author
_start_itunes_author = _start_author

def _end_author(self):
    """Handle an author-like closing tag: pop 'author', leave author mode
    and synchronize the author detail."""
    self.pop('author')
    self.inauthor = 0
    self._sync_author_detail()
_end_managingeditor = _end_author
_end_dc_author = _end_author
_end_dc_creator = _end_author
_end_itunes_author = _end_author
def _start_itunes_owner(self, attrsD):
    """Handle <itunes:owner>: flag publisher mode and push 'publisher'."""
    setattr(self, 'inpublisher', 1)
    self.push('publisher', 0)

def _end_itunes_owner(self):
    """Handle </itunes:owner>: pop 'publisher', leave publisher mode and
    synchronize the publisher detail."""
    self.pop('publisher')
    setattr(self, 'inpublisher', 0)
    self._sync_author_detail('publisher')
def _start_contributor(self, attrsD):
    """Handle <contributor>: flag contributor mode, append a fresh
    dictionary to the context's 'contributors' list and push the element."""
    self.incontributor = 1
    contributors = self._getContext().setdefault('contributors', [])
    contributors.append(FeedParserDict())
    self.push('contributor', 0)

def _end_contributor(self):
    """Handle </contributor>: pop the element and leave contributor mode."""
    self.pop('contributor')
    self.incontributor = 0
def _start_dc_contributor(self, attrsD):
    """Handle <dc:contributor>: flag contributor mode, append a fresh
    dictionary to the context's 'contributors' list and push 'name', so the
    element's text is collected as a name."""
    self.incontributor = 1
    contributors = self._getContext().setdefault('contributors', [])
    contributors.append(FeedParserDict())
    self.push('name', 0)

def _end_dc_contributor(self):
    """Handle </dc:contributor>: finish the name via ``_end_name`` and leave
    contributor mode."""
    self._end_name()
    self.incontributor = 0
def _start_name(self, attrsD):
    """Handle <name> / <itunes:name>: push the element."""
    self.push('name', 0)
_start_itunes_name = _start_name

def _end_name(self):
    """Handle </name> / </itunes:name>: store the popped value on whichever
    of publisher, author, contributor or text input is currently open
    (checked in that order); otherwise the value is dropped."""
    value = self.pop('name')
    if self.inpublisher:
        self._save_author('name', value, 'publisher')
        return
    if self.inauthor:
        self._save_author('name', value)
        return
    if self.incontributor:
        self._save_contributor('name', value)
        return
    if self.intextinput:
        self._getContext()['name'] = value
_end_itunes_name = _end_name
def _start_width(self, attrsD):
    """Handle <width>: push the element."""
    self.push('width', 0)

def _end_width(self):
    """Handle </width>: convert the popped text to an integer (0 when it is
    missing or not a number) and, while inside an image, store it in the
    current context as 'width'."""
    value = self.pop('width')
    try:
        value = int(value)
    except (TypeError, ValueError):
        # TypeError: pop() returned None; ValueError: text is not a number
        value = 0
    if self.inimage:
        context = self._getContext()
        context['width'] = value
def _start_height(self, attrsD):
    """Handle <height>: push the element."""
    self.push('height', 0)

def _end_height(self):
    """Handle </height>: convert the popped text to an integer (0 when it is
    missing or not a number) and, while inside an image, store it in the
    current context as 'height'."""
    value = self.pop('height')
    try:
        value = int(value)
    except (TypeError, ValueError):
        # TypeError: pop() returned None; ValueError: text is not a number
        value = 0
    if self.inimage:
        context = self._getContext()
        context['height'] = value
def _start_url(self, attrsD):
    """Handle <url> / <homepage> / <uri>: push the element as 'href'."""
    self.push('href', 1)
_start_homepage = _start_url
_start_uri = _start_url

def _end_url(self):
    """Handle </url> / </homepage> / </uri>: pop 'href' and store the value
    on the open author or, failing that, the open contributor."""
    href = self.pop('href')
    if self.inauthor:
        self._save_author('href', href)
        return
    if self.incontributor:
        self._save_contributor('href', href)
_end_homepage = _end_url
_end_uri = _end_url
def _start_email(self, attrsD):
    """Handle <email> / <itunes:email>: push the element."""
    self.push('email', 0)
_start_itunes_email = _start_email

def _end_email(self):
    """Handle </email> / </itunes:email>: store the popped value on
    whichever of publisher, author or contributor is currently open
    (checked in that order)."""
    address = self.pop('email')
    if self.inpublisher:
        self._save_author('email', address, 'publisher')
        return
    if self.inauthor:
        self._save_author('email', address)
        return
    if self.incontributor:
        self._save_contributor('email', address)
_end_itunes_email = _end_email
def _getContext(self):
    """Return the dictionary that parsed values currently belong in.

    Checked in order: the source data while inside a source; the feed's
    image dictionary while inside an image (if the feed has one); the
    feed's textinput dictionary while inside a text input; the newest entry
    while inside an entry; and otherwise the feed data itself.
    """
    if self.insource:
        return self.sourcedata
    if self.inimage and self.feeddata.has_key('image'):
        return self.feeddata['image']
    if self.intextinput:
        return self.feeddata['textinput']
    if self.inentry:
        return self.entries[-1]
    return self.feeddata
def _save_author(self, key, value, prefix='author'):
    """Record one field of an author-like person in the current context.

    key    -- field name, e.g. 'name', 'email' or 'href'
    value  -- field value
    prefix -- which person is being described ('author' or 'publisher')

    The value is stored in ``<prefix>_detail`` and the summary string for
    that same prefix is refreshed; previously the refresh always targeted
    'author', whatever the prefix.  As before, the value is also written to
    the last dictionary of the context's 'authors' list.
    """
    context = self._getContext()
    context.setdefault(prefix + '_detail', FeedParserDict())
    context[prefix + '_detail'][key] = value
    self._sync_author_detail(prefix)
    context.setdefault('authors', [FeedParserDict()])
    context['authors'][-1][key] = value
def _save_contributor(self, key, value):
    """Store ``value`` under ``key`` in the last dictionary of the current
    context's 'contributors' list, creating a one-element list when the
    context has none."""
    contributors = self._getContext().setdefault('contributors', [FeedParserDict()])
    contributors[-1][key] = value
def _sync_author_detail(self, key='author'):
    """Keep ``context[key]`` and ``context[key + '_detail']`` consistent.

    When a non-empty detail dictionary exists, the summary string is rebuilt
    from it: 'name (email)', or whichever of the two is present.  Otherwise
    the summary string, if any, is split into an e-mail address and a name
    and a detail dictionary is created from the parts that were found.
    """
    context = self._getContext()
    detail = context.get('%s_detail' % key)
    if detail:
        name = detail.get('name')
        email = detail.get('email')
        if name and email:
            context[key] = '%s (%s)' % (name, email)
        elif name:
            context[key] = name
        elif email:
            context[key] = email
    else:
        author, email = context.get(key), None
        if not author:
            return
        emailmatch = re.search(r'''(([a-zA-Z0-9\_\-\.\+]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-zA-Z0-9\-]+\.)+))([a-zA-Z]{2,4}|[0-9]{1,3})(\]?))(\?subject=\S+)?''', author)
        if emailmatch:
            email = emailmatch.group(0)
            # probably a better way to do the following, but it passes all the tests
            author = author.replace(email, '')
            # drop whatever brackets were left empty around the address,
            # including angle brackets that are still entity-escaped
            author = author.replace('()', '')
            author = author.replace('<>', '')
            author = author.replace('&lt;&gt;', '')
            author = author.strip()
            if author and (author[0] == '('):
                author = author[1:]
            if author and (author[-1] == ')'):
                author = author[:-1]
            author = author.strip()
        if author or email:
            context.setdefault('%s_detail' % key, FeedParserDict())
        if author:
            context['%s_detail' % key]['name'] = author
        if email:
            context['%s_detail' % key]['email'] = email
def _start_subtitle(self, attrsD):
    """Handle <subtitle> / <tagline> / <itunes:subtitle>: begin a content
    element named 'subtitle' whose default type is text/plain."""
    default_type = 'text/plain'
    self.pushContent('subtitle', attrsD, default_type, 1)
_start_tagline = _start_subtitle
_start_itunes_subtitle = _start_subtitle

def _end_subtitle(self):
    """Handle the matching closing tags: end the 'subtitle' content."""
    self.popContent('subtitle')
_end_tagline = _end_subtitle
_end_itunes_subtitle = _end_subtitle
def _start_rights(self, attrsD):
    """Handle <rights> / <dc:rights> / <copyright>: begin a content element
    named 'rights' whose default type is text/plain."""
    default_type = 'text/plain'
    self.pushContent('rights', attrsD, default_type, 1)
_start_dc_rights = _start_rights
_start_copyright = _start_rights

def _end_rights(self):
    """Handle the matching closing tags: end the 'rights' content."""
    self.popContent('rights')
_end_dc_rights = _end_rights
_end_copyright = _end_rights
def _start_item(self, attrsD):
    """Handle <item> / <entry> / <product>: append a new entry dictionary,
    push the element, enter entry mode and reset the per-entry flags.  An
    rdf:about attribute, when present, becomes the entry's id; the
    attributes are then handed to ``_cdf_common``."""
    self.entries.append(FeedParserDict())
    self.push('item', 0)
    self.inentry = 1
    self.guidislink = 0
    self.hasTitle = 0
    about = self._getAttribute(attrsD, 'rdf:about')
    if about:
        self._getContext()['id'] = about
    self._cdf_common(attrsD)
_start_entry = _start_item
_start_product = _start_item

def _end_item(self):
    """Handle </item> / </entry>: pop the element and leave entry mode."""
    self.pop('item')
    self.inentry = 0
_end_entry = _end_item
def _start_dc_language(self, attrsD):
    """Handle <dc:language> / <language>: push 'language' as a text
    element."""
    self.push('language', 1)
_start_language = _start_dc_language

def _end_dc_language(self):
    """Handle the matching closing tags: whatever ``pop`` returns becomes
    the parser's current language."""
    language = self.pop('language')
    self.lang = language
_end_language = _end_dc_language
def _start_dc_publisher(self, attrsD):
    """Handle <dc:publisher> / <webmaster>: push 'publisher' as a text
    element."""
    self.push('publisher', 1)
_start_webmaster = _start_dc_publisher

def _end_dc_publisher(self):
    """Handle the matching closing tags: pop 'publisher' and synchronize
    the publisher detail."""
    self.pop('publisher')
    self._sync_author_detail('publisher')
_end_webmaster = _end_dc_publisher
def _start_published(self, attrsD): | |
self.push('published', 1) | |
_start_dcterms_issued = _start_published | |
_start_issued = _start_published | |
def _end_published(self):
    """Pop 'published' and save its parsed date as 'published_parsed'."""
    raw_date = self.pop('published')
    parsed_date = _parse_date(raw_date)
    self._save('published_parsed', parsed_date, overwrite=True)

# The dcterms_issued and issued end handlers reuse the published handler.
_end_dcterms_issued = _end_issued = _end_published
def _start_updated(self, attrsD):
    """Push an 'updated' element; attrsD is not consulted."""
    self.push('updated', 1)

# Start handlers that reuse the updated handler.
_start_modified = _start_updated
_start_dcterms_modified = _start_updated
_start_pubdate = _start_updated
_start_dc_date = _start_updated
_start_lastbuilddate = _start_updated
def _end_updated(self):
    """Pop 'updated' and save its parsed date as 'updated_parsed'."""
    raw_date = self.pop('updated')
    self._save('updated_parsed', _parse_date(raw_date), overwrite=True)

# End handlers that reuse the updated handler.
_end_modified = _end_updated
_end_dcterms_modified = _end_updated
_end_pubdate = _end_updated
_end_dc_date = _end_updated
_end_lastbuilddate = _end_updated
def _start_created(self, attrsD):
    """Push a 'created' element; attrsD is not consulted."""
    self.push('created', 1)

# The dcterms_created start handler reuses the created handler.
_start_dcterms_created = _start_created
def _end_created(self):
    """Pop 'created' and save its parsed date as 'created_parsed'."""
    raw_date = self.pop('created')
    parsed_date = _parse_date(raw_date)
    self._save('created_parsed', parsed_date, overwrite=True)

# The dcterms_created end handler reuses the created handler.
_end_dcterms_created = _end_created
def _start_expirationdate(self, attrsD):
    """Push an 'expired' element; attrsD is not consulted."""
    self.push('expired', 1)
def _end_expirationdate(self):
    """Pop 'expired' and save its parsed date as 'expired_parsed'."""
    raw_date = self.pop('expired')
    parsed_date = _parse_date(raw_date)
    self._save('expired_parsed', parsed_date, overwrite=True)
def _start_cc_license(self, attrsD):
    """Append a rel='license' link built from the rdf:resource attribute.

    The link is appended even when rdf:resource is missing; in that case it
    carries no 'href' key.
    """
    context = self._getContext()
    href = self._getAttribute(attrsD, 'rdf:resource')
    link = FeedParserDict()
    link['rel'] = 'license'
    if href:
        link['href'] = href
    links = context.setdefault('links', [])
    links.append(link)
def _start_creativecommons_license(self, attrsD):
    """Push a 'license' element; attrsD is not consulted."""
    self.push('license', 1)

# Mixed-case spelling of the same handler name.
_start_creativeCommons_license = _start_creativecommons_license
def _end_creativecommons_license(self):
    """Turn the popped 'license' text into a rel='license' link.

    The link is appended to the context's 'links' list, and the 'license'
    key is then deleted from the context.
    """
    href = self.pop('license')
    context = self._getContext()
    link = FeedParserDict()
    link['rel'] = 'license'
    if href:
        link['href'] = href
    links = context.setdefault('links', [])
    links.append(link)
    del context['license']

# Mixed-case spelling of the same handler name.
_end_creativeCommons_license = _end_creativecommons_license
def _addXFN(self, relationships, href, name):
    """Record an XFN entry in the context's 'xfn' list, skipping duplicates."""
    known = self._getContext().setdefault('xfn', [])
    entry = FeedParserDict({'relationships': relationships, 'href': href, 'name': name})
    if entry in known:
        return
    known.append(entry)
def _addTag(self, term, scheme, label):
    """Record a tag in the context's 'tags' list, skipping duplicates.

    The 'tags' list is created even when all three fields are empty, in
    which case nothing is appended.
    """
    known = self._getContext().setdefault('tags', [])
    if not (term or scheme or label):
        return
    tag = FeedParserDict({'term': term, 'scheme': scheme, 'label': label})
    if tag in known:
        return
    known.append(tag)
def _start_category(self, attrsD):
    """Add a tag from the element's attributes, then push 'category'.

    The scheme comes from the 'scheme' attribute, falling back to 'domain'.
    """
    if _debug:
        sys.stderr.write('entering _start_category with %s\n' % repr(attrsD))
    term = attrsD.get('term')
    fallback_scheme = attrsD.get('domain')
    scheme = attrsD.get('scheme', fallback_scheme)
    label = attrsD.get('label')
    self._addTag(term, scheme, label)
    self.push('category', 1)

# The dc_subject and keywords start handlers reuse the category handler.
_start_dc_subject = _start_keywords = _start_category
def _start_media_category(self, attrsD):
    """Start a category whose scheme defaults to the Yahoo MRSS category schema."""
    default_scheme = 'http://search.yahoo.com/mrss/category_schema'
    attrsD.setdefault('scheme', default_scheme)
    self._start_category(attrsD)
def _end_itunes_keywords(self):
    """Add one iTunes-scheme tag per whitespace-separated keyword."""
    keywords = self.pop('itunes_keywords')
    for keyword in keywords.split():
        self._addTag(keyword, 'http://www.itunes.com/', None)
def _start_itunes_category(self, attrsD):
    """Add an iTunes-scheme tag from the 'text' attribute, then push 'category'."""
    term = attrsD.get('text')
    self._addTag(term, 'http://www.itunes.com/', None)
    self.push('category', 1)
def _end_category(self):
    """Attach the popped category text to the context's tags.

    If the most recent tag has no term yet, the text becomes its term;
    otherwise the text is added as a new tag with no scheme or label.
    Empty text is ignored.
    """
    text = self.pop('category')
    if not text:
        return
    tags = self._getContext()['tags']
    if tags and not tags[-1]['term']:
        tags[-1]['term'] = text
    else:
        self._addTag(text, None, None)

# End handlers that reuse the category handler.
_end_dc_subject = _end_category
_end_keywords = _end_category
_end_itunes_category = _end_category
_end_media_category = _end_category
def _start_cloud(self, attrsD):
    """Store a copy of the element's attributes under the context's 'cloud' key."""
    cloud = FeedParserDict(attrsD)
    self._getContext()['cloud'] = cloud
def _start_link(self, attrsD):
    """Record a link element in the current context.

    Defaults 'rel' to 'alternate' and 'type' to 'application/atom+xml' for
    rel='self' or 'text/html' otherwise.  The attributes are appended to the
    context's 'links' list unless the parser is inside both an entry and an
    image.  With an href, an alternate link of an HTML type also becomes the
    context's 'link'; without an href, 'link' is pushed instead.
    """
    attrsD.setdefault('rel', 'alternate')
    if attrsD['rel'] == 'self':
        default_type = 'application/atom+xml'
    else:
        default_type = 'text/html'
    attrsD.setdefault('type', default_type)

    context = self._getContext()
    attrsD = self._itsAnHrefDamnIt(attrsD)
    has_href = attrsD.has_key('href')
    if has_href:
        attrsD['href'] = self.resolveURI(attrsD['href'])
    expectingText = self.infeed or self.inentry or self.insource

    context.setdefault('links', [])
    if not (self.inentry and self.inimage):
        context['links'].append(FeedParserDict(attrsD))

    if not has_href:
        self.push('link', expectingText)
        return
    is_alternate = (attrsD.get('rel') == 'alternate')
    if is_alternate and (self.mapContentType(attrsD.get('type')) in self.html_types):
        context['link'] = attrsD['href']

# The producturl start handler reuses the link handler.
_start_producturl = _start_link
def _end_link(self):
    """Pop the 'link' element; the popped value is not used here."""
    self.pop('link')
    # The context lookup is kept from the original even though its result is unused.
    self._getContext()

# The producturl end handler reuses the link handler.
_end_producturl = _end_link
def _start_guid(self, attrsD):
    """Note whether the guid is a permalink, then push 'id'.

    'ispermalink' defaults to 'true' when the attribute is absent.
    """
    permalink_flag = attrsD.get('ispermalink', 'true')
    self.guidislink = (permalink_flag == 'true')
    self.push('id', 1)
def _end_guid(self):
    """Pop the guid and, when it is a permalink, save it as the link too."""
    guid = self.pop('id')
    acts_as_link = self.guidislink and not self._getContext().has_key('link')
    self._save('guidislink', acts_as_link)
    if not self.guidislink:
        return
    # guid acts as link, but only if 'ispermalink' is not present or is 'true',
    # and only if the item doesn't already have a link element
    self._save('link', guid)
def _start_title(self, attrsD):
    """Open a 'title' content element, or pass it through when svgOK is set."""
    if self.svgOK:
        return self.unknown_starttag('title', attrsD.items())
    expecting_text = self.infeed or self.inentry or self.insource
    self.pushContent('title', attrsD, 'text/plain', expecting_text)

# The dc_title and media_title start handlers reuse the title handler.
_start_dc_title = _start_media_title = _start_title
def _end_title(self):
    """Close the 'title' element and set hasTitle when it had content.

    Does nothing when svgOK is set.
    """
    if self.svgOK:
        return
    if self.popContent('title'):
        # The context lookup is kept from the original even though its result is unused.
        self._getContext()
        self.hasTitle = 1

# The dc_title end handler reuses the title handler.
_end_dc_title = _end_title
def _end_media_title(self):
    """Run the title end handler, then restore hasTitle to its prior value."""
    previous_flag = self.hasTitle
    self._end_title()
    self.hasTitle = previous_flag
def _start_description(self, attrsD):
    """Open a description, treating it as content if a summary already exists."""
    if self._getContext().has_key('summary'):
        self._summaryKey = 'content'
        self._start_content(attrsD)
        return
    expecting_text = self.infeed or self.inentry or self.insource
    self.pushContent('description', attrsD, 'text/html', expecting_text)

# The dc_description start handler reuses the description handler.
_start_dc_description = _start_description
def _start_abstract(self, attrsD):
    """Open a 'description' content element, using 'text/plain' as the default type."""
    expecting_text = self.infeed or self.inentry or self.insource
    self.pushContent('description', attrsD, 'text/plain', expecting_text)
def _end_description(self):
    """Close the description (or the content it was redirected to)."""
    redirected_to_content = (self._summaryKey == 'content')
    if redirected_to_content:
        self._end_content()
    else:
        self.popContent('description')
    self._summaryKey = None

# The abstract and dc_description end handlers reuse the description handler.
_end_abstract = _end_dc_description = _end_description
def _start_info(self, attrsD):
    """Open an 'info' content element, using 'text/plain' as the default type."""
    default_type = 'text/plain'
    self.pushContent('info', attrsD, default_type, 1)

# The feedburner_browserfriendly start handler reuses the info handler.
_start_feedburner_browserfriendly = _start_info
def _end_info(self):
    """Close the 'info' content element opened by the start handler."""
    self.popContent('info')

# The feedburner_browserfriendly end handler reuses the info handler.
_end_feedburner_browserfriendly = _end_info
def _start_generator(self, attrsD):
    """Store the generator's attributes as 'generator_detail', then push 'generator'.

    When attributes are present, they go through _itsAnHrefDamnIt and any
    resulting 'href' is resolved with resolveURI before being stored.
    """
    if attrsD:
        attrsD = self._itsAnHrefDamnIt(attrsD)
        if attrsD.has_key('href'):
            resolved = self.resolveURI(attrsD['href'])
            attrsD['href'] = resolved
    detail = FeedParserDict(attrsD)
    self._getContext()['generator_detail'] = detail
    self.push('generator', 1)
def _end_generator(self):
    """Pop 'generator' and store its text as the detail record's 'name'.

    Nothing is stored when the context has no 'generator_detail'.
    """
    name = self.pop('generator')
    context = self._getContext()
    if not context.has_key('generator_detail'):
        return
    context['generator_detail']['name'] = name
def _start_admin_generatoragent(self, attrsD):
    """Treat the rdf:resource attribute as the generator's text and href.

    The element is pushed and popped within this one call, with the
    attribute value appended to the element's collected pieces in between.
    """
    self.push('generator', 1)
    resource = self._getAttribute(attrsD, 'rdf:resource')
    if resource:
        # Index 2 of the top stack entry is the list that element text is appended to.
        pieces = self.elementstack[-1][2]
        pieces.append(resource)
    self.pop('generator')
    detail = FeedParserDict({'href': resource})
    self._getContext()['generator_detail'] = detail
def _start_admin_errorreportsto(self, attrsD):
    """Treat the rdf:resource attribute as the 'errorreportsto' text.

    The element is pushed and popped within this one call.
    """
    self.push('errorreportsto', 1)
    resource = self._getAttribute(attrsD, 'rdf:resource')
    if resource:
        # Index 2 of the top stack entry is the list that element text is appended to.
        pieces = self.elementstack[-1][2]
        pieces.append(resource)
    self.pop('errorreportsto')
def _start_summary(self, attrsD):
    """Open a summary, treating it as content if a summary already exists."""
    if self._getContext().has_key('summary'):
        self._summaryKey = 'content'
        self._start_content(attrsD)
        return
    self._summaryKey = 'summary'
    self.pushContent(self._summaryKey, attrsD, 'text/plain', 1)

# The itunes_summary start handler reuses the summary handler.
_start_itunes_summary = _start_summary
def _end_summary(self):
    """Close the summary (or the content it was redirected to)."""
    redirected_to_content = (self._summaryKey == 'content')
    if redirected_to_content:
        self._end_content()
    else:
        # Fall back to 'summary' when no key was recorded by the start handler.
        self.popContent(self._summaryKey or 'summary')
    self._summaryKey = None

# The itunes_summary end handler reuses the summary handler.
_end_itunes_summary = _end_summary
def _start_enclosure(self, attrsD):
    """Append the enclosure's attributes to 'links' with rel='enclosure'."""
    attrsD = self._itsAnHrefDamnIt(attrsD)
    context = self._getContext()
    attrsD['rel'] = 'enclosure'
    links = context.setdefault('links', [])
    links.append(FeedParserDict(attrsD))
def _start_source(self, attrsD):
    """Begin a source element and reset the title flag."""
    if 'url' in attrsD:
        # This means that we're processing a source element from an RSS 2.0 feed
        source_url = attrsD[u'url']
        self.sourcedata['href'] = source_url
    self.push('source', 1)
    self.insource, self.hasTitle = 1, 0
def _end_source(self):
    """Finish a source element and store a deep copy of its data in the context.

    The shared sourcedata dict is cleared afterwards.
    """
    self.insource = 0
    title = self.pop('source')
    if title:
        self.sourcedata['title'] = title
    snapshot = copy.deepcopy(self.sourcedata)
    self._getContext()['source'] = snapshot
    self.sourcedata.clear()
def _start_content(self, attrsD):
    """Open a 'content' element, recording its 'src' attribute when given."""
    self.pushContent('content', attrsD, 'text/plain', 1)
    source = attrsD.get('src')
    if source:
        self.contentparams['src'] = source
    self.push('content', 1)
def _start_prodlink(self, attrsD):
    """Open a 'content' element, using 'text/html' as the default type."""
    default_type = 'text/html'
    self.pushContent('content', attrsD, default_type, 1)
def _start_body(self, attrsD):
    """Open a 'content' element, using 'application/xhtml+xml' as the default type."""
    default_type = 'application/xhtml+xml'
    self.pushContent('content', attrsD, default_type, 1)

# The xhtml_body start handler reuses the body handler.
_start_xhtml_body = _start_body
def _start_content_encoded(self, attrsD):
    """Open a 'content' element, using 'text/html' as the default type."""
    default_type = 'text/html'
    self.pushContent('content', attrsD, default_type, 1)

# The fullitem start handler reuses the content_encoded handler.
_start_fullitem = _start_content_encoded
def _end_content(self):
    """Close the 'content' element and also save it as the summary when eligible.

    The value is saved as 'summary' when the mapped content type is
    'text/plain' or one of self.html_types.  The type is read before
    popContent is called.
    """
    content_type = self.mapContentType(self.contentparams.get('type'))
    summary_types = ['text/plain'] + self.html_types
    copy_to_summary = content_type in summary_types
    content = self.popContent('content')
    if copy_to_summary:
        self._save('summary', content)

# End handlers that reuse the content handler.
_end_body = _end_content
_end_xhtml_body = _end_content
_end_content_encoded = _end_content
_end_fullitem = _end_content
_end_prodlink = _end_content
def _start_itunes_image(self, attrsD):
    """Push 'itunes_image' and store a non-empty 'href' as the context's image."""
    self.push('itunes_image', 0)
    href = attrsD.get('href')
    if href:
        self._getContext()['image'] = FeedParserDict({'href': href})

# The itunes_link start handler reuses the itunes_image handler.
_start_itunes_link = _start_itunes_image
def _end_itunes_block(self):
    """Store 1 under 'itunes_block' when the popped text is 'yes', else 0."""
    text = self.pop('itunes_block', 0)
    if text == 'yes':
        blocked = 1
    else:
        blocked = 0
    self._getContext()['itunes_block'] = blocked
def _end_itunes_explicit(self):
    """Store the explicit flag as True, False or None under 'itunes_explicit'."""
    text = self.pop('itunes_explicit', 0)
    # Convert 'yes' -> True, 'clean' to False, and any other value to None
    # False and None both evaluate as False, so the difference can be ignored
    # by applications that only need to know if the content is explicit.
    if text == 'yes':
        explicit = True
    elif text == 'clean':
        explicit = False
    else:
        explicit = None
    self._getContext()['itunes_explicit'] = explicit
def _start_media_content(self, attrsD):
    """Append the element's attributes to the context's 'media_content' list."""
    context = self._getContext()
    context.setdefault('media_content', [])
    media_items = context['media_content']
    media_items.append(attrsD)
def _start_media_thumbnail(self, attrsD):
    """Push 'url' and append the attributes to the 'media_thumbnail' list."""
    context = self._getContext()
    context.setdefault('media_thumbnail', [])
    self.push('url', 1)  # new
    thumbnails = context['media_thumbnail']
    thumbnails.append(attrsD)
def _end_media_thumbnail(self):
    """Use the popped text as the latest thumbnail's url if it has none.

    Missing or whitespace-only text is ignored, and an existing 'url' key
    is never overwritten.
    """
    url = self.pop('url')
    context = self._getContext()
    if url is None or len(url.strip()) == 0:
        return
    latest = context['media_thumbnail'][-1]
    if not latest.has_key('url'):
        latest['url'] = url
def _start_media_player(self, attrsD):
    """Push 'media_player' and store a copy of its attributes in the context."""
    self.push('media_player', 0)
    player = FeedParserDict(attrsD)
    self._getContext()['media_player'] = player
def _end_media_player(self):
    """Pop 'media_player' and store its text as the player's 'content'."""
    text = self.pop('media_player')
    player = self._getContext()['media_player']
    player['content'] = text
def _start_newlocation(self, attrsD):
    """Push a 'newlocation' element; attrsD is not consulted."""
    self.push('newlocation', 1)
def _end_newlocation(self):
    """Store the popped URL as 'newlocation', but only at feed level.

    The URL is stripped and passed through _makeSafeAbsoluteURI together
    with self.baseuri.
    """
    url = self.pop('newlocation')
    context = self._getContext()
    # don't set newlocation if the context isn't right
    if context is self.feeddata:
        context['newlocation'] = _makeSafeAbsoluteURI(self.baseuri, url.strip())
if _XML_AVAILABLE:
    class _StrictFeedParser(_FeedParserMixin, xml.sax.handler.ContentHandler):
        """SAX content handler that forwards namespace-aware events to the mixin.

        Element and attribute names are lower-cased and given a prefix before
        being handed to unknown_starttag / unknown_endtag.  Parse errors are
        recorded in ``bozo`` and ``exc``.
        """

        def __init__(self, baseuri, baselang, encoding):
            if _debug:
                sys.stderr.write('trying StrictFeedParser\n')
            xml.sax.handler.ContentHandler.__init__(self)
            _FeedParserMixin.__init__(self, baseuri, baselang, encoding)
            self.bozo = 0
            self.exc = None
            # xmlns declarations waiting to be attached to the next start tag
            self.decls = {}

        def startPrefixMapping(self, prefix, uri):
            """Track the namespace; remember XLink declarations for the next tag."""
            self.trackNamespace(prefix, uri)
            if uri == 'http://www.w3.org/1999/xlink':
                self.decls['xmlns:' + prefix] = uri

        def startElementNS(self, name, qname, attrs):
            """Translate a SAX start event into an unknown_starttag call.

            Raises UndeclaredNamespace when the qname carries a prefix that
            is neither matched by namespace nor among the namespaces in use.
            """
            namespace, localname = name
            lowernamespace = str(namespace or '').lower()
            if 'backend.userland.com/rss' in lowernamespace:
                # match any backend.userland.com namespace
                namespace = 'http://backend.userland.com/rss'
                lowernamespace = namespace
            givenprefix = None
            if qname and qname.find(':') > 0:
                givenprefix = qname.split(':')[0]
            prefix = self._matchnamespaces.get(lowernamespace, givenprefix)
            unmatched = (prefix is None) or (prefix == '' and lowernamespace == '')
            if givenprefix and unmatched and not self.namespacesInUse.has_key(givenprefix):
                raise UndeclaredNamespace("'%s' is not associated with a namespace" % givenprefix)
            localname = str(localname).lower()

            # qname implementation is horribly broken in Python 2.1 (it
            # doesn't report any), and slightly broken in Python 2.2 (it
            # doesn't report the xml: namespace). So we match up namespaces
            # with a known list first, and then possibly override them with
            # the qnames the SAX parser gives us (if indeed it gives us any
            # at all).  Thanks to MatejC for helping me test this and
            # tirelessly telling me that it didn't work yet.
            attrsD, self.decls = self.decls, {}
            if localname == 'math' and namespace == 'http://www.w3.org/1998/Math/MathML':
                attrsD['xmlns'] = namespace
            if localname == 'svg' and namespace == 'http://www.w3.org/2000/svg':
                attrsD['xmlns'] = namespace

            if prefix:
                localname = prefix.lower() + ':' + localname
            elif namespace and not qname:  # Expat
                for usedprefix, useduri in self.namespacesInUse.items():
                    if usedprefix and useduri == namespace:
                        localname = usedprefix + ':' + localname
                        break
            if _debug:
                sys.stderr.write('startElementNS: qname = %s, namespace = %s, givenprefix = %s, prefix = %s, attrs = %s, localname = %s\n' % (qname, namespace, givenprefix, prefix, attrs.items(), localname))

            for (attrnamespace, attrlocalname), attrvalue in attrs._attrs.items():
                attrprefix = self._matchnamespaces.get((attrnamespace or '').lower(), '')
                if attrprefix:
                    attrlocalname = attrprefix + ':' + attrlocalname
                attrsD[str(attrlocalname).lower()] = attrvalue
            # Attributes are stored a second time under their qualified names.
            for attrqname in attrs.getQNames():
                attrsD[str(attrqname).lower()] = attrs.getValueByQName(attrqname)
            self.unknown_starttag(localname, attrsD.items())

        def characters(self, text):
            """Forward character data to handle_data."""
            self.handle_data(text)

        def endElementNS(self, name, qname):
            """Translate a SAX end event into an unknown_endtag call."""
            namespace, localname = name
            lowernamespace = str(namespace or '').lower()
            givenprefix = ''
            if qname and qname.find(':') > 0:
                givenprefix = qname.split(':')[0]
            prefix = self._matchnamespaces.get(lowernamespace, givenprefix)
            if prefix:
                localname = prefix + ':' + localname
            elif namespace and not qname:  # Expat
                for usedprefix, useduri in self.namespacesInUse.items():
                    if usedprefix and useduri == namespace:
                        localname = usedprefix + ':' + localname
                        break
            localname = str(localname).lower()
            self.unknown_endtag(localname)

        def error(self, exc):
            """Record a parse error without stopping."""
            self.bozo = 1
            self.exc = exc

        def fatalError(self, exc):
            """Record a fatal parse error, then re-raise it."""
            self.error(exc)
            raise exc
class _BaseHTMLProcessor(sgmllib.SGMLParser):
    """SGML parser that re-serializes the markup it is fed.

    Each handler appends a reconstructed fragment to ``self.pieces``;
    ``output()`` joins the fragments into a single string.  Subclasses
    override handlers to change what is emitted.
    """
    special = re.compile('''[<>'"]''')
    bare_ampersand = re.compile("&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)")
    # Elements that are written in self-closing form and never get an end tag.
    elements_no_end_tag = [
        'area', 'base', 'basefont', 'br', 'col', 'command', 'embed', 'frame',
        'hr', 'img', 'input', 'isindex', 'keygen', 'link', 'meta', 'param',
        'source', 'track', 'wbr'
    ]

    def __init__(self, encoding, _type):
        """Remember the encoding and content type, then initialize the parser."""
        self.encoding = encoding
        self._type = _type
        if _debug: sys.stderr.write('entering BaseHTMLProcessor, encoding=%s\n' % self.encoding)
        sgmllib.SGMLParser.__init__(self)

    def reset(self):
        """Discard collected pieces and reset the underlying parser."""
        self.pieces = []
        sgmllib.SGMLParser.reset(self)

    def _shorttag_replace(self, match):
        """Expand a '<tag/>' match to '<tag />' or '<tag></tag>'."""
        tag = match.group(1)
        if tag in self.elements_no_end_tag:
            return '<' + tag + ' />'
        else:
            return '<' + tag + '></' + tag + '>'

    def parse_starttag(self, i):
        """Parse a start tag; for XHTML, also close tags written as '<x ... />'."""
        j = sgmllib.SGMLParser.parse_starttag(self, i)
        if self._type == 'application/xhtml+xml':
            if j > 2 and self.rawdata[j-2:j] == '/>':
                self.unknown_endtag(self.lasttag)
        return j

    def feed(self, data):
        """Pre-process *data*, feed it to the SGML parser, and close the parser."""
        # Escape '<!' unless it starts a DOCTYPE, a comment, or a '<![' section.
        data = re.compile(r'<!((?!DOCTYPE|--|\[))', re.IGNORECASE).sub(r'&lt;!\1', data)
        #data = re.sub(r'<(\S+?)\s*?/>', self._shorttag_replace, data) # bug [ 1399464 ] Bad regexp for _shorttag_replace
        data = re.sub(r'<([^<>\s]+?)\s*/>', self._shorttag_replace, data)
        # Decode numeric references for the two quote characters up front.
        data = data.replace('&#39;', "'")
        data = data.replace('&#34;', '"')
        try:
            bytes
            if bytes is str:
                raise NameError
            self.encoding = self.encoding + '_INVALID_PYTHON_3'
        except NameError:
            # Reached when 'bytes' is missing or is just an alias of str.
            if self.encoding and type(data) == type(u''):
                data = data.encode(self.encoding)
        sgmllib.SGMLParser.feed(self, data)
        sgmllib.SGMLParser.close(self)

    def normalize_attrs(self, attrs):
        """Return attrs with lower-cased keys, lower-cased rel/type values, sorted."""
        if not attrs: return attrs
        # utility method to be called by descendants
        attrs = dict([(k.lower(), v) for k, v in attrs]).items()
        attrs = [(k, k in ('rel', 'type') and v.lower() or v) for k, v in attrs]
        attrs.sort()
        return attrs

    def unknown_starttag(self, tag, attrs):
        """Re-emit a start tag with its attribute values escaped."""
        # called for each start tag
        # attrs is a list of (attr, value) tuples
        # e.g. for <pre class='screen'>, tag='pre', attrs=[('class', 'screen')]
        if _debug: sys.stderr.write('_BaseHTMLProcessor, unknown_starttag, tag=%s\n' % tag)
        uattrs = []
        strattrs=''
        if attrs:
            for key, value in attrs:
                # Escape markup characters, then any ampersand that does not
                # already start a character or entity reference.
                value=value.replace('>','&gt;').replace('<','&lt;').replace('"','&quot;')
                value = self.bare_ampersand.sub("&amp;", value)
                # thanks to Kevin Marks for this breathtaking hack to deal with (valid) high-bit attribute values in UTF-8 feeds
                if type(value) != type(u''):
                    try:
                        value = unicode(value, self.encoding)
                    except:
                        # Best-effort fallback when the declared encoding fails.
                        value = unicode(value, 'iso-8859-1')
                try:
                    # Currently, in Python 3 the key is already a str, and cannot be decoded again
                    uattrs.append((unicode(key, self.encoding), value))
                except TypeError:
                    uattrs.append((key, value))
            strattrs = u''.join([u' %s="%s"' % (key, value) for key, value in uattrs])
            if self.encoding:
                try:
                    strattrs=strattrs.encode(self.encoding)
                except:
                    # Best-effort: keep the unicode form if encoding fails.
                    pass
        if tag in self.elements_no_end_tag:
            self.pieces.append('<%(tag)s%(strattrs)s />' % locals())
        else:
            self.pieces.append('<%(tag)s%(strattrs)s>' % locals())

    def unknown_endtag(self, tag):
        """Re-emit an end tag, except for elements that never take one."""
        # called for each end tag, e.g. for </pre>, tag will be 'pre'
        # Reconstruct the original end tag.
        if tag not in self.elements_no_end_tag:
            self.pieces.append("</%(tag)s>" % locals())

    def handle_charref(self, ref):
        """Re-emit a character reference, remapping code points found in _cp1252."""
        # called for each character reference, e.g. for '&#160;', ref will be '160'
        # Reconstruct the original character reference.
        if ref.startswith('x'):
            value = unichr(int(ref[1:],16))
        else:
            value = unichr(int(ref))
        if value in _cp1252.keys():
            # hex() yields '0x....'; dropping the leading '0' leaves 'x....'.
            self.pieces.append('&#%s;' % hex(ord(_cp1252[value]))[1:])
        else:
            self.pieces.append('&#%(ref)s;' % locals())

    def handle_entityref(self, ref):
        """Re-emit a known entity reference; escape the ampersand of unknown ones."""
        # called for each entity reference, e.g. for '&copy;', ref will be 'copy'
        # Reconstruct the original entity reference.
        if name2codepoint.has_key(ref):
            self.pieces.append('&%(ref)s;' % locals())
        else:
            self.pieces.append('&amp;%(ref)s' % locals())

    def handle_data(self, text):
        """Store a block of plain text verbatim."""
        # called for each block of plain text, i.e. outside of any tag and
        # not containing any character or entity references
        # Store the original text verbatim.
        if _debug: sys.stderr.write('_BaseHTMLProcessor, handle_data, text=%s\n' % text)
        self.pieces.append(text)

    def handle_comment(self, text):
        """Re-emit an HTML comment."""
        # called for each HTML comment, e.g. <!-- insert Javascript code here -->
        # Reconstruct the original comment.
        self.pieces.append('<!--%(text)s-->' % locals())

    def handle_pi(self, text):
        """Re-emit a processing instruction."""
        # called for each processing instruction, e.g. <?instruction>
        # Reconstruct original processing instruction.
        self.pieces.append('<?%(text)s>' % locals())

    def handle_decl(self, text):
        """Re-emit a declaration such as the DOCTYPE."""
        # called for the DOCTYPE, if present, e.g.
        # <!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN"
        #     "http://www.w3.org/TR/html4/loose.dtd">
        # Reconstruct original DOCTYPE
        self.pieces.append('<!%(text)s>' % locals())

    _new_declname_match = re.compile(r'[a-zA-Z][-_.a-zA-Z0-9:]*\s*').match

    def _scan_name(self, i, declstartpos):
        """Scan a declaration name at offset i of rawdata.

        Returns (lower-cased name, end offset), or (None, -1) when the
        buffer ends at or right after the name, or when no name matches.
        In the no-match case the whole of rawdata is emitted as data.
        """
        rawdata = self.rawdata
        n = len(rawdata)
        if i == n:
            return None, -1
        m = self._new_declname_match(rawdata, i)
        if m:
            s = m.group()
            name = s.strip()
            if (i + len(s)) == n:
                return None, -1  # end of buffer
            return name.lower(), m.end()
        else:
            self.handle_data(rawdata)
            # self.updatepos(declstartpos, i)
            return None, -1

    def convert_charref(self, name):
        """Return the character reference unchanged, in '&#...;' form."""
        return '&#%s;' % name

    def convert_entityref(self, name):
        """Return the entity reference unchanged, in '&...;' form."""
        return '&%s;' % name

    def output(self):
        '''Return processed HTML as a single string'''
        return ''.join([str(p) for p in self.pieces])

    def parse_declaration(self, i):
        """Parse a declaration; on a parse error, escape the '<' and move on."""
        try:
            return sgmllib.SGMLParser.parse_declaration(self, i)
        except sgmllib.SGMLParseError:
            # escape the doctype declaration and continue parsing
            self.handle_data('&lt;')
            return i+1
class _LooseFeedParser(_FeedParserMixin, _BaseHTMLProcessor):
    """Feed parser built on the SGML-based HTML processor."""

    def __init__(self, baseuri, baselang, encoding, entities):
        sgmllib.SGMLParser.__init__(self)
        _FeedParserMixin.__init__(self, baseuri, baselang, encoding)
        _BaseHTMLProcessor.__init__(self, encoding, 'application/xhtml+xml')
        self.entities=entities

    def decodeEntities(self, element, data):
        """Normalize numeric references for markup characters in *data*.

        Decimal and hexadecimal references for <, >, &, " and ' are first
        rewritten to their named entities.  If the current content has a
        'type' that does not end in 'xml', those named entities are then
        decoded to the literal characters.  *element* is not used.
        """
        data = data.replace('&#60;', '&lt;')
        data = data.replace('&#x3c;', '&lt;')
        data = data.replace('&#x3C;', '&lt;')
        data = data.replace('&#62;', '&gt;')
        data = data.replace('&#x3e;', '&gt;')
        data = data.replace('&#x3E;', '&gt;')
        data = data.replace('&#38;', '&amp;')
        data = data.replace('&#x26;', '&amp;')
        data = data.replace('&#34;', '&quot;')
        data = data.replace('&#x22;', '&quot;')
        data = data.replace('&#39;', '&apos;')
        data = data.replace('&#x27;', '&apos;')
        if self.contentparams.has_key('type') and not self.contentparams.get('type', 'xml').endswith('xml'):
            # '&amp;' is decoded after '&lt;' and '&gt;' so that an escaped
            # '&amp;lt;' ends up as '&lt;' rather than '<'.
            data = data.replace('&lt;', '<')
            data = data.replace('&gt;', '>')
            data = data.replace('&amp;', '&')
            data = data.replace('&quot;', '"')
            data = data.replace('&apos;', "'")
        return data

    def strattrs(self, attrs):
        """Serialize (name, value) pairs as ' name="value"', escaping double quotes."""
        return ''.join([' %s="%s"' % (n,v.replace('"','&quot;')) for n,v in attrs])
class _MicroformatsParser: | |
STRING = 1 | |
DATE = 2 | |
URI = 3 | |
NODE = 4 | |
EMAIL = 5 | |
known_xfn_relationships = ['contact', 'acquaintance', 'friend', 'met', 'co-worker', 'coworker', 'colleague', 'co-resident', 'coresident', 'neighbor', 'child', 'parent', 'sibling', 'brother', 'sister', 'spouse', 'wife', 'husband', 'kin', 'relative', 'muse', 'crush', 'date', 'sweetheart', 'me'] | |
known_binary_extensions = ['zip','rar','exe','gz','tar','tgz','tbz2','bz2','z','7z','dmg','img','sit','sitx','hqx','deb','rpm','bz2','jar','rar','iso','bin','msi','mp2','mp3','ogg','ogm','mp4','m4v','m4a','avi','wma','wmv'] | |
def __init__(self, data, baseuri, encoding):
    """Parse *data* with BeautifulSoup and set up empty result holders."""
    self.document = BeautifulSoup.BeautifulSoup(data)
    self.baseuri, self.encoding = baseuri, encoding
    if type(data) == type(u''):
        # Only the local name is rebound; the encoded value is not stored.
        data = data.encode(encoding)
    # Result holders, all empty at construction time.
    self.tags = []
    self.enclosures = []
    self.xfn = []
    self.vcard = None
def vcardEscape(self, s):
    """Backslash-escape commas, semicolons and newlines in a string.

    Values that are not exactly a str or unicode object are returned as-is.
    """
    if type(s) not in (type(''), type(u'')):
        return s
    for plain, escaped in ((',', '\\,'), (';', '\\;'), ('\n', '\\n')):
        s = s.replace(plain, escaped)
    return s
def vcardFold(self, s):
    """Strip trailing semicolons and fold the line into chunks.

    The first chunk holds up to 75 characters.  Each continuation starts
    with one space and holds up to 74 more characters.
    """
    remaining = re.sub(';+$', '', s)
    chunks = []
    width = 75
    lead = ''
    while len(remaining) > width:
        chunks.append(lead + remaining[:width] + '\n')
        remaining = remaining[width:]
        lead = ' '
        width = 74
    chunks.append(lead + remaining)
    return ''.join(chunks)
def normalize(self, s):
    """Collapse each whitespace run to one space and trim both ends."""
    collapsed = re.sub(r'\s+', ' ', s)
    return collapsed.strip()
def unique(self, aList):
    """Return a new list with duplicates removed, keeping first occurrences."""
    kept = []
    for item in aList:
        if item in kept:
            continue
        kept.append(item)
    return kept
def toISO8601(self, dt):
    """Format a time tuple as 'YYYY-MM-DDTHH:MM:SSZ' using time.strftime."""
    iso_format = '%Y-%m-%dT%H:%M:%SZ'
    return time.strftime(iso_format, dt)
def getPropertyValue(self, elmRoot, sProperty, iPropertyType=4, bAllowMultiple=0, bAutoEscape=0):
    """Find the value(s) of a class-named property beneath elmRoot.

    Candidate nodes are, in order of preference: the <li> children of
    <ul>/<ol> elements whose class matches sProperty (only when multiple
    non-NODE values are allowed); any element whose class matches; and, for
    the 'value' property only, <pre> elements or elmRoot itself.  For
    'vcard', matches nested inside another match are dropped.

    Returns a list when bAllowMultiple is set, otherwise the first value
    found.  When nothing matches, the result is [] for multiple values,
    '' for STRING and URI types, and None for every other type.  With
    bAutoEscape set, values are passed through vcardEscape.
    """
    matchAny = lambda x: 1
    sProperty = sProperty.lower()
    bNormalize = 1
    classMatch = {'class': re.compile(r'\b%s\b' % sProperty)}

    def tidy(text):
        # Whitespace is collapsed unless the value was taken from a <pre>
        # element, in which case it is only stripped.
        return bNormalize and self.normalize(text) or text.strip()

    # Collect candidate nodes.
    candidates = []
    bFound = 0
    if bAllowMultiple and (iPropertyType != self.NODE):
        for container in elmRoot(['ul', 'ol'], classMatch):
            candidates.extend(container('li'))
        bFound = (len(candidates) != 0)
    if not bFound:
        candidates = elmRoot(matchAny, classMatch)
        bFound = (len(candidates) != 0)
    if (not bFound) and (sProperty == 'value'):
        candidates = elmRoot('pre')
        bFound = (len(candidates) != 0)
        bNormalize = not bFound
        if not bFound:
            candidates = [elmRoot]

    # For 'vcard', drop matches that sit inside another matching element.
    nested = []
    if sProperty == 'vcard':
        for node in elmRoot(matchAny, classMatch):
            if node.findParent(matchAny, classMatch):
                nested.append(node)
    arResults = [node for node in candidates if node not in nested]

    if len(arResults) == 0:
        if bAllowMultiple:
            return []
        if iPropertyType in (self.STRING, self.URI):
            return ''
        return None

    # Extract a value from each surviving node.
    arValues = []
    for elmResult in arResults:
        if iPropertyType == self.NODE:
            if not bAllowMultiple:
                return elmResult
            arValues.append(elmResult)
            continue
        sValue = None
        sNodeName = elmResult.name.lower()
        if (iPropertyType == self.EMAIL) and (sNodeName == 'a'):
            href = elmResult.get('href') or ''
            sValue = href.split('mailto:').pop().split('?')[0]
        if sValue:
            sValue = tidy(sValue)
        if (not sValue) and (sNodeName == 'abbr'):
            sValue = elmResult.get('title')
        if sValue:
            sValue = tidy(sValue)
        if (not sValue) and (iPropertyType == self.URI):
            uriAttribute = {'a': 'href', 'img': 'src', 'object': 'data'}.get(sNodeName)
            if uriAttribute:
                sValue = elmResult.get(uriAttribute)
        if sValue:
            sValue = tidy(sValue)
        if (not sValue) and (sNodeName == 'img'):
            sValue = elmResult.get('alt')
        if sValue:
            sValue = tidy(sValue)
        if not sValue:
            # Fall back to the rendered contents with tags removed.
            sValue = elmResult.renderContents()
            sValue = re.sub(r'<\S[^>]*>', '', sValue)
            sValue = sValue.replace('\r\n', '\n')
            sValue = sValue.replace('\r', '\n')
        if sValue:
            sValue = tidy(sValue)
        if not sValue:
            continue
        if iPropertyType == self.DATE:
            sValue = _parse_date_iso8601(sValue)
        result = bAutoEscape and self.vcardEscape(sValue) or sValue
        if not bAllowMultiple:
            return result
        arValues.append(result)
    return arValues
def findVCards(self, elmRoot, bAgentParsing=0):
    """Serialize the hCards found under elmRoot as vCard 3.0 text.

    elmRoot       -- parse-tree node to search
    bAgentParsing -- when true, elmRoot itself is treated as the single
                     card (used for the recursive call on nested AGENT
                     vcards) instead of searching for class="vcard" nodes

    Returns the concatenated BEGIN:vCard ... END:vCard blocks, stripped of
    leading/trailing whitespace; an empty string when no card produced
    any line.
    """
    sVCards = ''
    if not bAgentParsing:
        arCards = self.getPropertyValue(elmRoot, 'vcard', bAllowMultiple=1)
    else:
        arCards = [elmRoot]
    for elmCard in arCards:
        arLines = []
        # Emit "PROP:value" for a single escaped string property of this
        # card; returns the (decoded) value, or u'' when it is absent.
        def processSingleString(sProperty):
            sValue = self.getPropertyValue(elmCard, sProperty, self.STRING, bAutoEscape=1).decode(self.encoding)
            if sValue:
                arLines.append(self.vcardFold(sProperty.upper() + ':' + sValue))
            return sValue or u''
        # Emit a URI-valued property. data: URIs become inline
        # ";ENCODING=b" payloads; other URIs get ";VALUE=uri" (except for
        # 'url') plus a ";TYPE=" taken from the node's type attribute.
        def processSingleURI(sProperty):
            sValue = self.getPropertyValue(elmCard, sProperty, self.URI)
            if sValue:
                sContentType = ''
                sEncoding = ''
                sValueKey = ''
                if sValue.startswith('data:'):
                    sEncoding = ';ENCODING=b'
                    # subtype of the data: URI's media type, e.g. "png"
                    sContentType = sValue.split(';')[0].split('/').pop()
                    # keep only the payload after the first comma
                    sValue = sValue.split(',', 1).pop()
                else:
                    elmValue = self.getPropertyValue(elmCard, sProperty)
                    if elmValue:
                        if sProperty != 'url':
                            sValueKey = ';VALUE=uri'
                        sContentType = elmValue.get('type', '').strip().split('/').pop().strip()
                sContentType = sContentType.upper()
                # a generic binary type carries no useful information
                if sContentType == 'OCTET-STREAM':
                    sContentType = ''
                if sContentType:
                    sContentType = ';TYPE=' + sContentType.upper()
                arLines.append(self.vcardFold(sProperty.upper() + sEncoding + sContentType + sValueKey + ':' + sValue))
        # Emit one "PROP;TYPE=a,b:value" line per matching node.
        # arForceType entries are always merged in; arDefaultType is used
        # only when no type could be determined.
        def processTypeValue(sProperty, arDefaultType, arForceType=None):
            arResults = self.getPropertyValue(elmCard, sProperty, bAllowMultiple=1)
            for elmResult in arResults:
                arType = self.getPropertyValue(elmResult, 'type', self.STRING, 1, 1)
                if arForceType:
                    arType = self.unique(arForceType + arType)
                if not arType:
                    arType = arDefaultType
                sValue = self.getPropertyValue(elmResult, 'value', self.EMAIL, 0)
                if sValue:
                    arLines.append(self.vcardFold(sProperty.upper() + ';TYPE=' + ','.join(arType) + ':' + sValue))
        # AGENT
        # must do this before all other properties because it is destructive
        # (removes nested class="vcard" nodes so they don't interfere with
        # this vcard's other properties)
        arAgent = self.getPropertyValue(elmCard, 'agent', bAllowMultiple=1)
        for elmAgent in arAgent:
            if re.compile(r'\bvcard\b').search(elmAgent.get('class')):
                # nested vcard: serialize it recursively and embed it as an
                # escaped single-line value
                sAgentValue = self.findVCards(elmAgent, 1) + '\n'
                sAgentValue = sAgentValue.replace('\n', '\\n')
                sAgentValue = sAgentValue.replace(';', '\\;')
                # NOTE(review): always true here, because '\n' was appended
                # above even when the nested card is empty
                if sAgentValue:
                    arLines.append(self.vcardFold('AGENT:' + sAgentValue))
                # Completely remove the agent element from the parse tree
                elmAgent.extract()
            else:
                sAgentValue = self.getPropertyValue(elmAgent, 'value', self.URI, bAutoEscape=1);
                if sAgentValue:
                    arLines.append(self.vcardFold('AGENT;VALUE=uri:' + sAgentValue))
        # FN (full name)
        sFN = processSingleString('fn')
        # N (name)
        elmName = self.getPropertyValue(elmCard, 'n')
        if elmName:
            sFamilyName = self.getPropertyValue(elmName, 'family-name', self.STRING, bAutoEscape=1)
            sGivenName = self.getPropertyValue(elmName, 'given-name', self.STRING, bAutoEscape=1)
            # both singular and plural class names are accepted
            arAdditionalNames = self.getPropertyValue(elmName, 'additional-name', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'additional-names', self.STRING, 1, 1)
            arHonorificPrefixes = self.getPropertyValue(elmName, 'honorific-prefix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-prefixes', self.STRING, 1, 1)
            arHonorificSuffixes = self.getPropertyValue(elmName, 'honorific-suffix', self.STRING, 1, 1) + self.getPropertyValue(elmName, 'honorific-suffixes', self.STRING, 1, 1)
            arLines.append(self.vcardFold('N:' + sFamilyName + ';' +
                                          sGivenName + ';' +
                                          ','.join(arAdditionalNames) + ';' +
                                          ','.join(arHonorificPrefixes) + ';' +
                                          ','.join(arHonorificSuffixes)))
        elif sFN:
            # implied "N" optimization
            # http://microformats.org/wiki/hcard#Implied_.22N.22_Optimization
            arNames = self.normalize(sFN).split()
            if len(arNames) == 2:
                # "Family, Given", "Family G" and "Family G." put the
                # family name first; anything else is "Given Family"
                bFamilyNameFirst = (arNames[0].endswith(',') or
                                    len(arNames[1]) == 1 or
                                    ((len(arNames[1]) == 2) and (arNames[1].endswith('.'))))
                if bFamilyNameFirst:
                    arLines.append(self.vcardFold('N:' + arNames[0] + ';' + arNames[1]))
                else:
                    arLines.append(self.vcardFold('N:' + arNames[1] + ';' + arNames[0]))
        # SORT-STRING
        sSortString = self.getPropertyValue(elmCard, 'sort-string', self.STRING, bAutoEscape=1)
        if sSortString:
            arLines.append(self.vcardFold('SORT-STRING:' + sSortString))
        # NICKNAME
        arNickname = self.getPropertyValue(elmCard, 'nickname', self.STRING, 1, 1)
        if arNickname:
            arLines.append(self.vcardFold('NICKNAME:' + ','.join(arNickname)))
        # PHOTO
        processSingleURI('photo')
        # BDAY
        dtBday = self.getPropertyValue(elmCard, 'bday', self.DATE)
        if dtBday:
            arLines.append(self.vcardFold('BDAY:' + self.toISO8601(dtBday)))
        # ADR (address)
        arAdr = self.getPropertyValue(elmCard, 'adr', bAllowMultiple=1)
        for elmAdr in arAdr:
            arType = self.getPropertyValue(elmAdr, 'type', self.STRING, 1, 1)
            if not arType:
                arType = ['intl','postal','parcel','work'] # default adr types, see RFC 2426 section 3.2.1
            sPostOfficeBox = self.getPropertyValue(elmAdr, 'post-office-box', self.STRING, 0, 1)
            sExtendedAddress = self.getPropertyValue(elmAdr, 'extended-address', self.STRING, 0, 1)
            sStreetAddress = self.getPropertyValue(elmAdr, 'street-address', self.STRING, 0, 1)
            sLocality = self.getPropertyValue(elmAdr, 'locality', self.STRING, 0, 1)
            sRegion = self.getPropertyValue(elmAdr, 'region', self.STRING, 0, 1)
            sPostalCode = self.getPropertyValue(elmAdr, 'postal-code', self.STRING, 0, 1)
            sCountryName = self.getPropertyValue(elmAdr, 'country-name', self.STRING, 0, 1)
            arLines.append(self.vcardFold('ADR;TYPE=' + ','.join(arType) + ':' +
                                          sPostOfficeBox + ';' +
                                          sExtendedAddress + ';' +
                                          sStreetAddress + ';' +
                                          sLocality + ';' +
                                          sRegion + ';' +
                                          sPostalCode + ';' +
                                          sCountryName))
        # LABEL
        processTypeValue('label', ['intl','postal','parcel','work'])
        # TEL (phone number)
        processTypeValue('tel', ['voice'])
        # EMAIL ('internet' is always included in the type list)
        processTypeValue('email', ['internet'], ['internet'])
        # MAILER
        processSingleString('mailer')
        # TZ (timezone)
        processSingleString('tz')
        # GEO (geographical information)
        elmGeo = self.getPropertyValue(elmCard, 'geo')
        if elmGeo:
            sLatitude = self.getPropertyValue(elmGeo, 'latitude', self.STRING, 0, 1)
            sLongitude = self.getPropertyValue(elmGeo, 'longitude', self.STRING, 0, 1)
            arLines.append(self.vcardFold('GEO:' + sLatitude + ';' + sLongitude))
        # TITLE
        processSingleString('title')
        # ROLE
        processSingleString('role')
        # LOGO
        processSingleURI('logo')
        # ORG (organization)
        elmOrg = self.getPropertyValue(elmCard, 'org')
        if elmOrg:
            sOrganizationName = self.getPropertyValue(elmOrg, 'organization-name', self.STRING, 0, 1)
            if not sOrganizationName:
                # implied "organization-name" optimization
                # http://microformats.org/wiki/hcard#Implied_.22organization-name.22_Optimization
                sOrganizationName = self.getPropertyValue(elmCard, 'org', self.STRING, 0, 1)
                if sOrganizationName:
                    arLines.append(self.vcardFold('ORG:' + sOrganizationName))
            else:
                arOrganizationUnit = self.getPropertyValue(elmOrg, 'organization-unit', self.STRING, 1, 1)
                arLines.append(self.vcardFold('ORG:' + sOrganizationName + ';' + ';'.join(arOrganizationUnit)))
        # CATEGORY
        arCategory = self.getPropertyValue(elmCard, 'category', self.STRING, 1, 1) + self.getPropertyValue(elmCard, 'categories', self.STRING, 1, 1)
        if arCategory:
            arLines.append(self.vcardFold('CATEGORIES:' + ','.join(arCategory)))
        # NOTE
        processSingleString('note')
        # REV
        processSingleString('rev')
        # SOUND
        processSingleURI('sound')
        # UID
        processSingleString('uid')
        # URL
        processSingleURI('url')
        # CLASS
        processSingleString('class')
        # KEY
        processSingleURI('key')
        # only cards that produced at least one property line are emitted
        if arLines:
            arLines = [u'BEGIN:vCard',u'VERSION:3.0'] + arLines + [u'END:vCard']
            sVCards += u'\n'.join(arLines) + u'\n'
    return sVCards.strip()
def isProbablyDownloadable(self, elm):
    """Guess whether the link element elm points at a downloadable file.

    elm must expose its attributes as the mapping elm.attrMap.

    Returns a false value when there is no href attribute or the URL path
    has no extension, 1 when the declared type attribute is audio/*,
    video/* or a non-XML application/* type, and otherwise whether the
    path's extension (lowercased) is in self.known_binary_extensions.
    """
    attrsD = elm.attrMap
    # "in" replaces the deprecated dict.has_key()
    if 'href' not in attrsD:
        return 0
    linktype = attrsD.get('type', '').strip()
    if linktype.startswith('audio/') or \
       linktype.startswith('video/') or \
       (linktype.startswith('application/') and not linktype.endswith('xml')):
        return 1
    # fall back to the extension of the URL's path component
    path = urlparse.urlparse(attrsD['href'])[2]
    if path.find('.') == -1:
        return 0
    fileext = path.split('.').pop().lower()
    return fileext in self.known_binary_extensions
def findTags(self):
    """Collect rel="tag" links from self.document into self.tags.

    For every element whose rel attribute contains the word "tag" and
    that has a non-empty href, the href is resolved against self.baseuri.
    The last non-empty path segment becomes the tag term, and the rest of
    the URL (scheme, host and leading path, with a trailing slash)
    becomes the tag scheme.  Links whose resolved URL has an empty path
    are skipped.
    """
    match_any = lambda x: 1
    for elm in self.document(match_any, {'rel': re.compile(r'\btag\b')}):
        href = elm.get('href')
        if not href:
            continue
        urlscheme, domain, path, params, query, fragment = \
            urlparse.urlparse(_urljoin(self.baseuri, href))
        segments = path.split('/')
        tag = segments.pop()
        if not tag:
            # The path ended with "/" (or was empty).  An empty path
            # leaves nothing to pop, which used to raise IndexError;
            # such a URL carries no tag at all.
            if not segments:
                continue
            tag = segments.pop()
        tagscheme = urlparse.urlunparse((urlscheme, domain, '/'.join(segments), '', '', ''))
        if not tagscheme.endswith('/'):
            tagscheme += '/'
        self.tags.append(FeedParserDict({"term": tag, "scheme": tagscheme, "label": elm.string or ''}))
def findEnclosures(self):
    """Collect enclosure-like links from self.document into self.enclosures.

    A link (any element with a non-empty href) qualifies when its rel
    attribute contains the word "enclosure" or when
    self.isProbablyDownloadable() accepts it.  The element's attribute
    mapping is appended once (duplicates are skipped); when the element
    has string content and no title attribute, that content is stored as
    the mapping's 'title'.
    """
    def match_any(node):
        return 1
    rel_is_enclosure = re.compile(r'\benclosure\b').search
    for link in self.document(match_any, {'href': re.compile(r'.+')}):
        if rel_is_enclosure(link.get('rel', '')) or self.isProbablyDownloadable(link):
            if link.attrMap in self.enclosures:
                continue
            self.enclosures.append(link.attrMap)
            if link.string and not link.get('title'):
                # write through the list so the stored mapping is updated
                self.enclosures[-1]['title'] = link.string
def findXFN(self):
    """Collect XFN relationship links from self.document into self.xfn.

    Every element with non-empty rel and href attributes is inspected;
    the whitespace-separated rel values that appear in
    self.known_xfn_relationships are kept.  Elements with at least one
    such value contribute a dict with the keys "relationships", "href"
    and "name" (the element's string content).
    """
    def match_any(node):
        return 1
    wanted = {'rel': re.compile('.+'), 'href': re.compile('.+')}
    for link in self.document(match_any, wanted):
        recognised = [rel for rel in link.get('rel', '').split()
                      if rel in self.known_xfn_relationships]
        if not recognised:
            continue
        self.xfn.append({"relationships": recognised, "href": link.get('href', ''), "name": link.string})
def _parseMicroformats(htmlSource, baseURI, encoding):
    """Extract microformat data from an HTML fragment.

    Returns None when BeautifulSoup is unavailable or the parser cannot
    be constructed; otherwise a dict with the keys "tags", "enclosures",
    "xfn" and "vcard".
    """
    if not BeautifulSoup:
        return
    if _debug:
        sys.stderr.write('entering _parseMicroformats\n')
    try:
        parser = _MicroformatsParser(htmlSource, baseURI, encoding)
    except UnicodeEncodeError:
        # sgmllib throws this exception when performing lookups of tags
        # with non-ASCII characters in them.
        return
    # vcards first, then the three collectors, in this fixed order
    parser.vcard = parser.findVCards(parser.document)
    parser.findTags()
    parser.findEnclosures()
    parser.findXFN()
    result = {}
    result["tags"] = parser.tags
    result["enclosures"] = parser.enclosures
    result["xfn"] = parser.xfn
    result["vcard"] = parser.vcard
    return result
class _RelativeURIResolver(_BaseHTMLProcessor):
    """HTML processor that rewrites URI-valued attributes as absolute URIs.

    Only the (tag, attribute) pairs listed in relative_uris are rewritten;
    every other attribute passes through untouched.
    """

    # (tag, attribute) pairs whose values are URIs
    relative_uris = [
        ('a', 'href'),
        ('applet', 'codebase'),
        ('area', 'href'),
        ('blockquote', 'cite'),
        ('body', 'background'),
        ('del', 'cite'),
        ('form', 'action'),
        ('frame', 'longdesc'),
        ('frame', 'src'),
        ('iframe', 'longdesc'),
        ('iframe', 'src'),
        ('head', 'profile'),
        ('img', 'longdesc'),
        ('img', 'src'),
        ('img', 'usemap'),
        ('input', 'src'),
        ('input', 'usemap'),
        ('ins', 'cite'),
        ('link', 'href'),
        ('object', 'classid'),
        ('object', 'codebase'),
        ('object', 'data'),
        ('object', 'usemap'),
        ('q', 'cite'),
        ('script', 'src'),
    ]

    def __init__(self, baseuri, encoding, _type):
        _BaseHTMLProcessor.__init__(self, encoding, _type)
        self.baseuri = baseuri

    def resolveURI(self, uri):
        """Join uri (stripped) onto self.baseuri and vet the result."""
        joined = _urljoin(self.baseuri, uri.strip())
        return _makeSafeAbsoluteURI(joined)

    def unknown_starttag(self, tag, attrs):
        """Resolve the URI attributes of tag, then hand off to the base class."""
        if _debug:
            sys.stderr.write('tag: [%s] with attributes: [%s]\n' % (tag, str(attrs)))
        resolved_attrs = []
        for key, value in self.normalize_attrs(attrs):
            if (tag, key) in self.relative_uris:
                # NOTE(review): when resolveURI() returns an empty value
                # the original attribute value is kept unchanged -- confirm
                # that this fallback is intended.
                value = self.resolveURI(value) or value
            resolved_attrs.append((key, value))
        _BaseHTMLProcessor.unknown_starttag(self, tag, resolved_attrs)
def _resolveRelativeURIs(htmlSource, baseURI, encoding, _type):
    """Return htmlSource after running it through a _RelativeURIResolver
    configured with baseURI, encoding and _type."""
    if _debug:
        sys.stderr.write('entering _resolveRelativeURIs\n')
    resolver = _RelativeURIResolver(baseURI, encoding, _type)
    resolver.feed(htmlSource)
    return resolver.output()
def _makeSafeAbsoluteURI(base, rel=None):
    """Join rel onto base and return the result only if its scheme is allowed.

    base -- base URI (or, when rel is omitted, the URI to vet)
    rel  -- optional URI reference to resolve against base

    Returns the joined URI, or u'' when the resulting scheme is not in
    ACCEPTABLE_URI_SCHEMES or the URI is too malformed to be parsed.
    When ACCEPTABLE_URI_SCHEMES is empty no scheme filtering is done.
    When base is empty, rel (or u'') is returned as-is.
    """
    # bail if ACCEPTABLE_URI_SCHEMES is empty
    if not ACCEPTABLE_URI_SCHEMES:
        try:
            return _urljoin(base, rel or u'')
        except ValueError:
            # urlparse raises ValueError for malformed URLs (for example
            # an unterminated IPv6 literal); treat them as unsafe rather
            # than letting the error escape.
            return u''
    if not base:
        return rel or u''
    if not rel:
        try:
            scheme = urlparse.urlparse(base)[0]
        except ValueError:
            return u''
        if not scheme or scheme in ACCEPTABLE_URI_SCHEMES:
            return base
        return u''
    try:
        uri = _urljoin(base, rel)
    except ValueError:
        return u''
    if uri.strip().split(':', 1)[0] not in ACCEPTABLE_URI_SCHEMES:
        return u''
    return uri
class _HTMLSanitizer(_BaseHTMLProcessor):
    """HTML processor that keeps only whitelisted markup.

    Elements and attributes not on the acceptable_* lists are dropped,
    href values are vetted with _makeSafeAbsoluteURI, style attributes are
    filtered through sanitize_style(), and the text content of the
    elements in unacceptable_elements_with_end_tag is suppressed.  MathML
    and SVG get their own element/attribute whitelists once a properly
    namespaced <math> or <svg> element has been opened.
    """

    acceptable_elements = ['a', 'abbr', 'acronym', 'address', 'area',
        'article', 'aside', 'audio', 'b', 'big', 'blockquote', 'br', 'button',
        'canvas', 'caption', 'center', 'cite', 'code', 'col', 'colgroup',
        'command', 'datagrid', 'datalist', 'dd', 'del', 'details', 'dfn',
        'dialog', 'dir', 'div', 'dl', 'dt', 'em', 'event-source', 'fieldset',
        'figcaption', 'figure', 'footer', 'font', 'form', 'header', 'h1',
        'h2', 'h3', 'h4', 'h5', 'h6', 'hr', 'i', 'img', 'input', 'ins',
        'keygen', 'kbd', 'label', 'legend', 'li', 'm', 'map', 'menu', 'meter',
        'multicol', 'nav', 'nextid', 'ol', 'output', 'optgroup', 'option',
        'p', 'pre', 'progress', 'q', 's', 'samp', 'section', 'select',
        'small', 'sound', 'source', 'spacer', 'span', 'strike', 'strong',
        'sub', 'sup', 'table', 'tbody', 'td', 'textarea', 'time', 'tfoot',
        'th', 'thead', 'tr', 'tt', 'u', 'ul', 'var', 'video', 'noscript']

    acceptable_attributes = ['abbr', 'accept', 'accept-charset', 'accesskey',
        'action', 'align', 'alt', 'autocomplete', 'autofocus', 'axis',
        'background', 'balance', 'bgcolor', 'bgproperties', 'border',
        'bordercolor', 'bordercolordark', 'bordercolorlight', 'bottompadding',
        'cellpadding', 'cellspacing', 'ch', 'challenge', 'char', 'charoff',
        'choff', 'charset', 'checked', 'cite', 'class', 'clear', 'color', 'cols',
        'colspan', 'compact', 'contenteditable', 'controls', 'coords', 'data',
        'datafld', 'datapagesize', 'datasrc', 'datetime', 'default', 'delay',
        'dir', 'disabled', 'draggable', 'dynsrc', 'enctype', 'end', 'face', 'for',
        'form', 'frame', 'galleryimg', 'gutter', 'headers', 'height', 'hidefocus',
        'hidden', 'high', 'href', 'hreflang', 'hspace', 'icon', 'id', 'inputmode',
        'ismap', 'keytype', 'label', 'leftspacing', 'lang', 'list', 'longdesc',
        'loop', 'loopcount', 'loopend', 'loopstart', 'low', 'lowsrc', 'max',
        'maxlength', 'media', 'method', 'min', 'multiple', 'name', 'nohref',
        'noshade', 'nowrap', 'open', 'optimum', 'pattern', 'ping', 'point-size',
        'prompt', 'pqg', 'radiogroup', 'readonly', 'rel', 'repeat-max',
        'repeat-min', 'replace', 'required', 'rev', 'rightspacing', 'rows',
        'rowspan', 'rules', 'scope', 'selected', 'shape', 'size', 'span', 'src',
        'start', 'step', 'summary', 'suppress', 'tabindex', 'target', 'template',
        'title', 'toppadding', 'type', 'unselectable', 'usemap', 'urn', 'valign',
        'value', 'variable', 'volume', 'vspace', 'vrml', 'width', 'wrap',
        'xml:lang']

    # elements whose text content is dropped along with the tags
    unacceptable_elements_with_end_tag = ['script', 'applet', 'style']

    acceptable_css_properties = ['azimuth', 'background-color',
        'border-bottom-color', 'border-collapse', 'border-color',
        'border-left-color', 'border-right-color', 'border-top-color', 'clear',
        'color', 'cursor', 'direction', 'display', 'elevation', 'float', 'font',
        'font-family', 'font-size', 'font-style', 'font-variant', 'font-weight',
        'height', 'letter-spacing', 'line-height', 'overflow', 'pause',
        'pause-after', 'pause-before', 'pitch', 'pitch-range', 'richness',
        'speak', 'speak-header', 'speak-numeral', 'speak-punctuation',
        'speech-rate', 'stress', 'text-align', 'text-decoration', 'text-indent',
        'unicode-bidi', 'vertical-align', 'voice-family', 'volume',
        'white-space', 'width']

    # survey of common keywords found in feeds
    acceptable_css_keywords = ['auto', 'aqua', 'black', 'block', 'blue',
        'bold', 'both', 'bottom', 'brown', 'center', 'collapse', 'dashed',
        'dotted', 'fuchsia', 'gray', 'green', '!important', 'italic', 'left',
        'lime', 'maroon', 'medium', 'none', 'navy', 'normal', 'nowrap', 'olive',
        'pointer', 'purple', 'red', 'right', 'solid', 'silver', 'teal', 'top',
        'transparent', 'underline', 'white', 'yellow']

    # raw strings: same pattern, without relying on unknown string escapes
    valid_css_values = re.compile(r'^(#[0-9a-f]+|rgb\(\d+%?,\d*%?,?\d*%?\)?|' +
        r'\d{0,2}\.?\d{0,2}(cm|em|ex|in|mm|pc|pt|px|%|,|\))?)$')

    mathml_elements = ['annotation', 'annotation-xml', 'maction', 'math',
        'merror', 'mfenced', 'mfrac', 'mi', 'mmultiscripts', 'mn', 'mo', 'mover', 'mpadded',
        'mphantom', 'mprescripts', 'mroot', 'mrow', 'mspace', 'msqrt', 'mstyle',
        'msub', 'msubsup', 'msup', 'mtable', 'mtd', 'mtext', 'mtr', 'munder',
        'munderover', 'none', 'semantics']

    mathml_attributes = ['actiontype', 'align', 'columnalign', 'columnalign',
        'columnalign', 'close', 'columnlines', 'columnspacing', 'columnspan', 'depth',
        'display', 'displaystyle', 'encoding', 'equalcolumns', 'equalrows',
        'fence', 'fontstyle', 'fontweight', 'frame', 'height', 'linethickness',
        'lspace', 'mathbackground', 'mathcolor', 'mathvariant', 'mathvariant',
        'maxsize', 'minsize', 'open', 'other', 'rowalign', 'rowalign', 'rowalign',
        'rowlines', 'rowspacing', 'rowspan', 'rspace', 'scriptlevel', 'selection',
        'separator', 'separators', 'stretchy', 'width', 'width', 'xlink:href',
        'xlink:show', 'xlink:type', 'xmlns', 'xmlns:xlink']

    # svgtiny - foreignObject + linearGradient + radialGradient + stop
    svg_elements = ['a', 'animate', 'animateColor', 'animateMotion',
        'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'foreignObject',
        'font-face', 'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern',
        'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph', 'mpath',
        'path', 'polygon', 'polyline', 'radialGradient', 'rect', 'set', 'stop',
        'svg', 'switch', 'text', 'title', 'tspan', 'use']

    # svgtiny + class + opacity + offset + xmlns + xmlns:xlink
    svg_attributes = ['accent-height', 'accumulate', 'additive', 'alphabetic',
        'arabic-form', 'ascent', 'attributeName', 'attributeType',
        'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height',
        'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx',
        'dy', 'descent', 'display', 'dur', 'end', 'fill', 'fill-opacity',
        'fill-rule', 'font-family', 'font-size', 'font-stretch', 'font-style',
        'font-variant', 'font-weight', 'from', 'fx', 'fy', 'g1', 'g2',
        'glyph-name', 'gradientUnits', 'hanging', 'height', 'horiz-adv-x',
        'horiz-origin-x', 'id', 'ideographic', 'k', 'keyPoints', 'keySplines',
        'keyTimes', 'lang', 'mathematical', 'marker-end', 'marker-mid',
        'marker-start', 'markerHeight', 'markerUnits', 'markerWidth', 'max',
        'min', 'name', 'offset', 'opacity', 'orient', 'origin',
        'overline-position', 'overline-thickness', 'panose-1', 'path',
        'pathLength', 'points', 'preserveAspectRatio', 'r', 'refX', 'refY',
        'repeatCount', 'repeatDur', 'requiredExtensions', 'requiredFeatures',
        'restart', 'rotate', 'rx', 'ry', 'slope', 'stemh', 'stemv',
        'stop-color', 'stop-opacity', 'strikethrough-position',
        'strikethrough-thickness', 'stroke', 'stroke-dasharray',
        'stroke-dashoffset', 'stroke-linecap', 'stroke-linejoin',
        'stroke-miterlimit', 'stroke-opacity', 'stroke-width', 'systemLanguage',
        'target', 'text-anchor', 'to', 'transform', 'type', 'u1', 'u2',
        'underline-position', 'underline-thickness', 'unicode', 'unicode-range',
        'units-per-em', 'values', 'version', 'viewBox', 'visibility', 'width',
        'widths', 'x', 'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole',
        'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title', 'xlink:type',
        'xml:base', 'xml:lang', 'xml:space', 'xmlns', 'xmlns:xlink', 'y', 'y1',
        'y2', 'zoomAndPan']

    # lowercase -> camelCase lookup tables, built lazily on first SVG use
    svg_attr_map = None
    svg_elem_map = None

    acceptable_svg_properties = [ 'fill', 'fill-opacity', 'fill-rule',
        'stroke', 'stroke-width', 'stroke-linecap', 'stroke-linejoin',
        'stroke-opacity']

    def reset(self):
        """Reset the base processor and the sanitizer's nesting counters."""
        _BaseHTMLProcessor.reset(self)
        self.unacceptablestack = 0
        self.mathmlOK = 0
        self.svgOK = 0

    def unknown_starttag(self, tag, attrs):
        """Emit tag with only its acceptable attributes, or drop it."""
        acceptable_attributes = self.acceptable_attributes
        keymap = {}
        if not tag in self.acceptable_elements or self.svgOK:
            if tag in self.unacceptable_elements_with_end_tag:
                self.unacceptablestack += 1

            # add implicit namespaces to html5 inline svg/mathml
            if self._type.endswith('html'):
                if not dict(attrs).get('xmlns'):
                    if tag=='svg':
                        attrs.append( ('xmlns','http://www.w3.org/2000/svg') )
                    if tag=='math':
                        attrs.append( ('xmlns','http://www.w3.org/1998/Math/MathML') )

            # not otherwise acceptable, perhaps it is MathML or SVG?
            if tag=='math' and ('xmlns','http://www.w3.org/1998/Math/MathML') in attrs:
                self.mathmlOK += 1
            if tag=='svg' and ('xmlns','http://www.w3.org/2000/svg') in attrs:
                self.svgOK += 1

            # chose acceptable attributes based on tag class, else bail
            if self.mathmlOK and tag in self.mathml_elements:
                acceptable_attributes = self.mathml_attributes
            elif self.svgOK and tag in self.svg_elements:
                # for most vocabularies, lowercasing is a good idea. Many
                # svg elements, however, are camel case
                if not self.svg_attr_map:
                    lower=[attr.lower() for attr in self.svg_attributes]
                    mix=[a for a in self.svg_attributes if a not in lower]
                    self.svg_attributes = lower
                    self.svg_attr_map = dict([(a.lower(),a) for a in mix])

                    lower=[attr.lower() for attr in self.svg_elements]
                    mix=[a for a in self.svg_elements if a not in lower]
                    self.svg_elements = lower
                    self.svg_elem_map = dict([(a.lower(),a) for a in mix])
                acceptable_attributes = self.svg_attributes
                tag = self.svg_elem_map.get(tag,tag)
                keymap = self.svg_attr_map
            elif not tag in self.acceptable_elements:
                return

        # declare xlink namespace, if needed
        if self.mathmlOK or self.svgOK:
            # a list comprehension instead of filter() with a
            # tuple-unpacking lambda; truthiness is that of the list
            if [n for n, v in attrs if n.startswith('xlink:')]:
                if not ('xmlns:xlink','http://www.w3.org/1999/xlink') in attrs:
                    attrs.append(('xmlns:xlink','http://www.w3.org/1999/xlink'))

        clean_attrs = []
        for key, value in self.normalize_attrs(attrs):
            if key in acceptable_attributes:
                key=keymap.get(key,key)
                # make sure the uri uses an acceptable uri scheme
                if key == u'href':
                    value = _makeSafeAbsoluteURI(value)
                clean_attrs.append((key,value))
            elif key=='style':
                clean_value = self.sanitize_style(value)
                if clean_value:
                    clean_attrs.append((key,clean_value))
        _BaseHTMLProcessor.unknown_starttag(self, tag, clean_attrs)

    def unknown_endtag(self, tag):
        """Emit the end tag if its element is acceptable, else drop it."""
        if not tag in self.acceptable_elements:
            if tag in self.unacceptable_elements_with_end_tag:
                # Never go below zero: a stray end tag with no matching
                # start tag would otherwise leave a negative (truthy)
                # counter, making handle_data() swallow all later text.
                if self.unacceptablestack > 0:
                    self.unacceptablestack -= 1
            if self.mathmlOK and tag in self.mathml_elements:
                if tag == 'math' and self.mathmlOK:
                    self.mathmlOK -= 1
            elif self.svgOK and tag in self.svg_elements:
                tag = self.svg_elem_map.get(tag,tag)
                if tag == 'svg' and self.svgOK:
                    self.svgOK -= 1
            else:
                return
        _BaseHTMLProcessor.unknown_endtag(self, tag)

    def handle_pi(self, text):
        """Drop processing instructions."""
        pass

    def handle_decl(self, text):
        """Drop declarations."""
        pass

    def handle_data(self, text):
        """Pass text through unless inside an unacceptable element."""
        if not self.unacceptablestack:
            _BaseHTMLProcessor.handle_data(self, text)

    def sanitize_style(self, style):
        """Return the acceptable declarations of a style attribute value.

        Returns '' when the value contains anything outside the
        conservative character/structure whitelist; otherwise the accepted
        "prop: value;" declarations joined by single spaces.
        """
        # disallow urls
        style=re.compile(r'url\s*\(\s*[^\s)]+?\s*\)\s*').sub(' ',style)

        # gauntlet
        if not re.match(r"""^([:,;#%.\sa-zA-Z0-9!]|\w-\w|'[\s\w]+'|"[\s\w]+"|\([\d,\s]+\))*$""", style):
            return ''
        # This replaced a regexp that used re.match and was prone to pathological back-tracking.
        if re.sub(r"\s*[-\w]+\s*:\s*[^:;]*;?", '', style).strip():
            return ''

        clean = []
        for prop,value in re.findall(r"([-\w]+)\s*:\s*([^:;]*)",style):
            if not value:
                continue
            if prop.lower() in self.acceptable_css_properties:
                clean.append(prop + ': ' + value + ';')
            elif prop.split('-')[0].lower() in ['background','border','margin','padding']:
                # shorthand families: every keyword must be a known
                # keyword or a plain colour/length value
                for keyword in value.split():
                    if not keyword in self.acceptable_css_keywords and \
                        not self.valid_css_values.match(keyword):
                        break
                else:
                    clean.append(prop + ': ' + value + ';')
            elif self.svgOK and prop.lower() in self.acceptable_svg_properties:
                clean.append(prop + ': ' + value + ';')

        return ' '.join(clean)

    def parse_comment(self, i, report=1):
        """Parse a comment at offset i, tolerating malformed/unclosed ones."""
        ret = _BaseHTMLProcessor.parse_comment(self, i, report)
        if ret >= 0:
            return ret
        # if ret == -1, this may be a malicious attempt to circumvent
        # sanitization, or a page-destroying unclosed comment
        match = re.compile(r'--[^>]*>').search(self.rawdata, i+4)
        if match:
            return match.end()
        # unclosed comment; deliberately fail to handle_data()
        return len(self.rawdata)
def _sanitizeHTML(htmlSource, encoding, _type):
    """Return htmlSource with unsafe markup removed.

    The source is run through _HTMLSanitizer.  When TIDY_MARKUP is set and
    one of PREFERRED_TIDY_INTERFACES can be imported, the result is also
    passed through HTML Tidy and reduced to the contents of its <body>.
    Line endings are normalized from CRLF to LF and the result is stripped.
    """
    p = _HTMLSanitizer(encoding, _type)
    # Escape CDATA section openers so they are treated as text rather
    # than as markup declarations.  (The replacement used to be identical
    # to the search string, which made this line a no-op.)
    htmlSource = htmlSource.replace('<![CDATA[', '&lt;![CDATA[')
    p.feed(htmlSource)
    data = p.output()
    if TIDY_MARKUP:
        # loop through list of preferred Tidy interfaces looking for one that's installed,
        # then set up a common _tidy function to wrap the interface-specific API.
        _tidy = None
        for tidy_interface in PREFERRED_TIDY_INTERFACES:
            try:
                if tidy_interface == "uTidy":
                    from tidy import parseString as _utidy
                    def _tidy(data, **kwargs):
                        return str(_utidy(data, **kwargs))
                    break
                elif tidy_interface == "mxTidy":
                    from mx.Tidy import Tidy as _mxtidy
                    def _tidy(data, **kwargs):
                        nerrors, nwarnings, data, errordata = _mxtidy.tidy(data, **kwargs)
                        return data
                    break
            except Exception:
                # best effort: an interface that fails to import is skipped
                pass
        if _tidy:
            # Tidy is fed UTF-8 bytes; remember whether to decode again
            utf8 = type(data) == type(u'')
            if utf8:
                data = data.encode('utf-8')
            data = _tidy(data, output_xhtml=1, numeric_entities=1, wrap=0, char_encoding="utf8")
            if utf8:
                data = unicode(data, 'utf-8')
            # keep only what lies between <body ...> and </body>
            if data.count('<body'):
                data = data.split('<body', 1)[1]
                if data.count('>'):
                    data = data.split('>', 1)[1]
            if data.count('</body'):
                data = data.split('</body', 1)[0]
    data = data.strip().replace('\r\n', '\n')
    return data
class _FeedURLHandler(urllib2.HTTPDigestAuthHandler, urllib2.HTTPRedirectHandler, urllib2.HTTPDefaultErrorHandler):
    """urllib2 handler that returns error responses instead of raising.

    Non-2xx responses come back as addinfourl objects carrying a .status
    attribute, redirects are followed only when a Location header is
    present, and a 401 answered to Basic credentials is retried with
    digest authentication.
    """

    def http_error_default(self, req, fp, code, msg, headers):
        # any 3xx other than 304 is treated like a 302
        if (code // 100) == 3 and code != 304:
            return self.http_error_302(req, fp, code, msg, headers)
        infourl = urllib.addinfourl(fp, headers, req.get_full_url())
        infourl.status = code
        return infourl

    def http_error_302(self, req, fp, code, msg, headers):
        if 'location' in headers.dict:
            infourl = urllib2.HTTPRedirectHandler.http_error_302(self, req, fp, code, msg, headers)
        else:
            infourl = urllib.addinfourl(fp, headers, req.get_full_url())
        if not hasattr(infourl, 'status'):
            infourl.status = code
        return infourl

    def http_error_301(self, req, fp, code, msg, headers):
        if 'location' in headers.dict:
            infourl = urllib2.HTTPRedirectHandler.http_error_301(self, req, fp, code, msg, headers)
        else:
            infourl = urllib.addinfourl(fp, headers, req.get_full_url())
        if not hasattr(infourl, 'status'):
            infourl.status = code
        return infourl

    http_error_300 = http_error_302
    http_error_303 = http_error_302
    http_error_307 = http_error_302

    def http_error_401(self, req, fp, code, msg, headers):
        # Check if
        # - server requires digest auth, AND
        # - we tried (unsuccessfully) with basic auth, AND
        # - we're using Python 2.3.3 or later (digest auth is irreparably broken in earlier versions)
        # If all conditions hold, parse authentication information
        # out of the Authorization header we sent the first time
        # (for the username and password) and the WWW-Authenticate
        # header the server sent back (for the realm) and retry
        # the request with the appropriate digest auth headers instead.
        # This evil genius hack has been brought to you by Aaron Swartz.
        host = urlparse.urlparse(req.get_full_url())[1]
        try:
            # explicit checks rather than assert, which vanishes under -O;
            # version_info compares numerically, unlike the version string
            if sys.version_info < (2, 3, 3) or base64 is None:
                raise ValueError('digest auth retry is not available')
            credentials = _base64decode(req.headers['Authorization'].split(' ')[1])
            # split on the first colon only: the password may contain ':'
            user, passw = credentials.split(':', 1)
            realm = re.findall('realm="([^"]*)"', headers['WWW-Authenticate'])[0]
            self.add_password(realm, host, user, passw)
            retry = self.http_error_auth_reqed('www-authenticate', host, req, headers)
            self.reset_retry_count()
            return retry
        except Exception:
            # any failure falls back to returning the 401 response itself
            return self.http_error_default(req, fp, code, msg, headers)
def _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, handlers, request_headers):
    """URL, filename, or string --> stream

    This function lets you define parsers that take any input source
    (URL, pathname to local or network file, or actual data as a string)
    and deal with it in a uniform manner.  Returned object is guaranteed
    to have all the basic stdio read methods (read, readline, readlines).
    Just .close() the object when you're done with it.

    If the etag argument is supplied, it will be used as the value of an
    If-None-Match request header.

    If the modified argument is supplied, it can be a tuple of 9 integers
    (as returned by gmtime() in the standard Python time module) or a date
    string in any format supported by feedparser. Regardless, it MUST
    be in GMT (Greenwich Mean Time). It will be reformatted into an
    RFC 1123-compliant date and used as the value of an If-Modified-Since
    request header.

    If the agent argument is supplied, it will be used as the value of a
    User-Agent request header.

    If the referrer argument is supplied, it will be used as the value of a
    Referer[sic] request header.

    If handlers is supplied, it is a sequence of handlers used to build a
    urllib2 opener.

    If request_headers is supplied it is a dictionary of HTTP request headers
    that will override the values generated by FeedParser.
    """
    # already a stream
    if hasattr(url_file_stream_or_string, 'read'):
        return url_file_stream_or_string

    if url_file_stream_or_string == '-':
        return sys.stdin

    if urlparse.urlparse(url_file_stream_or_string)[0] in ('http', 'https', 'ftp', 'file', 'feed'):
        # Deal with the feed URI scheme
        if url_file_stream_or_string.startswith('feed:http'):
            url_file_stream_or_string = url_file_stream_or_string[5:]
        elif url_file_stream_or_string.startswith('feed:'):
            url_file_stream_or_string = 'http:' + url_file_stream_or_string[5:]
        if not agent:
            agent = USER_AGENT
        # test for inline user:password for basic auth
        auth = None
        if base64:
            urltype, rest = urllib.splittype(url_file_stream_or_string)
            realhost, rest = urllib.splithost(rest)
            if realhost:
                user_passwd, realhost = urllib.splituser(realhost)
                if user_passwd:
                    # strip the credentials from the URL and send them as
                    # a Basic Authorization header instead
                    url_file_stream_or_string = '%s://%s%s' % (urltype, realhost, rest)
                    auth = base64.standard_b64encode(user_passwd).strip()

        # iri support
        try:
            if isinstance(url_file_stream_or_string, unicode):
                url_file_stream_or_string = url_file_stream_or_string.encode('idna').decode('utf-8')
            else:
                url_file_stream_or_string = url_file_stream_or_string.decode('utf-8').encode('idna').decode('utf-8')
        except Exception:
            # best effort: keep the URL unchanged if it cannot be converted
            pass

        # try to open with urllib2 (to use optional headers)
        request = _build_urllib2_request(url_file_stream_or_string, agent, etag, modified, referrer, auth, request_headers)
        # star-args instead of the deprecated apply(); list() lets callers
        # pass handlers as a tuple as well as a list
        opener = urllib2.build_opener(*(list(handlers) + [_FeedURLHandler()]))
        opener.addheaders = [] # RMK - must clear so we only send our custom User-Agent
        try:
            return opener.open(request)
        finally:
            opener.close() # JohnD

    # try to open with native open function (if url_file_stream_or_string is a filename)
    try:
        return open(url_file_stream_or_string, 'rb')
    except Exception:
        # not an openable file name; fall through to the string case
        pass

    # treat url_file_stream_or_string as string
    return _StringIO(str(url_file_stream_or_string))
def _build_urllib2_request(url, agent, etag, modified, referrer, auth, request_headers):
    """Build the urllib2.Request used to fetch url.

    agent           -- User-Agent header value
    etag            -- If-None-Match value, or a false value to omit it
    modified        -- date string (byte or unicode), datetime.datetime or
                       time tuple in GMT used for If-Modified-Since; a
                       false value omits the header
    referrer        -- Referer value, or a false value to omit it
    auth            -- base64 credentials for a Basic Authorization
                       header, or a false value to omit it
    request_headers -- mapping of extra headers that override the
                       generated ones; may be None
    """
    request = urllib2.Request(url)
    request.add_header('User-Agent', agent)
    if etag:
        request.add_header('If-None-Match', etag)
    # accept unicode as well as byte strings; a unicode date string used
    # to skip parsing and was then indexed as if it were a time tuple
    if isinstance(modified, (type(''), type(u''))):
        modified = _parse_date(modified)
    elif isinstance(modified, datetime.datetime):
        modified = modified.utctimetuple()
    if modified:
        # format into an RFC 1123-compliant timestamp. We can't use
        # time.strftime() since the %a and %b directives can be affected
        # by the current locale, but RFC 2616 states that dates must be
        # in English.
        short_weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
        months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
        request.add_header('If-Modified-Since', '%s, %02d %s %04d %02d:%02d:%02d GMT' % (short_weekdays[modified[6]], modified[2], months[modified[1] - 1], modified[0], modified[3], modified[4], modified[5]))
    if referrer:
        request.add_header('Referer', referrer)
    # advertise only the encodings whose modules are available
    if gzip and zlib:
        request.add_header('Accept-encoding', 'gzip, deflate')
    elif gzip:
        request.add_header('Accept-encoding', 'gzip')
    elif zlib:
        request.add_header('Accept-encoding', 'deflate')
    else:
        request.add_header('Accept-encoding', '')
    if auth:
        request.add_header('Authorization', 'Basic %s' % auth)
    if ACCEPT_HEADER:
        request.add_header('Accept', ACCEPT_HEADER)
    # use this for whatever -- cookies, special headers, etc
    # [('Cookie','Something'),('x-special-header','Another Value')]
    for header_name, header_value in (request_headers or {}).items():
        request.add_header(header_name, header_value)
    request.add_header('A-IM', 'feed') # RFC 3229 support
    return request
# Registered date handlers; the most recently registered one comes first.
_date_handlers = []

def registerDateHandler(func):
    """Register a date handler function (takes string, returns 9-tuple date in GMT)"""
    # prepend, so the newest handler sits at the front of the list
    _date_handlers.insert(0, func)
# ISO-8601 date parsing routines written by Fazal Majid. | |
# The ISO 8601 standard is very convoluted and irregular - a full ISO 8601 | |
# parser is beyond the scope of feedparser and would be a worthwhile addition | |
# to the Python library. | |
# A single regular expression cannot parse ISO 8601 date formats into groups | |
# as the standard is highly irregular (for instance is 030104 2003-01-04 or | |
# 0301-04-01), so we use templates instead. | |
# Please note the order in templates is significant because we need a | |
# greedy match. | |
_iso8601_tmpl = [
    'YYYY-?MM-?DD', 'YYYY-0MM?-?DD', 'YYYY-MM', 'YYYY-?OOO',
    'YY-?MM-?DD', 'YY-?OOO', 'YYYY',
    '-YY-?MM', '-OOO', '-YY',
    '--MM-?DD', '--MM',
    '---DD',
    'CC', '',
]

# Expand every template into a regular expression: each placeholder is
# swapped for a named group (in this exact order, so that 'YYYY' is
# consumed before 'YY'), then the optional time/timezone suffix is added.
_iso8601_re = []
for tmpl in _iso8601_tmpl:
    regex = tmpl.replace('YYYY', r'(?P<year>\d{4})')
    regex = regex.replace('YY', r'(?P<year>\d\d)')
    regex = regex.replace('MM', r'(?P<month>[01]\d)')
    regex = regex.replace('DD', r'(?P<day>[0123]\d)')
    regex = regex.replace('OOO', r'(?P<ordinal>[0123]\d\d)')
    regex = regex.replace('CC', r'(?P<century>\d\d$)')
    regex = (regex
             + r'(T?(?P<hour>\d{2}):(?P<minute>\d{2})'
             + r'(:(?P<second>\d{2}))?'
             + r'(\.(?P<fracsecond>\d+))?'
             + r'(?P<tz>[+-](?P<tzhour>\d{2})(:(?P<tzmin>\d{2}))?|Z)?)?')
    _iso8601_re.append(regex)
# do not leave the loop variables behind in the module namespace
del tmpl, regex

_iso8601_matches = [re.compile(regex).match for regex in _iso8601_re]
try:
    # the comprehension variable leaks on interpreters where list
    # comprehensions share the enclosing scope
    del regex
except NameError:
    pass
def _parse_date_iso8601(dateString):
    '''Parse a variety of ISO-8601-compatible formats like 20040105

    Tries each compiled template in _iso8601_matches in order and uses the
    first that matches.  Returns a time tuple produced by
    time.localtime(time.mktime(...)), or None when nothing matches or the
    match is empty.
    '''
    m = None
    for matcher in _iso8601_matches:
        m = matcher(dateString)
        if m:
            break
    # nothing matched, or only the empty template matched zero characters
    if (not m) or m.span() == (0, 0):
        return
    params = m.groupdict()

    ordinal = int(params.get('ordinal', 0) or 0)

    year = params.get('year', '--')
    if (not year) or year == '--':
        year = time.gmtime()[0]
    elif len(year) == 2:
        # ISO 8601 assumes current century, i.e. 93 -> 2093, NOT 1993
        year = 100 * int(time.gmtime()[0] / 100) + int(year)
    else:
        year = int(year)

    month = params.get('month', '-')
    if (not month) or month == '-':
        # ordinals are NOT normalized by mktime, we simulate them
        # by setting month=1, day=ordinal
        if ordinal:
            month = 1
        else:
            month = time.gmtime()[1]
    month = int(month)

    day = params.get('day', 0)
    if day:
        day = int(day)
    elif ordinal:
        # see above
        day = ordinal
    elif params.get('century', 0) or \
         params.get('year', 0) or params.get('month', 0):
        day = 1
    else:
        day = time.gmtime()[2]

    # special case of the century - is the first year of the 21st century
    # 2000 or 2001 ? The debate goes on...
    if 'century' in params:
        year = (int(params['century']) - 1) * 100 + 1

    # in ISO 8601 most fields are optional
    for field in ['hour', 'minute', 'second', 'tzhour', 'tzmin']:
        if not params.get(field, None):
            params[field] = 0
    hour = int(params.get('hour', 0))
    minute = int(params.get('minute', 0))
    second = int(float(params.get('second', 0)))

    # weekday is normalized by mktime(), we can ignore it
    weekday = 0
    daylight_savings_flag = -1
    tm = [year, month, day, hour, minute, second, weekday,
          ordinal, daylight_savings_flag]

    # ISO 8601 time zone adjustments: a zone west of GMT ('-') is added
    # to the clock reading, a zone east of GMT ('+') is subtracted
    tz = params.get('tz')
    if tz and tz != 'Z':
        if tz[0] == '-':
            direction = 1
        elif tz[0] == '+':
            direction = -1
        else:
            return None
        tm[3] += direction * int(params.get('tzhour', 0))
        tm[4] += direction * int(params.get('tzmin', 0))

    # Python's time.mktime() is a wrapper around the ANSI C mktime(3c)
    # which is guaranteed to normalize d/m/y/h/m/s.
    # Many implementations have bugs, but we'll pretend they don't.
    return time.localtime(time.mktime(tuple(tm)))
registerDateHandler(_parse_date_iso8601)
# 8-bit date handling routines written by ytrewq1.
_korean_year = u'\ub144'  # b3e2 in euc-kr
_korean_month = u'\uc6d4'  # bff9 in euc-kr
_korean_day = u'\uc77c'  # c0cf in euc-kr
_korean_am = u'\uc624\uc804'  # bfc0 c0fc in euc-kr
_korean_pm = u'\uc624\ud6c4'  # bfc0 c8c4 in euc-kr

# OnBlog: four-digit year, month and day, each followed by its Korean
# marker, then hh:mm:ss
_korean_onblog_date_re = re.compile(
    '(\d{4})%s\s+(\d{2})%s\s+(\d{2})%s\s+(\d{2}):(\d{2}):(\d{2})'
    % (_korean_year, _korean_month, _korean_day))

# Nate: yyyy-mm-dd, the Korean AM or PM marker, then h:m:s
_korean_nate_date_re = re.compile(
    u'(\d{4})-(\d{2})-(\d{2})\s+(%s|%s)\s+(\d{,2}):(\d{,2}):(\d{,2})'
    % (_korean_am, _korean_pm))
def _parse_date_onblog(dateString):
    '''Parse a string according to the OnBlog 8-bit date format

    The matched fields are rewritten as a W3DTF string with a fixed
    +09:00 zone and handed to _parse_date_w3dtf.  Returns None when the
    string does not match.
    '''
    m = _korean_onblog_date_re.match(dateString)
    if not m:
        return
    # the pattern has exactly six groups, in this order
    fields = dict(zip(('year', 'month', 'day', 'hour', 'minute', 'second'),
                      m.groups()))
    fields['zonediff'] = '+09:00'
    w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % fields
    if _debug:
        sys.stderr.write('OnBlog date parsed as: %s\n' % w3dtfdate)
    return _parse_date_w3dtf(w3dtfdate)
registerDateHandler(_parse_date_onblog)
def _parse_date_nate(dateString):
    '''Parse a string according to the Nate 8-bit date format

    The 12-hour clock reading is converted to 24-hour form, rewritten as
    a W3DTF string with a fixed +09:00 zone and handed to
    _parse_date_w3dtf.  Returns None when the string does not match.
    '''
    m = _korean_nate_date_re.match(dateString)
    if not m:
        return
    hour = int(m.group(5))
    ampm = m.group(4)
    # 12-hour -> 24-hour conversion.  12:xx PM is already hour 12 and must
    # not become 24 (which would roll over to the next day), and 12:xx AM
    # is hour 0.
    if ampm == _korean_pm:
        if hour < 12:
            hour += 12
    elif hour == 12:
        hour = 0
    hour = '%02d' % hour
    w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % \
                {'year': m.group(1), 'month': m.group(2), 'day': m.group(3),
                 'hour': hour, 'minute': m.group(6), 'second': m.group(7),
                 'zonediff': '+09:00'}
    if _debug:
        sys.stderr.write('Nate date parsed as: %s\n' % w3dtfdate)
    return _parse_date_w3dtf(w3dtfdate)
registerDateHandler(_parse_date_nate)
_mssql_date_re = re.compile(
    '(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})(\.\d+)?')


def _parse_date_mssql(dateString):
    '''Parse a string according to the MS SQL date format

    The matched fields are rewritten as a W3DTF string and handed to
    _parse_date_w3dtf.  Returns None when the string does not match.
    '''
    m = _mssql_date_re.match(dateString)
    if not m:
        return
    # only the first six groups are used; the optional seventh group
    # (fractional seconds) is dropped by zip()
    fields = dict(zip(('year', 'month', 'day', 'hour', 'minute', 'second'),
                      m.groups()))
    # NOTE(review): the zone is hard-coded to +09:00, the same value the
    # Korean handlers use -- confirm this is intended for MS SQL dates
    fields['zonediff'] = '+09:00'
    w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s:%(second)s%(zonediff)s' % fields
    if _debug:
        sys.stderr.write('MS SQL date parsed as: %s\n' % w3dtfdate)
    return _parse_date_w3dtf(w3dtfdate)
registerDateHandler(_parse_date_mssql)
# Unicode strings for Greek date strings
# Maps abbreviated Greek month names (several spellings each) to the
# English abbreviations used in RFC 822 dates.
_greek_months = {
    u'\u0399\u03b1\u03bd': u'Jan',        # c9e1ed in iso-8859-7
    u'\u03a6\u03b5\u03b2': u'Feb',        # d6e5e2 in iso-8859-7
    u'\u039c\u03ac\u03ce': u'Mar',        # ccdcfe in iso-8859-7
    u'\u039c\u03b1\u03ce': u'Mar',        # cce1fe in iso-8859-7
    u'\u0391\u03c0\u03c1': u'Apr',        # c1f0f1 in iso-8859-7
    u'\u039c\u03ac\u03b9': u'May',        # ccdce9 in iso-8859-7
    u'\u039c\u03b1\u03ca': u'May',        # cce1fa in iso-8859-7
    u'\u039c\u03b1\u03b9': u'May',        # cce1e9 in iso-8859-7
    u'\u0399\u03bf\u03cd\u03bd': u'Jun',  # c9effded in iso-8859-7
    u'\u0399\u03bf\u03bd': u'Jun',        # c9efed in iso-8859-7
    u'\u0399\u03bf\u03cd\u03bb': u'Jul',  # c9effdeb in iso-8859-7
    u'\u0399\u03bf\u03bb': u'Jul',        # c9f9eb in iso-8859-7
    u'\u0391\u03cd\u03b3': u'Aug',        # c1fde3 in iso-8859-7
    u'\u0391\u03c5\u03b3': u'Aug',        # c1f5e3 in iso-8859-7
    u'\u03a3\u03b5\u03c0': u'Sep',        # d3e5f0 in iso-8859-7
    u'\u039f\u03ba\u03c4': u'Oct',        # cfeaf4 in iso-8859-7
    u'\u039d\u03bf\u03ad': u'Nov',        # cdefdd in iso-8859-7
    u'\u039d\u03bf\u03b5': u'Nov',        # cdefe5 in iso-8859-7
    u'\u0394\u03b5\u03ba': u'Dec',        # c4e5ea in iso-8859-7
}

# Maps abbreviated Greek weekday names to the English abbreviations.
_greek_wdays = {
    u'\u039a\u03c5\u03c1': u'Sun',  # caf5f1 in iso-8859-7
    u'\u0394\u03b5\u03c5': u'Mon',  # c4e5f5 in iso-8859-7
    u'\u03a4\u03c1\u03b9': u'Tue',  # d4f1e9 in iso-8859-7
    u'\u03a4\u03b5\u03c4': u'Wed',  # d4e5f4 in iso-8859-7
    u'\u03a0\u03b5\u03bc': u'Thu',  # d0e5ec in iso-8859-7
    u'\u03a0\u03b1\u03c1': u'Fri',  # d0e1f1 in iso-8859-7
    u'\u03a3\u03b1\u03b2': u'Sat',  # d3e1e2 in iso-8859-7
}

# groups: weekday, day, month, year, hour, minute, second, zone
_greek_date_format_re = re.compile(
    u'([^,]+),\s+(\d{2})\s+([^\s]+)\s+(\d{4})\s+(\d{2}):(\d{2}):(\d{2})\s+([^\s]+)')
def _parse_date_greek(dateString):
    '''Parse a string according to a Greek 8-bit date format.

    The Greek weekday and month abbreviations are translated through
    _greek_wdays and _greek_months, the date is rewritten in RFC 822 form
    and handed to _parse_date_rfc822.  Returns None when the string does
    not match or uses an abbreviation missing from the tables.
    '''
    m = _greek_date_format_re.match(dateString)
    if not m:
        return
    try:
        wday = _greek_wdays[m.group(1)]
        month = _greek_months[m.group(3)]
    except KeyError:
        # unknown weekday or month abbreviation: not a date we understand
        return
    rfc822date = '%(wday)s, %(day)s %(month)s %(year)s %(hour)s:%(minute)s:%(second)s %(zonediff)s' % \
                 {'wday': wday, 'day': m.group(2), 'month': month, 'year': m.group(4),
                  'hour': m.group(5), 'minute': m.group(6), 'second': m.group(7),
                  'zonediff': m.group(8)}
    if _debug:
        sys.stderr.write('Greek date parsed as: %s\n' % rfc822date)
    return _parse_date_rfc822(rfc822date)
registerDateHandler(_parse_date_greek)
# Unicode strings for Hungarian date strings
# Maps Hungarian month names to two-digit month numbers.
_hungarian_months = {
    u'janu\u00e1r': u'01',    # e1 in iso-8859-2
    u'febru\u00e1r': u'02',   # e1 in iso-8859-2
    u'febru\u00e1ri': u'02',  # legacy spelling, kept for backward compatibility
    u'm\u00e1rcius': u'03',   # e1 in iso-8859-2
    u'\u00e1prilis': u'04',   # e1 in iso-8859-2
    u'm\u00e1jus': u'05',     # e1 in iso-8859-2
    u'm\u00e1ujus': u'05',    # legacy spelling, kept for backward compatibility
    u'j\u00fanius': u'06',    # fa in iso-8859-2
    u'j\u00falius': u'07',    # fa in iso-8859-2
    u'augusztus': u'08',
    u'szeptember': u'09',
    u'okt\u00f3ber': u'10',   # f3 in iso-8859-2
    u'november': u'11',
    u'december': u'12',
}

# groups: year, month name, day, hour, minute, full zone offset,
# zone sign, zone hh:mm
_hungarian_date_format_re = re.compile(
    u'(\d{4})-([^-]+)-(\d{,2})T(\d{,2}):(\d{2})((\+|-)(\d{,2}:\d{2}))')
def _parse_date_hungarian(dateString):
    '''Parse a string according to a Hungarian 8-bit date format.

    The month name is translated through _hungarian_months, single-digit
    day and hour values are zero-padded, and the result is handed to
    _parse_date_w3dtf as a W3DTF string.  Returns None when the string
    does not match or the month name is not in the table.
    '''
    m = _hungarian_date_format_re.match(dateString)
    if not m:
        return
    try:
        month = _hungarian_months[m.group(2)]
    except KeyError:
        # unknown month name: not a date we understand
        return
    day = m.group(3)
    if len(day) == 1:
        day = '0' + day
    hour = m.group(4)
    if len(hour) == 1:
        hour = '0' + hour
    w3dtfdate = '%(year)s-%(month)s-%(day)sT%(hour)s:%(minute)s%(zonediff)s' % \
                {'year': m.group(1), 'month': month, 'day': day,
                 'hour': hour, 'minute': m.group(5),
                 'zonediff': m.group(6)}
    if _debug:
        sys.stderr.write('Hungarian date parsed as: %s\n' % w3dtfdate)
    return _parse_date_w3dtf(w3dtfdate)
registerDateHandler(_parse_date_hungarian)
# W3DTF-style date parsing adapted from PyXML xml.utils.iso8601, written by
# Drake and licensed under the Python license.  Removed all range checking
# for month, day, hour, minute, and second, since mktime will normalize
# these later
def _parse_date_w3dtf(dateString):
    '''Parse a W3DTF date such as 2003-12-31T10:14:55Z into a GMT time tuple.

    The whole string must match: a date (year, optionally month and day),
    optionally followed by 'T', a time and a mandatory time zone
    designator ('Z' or a +hh:mm / -hh:mm offset).  Returns the result of
    time.gmtime(), or None when the string does not match or the year is
    below 1000.
    '''
    def _date_part(m):
        # returns (year, month, day); (0, 0, 0) flags an unusable year
        year = int(m.group('year'))
        if year < 100:
            year = 100 * int(time.gmtime()[0] / 100) + int(year)
        if year < 1000:
            return 0, 0, 0
        julian = m.group('julian')
        if julian:
            # NOTE(review): the month alternative of the date pattern
            # accepts the first two digits and everything after it is
            # optional, so this branch looks unreachable -- confirm
            julian = int(julian)
            month = julian / 30 + 1
            day = julian % 30 + 1
            jday = None
            # start from a rough guess and step towards the requested
            # day of the year
            while jday != julian:
                t = time.mktime((year, month, day, 0, 0, 0, 0, 0, 0))
                jday = time.gmtime(t)[-2]
                diff = abs(jday - julian)
                if jday > julian:
                    if diff < day:
                        day = day - diff
                    else:
                        month = month - 1
                        day = 31
                elif jday < julian:
                    if day + diff < 28:
                        day = day + diff
                    else:
                        month = month + 1
            return year, month, day
        month = m.group('month')
        if month is None:
            return year, 1, 1
        month = int(month)
        day = m.group('day')
        if day:
            day = int(day)
        else:
            day = 1
        return year, month, day

    def _time_part(m):
        # returns (hours, minutes, seconds); all zero when no time was given
        if not m:
            return 0, 0, 0
        hours = m.group('hours')
        if not hours:
            return 0, 0, 0
        seconds = m.group('seconds')
        if seconds:
            seconds = int(seconds)
        else:
            seconds = 0
        return int(hours), int(m.group('minutes')), seconds

    def _tzd_seconds(m):
        '''Return the Time Zone Designator as an offset in seconds from UTC.'''
        if not m:
            return 0
        tzd = m.group('tzd')
        if (not tzd) or tzd == 'Z':
            return 0
        hours = int(m.group('tzdhours'))
        minutes = m.group('tzdminutes')
        if minutes:
            minutes = int(minutes)
        else:
            minutes = 0
        offset = (hours * 60 + minutes) * 60
        # a zone east of UTC ('+') is ahead, so its offset is subtracted
        if tzd[0] == '+':
            return -offset
        return offset

    date_re = ('(?P<year>\d\d\d\d)'
               '(?:(?P<dsep>-|)'
               '(?:(?P<month>\d\d)(?:(?P=dsep)(?P<day>\d\d))?'
               '|(?P<julian>\d\d\d)))?')
    tzd_re = '(?P<tzd>[-+](?P<tzdhours>\d\d)(?::?(?P<tzdminutes>\d\d))|Z)'
    time_re = ('(?P<hours>\d\d)(?P<tsep>:|)(?P<minutes>\d\d)'
               '(?:(?P=tsep)(?P<seconds>\d\d)(?:[.,]\d+)?)?'
               + tzd_re)
    datetime_rx = re.compile('%s(?:T%s)?' % (date_re, time_re))

    m = datetime_rx.match(dateString)
    # reject partial matches: the pattern must consume the whole string
    if (m is None) or (m.group() != dateString):
        return
    gmt = _date_part(m) + _time_part(m) + (0, 0, 0)
    if gmt[0] == 0:
        return
    # mktime() reads the tuple as local standard time; removing
    # time.timezone turns the result back into a UTC-based timestamp
    return time.gmtime(time.mktime(gmt) + _tzd_seconds(m) - time.timezone)
registerDateHandler(_parse_date_w3dtf) | |
def _parse_date_rfc822(dateString):
    '''Parse an RFC822, RFC1123, RFC2822, or asctime-style date

    Returns a GMT time tuple from time.gmtime(), or None when the string
    is empty or rfc822.parsedate_tz() cannot parse it.
    '''
    data = dateString.split()
    if not data:
        # empty or whitespace-only input: nothing to parse
        return None
    # drop a leading day name ("Thu," / "Thu." / "thu")
    if data[0][-1] in (',', '.') or data[0].lower() in rfc822._daynames:
        del data[0]
    if len(data) == 4:
        # the zone may be glued to the time with a '+'; split it off,
        # otherwise add an empty zone field
        s = data[3]
        i = s.find('+')
        if i > 0:
            data[3:] = [s[:i], s[i+1:]]
        else:
            data.append('')
        dateString = " ".join(data)
    # Account for the Etc/GMT timezone by stripping 'Etc/'
    elif len(data) == 5 and data[4].lower().startswith('etc/'):
        data[4] = data[4][4:]
        dateString = " ".join(data)
    if len(data) < 5:
        # no time given: assume midnight GMT
        dateString += ' 00:00:00 GMT'
    tm = rfc822.parsedate_tz(dateString)
    if tm:
        return time.gmtime(rfc822.mktime_tz(tm))
# rfc822.py defines several time zones, but we define some extra ones.
# 'ET' is equivalent to 'EST', etc.
_additional_timezones = {'AT': -400, 'ET': -500, 'CT': -600, 'MT': -700, 'PT': -800}
rfc822._timezones.update(_additional_timezones)
registerDateHandler(_parse_date_rfc822)
def _parse_date_perforce(aDateString):
    """parse a date in yyyy/mm/dd hh:mm:ss TTT format

    The date is rewritten in RFC 822 form and parsed with
    rfc822.parsedate_tz().  Returns a GMT time tuple from time.gmtime(),
    or None when the string does not contain such a date or cannot be
    parsed.
    """
    # Fri, 2006/09/15 08:19:53 EDT
    _my_date_pattern = re.compile( \
        r'(\w{,3}), (\d{,4})/(\d{,2})/(\d{2}) (\d{,2}):(\d{2}):(\d{2}) (\w{,3})')
    found = _my_date_pattern.search(aDateString)
    if found is None:
        # not a perforce-style date; let another handler try
        return None
    dow, year, month, day, hour, minute, second, tz = found.groups()
    months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
    dateString = "%s, %s %s %s %s:%s:%s %s" % (dow, day, months[int(month) - 1], year, hour, minute, second, tz)
    tm = rfc822.parsedate_tz(dateString)
    if tm:
        return time.gmtime(rfc822.mktime_tz(tm))
registerDateHandler(_parse_date_perforce)
def _parse_date(dateString): | |
'''Parses a variety of date formats into a 9-tuple in GMT''' | |
for handler in _date_handlers: | |
try: | |
date9tuple = handler(dateString) | |
if not date9tuple: continue | |
if len(date9tuple) != 9: | |
if _debug: sys.stderr.write('date handler function must return 9-tuple\n') | |
raise ValueError | |
map(int, date9tuple) | |
return date9tuple | |
except Exception, e: | |
if _debug: sys.stderr.write('%s raised %s\n' % (handler.__name__, repr(e))) | |
pass | |
return None | |
def _getCharacterEncoding(http_headers, xml_data):
    '''Get the character encoding of the XML document

    http_headers is a dictionary
    xml_data is a raw string (not Unicode)

    Returns the tuple (true_encoding, http_encoding, xml_encoding,
    sniffed_xml_encoding, acceptable_content_type).

    This is so much trickier than it sounds, it's not even funny.
    According to RFC 3023 ('XML Media Types'), if the HTTP Content-Type
    is application/xml, application/*+xml,
    application/xml-external-parsed-entity, or application/xml-dtd,
    the encoding given in the charset parameter of the HTTP Content-Type
    takes precedence over the encoding given in the XML prefix within the
    document, and defaults to 'utf-8' if neither are specified.  But, if
    the HTTP Content-Type is text/xml, text/*+xml, or
    text/xml-external-parsed-entity, the encoding given in the XML prefix
    within the document is ALWAYS IGNORED and only the encoding given in
    the charset parameter of the HTTP Content-Type header should be
    respected, and it defaults to 'us-ascii' if not specified.

    Furthermore, discussion on the atom-syntax mailing list with the
    author of RFC 3023 leads me to the conclusion that any document
    served with a Content-Type of text/* and no charset parameter
    must be treated as us-ascii.  (We now do this.)  And also that it
    must always be flagged as non-well-formed.  (We now do this too.)

    If Content-Type is unspecified (input was local file or non-HTTP source)
    or unrecognized (server just got it totally wrong), then go by the
    encoding given in the XML prefix of the document and default to
    'iso-8859-1' as per the HTTP specification (RFC 2616).

    Then, assuming we didn't find a character encoding in the HTTP headers
    (and the HTTP Content-type allowed us to look in the body), we need
    to sniff the first few bytes of the XML data and try to determine
    whether the encoding is ASCII-compatible.  Section F of the XML
    specification shows the way here:
    http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info

    If the sniffed encoding is not ASCII-compatible, we need to make it
    ASCII compatible so that we can sniff further into the XML declaration
    to find the encoding attribute, which will tell us the true encoding.

    Of course, none of this guarantees that we will be able to parse the
    feed in the declared character encoding (assuming it was declared
    correctly, which many are not).  CJKCodecs and iconv_codec help a lot;
    you should definitely install them if you can.
    http://cjkpython.i18n.org/
    '''

    def _parseHTTPContentType(content_type):
        '''takes HTTP Content-Type header and returns (content type, charset)

        If no charset is specified, returns (content type, '')
        If no content type is specified, returns ('', '')
        Single quotes are removed from the charset value.
        '''
        main_type, params = cgi.parse_header(content_type or '')
        return main_type, params.get('charset', '').replace("'", '')

    def _as_utf8(raw, source_encoding):
        # decode from the sniffed encoding and re-encode as utf-8
        return unicode(raw, source_encoding).encode('utf-8')

    sniffed_xml_encoding = ''
    xml_encoding = ''
    true_encoding = ''
    content_type_header = http_headers.get('content-type', http_headers.get('Content-type'))
    http_content_type, http_encoding = _parseHTTPContentType(content_type_header)

    # Must sniff for non-ASCII-compatible character encodings before
    # searching for XML declaration.  This heuristic is defined in
    # section F of the XML specification:
    # http://www.w3.org/TR/REC-xml/#sec-guessing-no-ext-info
    try:
        lead = xml_data[:4]
        if lead == _l2bytes([0x4c, 0x6f, 0xa7, 0x94]):
            # EBCDIC
            xml_data = _ebcdic_to_ascii(xml_data)
        elif lead == _l2bytes([0x00, 0x3c, 0x00, 0x3f]):
            # UTF-16BE
            sniffed_xml_encoding = 'utf-16be'
            xml_data = _as_utf8(xml_data, 'utf-16be')
        elif (len(xml_data) >= 4) and (lead[:2] == _l2bytes([0xfe, 0xff])) and (lead[2:4] != _l2bytes([0x00, 0x00])):
            # UTF-16BE with BOM
            sniffed_xml_encoding = 'utf-16be'
            xml_data = _as_utf8(xml_data[2:], 'utf-16be')
        elif lead == _l2bytes([0x3c, 0x00, 0x3f, 0x00]):
            # UTF-16LE
            sniffed_xml_encoding = 'utf-16le'
            xml_data = _as_utf8(xml_data, 'utf-16le')
        elif (len(xml_data) >= 4) and (lead[:2] == _l2bytes([0xff, 0xfe])) and (lead[2:4] != _l2bytes([0x00, 0x00])):
            # UTF-16LE with BOM
            sniffed_xml_encoding = 'utf-16le'
            xml_data = _as_utf8(xml_data[2:], 'utf-16le')
        elif lead == _l2bytes([0x00, 0x00, 0x00, 0x3c]):
            # UTF-32BE
            sniffed_xml_encoding = 'utf-32be'
            xml_data = _as_utf8(xml_data, 'utf-32be')
        elif lead == _l2bytes([0x3c, 0x00, 0x00, 0x00]):
            # UTF-32LE
            sniffed_xml_encoding = 'utf-32le'
            xml_data = _as_utf8(xml_data, 'utf-32le')
        elif lead == _l2bytes([0x00, 0x00, 0xfe, 0xff]):
            # UTF-32BE with BOM
            sniffed_xml_encoding = 'utf-32be'
            xml_data = _as_utf8(xml_data[4:], 'utf-32be')
        elif lead == _l2bytes([0xff, 0xfe, 0x00, 0x00]):
            # UTF-32LE with BOM
            sniffed_xml_encoding = 'utf-32le'
            xml_data = _as_utf8(xml_data[4:], 'utf-32le')
        elif lead[:3] == _l2bytes([0xef, 0xbb, 0xbf]):
            # UTF-8 with BOM
            sniffed_xml_encoding = 'utf-8'
            xml_data = _as_utf8(xml_data[3:], 'utf-8')
        else:
            # ASCII-compatible
            pass
        xml_encoding_match = re.compile(_s2bytes('^<\?.*encoding=[\'"](.*?)[\'"].*\?>')).match(xml_data)
    except:
        # deliberately broad: any failure above, including xml_data not
        # being sliceable or a decoding error, means "no declaration found"
        xml_encoding_match = None

    if xml_encoding_match:
        xml_encoding = xml_encoding_match.groups()[0].decode('utf-8').lower()
        # a declared multi-byte family name is replaced by the concrete
        # encoding that was sniffed from the leading bytes
        generic_names = ('iso-10646-ucs-2', 'ucs-2', 'csunicode',
                         'iso-10646-ucs-4', 'ucs-4', 'csucs4',
                         'utf-16', 'utf-32', 'utf_16', 'utf_32',
                         'utf16', 'u16')
        if sniffed_xml_encoding and (xml_encoding in generic_names):
            xml_encoding = sniffed_xml_encoding

    acceptable_content_type = 0
    application_content_types = ('application/xml', 'application/xml-dtd', 'application/xml-external-parsed-entity')
    text_content_types = ('text/xml', 'text/xml-external-parsed-entity')
    is_application_xml = (http_content_type in application_content_types) or \
        (http_content_type.startswith('application/') and http_content_type.endswith('+xml'))
    is_text_xml = (http_content_type in text_content_types) or \
        (http_content_type.startswith('text/') and http_content_type.endswith('+xml'))

    if is_application_xml:
        acceptable_content_type = 1
        true_encoding = http_encoding or xml_encoding or 'utf-8'
    elif is_text_xml:
        acceptable_content_type = 1
        true_encoding = http_encoding or 'us-ascii'
    elif http_content_type.startswith('text/'):
        true_encoding = http_encoding or 'us-ascii'
    elif http_headers and (not (http_headers.has_key('content-type') or http_headers.has_key('Content-type'))):
        true_encoding = xml_encoding or 'iso-8859-1'
    else:
        true_encoding = xml_encoding or 'utf-8'

    # some feeds claim to be gb2312 but are actually gb18030.
    # apparently MSIE and Firefox both do the following switch:
    if true_encoding.lower() == 'gb2312':
        true_encoding = 'gb18030'
    return true_encoding, http_encoding, xml_encoding, sniffed_xml_encoding, acceptable_content_type
def _toUTF8(data, encoding):
    '''Changes an XML data stream on the fly to specify a new encoding

    data is a raw sequence of bytes (not Unicode) that is presumed to be in %encoding already
    encoding is a string recognized by encodings.aliases

    A leading byte order mark is removed, and the encoding it implies
    replaces the requested one.  Returns the document as utf-8 bytes with
    an XML declaration that says so.  Decoding errors are not caught here.
    '''
    if _debug:
        sys.stderr.write('entering _toUTF8, trying encoding %s\n' % encoding)

    # strip Byte Order Mark (if present); checked in this order
    bom_table = (
        ([0xfe, 0xff], 'utf-16be'),
        ([0xff, 0xfe], 'utf-16le'),
        ([0xef, 0xbb, 0xbf], 'utf-8'),
        ([0x00, 0x00, 0xfe, 0xff], 'utf-32be'),
        ([0xff, 0xfe, 0x00, 0x00], 'utf-32le'),
    )
    for bom_values, bom_encoding in bom_table:
        width = len(bom_values)
        if data[:width] != _l2bytes(bom_values):
            continue
        # a two-byte mark only counts when at least four bytes are present
        # and the next two are not both zero (that would be a utf-32 mark)
        if width == 2 and (len(data) < 4 or data[2:4] == _l2bytes([0x00, 0x00])):
            continue
        if _debug:
            sys.stderr.write('stripping BOM\n')
            if encoding != bom_encoding:
                sys.stderr.write('trying %s instead\n' % bom_encoding)
        encoding = bom_encoding
        data = data[width:]
        break

    newdata = unicode(data, encoding)
    if _debug:
        sys.stderr.write('successfully converted %s data to unicode\n' % encoding)

    # replace the existing XML declaration, or add one if there is none
    declmatch = re.compile('^<\?xml[^>]*?>')
    newdecl = '''<?xml version='1.0' encoding='utf-8'?>'''
    if declmatch.search(newdata):
        newdata = declmatch.sub(newdecl, newdata)
    else:
        newdata = newdecl + u'\n' + newdata
    return newdata.encode('utf-8')
def _stripDoctype(data):
    '''Strips DOCTYPE from XML document, returns (rss_version, stripped_data, safe_entities)

    rss_version may be 'rss091n' or None
    stripped_data is the same XML document, minus the DOCTYPE (except for
    a rebuilt DOCTYPE carrying the 'safe' inline entity definitions)
    safe_entities is a dict of the entity names and values that were kept
    '''
    # everything up to the first tag is the head; only the head is searched
    first_tag = re.search(_s2bytes('<\w'), data)
    start = -1
    if first_tag and first_tag.start():
        start = first_tag.start()
    head, data = data[:start+1], data[start+1:]

    # pull the inline entity declarations out of the head
    entity_pattern = re.compile(_s2bytes(r'^\s*<!ENTITY([^>]*?)>'), re.MULTILINE)
    entity_results = entity_pattern.findall(head)
    head = entity_pattern.sub(_s2bytes(''), head)

    doctype_pattern = re.compile(_s2bytes(r'^\s*<!DOCTYPE([^>]*?)>'), re.MULTILINE)
    doctype_results = doctype_pattern.findall(head)
    doctype = _s2bytes('')
    if doctype_results and doctype_results[0]:
        doctype = doctype_results[0]

    version = None
    if doctype.lower().count(_s2bytes('netscape')):
        version = 'rss091n'

    # only allow in 'safe' inline entity definitions
    replacement = _s2bytes('')
    safe_entity_map = {}
    if len(doctype_results) == 1 and entity_results:
        safe_pattern = re.compile(_s2bytes('\s+(\w+)\s+"(&#\w+;|[^&"]*)"'))
        safe_entities = [e for e in entity_results if safe_pattern.match(e)]
        if safe_entities:
            replacement = _s2bytes('<!DOCTYPE feed [\n <!ENTITY') + _s2bytes('>\n <!ENTITY ').join(safe_entities) + _s2bytes('>\n]>')
            if replacement:
                for name, value in safe_pattern.findall(replacement):
                    safe_entity_map[name.decode('utf-8')] = value.decode('utf-8')

    data = doctype_pattern.sub(replacement, head) + data
    return version, data, safe_entity_map
def parse(url_file_stream_or_string, etag=None, modified=None, agent=None, referrer=None, handlers=[], request_headers={}, response_headers={}): | |
'''Parse a feed from a URL, file, stream, or string. | |
request_headers, if given, is a dict from http header name to value to add | |
to the request; this overrides internally generated values. | |
''' | |
result = FeedParserDict() | |
result['feed'] = FeedParserDict() | |
result['entries'] = [] | |
if _XML_AVAILABLE: | |
result['bozo'] = 0 | |
if not isinstance(handlers, list): | |
handlers = [handlers] | |
try: | |
f = _open_resource(url_file_stream_or_string, etag, modified, agent, referrer, handlers, request_headers) | |
data = f.read() | |
except Exception, e: | |
result['bozo'] = 1 | |
result['bozo_exception'] = e | |
data = None | |
f = None | |
if hasattr(f, 'headers'): | |
result['headers'] = dict(f.headers) | |
# overwrite existing headers using response_headers | |
if 'headers' in result: | |
result['headers'].update(response_headers) | |
elif response_headers: | |
result['headers'] = copy.deepcopy(response_headers) | |
# if feed is gzip-compressed, decompress it | |
if f and data and 'headers' in result: | |
if gzip and result['headers'].get('content-encoding') == 'gzip': | |
try: | |
data = gzip.GzipFile(fileobj=_StringIO(data)).read() | |
except Exception, e: | |
# Some feeds claim to be gzipped but they're not, so | |
# we get garbage. Ideally, we should re-request the | |
# feed without the 'Accept-encoding: gzip' header, | |
# but we don't. | |
result['bozo'] = 1 | |
result['bozo_exception'] = e | |
data = '' | |
elif zlib and result['headers'].get('content-encoding') == 'deflate': | |
try: | |
data = zlib.decompress(data, -zlib.MAX_WBITS) | |
except Exception, e: | |
result['bozo'] = 1 | |
result['bozo_exception'] = e | |
data = '' | |
# save HTTP headers | |
if 'headers' in result: | |
if 'etag' in result['headers'] or 'ETag' in result['headers']: | |
etag = result['headers'].get('etag', result['headers'].get('ETag')) | |
if etag: | |
result['etag'] = etag | |
if 'last-modified' in result['headers'] or 'Last-Modified' in result['headers']: | |
modified = result['headers'].get('last-modified', result['headers'].get('Last-Modified')) | |
if modified: | |
result['modified'] = _parse_date(modified) | |
if hasattr(f, 'url'): | |
result['href'] = f.url | |
result['status'] = 200 | |
if hasattr(f, 'status'): | |
result['status'] = f.status | |
if hasattr(f, 'close'): | |
f.close() | |
# there are four encodings to keep track of: | |
# - http_encoding is the encoding declared in the Content-Type HTTP header | |
# - xml_encoding is the encoding declared in the <?xml declaration | |
# - sniffed_encoding is the encoding sniffed from the first 4 bytes of the XML data | |
# - result['encoding'] is the actual encoding, as per RFC 3023 and a variety of other conflicting specifications | |
http_headers = result.get('headers', {}) | |
result['encoding'], http_encoding, xml_encoding, sniffed_xml_encoding, acceptable_content_type = \ | |
_getCharacterEncoding(http_headers, data) | |
if http_headers and (not acceptable_content_type): | |
if http_headers.has_key('content-type') or http_headers.has_key('Content-type'): | |
bozo_message = '%s is not an XML media type' % http_headers.get('content-type', http_headers.get('Content-type')) | |
else: | |
bozo_message = 'no Content-type specified' | |
result['bozo'] = 1 | |
result['bozo_exception'] = NonXMLContentType(bozo_message) | |
if data is not None: | |
result['version'], data, entities = _stripDoctype(data) | |
# ensure that baseuri is an absolute uri using an acceptable URI scheme | |
contentloc = http_headers.get('content-location', http_headers.get('Content-Location', '')) | |
href = result.get('href', '') | |
baseuri = _makeSafeAbsoluteURI(href, contentloc) or _makeSafeAbsoluteURI(contentloc) or href | |
baselang = http_headers.get('content-language', http_headers.get('Content-Language', None)) | |
# if server sent 304, we're done | |
if result.get('status', 0) == 304: | |
result['version'] = '' | |
result['debug_message'] = 'The feed has not changed since you last checked, ' + \ | |
'so the server sent no data. This is a feature, not a bug!' | |
return result | |
# if there was a problem downloading, we're done | |
if data is None: | |
return result | |
# determine character encoding | |
use_strict_parser = 0 | |
known_encoding = 0 | |
tried_encodings = [] | |
# try: HTTP encoding, declared XML encoding, encoding sniffed from BOM | |
for proposed_encoding in (result['encoding'], xml_encoding, sniffed_xml_encoding): | |
if not proposed_encoding: continue | |
if proposed_encoding in tried_encodings: continue | |
tried_encodings.append(proposed_encoding) | |
try: | |
data = _toUTF8(data, proposed_encoding) | |
known_encoding = use_strict_parser = 1 | |
break | |
except: | |
pass | |
# if no luck and we have auto-detection library, try that | |
if (not known_encoding) and chardet: | |
try: | |
proposed_encoding = chardet.detect(data)['encoding'] | |
if proposed_encoding and (proposed_encoding not in tried_encodings): | |
tried_encodings.append(proposed_encoding) | |
data = _toUTF8(data, proposed_encoding) | |
known_encoding = use_strict_parser = 1 | |
except: | |
pass | |
# if still no luck and we haven't tried utf-8 yet, try that | |
if (not known_encoding) and ('utf-8' not in tried_encodings): | |
try: | |
proposed_encoding = 'utf-8' | |
tried_encodings.append(proposed_encoding) | |
data = _toUTF8(data, proposed_encoding) | |
known_encoding = use_strict_parser = 1 | |
except: | |
pass | |
# if still no luck and we haven't tried windows-1252 yet, try that | |
if (not known_encoding) and ('windows-1252' not in tried_encodings): | |
try: | |
proposed_encoding = 'windows-1252' | |
tried_encodings.append(proposed_encoding) | |
data = _toUTF8(data, proposed_encoding) | |
known_encoding = use_strict_parser = 1 | |
except: | |
pass | |
# if still no luck and we haven't tried iso-8859-2 yet, try that. | |
if (not known_encoding) and ('iso-8859-2' not in tried_encodings): | |
try: | |
proposed_encoding = 'iso-8859-2' | |
tried_encodings.append(proposed_encoding) | |
data = _toUTF8(data, proposed_encoding) | |
known_encoding = use_strict_parser = 1 | |
except: | |
pass | |
# if still no luck, give up | |
if not known_encoding: | |
result['bozo'] = 1 | |
result['bozo_exception'] = CharacterEncodingUnknown( \ | |
'document encoding unknown, I tried ' + \ | |
'%s, %s, utf-8, windows-1252, and iso-8859-2 but nothing worked' % \ | |
(result['encoding'], xml_encoding)) | |
result['encoding'] = '' | |
elif proposed_encoding != result['encoding']: | |
result['bozo'] = 1 | |
result['bozo_exception'] = CharacterEncodingOverride( \ | |
'document declared as %s, but parsed as %s' % \ | |
(result['encoding'], proposed_encoding)) | |
result['encoding'] = proposed_encoding | |
if not _XML_AVAILABLE: | |
use_strict_parser = 0 | |
if use_strict_parser: | |
# initialize the SAX parser | |
feedparser = _StrictFeedParser(baseuri, baselang, 'utf-8') | |
saxparser = xml.sax.make_parser(PREFERRED_XML_PARSERS) | |
saxparser.setFeature(xml.sax.handler.feature_namespaces, 1) | |
saxparser.setContentHandler(feedparser) | |
saxparser.setErrorHandler(feedparser) | |
source = xml.sax.xmlreader.InputSource() | |
source.setByteStream(_StringIO(data)) | |
if hasattr(saxparser, '_ns_stack'): | |
# work around bug in built-in SAX parser (doesn't recognize xml: namespace) | |
# PyXML doesn't have this problem, and it doesn't have _ns_stack either | |
saxparser._ns_stack.append({'http://www.w3.org/XML/1998/namespace':'xml'}) | |
try: | |
saxparser.parse(source) | |
except Exception, e: | |
if _debug: | |
import traceback | |
traceback.print_stack() | |
traceback.print_exc() | |
sys.stderr.write('xml parsing failed\n') | |
result['bozo'] = 1 | |
result['bozo_exception'] = feedparser.exc or e | |
use_strict_parser = 0 | |
if not use_strict_parser: | |
feedparser = _LooseFeedParser(baseuri, baselang, 'utf-8', entities) | |
feedparser.feed(data.decode('utf-8', 'replace')) | |
result['feed'] = feedparser.feeddata | |
result['entries'] = feedparser.entries | |
result['version'] = result['version'] or feedparser.version | |
result['namespaces'] = feedparser.namespacesInUse | |
return result | |
class Serializer:
    """Base class for result serializers.

    Only stores the parsed results; the subclasses in this file add a
    ``write(stream)`` method that emits them.
    """

    def __init__(self, results):
        # Parsed feed data that a subclass will write out.
        self.results = results
class TextSerializer(Serializer):
    """Writes the results as flat ``dotted.key[index]=value`` lines."""

    def write(self, stream=sys.stdout):
        """Serialize ``self.results`` to `stream`."""
        self._writer(stream, self.results, '')

    def _writer(self, stream, node, prefix):
        """Recursively emit `node`; `prefix` is the path so far, ending in '.'."""
        # Falsy nodes (None, '', 0, empty containers) produce no output.
        if not node:
            return

        if hasattr(node, 'keys'):
            # Mapping: visit keys in sorted order.
            names = node.keys()
            names.sort()
            for name in names:
                # Skip 'description'/'link' and any key that has a
                # '<key>_detail' or '<key>_parsed' companion entry.
                if (name in ('description', 'link')
                        or node.has_key(name + '_detail')
                        or node.has_key(name + '_parsed')):
                    continue
                self._writer(stream, node[name], prefix + name + '.')
            return

        if type(node) == types.ListType:
            # List: replace the trailing '.' of the prefix with '[i].'.
            for position, item in enumerate(node):
                self._writer(stream, item, prefix[:-1] + '[' + str(position) + '].')
            return

        # Scalar: escape backslashes and newlines, drop carriage returns.
        try:
            text = str(node).encode('utf-8')
            text = text.replace('\\', '\\\\').replace('\r', '').replace('\n', r'\n')
            for piece in (prefix[:-1], '=', text, '\n'):
                stream.write(piece)
        except:
            # Best effort: values that cannot be converted/written are skipped.
            pass
class PprintSerializer(Serializer):
    """Writes the results with ``pprint``, preceded by the 'href' if present."""

    def write(self, stream=sys.stdout):
        """Pretty-print ``self.results`` to `stream`."""
        from pprint import pprint

        results = self.results
        if results.has_key('href'):
            # Header line with the source URL, followed by a blank line.
            stream.write(results['href'] + '\n\n')
        pprint(results, stream)
        stream.write('\n')
if __name__ == '__main__': | |
try: | |
from optparse import OptionParser | |
except: | |
OptionParser = None | |
if OptionParser: | |
optionParser = OptionParser(version=__version__, usage="%prog [options] url_or_filename_or_-") | |
optionParser.set_defaults(format="pprint") | |
optionParser.add_option("-A", "--user-agent", dest="agent", metavar="AGENT", help="User-Agent for HTTP URLs") | |
optionParser.add_option("-e", "--referer", "--referrer", dest="referrer", metavar="URL", help="Referrer for HTTP URLs") | |
optionParser.add_option("-t", "--etag", dest="etag", metavar="TAG", help="ETag/If-None-Match for HTTP URLs") | |
optionParser.add_option("-m", "--last-modified", dest="modified", metavar="DATE", help="Last-modified/If-Modified-Since for HTTP URLs (any supported date format)") | |
optionParser.add_option("-f", "--format", dest="format", metavar="FORMAT", help="output results in FORMAT (text, pprint)") | |
optionParser.add_option("-v", "--verbose", action="store_true", dest="verbose", default=False, help="write debugging information to stderr") | |
(options, urls) = optionParser.parse_args() | |
if options.verbose: | |
_debug = 1 | |
if not urls: | |
optionParser.print_help() | |
sys.exit(0) | |
else: | |
if not sys.argv[1:]: | |
print __doc__ | |
sys.exit(0) | |
class _Options: | |
etag = modified = agent = referrer = None | |
format = 'pprint' | |
options = _Options() | |
urls = sys.argv[1:] | |
zopeCompatibilityHack() | |
serializer = globals().get(options.format.capitalize() + 'Serializer', Serializer) | |
for url in urls: | |
results = parse(url, etag=options.etag, modified=options.modified, agent=options.agent, referrer=options.referrer) | |
serializer(results).write(sys.stdout) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
HTML forms | |
(part of web.py) | |
""" | |
import copy, re | |
import webapi as web | |
import utils, net | |
def attrget(obj, attr, value=None):
    """Fetch `attr` from `obj` as a mapping key first, then as an attribute.

    Returns `value` when `obj` provides neither.
    """
    # Mapping-style lookup for objects that expose has_key().
    if hasattr(obj, 'has_key'):
        if obj.has_key(attr):
            return obj[attr]
    # Otherwise fall back to plain attribute access.
    if not hasattr(obj, attr):
        return value
    return getattr(obj, attr)
class Form(object):
    r"""
    HTML form built from a sequence of inputs.

        >>> f = Form(Textbox("x"))
        >>> f.render()
        '<table>\n <tr><th><label for="x">x</label></th><td><input type="text" id="x" name="x"/></td></tr>\n</table>'
    """

    def __init__(self, *inputs, **kw):
        self.inputs = inputs
        self.valid = True
        self.note = None
        # Form-level validators, run against the whole source.
        self.validators = kw.pop('validators', [])

    def __call__(self, x=None):
        """Return a deep copy of the form, validated against `x` if it is truthy."""
        clone = copy.deepcopy(self)
        if x:
            clone.validates(x)
        return clone

    def render(self):
        """Render the form as an HTML table, one row per input."""
        parts = ['', self.rendernote(self.note), '<table>\n']
        for field in self.inputs:
            html = (utils.safeunicode(field.pre) + field.render()
                    + self.rendernote(field.note) + utils.safeunicode(field.post))
            if field.is_hidden():
                row = ' <tr style="display: none;"><th></th><td>%s</td></tr>\n' % (html)
            else:
                row = ' <tr><th><label for="%s">%s</label></th><td>%s</td></tr>\n' % (field.id, net.websafe(field.description), html)
            parts.append(row)
        parts.append("</table>")
        return ''.join(parts)

    def render_css(self):
        """Render the form as labels and inputs without a table."""
        pieces = [self.rendernote(self.note)]
        for field in self.inputs:
            if not field.is_hidden():
                pieces.append('<label for="%s">%s</label>' % (field.id, net.websafe(field.description)))
            pieces.extend([
                field.pre,
                field.render(),
                self.rendernote(field.note),
                field.post,
                '\n',
            ])
        return ''.join(pieces)

    def rendernote(self, note):
        """Return the HTML for an error note, or "" when there is none."""
        if not note:
            return ""
        return '<strong class="wrong">%s</strong>' % net.websafe(note)

    def validates(self, source=None, _validate=True, **kw):
        """Load values from `source` (or `kw`, or the request input) into the inputs.

        With `_validate` true, every input validator and then the form-level
        validators are run, and the result is stored in `self.valid`.
        """
        source = source or kw or web.input()

        if not _validate:
            # Fill-only mode: copy values in, run no validators.
            for field in self.inputs:
                field.set_value(attrget(source, field.name))
            return True

        ok = True
        for field in self.inputs:
            # Validate every input even after a failure so all notes are set.
            ok = field.validate(attrget(source, field.name)) and ok
        ok = ok and self._validate(source)
        self.valid = ok
        return ok

    def _validate(self, value):
        """Run the form-level validators; record the first failure in `self.note`."""
        self.value = value
        for validator in self.validators:
            if not validator.valid(value):
                self.note = validator.msg
                return False
        return True

    def fill(self, source=None, **kw):
        """Populate the inputs from `source` without validating."""
        return self.validates(source, _validate=False, **kw)

    def __getitem__(self, i):
        for field in self.inputs:
            if field.name == i:
                return field
        raise KeyError(i)

    def __getattr__(self, name):
        # don't interfere with deepcopy
        fields = self.__dict__.get('inputs') or []
        for field in fields:
            if field.name == name:
                return field
        raise AttributeError(name)

    def get(self, i, default=None):
        """Return the input named `i`, or `default` if there is none."""
        try:
            return self[i]
        except KeyError:
            return default

    def _get_d(self): #@@ should really be form.attr, no?
        pairs = [(field.name, field.get_value()) for field in self.inputs]
        return utils.storage(pairs)
    d = property(_get_d)
class Input(object):
    """Base class for form inputs.

    Keyword arguments other than `description`, `value`, `pre` and `post`
    are kept as HTML attributes; `class_` is stored as `class`.
    """

    def __init__(self, name, *validators, **attrs):
        self.name = name
        self.validators = validators
        self.attrs = attrs = AttributeList(attrs)

        # These keywords configure the input itself and are not rendered
        # as HTML attributes.
        self.description = attrs.pop('description', name)
        self.value = attrs.pop('value', None)
        self.pre = attrs.pop('pre', "")
        self.post = attrs.pop('post', "")
        self.note = None

        self.id = attrs.setdefault('id', self.get_default_id())

        # 'class' is a reserved word, so callers pass 'class_' instead.
        if 'class_' in attrs:
            attrs['class'] = attrs['class_']
            del attrs['class_']

    def is_hidden(self):
        return False

    def get_type(self):
        """Return the HTML input type; subclasses must override."""
        raise NotImplementedError

    def get_default_id(self):
        return self.name

    def validate(self, value):
        """Store `value`, then run the validators; the first failure sets `self.note`."""
        self.set_value(value)
        for validator in self.validators:
            if validator.valid(value):
                continue
            self.note = validator.msg
            return False
        return True

    def set_value(self, value):
        self.value = value

    def get_value(self):
        return self.value

    def render(self):
        """Render the input as an `<input .../>` tag."""
        html_attrs = self.attrs.copy()
        html_attrs['type'] = self.get_type()
        if self.value is not None:
            html_attrs['value'] = self.value
        html_attrs['name'] = self.name
        return '<input %s/>' % html_attrs

    def rendernote(self, note):
        if not note:
            return ""
        return '<strong class="wrong">%s</strong>' % net.websafe(note)

    def addatts(self):
        # add leading space for backward-compatibility
        return " " + str(self.attrs)
class AttributeList(dict):
    """List of attributes of an input.

        >>> a = AttributeList(type='text', name='x', value=20)
        >>> a
        <attrs: 'type="text" name="x" value="20"'>
    """

    def copy(self):
        # dict.copy() would return a plain dict; keep the subclass.
        return AttributeList(self)

    def __str__(self):
        rendered = []
        for key, val in self.items():
            rendered.append('%s="%s"' % (key, net.websafe(val)))
        return " ".join(rendered)

    def __repr__(self):
        return '<attrs: %s>' % repr(str(self))
class Textbox(Input):
    """Single-line text input (``type="text"``).

        >>> Textbox(name='foo', value='bar').render()
        '<input type="text" id="foo" value="bar" name="foo"/>'
        >>> Textbox(name='foo', value=0).render()
        '<input type="text" id="foo" value="0" name="foo"/>'
    """

    def get_type(self):
        # Value of the HTML "type" attribute used by Input.render().
        return 'text'
class Password(Input):
    """Password input (``type="password"``).

        >>> Password(name='password', value='secret').render()
        '<input type="password" id="password" value="secret" name="password"/>'
    """

    def get_type(self):
        # Value of the HTML "type" attribute used by Input.render().
        return 'password'
class Textarea(Input):
    """Multi-line text input rendered as a `<textarea>` element.

        >>> Textarea(name='foo', value='bar').render()
        '<textarea id="foo" name="foo">bar</textarea>'
    """

    def render(self):
        html_attrs = self.attrs.copy()
        html_attrs['name'] = self.name
        # The value is the element body, not an attribute; a falsy value
        # renders as an empty body.
        return '<textarea %s>%s</textarea>' % (html_attrs, net.websafe(self.value or ''))
class Dropdown(Input):
    r"""Dropdown/select input.

    `args` is a sequence of options; each is either a plain value or a
    ``(value, description)`` pair.

        >>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render()
        '<select id="foo" name="foo">\n <option value="a">a</option>\n <option selected="selected" value="b">b</option>\n <option value="c">c</option>\n</select>\n'
        >>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render()
        '<select id="foo" name="foo">\n <option value="a">aa</option>\n <option selected="selected" value="b">bb</option>\n <option value="c">cc</option>\n</select>\n'
    """

    def __init__(self, name, args, *validators, **attrs):
        self.args = args
        super(Dropdown, self).__init__(name, *validators, **attrs)

    def render(self):
        html_attrs = self.attrs.copy()
        html_attrs['name'] = self.name

        chunks = ['<select %s>\n' % html_attrs]
        for option in self.args:
            if isinstance(option, (tuple, list)):
                value, desc = option
            else:
                value = desc = option

            # Selected when equal to the current value, or contained in it
            # when the current value is a list.
            if self.value == value or (isinstance(self.value, list) and value in self.value):
                select_p = ' selected="selected"'
            else:
                select_p = ''
            chunks.append(' <option%s value="%s">%s</option>\n' % (select_p, net.websafe(value), net.websafe(desc)))
        chunks.append('</select>\n')
        return ''.join(chunks)
class Radio(Input):
    """Group of radio buttons, one per entry in `args`.

    Each entry is either a plain value or a ``(value, description)`` pair.
    """

    def __init__(self, name, args, *validators, **attrs):
        self.args = args
        super(Radio, self).__init__(name, *validators, **attrs)

    def render(self):
        chunks = ['<span>']
        for option in self.args:
            if isinstance(option, (tuple, list)):
                value, desc = option
            else:
                value = desc = option

            # Every button starts from a copy of the shared attributes.
            html_attrs = self.attrs.copy()
            html_attrs['name'] = self.name
            html_attrs['type'] = 'radio'
            html_attrs['value'] = value
            if self.value == value:
                html_attrs['checked'] = 'checked'
            chunks.append('<input %s/> %s' % (html_attrs, net.websafe(desc)))
        chunks.append('</span>')
        return ''.join(chunks)
class Checkbox(Input):
    """Checkbox input.

    The `value` attribute is fixed at construction; setting a value only
    toggles the checked state.

        >>> Checkbox('foo', value='bar', checked=True).render()
        '<input checked="checked" type="checkbox" id="foo_bar" value="bar" name="foo"/>'
        >>> Checkbox('foo', value='bar').render()
        '<input type="checkbox" id="foo_bar" value="bar" name="foo"/>'
        >>> c = Checkbox('foo', value='bar')
        >>> c.validate('on')
        True
        >>> c.render()
        '<input checked="checked" type="checkbox" id="foo_bar" value="bar" name="foo"/>'
    """

    def __init__(self, name, *validators, **attrs):
        # 'checked' is state, not an HTML attribute to pass through.
        self.checked = attrs.pop('checked', False)
        Input.__init__(self, name, *validators, **attrs)

    def get_default_id(self):
        # name + '_' + value, with spaces in the value replaced by '_'.
        suffix = utils.safestr(self.value or "")
        return self.name + '_' + suffix.replace(' ', '_')

    def render(self):
        html_attrs = self.attrs.copy()
        html_attrs['type'] = 'checkbox'
        html_attrs['name'] = self.name
        html_attrs['value'] = self.value

        if self.checked:
            html_attrs['checked'] = 'checked'
        return '<input %s/>' % html_attrs

    def set_value(self, value):
        # Any truthy submitted value means the box was ticked.
        self.checked = bool(value)

    def get_value(self):
        return self.checked
class Button(Input):
    """HTML Button.

    The optional `html` attribute supplies the button body; otherwise the
    escaped name is used.

        >>> Button("save").render()
        '<button id="save" name="save">save</button>'
        >>> Button("action", value="save", html="<b>Save Changes</b>").render()
        '<button id="action" value="save" name="action"><b>Save Changes</b></button>'
    """

    def __init__(self, name, *validators, **attrs):
        super(Button, self).__init__(name, *validators, **attrs)
        # Buttons carry no label text.
        self.description = ""

    def render(self):
        html_attrs = self.attrs.copy()
        html_attrs['name'] = self.name
        if self.value is not None:
            html_attrs['value'] = self.value
        # 'html' is inserted as-is (not escaped) and removed from the attributes.
        body = html_attrs.pop('html', None) or net.websafe(self.name)
        return '<button %s>%s</button>' % (html_attrs, body)
class Hidden(Input):
    """Hidden input (``type="hidden"``).

        >>> Hidden(name='foo', value='bar').render()
        '<input type="hidden" id="foo" value="bar" name="foo"/>'
    """

    def is_hidden(self):
        # Tells the form renderers to omit the label for this input.
        return True

    def get_type(self):
        return 'hidden'
class File(Input):
    """File upload input (``type="file"``).

        >>> File(name='f').render()
        '<input type="file" id="f" name="f"/>'
    """

    def get_type(self):
        # Value of the HTML "type" attribute used by Input.render().
        return 'file'
class Validator:
    """Pairs an error message `msg` with a predicate `test`.

    `jstest` is accepted alongside them; it is not used in this class.
    """

    def __deepcopy__(self, memo):
        # A shallow copy is used when a form is deep-copied.
        return copy.copy(self)

    def __init__(self, msg, test, jstest=None):
        # presumably stores msg/test/jstest as attributes of self --
        # verify against utils.autoassign
        utils.autoassign(self, locals())

    def valid(self, value):
        """Return the result of `test(value)`, or False if it raises."""
        try:
            return self.test(value)
        except:
            return False


# Ready-made validator: the value must be truthy.
notnull = Validator("Required", bool)
class regexp(Validator):
    """Validator that passes when the value matches `rexp` at its start.

    `rexp` is compiled once at construction; `msg` is the error message
    reported on failure.
    """

    def __init__(self, rexp, msg):
        self.rexp = re.compile(rexp)
        self.msg = msg

    def valid(self, value):
        """Return True if `value` matches the pattern, False otherwise.

        A non-string value (for example None for a field missing from the
        submitted data) is treated as invalid rather than raising, in line
        with Validator.valid.
        """
        try:
            return bool(self.rexp.match(value))
        except TypeError:
            return False
if __name__ == "__main__":
    # Run the docstring examples of this module when executed as a script.
    import doctest
    doctest.testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
HTTP Utilities | |
(from web.py) | |
""" | |
__all__ = [ | |
"expires", "lastmodified", | |
"prefixurl", "modified", | |
"changequery", "url", | |
"profiler", | |
] | |
import sys, os, threading, urllib, urlparse | |
try: import datetime | |
except ImportError: pass | |
import net, utils, webapi as web | |
def prefixurl(base=''):
    """
    Returns `base` followed by one `../` for every `/` in the current
    request path (leading slashes stripped), or `./` if that is empty.
    """
    relative_path = web.ctx.path.lstrip('/')
    base += '../' * relative_path.count('/')
    return base or './'
def expires(delta):
    """
    Outputs an `Expires` header for `delta` from now.

    `delta` is a `timedelta` object or a number of seconds.
    """
    # Integer input is taken as seconds.
    if isinstance(delta, (int, long)):
        delta = datetime.timedelta(seconds=delta)
    expiry = datetime.datetime.utcnow() + delta
    web.header('Expires', net.httpdate(expiry))
def lastmodified(date_obj):
    """Outputs a `Last-Modified` header for the datetime `date_obj`."""
    formatted = net.httpdate(date_obj)
    web.header('Last-Modified', formatted)
def modified(date=None, etag=None):
    """
    Checks to see if the page has been modified since the version in the
    requester's cache.

    When you publish pages, you can include `Last-Modified` and `ETag`
    with the date the page was last modified and an opaque token for
    the particular version, respectively. When readers reload the page,
    the browser sends along the modification date and etag value for
    the version it has in its cache. If the page hasn't changed,
    the server can just return `304 Not Modified` and not have to
    send the whole page again.

    This function takes the last-modified date `date` and the ETag `etag`
    and checks them against the request headers. If the cached version is
    still current (the ETag matches, or the page has not changed since the
    client's date), it raises the NotModified error; otherwise it returns
    `True`. Either way it sets the `Last-Modified` and `ETag` output
    headers first.
    """
    try:
        from __builtin__ import set
    except ImportError:
        # for python 2.3
        from sets import Set as set

    env = web.ctx.env
    client_etags = set([tag.strip('" ') for tag in env.get('HTTP_IF_NONE_MATCH', '').split(',')])
    # Anything after ';' in If-Modified-Since is ignored.
    client_date = net.parsehttpdate(env.get('HTTP_IF_MODIFIED_SINCE', '').split(';')[0])

    unchanged = False
    if etag and ('*' in client_etags or etag in client_etags):
        unchanged = True
    # we subtract a second because
    # HTTP dates don't have sub-second precision
    if date and client_date and date - datetime.timedelta(seconds=1) <= client_date:
        unchanged = True

    if date:
        lastmodified(date)
    if etag:
        web.header('ETag', '"' + etag + '"')

    if unchanged:
        raise web.notmodified()
    return True
def urlencode(query, doseq=0):
    """
    Same as urllib.urlencode, but supports unicode strings.

        >>> urlencode({'text':'foo bar'})
        'text=foo+bar'
        >>> urlencode({'x': [1, 2]}, doseq=True)
        'x=1&x=2'
    """
    def convert(value, doseq=False):
        # With doseq, list values are converted element by element.
        if doseq and isinstance(value, list):
            return [convert(item) for item in value]
        return utils.safestr(value)

    encoded = {}
    for key, value in query.items():
        encoded[key] = convert(value, doseq)
    return urllib.urlencode(encoded, doseq=doseq)
def changequery(query=None, **kw):
    """
    Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return
    `/foo?a=3&b=2` -- the same URL but with the arguments you requested
    changed.

    A keyword argument set to `None` removes that parameter.
    """
    if query is None:
        query = web.rawinput(method='get')

    for name, new_value in kw.iteritems():
        if new_value is not None:
            query[name] = new_value
        else:
            query.pop(name, None)

    result = web.ctx.path
    if query:
        result += '?' + urlencode(query, doseq=True)
    return result
def url(path=None, doseq=False, **kw):
    """
    Makes a url by concatenating web.ctx.homepath and `path`, plus the
    query string created from the keyword arguments.

    `path` defaults to the current request path; a path that does not
    start with "/" is used as given, without the homepath.
    """
    if path is None:
        path = web.ctx.path

    if not path.startswith("/"):
        result = path
    else:
        result = web.ctx.homepath + path

    if kw:
        result += '?' + urlencode(kw, doseq=doseq)
    return result
def profiler(app):
    """Outputs basic profiling information at the bottom of each response."""
    from utils import profile

    def profile_internal(e, o):
        # The wrapped call yields the response body and the profiler report.
        body, report = profile(app)(e, o)
        footer = '<pre>' + net.websafe(report) + '</pre>'
        return list(body) + [footer]

    return profile_internal
if __name__ == "__main__":
    # Run the docstring examples of this module when executed as a script.
    import doctest
    doctest.testmod()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__all__ = ["runsimple"] | |
import sys, os | |
from SimpleHTTPServer import SimpleHTTPRequestHandler | |
import urllib | |
import posixpath | |
import webapi as web | |
import net | |
import utils | |
def runbasic(func, server_address=("0.0.0.0", 8080)):
    """
    Runs a simple HTTP server hosting WSGI app `func`. The directory `static/`
    is hosted statically.

    Based on [WsgiServer][ws] from [Colin Stewart][cs].

    [ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html
    [cs]: http://www.owlfish.com/
    """
    # Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/)
    # Modified somewhat for simplicity
    # Used under the modified BSD license:
    # http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5

    import SimpleHTTPServer, SocketServer, BaseHTTPServer, urlparse
    import socket, errno
    import traceback

    class WSGIHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
        # Request handler that builds a WSGI environ from the HTTP request
        # and streams the application's response back to the client.
        def run_wsgi_app(self):
            # Parse the request path against a dummy host to split off the
            # query string.
            protocol, host, path, parameters, query, fragment = \
                urlparse.urlparse('http://dummyhost%s' % self.path)

            # we only use path, query
            env = {'wsgi.version': (1, 0)
                   ,'wsgi.url_scheme': 'http'
                   ,'wsgi.input': self.rfile
                   ,'wsgi.errors': sys.stderr
                   ,'wsgi.multithread': 1
                   ,'wsgi.multiprocess': 0
                   ,'wsgi.run_once': 0
                   ,'REQUEST_METHOD': self.command
                   ,'REQUEST_URI': self.path
                   ,'PATH_INFO': path
                   ,'QUERY_STRING': query
                   ,'CONTENT_TYPE': self.headers.get('Content-Type', '')
                   ,'CONTENT_LENGTH': self.headers.get('Content-Length', '')
                   ,'REMOTE_ADDR': self.client_address[0]
                   ,'SERVER_NAME': self.server.server_address[0]
                   ,'SERVER_PORT': str(self.server.server_address[1])
                   ,'SERVER_PROTOCOL': self.request_version
                   }

            # Expose every request header as HTTP_<NAME> with '-' -> '_'.
            for http_header, http_value in self.headers.items():
                env ['HTTP_%s' % http_header.replace('-', '_').upper()] = \
                    http_value

            # Setup the state
            self.wsgi_sent_headers = 0
            self.wsgi_headers = []

            try:
                # We have there environment, now invoke the application
                result = self.server.app(env, self.wsgi_start_response)
                try:
                    try:
                        for data in result:
                            if data:
                                self.wsgi_write_data(data)
                    finally:
                        # Close the result iterable if it supports it, even
                        # when writing failed.
                        if hasattr(result, 'close'):
                            result.close()
                except socket.error, socket_err:
                    # Catch common network errors and suppress them
                    if (socket_err.args[0] in \
                       (errno.ECONNABORTED, errno.EPIPE)):
                        return
                except socket.timeout, socket_timeout:
                    return
            except:
                # Any other failure is written to web.debug as a traceback.
                print >> web.debug, traceback.format_exc(),

            if (not self.wsgi_sent_headers):
                # We must write out something!
                self.wsgi_write_data(" ")
            return

        # POST, PUT and DELETE always go to the WSGI application.
        do_POST = run_wsgi_app
        do_PUT = run_wsgi_app
        do_DELETE = run_wsgi_app

        def do_GET(self):
            # GETs under /static/ are served by the base file handler.
            if self.path.startswith('/static/'):
                SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
            else:
                self.run_wsgi_app()

        def wsgi_start_response(self, response_status, response_headers,
                                exc_info=None):
            # start_response callable handed to the application; the headers
            # are only recorded here and sent on the first write.
            if (self.wsgi_sent_headers):
                raise Exception \
                      ("Headers already sent and start_response called again!")
            # Should really take a copy to avoid changes in the application....
            self.wsgi_headers = (response_status, response_headers)
            return self.wsgi_write_data

        def wsgi_write_data(self, data):
            if (not self.wsgi_sent_headers):
                status, headers = self.wsgi_headers
                # Need to send header prior to data
                status_code = status[:status.find(' ')]
                status_msg = status[status.find(' ') + 1:]
                self.send_response(int(status_code), status_msg)
                for header, value in headers:
                    self.send_header(header, value)
                self.end_headers()
                self.wsgi_sent_headers = 1
            # Send the data
            self.wfile.write(data)

    class WSGIServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
        # Threaded HTTP server that keeps the WSGI app where the handler
        # can reach it (self.server.app).
        def __init__(self, func, server_address):
            BaseHTTPServer.HTTPServer.__init__(self,
                                               server_address,
                                               WSGIHandler)
            self.app = func
            self.serverShuttingDown = 0

    print "http://%s:%d/" % server_address
    WSGIServer(func, server_address).serve_forever()
def runsimple(func, server_address=("0.0.0.0", 8080)): | |
""" | |
Runs [CherryPy][cp] WSGI server hosting WSGI app `func`. | |
The directory `static/` is hosted statically. | |
[cp]: http://www.cherrypy.org | |
""" | |
func = StaticMiddleware(func) | |
func = LogMiddleware(func) | |
server = WSGIServer(server_address, func) | |
if server.ssl_adapter: | |
print "https://%s:%d/" % server_address | |
else: | |
print "http://%s:%d/" % server_address | |
try: | |
server.start() | |
except KeyboardInterrupt: | |
server.stop() | |
def WSGIServer(server_address, wsgi_app):
    """Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`.

    This function can be overwritten to customize the webserver or use a different webserver.
    """
    import wsgiserver

    # Default values of wsgiserver.ssl_adapters use the cherrypy.wsgiserver
    # prefix. Overwriting them makes it work with web.wsgiserver.
    wsgiserver.ssl_adapters = {
        'builtin': 'web.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
        'pyopenssl': 'web.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
    }

    server = wsgiserver.CherryPyWSGIServer(server_address, wsgi_app, server_name="localhost")

    def create_ssl_adapter(cert, key):
        # wsgiserver tries to import submodules as cherrypy.wsgiserver.foo.
        # That doesn't work, as here it is web.wsgiserver.
        # Patching sys.modules temporarily to make it work.
        import types
        cherrypy = types.ModuleType('cherrypy')
        cherrypy.wsgiserver = wsgiserver
        patches = {'cherrypy': cherrypy, 'cherrypy.wsgiserver': wsgiserver}

        # Remember whatever was registered under these names so a real
        # cherrypy import is not lost once we are done.
        previous = {}
        for modname in patches:
            if modname in sys.modules:
                previous[modname] = sys.modules[modname]
            sys.modules[modname] = patches[modname]

        try:
            from wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
            adapter = pyOpenSSLAdapter(cert, key)
        finally:
            # Always undo the patches, even if the import or the adapter
            # construction fails.
            for modname in patches:
                if modname in previous:
                    sys.modules[modname] = previous[modname]
                else:
                    sys.modules.pop(modname, None)

        return adapter

    # SSL backward compatibility
    if (server.ssl_adapter is None and
        getattr(server, 'ssl_certificate', None) and
        getattr(server, 'ssl_private_key', None)):
        server.ssl_adapter = create_ssl_adapter(server.ssl_certificate, server.ssl_private_key)

    server.nodelay = not sys.platform.startswith('java') # TCP_NODELAY isn't supported on the JVM
    return server
class StaticApp(SimpleHTTPRequestHandler):
    """WSGI application for serving static files.

    Reuses SimpleHTTPRequestHandler's file lookup, but records the status
    and headers instead of writing them to a socket, then passes them to
    `start_response`.
    """

    def __init__(self, environ, start_response):
        # The base-class constructor is not called: there is no socket.
        self.headers = []
        self.environ = environ
        self.start_response = start_response

    def send_response(self, status, msg=""):
        # Record the status line for start_response.
        self.status = str(status) + " " + msg

    def send_header(self, name, value):
        # Collect headers for start_response.
        self.headers.append((name, value))

    def end_headers(self):
        pass

    def log_message(*a): pass

    def __iter__(self):
        environ = self.environ

        # Attributes the base-class methods expect on a request handler.
        self.path = environ.get('PATH_INFO', '')
        self.client_address = (environ.get('REMOTE_ADDR','-'),
                               environ.get('REMOTE_PORT','-'))
        self.command = environ.get('REQUEST_METHOD', '-')

        from cStringIO import StringIO
        self.wfile = StringIO() # for capturing error

        try:
            fs_path = self.translate_path(self.path)
            # The file's modification time doubles as its ETag.
            etag = '"%s"' % os.path.getmtime(fs_path)
            client_etag = environ.get('HTTP_IF_NONE_MATCH')
            self.send_header('ETag', etag)
            if etag == client_etag:
                self.send_response(304, "Not Modified")
                self.start_response(self.status, self.headers)
                # Nothing to send: end the generator here.
                return
        except OSError:
            pass # Probably a 404

        f = self.send_head()
        self.start_response(self.status, self.headers)

        if not f:
            # No file: send whatever the base class wrote to wfile.
            yield self.wfile.getvalue()
        else:
            block_size = 16 * 1024
            chunk = f.read(block_size)
            while chunk:
                yield chunk
                chunk = f.read(block_size)
            f.close()
class StaticMiddleware:
    """WSGI middleware for serving static files.

    Requests whose normalized path starts with `prefix` go to StaticApp;
    everything else is passed to the wrapped app.
    """

    def __init__(self, app, prefix='/static/'):
        self.app = app
        self.prefix = prefix

    def __call__(self, environ, start_response):
        # Check the prefix on the normalized path, so '..' segments cannot
        # disguise a non-static path as static.
        requested = self.normpath(environ.get('PATH_INFO', ''))
        if not requested.startswith(self.prefix):
            return self.app(environ, start_response)
        return StaticApp(environ, start_response)

    def normpath(self, path):
        """Unquote and normalize `path`, keeping a trailing slash if it had one."""
        normalized = posixpath.normpath(urllib.unquote(path))
        if path.endswith("/"):
            normalized += "/"
        return normalized
class LogMiddleware:
    """WSGI middleware for logging the status.

    One line is written per response, at the moment the wrapped app calls
    `start_response`.
    """

    def __init__(self, app):
        self.app = app
        self.format = '%s - - [%s] "%s %s %s" - %s'

        from BaseHTTPServer import BaseHTTPRequestHandler
        import StringIO
        empty_stream = StringIO.StringIO()

        class FakeSocket:
            # Minimal stand-in so a handler can be built without a connection.
            def makefile(self, *a):
                return empty_stream

        # take log_date_time_string method from BaseHTTPRequestHandler
        handler = BaseHTTPRequestHandler(FakeSocket(), None, None)
        self.log_date_time_string = handler.log_date_time_string

    def __call__(self, environ, start_response):
        def xstart_response(status, response_headers, *args):
            # Forward to the real start_response first, then log the status.
            writer = start_response(status, response_headers, *args)
            self.log(status, environ)
            return writer

        return self.app(environ, xstart_response)

    def log(self, status, environ):
        """Write one log line to `wsgi.errors` (or web.debug if that is absent)."""
        outfile = environ.get('wsgi.errors', web.debug)

        req = environ.get('PATH_INFO', '_')
        protocol = environ.get('ACTUAL_SERVER_PROTOCOL', '-')
        method = environ.get('REQUEST_METHOD', '-')
        host = "%s:%s" % (environ.get('REMOTE_ADDR','-'),
                          environ.get('REMOTE_PORT','-'))
        timestamp = self.log_date_time_string()

        line = self.format % (host, timestamp, protocol, method, req, status)
        print >> outfile, utils.safestr(line)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Copyright (c) 2004-2007, CherryPy Team (team@cherrypy.org) | |
All rights reserved. | |
Redistribution and use in source and binary forms, with or without modification, | |
are permitted provided that the following conditions are met: | |
* Redistributions of source code must retain the above copyright notice, | |
this list of conditions and the following disclaimer. | |
* Redistributions in binary form must reproduce the above copyright notice, | |
this list of conditions and the following disclaimer in the documentation | |
and/or other materials provided with the distribution. | |
* Neither the name of the CherryPy Team nor the names of its contributors | |
may be used to endorse or promote products derived from this software | |
without specific prior written permission. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND | |
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED | |
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE | |
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Network Utilities | |
(from web.py) | |
""" | |
__all__ = [ | |
"validipaddr", "validipport", "validip", "validaddr", | |
"urlquote", | |
"httpdate", "parsehttpdate", | |
"htmlquote", "htmlunquote", "websafe", | |
] | |
import urllib, time | |
try: import datetime | |
except ImportError: pass | |
def validipaddr(address):
    """
    Returns True if `address` is a valid IPv4 address.

        >>> validipaddr('192.168.1.1')
        True
        >>> validipaddr('192.168.1.800')
        False
        >>> validipaddr('192.168.1')
        False
    """
    parts = address.split('.')
    if len(parts) != 4:
        return False
    for part in parts:
        # Each octet must be an integer in 0..255.
        try:
            number = int(part)
        except ValueError:
            return False
        if number < 0 or number > 255:
            return False
    return True
def validipport(port):
    """
    Returns True if `port` is a valid IPv4 port.

        >>> validipport('9000')
        True
        >>> validipport('foo')
        False
        >>> validipport('1000000')
        False
    """
    try:
        number = int(port)
    except ValueError:
        # Not an integer at all.
        return False
    return 0 <= number <= 65535
def validip(ip, defaultaddr="0.0.0.0", defaultport=8080):
    """Returns `(ip_address, port)` from the string `ip`.

    `ip` may be empty, an address, a port, or `address:port`. Missing
    parts are taken from `defaultaddr` / `defaultport`.

    Raises ValueError if `ip` is none of these. In the `address:port` form
    both parts must be valid; the original check
    (`not validipaddr(addr) and validipport(port)`) let a bad address
    through whenever the port was also bad.
    """
    addr = defaultaddr
    port = defaultport

    # maxsplit=1 means there are always exactly one or two parts.
    parts = ip.split(":", 1)
    error = ':'.join(parts) + ' is not a valid IP address/port'

    if len(parts) == 1:
        single = parts[0]
        if not single:
            pass
        elif validipaddr(single):
            addr = single
        elif validipport(single):
            port = int(single)
        else:
            raise ValueError(error)
    else:
        addr, port = parts
        if not (validipaddr(addr) and validipport(port)):
            raise ValueError(error)
        port = int(port)
    return (addr, port)
def validaddr(string_):
    """
    Interpret `string_` as either a socket path or an IP address/port.

    Anything containing a ``/`` is treated as a filesystem path and returned
    unchanged; everything else is handed to `validip`, which returns an
    `(ip_address, port)` tuple or raises ValueError.

        >>> validaddr('/path/to/socket')
        '/path/to/socket'
        >>> validaddr('8000')
        ('0.0.0.0', 8000)
        >>> validaddr('127.0.0.1')
        ('127.0.0.1', 8080)
        >>> validaddr('127.0.0.1:8000')
        ('127.0.0.1', 8000)
        >>> validaddr('fff')
        Traceback (most recent call last):
            ...
        ValueError: fff is not a valid IP address/port
    """
    if '/' not in string_:
        return validip(string_)
    return string_
def urlquote(val):
    """
    Percent-encode `val` for use inside a URL.

    None becomes the empty string, unicode objects are encoded as UTF-8
    first, and every other value is passed through `str()`.

        >>> urlquote('://?f=1&j=1')
        '%3A//%3Ff%3D1%26j%3D1'
        >>> urlquote(None)
        ''
        >>> urlquote(u'\u203d')
        '%E2%80%BD'
    """
    if val is None:
        return ''
    if isinstance(val, unicode):
        raw = val.encode('utf-8')
    else:
        raw = str(val)
    return urllib.quote(raw)
def httpdate(date_obj):
    """
    Render a datetime object in the date format used by HTTP headers.

    The literal suffix ``GMT`` is appended as-is; no timezone conversion is
    performed on `date_obj`.

        >>> import datetime
        >>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1))
        'Thu, 01 Jan 1970 01:01:01 GMT'
    """
    http_format = "%a, %d %b %Y %H:%M:%S GMT"
    return date_obj.strftime(http_format)
def parsehttpdate(string_):
    """
    Turn an HTTP date string into a datetime object.

    Returns None when `string_` does not match the expected
    ``"%a, %d %b %Y %H:%M:%S %Z"`` layout.

        >>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT')
        datetime.datetime(1970, 1, 1, 1, 1, 1)
    """
    try:
        parsed = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z")
    except ValueError:
        return None
    # The first six struct_time fields are year..second.
    year, month, day, hour, minute, second = parsed[:6]
    return datetime.datetime(year, month, day, hour, minute, second)
def htmlquote(text):
    r"""
    Encodes `text` for raw use in HTML.

    The five characters ``& < > ' "`` are replaced by their entity
    references; everything else is left untouched.

        >>> htmlquote(u"<'&\">")
        u'&lt;&#39;&amp;&quot;&gt;'
    """
    # The ampersand must be escaped first, otherwise the ampersands
    # introduced by the later replacements would be escaped a second time.
    text = text.replace(u"&", u"&amp;")
    text = text.replace(u"<", u"&lt;")
    text = text.replace(u">", u"&gt;")
    text = text.replace(u"'", u"&#39;")
    text = text.replace(u'"', u"&quot;")
    return text
def htmlunquote(text):
    r"""
    Decodes `text` that's HTML quoted.

    Reverses the five entity references ``&quot; &#39; &gt; &lt; &amp;``
    back into the characters they stand for.

        >>> htmlunquote(u'&lt;&#39;&amp;&quot;&gt;')
        u'<\'&">'
    """
    text = text.replace(u"&quot;", u'"')
    text = text.replace(u"&#39;", u"'")
    text = text.replace(u"&gt;", u">")
    text = text.replace(u"&lt;", u"<")
    # The ampersand must be decoded last, otherwise text such as
    # "&amp;lt;" would be decoded twice and come out as "<".
    text = text.replace(u"&amp;", u"&")
    return text
def websafe(val):
    r"""Return `val` as HTML-quoted unicode text.

    None becomes the empty unicode string, byte strings are decoded as
    UTF-8, other non-unicode values go through `unicode()`, and the result
    is passed to `htmlquote`.

        >>> websafe("<'&\">")
        u'&lt;&#39;&amp;&quot;&gt;'
        >>> websafe(None)
        u''
        >>> websafe(u'\u203d')
        u'\u203d'
        >>> websafe('\xe2\x80\xbd')
        u'\u203d'
    """
    if val is None:
        return u''
    if isinstance(val, str):
        text = val.decode('utf-8')
    elif isinstance(val, unicode):
        text = val
    else:
        text = unicode(val)
    return htmlquote(text)
if __name__ == "__main__": | |
import doctest | |
doctest.testmod() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Give me a feed url, and I'll attempt to parse it for you and give you JSON | |
# Usage: /?url=http://example.com/feed | |
# Dependencies: python 2.6+, feedparser, and web.py | |
# Now runs on GAE, using web.py rather than webapp2 (so it can easily be run elsewhere) | |
from google.appengine.ext.webapp.util import run_wsgi_app | |
import web | |
import json | |
import feedparser | |
import time | |
def to_json(python_object):
    """`default=` hook for json.dumps that knows about time.struct_time.

    A struct_time is rendered as a small tagged dict holding its
    `time.asctime()` text; any other object raises TypeError, as the
    json module expects from a `default` callable.
    """
    if not isinstance(python_object, time.struct_time):
        raise TypeError(repr(python_object) + ' is not JSON serializable')
    return {
        '__class__': 'time.asctime',
        '__value__': time.asctime(python_object),
    }
urls = ('/', 'Parser') | |
app = web.application(urls, globals()) | |
class Parser:
    """Handler for ``/``: parse the feed named by ``?url=`` and return JSON."""

    def GET(self):
        """Return the parsed feed as JSON, or a 400 JSON error message.

        Any failure (missing `url` parameter, unsupported URL scheme,
        parser or serialisation error) is reported to the client as
        "400 Bad Request" with a usage hint.
        """
        web.header('Content-Type', 'application/json')
        try:
            url = web.input().url
            # `url` is untrusted.  feedparser.parse() accepts local file
            # paths and raw document strings as well as URLs, so anything
            # that is not plain http(s) is refused before it gets there.
            if not url.lower().startswith(('http://', 'https://')):
                raise ValueError('unsupported feed url: %r' % (url,))
            return json.dumps(feedparser.parse(url), default=to_json)
        except Exception:
            # Request boundary: every error becomes the same 400 response.
            web.ctx.status = "400 Bad Request"
            return json.dumps({"message": "Sorry, i couldn't parse that feed. Usage: /?url=http://example.com/feed"})
if __name__ == "__main__": | |
application = app.wsgifunc() | |
run_wsgi_app(application) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Python 2.3 compatabilty""" | |
import threading | |
class threadlocal(object):
    """Stand-in for threading.local on python2.3.

    Attribute values are not kept on the instance itself: they live in a
    dict stored on the current thread object, so each thread sees its own
    set of attributes.
    """

    def __getattribute__(self, name):
        # `__dict__` is redirected to the per-thread storage.
        if name == "__dict__":
            return threadlocal._getd(self)
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            pass
        # Not a regular (class-level) attribute: look in this thread's dict.
        storage = self.__dict__
        if name in storage:
            return storage[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        storage = self.__dict__
        storage[name] = value

    def __delattr__(self, name):
        storage = self.__dict__
        if name not in storage:
            raise AttributeError(name)
        del storage[name]

    def _getd(self):
        """Return the attribute dict of this instance for the current thread."""
        thread = threading.currentThread()
        # The thread object's own attribute `_d` serves as thread-local
        # storage; it is created on first use.
        if not hasattr(thread, '_d'):
            thread._d = {}
        # Several threadlocal instances may coexist, so each one gets its
        # own sub-dict keyed by id(self).
        return thread._d.setdefault(id(self), {})
if __name__ == '__main__': | |
d = threadlocal() | |
d.x = 1 | |
print d.__dict__ | |
print d.x | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Session Management | |
(from web.py) | |
""" | |
import os, time, datetime, random, base64 | |
import os.path | |
from copy import deepcopy | |
try: | |
import cPickle as pickle | |
except ImportError: | |
import pickle | |
try: | |
import hashlib | |
sha1 = hashlib.sha1 | |
except ImportError: | |
import sha | |
sha1 = sha.new | |
import utils | |
import webapi as web | |
__all__ = [ | |
'Session', 'SessionExpired', | |
'Store', 'DiskStore', 'DBStore', | |
] | |
web.config.session_parameters = utils.storage({ | |
'cookie_name': 'webpy_session_id', | |
'cookie_domain': None, | |
'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds | |
'ignore_expiry': True, | |
'ignore_change_ip': True, | |
'secret_key': 'fLjUfxqXtfNoIldA0A0J', | |
'expired_message': 'Session expired', | |
'httponly': True, | |
'secure': False | |
}) | |
class SessionExpired(web.HTTPError):
    """HTTP error raised when a session has expired.

    The response is sent with status ``200 OK`` and `message` as its body.
    """

    def __init__(self, message):
        status = '200 OK'
        headers = {}
        web.HTTPError.__init__(self, status, headers, data=message)
class Session(object):
    """Session management for web.py.

    Session values are kept in a ``utils.threadeddict`` (``self._data``).
    Attribute access on the session that is not one of the slots below is
    forwarded to that object with getattr/setattr/delattr, and item access
    is forwarded through its ``__getitem__``/``__setitem__``/``__delitem__``.

    `app` (if given) gets `_processor` registered as a processor; `store`
    is the backing session store; `initializer` is either a dict of
    initial values or a callable invoked when a new session is created.
    """
    # Only these names are stored on the Session object itself; see
    # __setattr__, which routes every other name to self._data.
    __slots__ = [
        "store", "_initializer", "_last_cleanup_time", "_config", "_data",
        "__getitem__", "__setitem__", "__delitem__"
    ]

    def __init__(self, app, store, initializer=None):
        self.store = store
        self._initializer = initializer
        self._last_cleanup_time = 0
        # Private copy of the global parameters.
        self._config = utils.storage(web.config.session_parameters)
        self._data = utils.threadeddict()

        self.__getitem__ = self._data.__getitem__
        self.__setitem__ = self._data.__setitem__
        self.__delitem__ = self._data.__delitem__

        if app:
            app.add_processor(self._processor)

    def __contains__(self, name):
        return name in self._data

    def __getattr__(self, name):
        return getattr(self._data, name)

    def __setattr__(self, name, value):
        if name in self.__slots__:
            object.__setattr__(self, name, value)
        else:
            setattr(self._data, name, value)

    def __delattr__(self, name):
        delattr(self._data, name)

    def _processor(self, handler):
        """Application processor to setup session for every request"""
        self._cleanup()
        self._load()

        try:
            return handler()
        finally:
            # Persist the session even when the handler raised.
            self._save()

    def _load(self):
        """Load the session from the store, by the id from cookie"""
        cookie_name = self._config.cookie_name
        self.session_id = web.cookies().get(cookie_name)

        # protection against session_id tampering
        if self.session_id and not self._valid_session_id(self.session_id):
            self.session_id = None

        self._check_expiry()
        if self.session_id:
            d = self.store[self.session_id]
            self.update(d)
            self._validate_ip()

        if not self.session_id:
            # No usable session: start a fresh one and seed it.
            self.session_id = self._generate_session_id()

            if self._initializer:
                if isinstance(self._initializer, dict):
                    # deepcopy so sessions never share mutable initial values
                    self.update(deepcopy(self._initializer))
                elif hasattr(self._initializer, '__call__'):
                    self._initializer()

        self.ip = web.ctx.ip

    def _check_expiry(self):
        """Handle a session id that is no longer present in the store.

        With `ignore_expiry` set the id is simply dropped (a new session is
        created later in `_load`); otherwise `expired()` is called, which
        raises SessionExpired.
        """
        if self.session_id and self.session_id not in self.store:
            if self._config.ignore_expiry:
                self.session_id = None
            else:
                return self.expired()

    def _validate_ip(self):
        """Expire the session when the client IP changed, unless the
        `ignore_change_ip` parameter is set."""
        if self.session_id and self.get('ip', None) != web.ctx.ip:
            if not self._config.ignore_change_ip:
                return self.expired()

    def _save(self):
        """Write the session to the store and (re)send its cookie.

        A killed session is not stored; its cookie is sent with
        ``expires=-1`` instead.
        """
        if not self.get('_killed'):
            self._setcookie(self.session_id)
            self.store[self.session_id] = dict(self._data)
        else:
            self._setcookie(self.session_id, expires=-1)

    def _setcookie(self, session_id, expires='', **kw):
        """Send the session cookie using the configured cookie parameters.

        Extra keyword arguments are accepted but not used.
        """
        cookie_name = self._config.cookie_name
        cookie_domain = self._config.cookie_domain
        httponly = self._config.httponly
        secure = self._config.secure
        web.setcookie(cookie_name, session_id, expires=expires, domain=cookie_domain, httponly=httponly, secure=secure)

    def _generate_session_id(self):
        """Generate a random id for session"""

        # Retry until the id is not already present in the store.
        while True:
            rand = os.urandom(16)
            now = time.time()
            secret_key = self._config.secret_key
            session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key))
            session_id = session_id.hexdigest()
            if session_id not in self.store:
                break
        return session_id

    def _valid_session_id(self, session_id):
        """Return a match object if `session_id` is purely hexadecimal,
        None otherwise."""
        rx = utils.re_compile('^[0-9a-fA-F]+$')
        return rx.match(session_id)

    def _cleanup(self):
        """Cleanup the stored sessions"""
        current_time = time.time()
        timeout = self._config.timeout
        # Run the store cleanup at most once per `timeout` seconds.
        if current_time - self._last_cleanup_time > timeout:
            self.store.cleanup(timeout)
            self._last_cleanup_time = current_time

    def expired(self):
        """Called when an expired session is accessed.

        Marks the session as killed, saves it (which expires the cookie)
        and raises SessionExpired with the configured message.
        """
        self._killed = True
        self._save()
        raise SessionExpired(self._config.expired_message)

    def kill(self):
        """Kill the session, make it no longer available"""
        del self.store[self.session_id]
        self._killed = True
class Store:
    """Abstract interface that every session store implements.

    Subclasses provide the mapping operations and `cleanup`; the
    `encode`/`decode` helpers are shared.
    """

    def __contains__(self, key):
        raise NotImplementedError

    def __getitem__(self, key):
        raise NotImplementedError

    def __setitem__(self, key, value):
        raise NotImplementedError

    def cleanup(self, timeout):
        """Remove every session that has expired."""
        raise NotImplementedError

    def encode(self, session_dict):
        """Serialize a session dict: pickle it, then base64-encode it."""
        return base64.encodestring(pickle.dumps(session_dict))

    def decode(self, session_data):
        """Inverse of `encode`: base64-decode, then unpickle.

        NOTE(review): unpickling executes arbitrary code if the stored data
        can be written by an untrusted party -- the backing store must be
        trusted.
        """
        return pickle.loads(base64.decodestring(session_data))
class DiskStore(Store):
    """
    Store for saving a session on disk.

    Each session is kept in its own file, named after the session key,
    directly inside `root`.

        >>> import tempfile
        >>> root = tempfile.mkdtemp()
        >>> s = DiskStore(root)
        >>> s['a'] = 'foo'
        >>> s['a']
        'foo'
        >>> time.sleep(0.01)
        >>> s.cleanup(0.01)
        >>> s['a']
        Traceback (most recent call last):
            ...
        KeyError: 'a'
    """
    def __init__(self, root):
        # if the storage root doesn't exists, create it.
        if not os.path.exists(root):
            os.makedirs(
                    os.path.abspath(root)
                    )
        self.root = root

    def _get_path(self, key):
        """Return the file path for `key`.

        Raises ValueError for keys containing the path separator, so a key
        cannot name a file in another directory.
        """
        if os.path.sep in key:
            raise ValueError("Bad key: %s" % repr(key))
        return os.path.join(self.root, key)

    def __contains__(self, key):
        path = self._get_path(key)
        return os.path.exists(path)

    def __getitem__(self, key):
        path = self._get_path(key)
        if os.path.exists(path):
            # Close the handle explicitly instead of relying on garbage
            # collection (same try/finally pattern as __setitem__).
            f = open(path)
            try:
                pickled = f.read()
            finally:
                f.close()
            return self.decode(pickled)
        else:
            raise KeyError(key)

    def __setitem__(self, key, value):
        path = self._get_path(key)
        pickled = self.encode(value)
        try:
            f = open(path, 'w')
            try:
                f.write(pickled)
            finally:
                f.close()
        except IOError:
            # Best effort: a session that cannot be written is dropped.
            pass

    def __delitem__(self, key):
        path = self._get_path(key)
        if os.path.exists(path):
            os.remove(path)

    def cleanup(self, timeout):
        """Delete every session file last accessed more than `timeout`
        seconds ago (judged by the file's st_atime)."""
        now = time.time()
        for f in os.listdir(self.root):
            path = self._get_path(f)
            atime = os.stat(path).st_atime
            if now - atime > timeout :
                os.remove(path)
class DBStore(Store):
    """Store for saving a session in database
    Needs a table with the following columns:

        session_id CHAR(128) UNIQUE NOT NULL,
        atime DATETIME NOT NULL default current_timestamp,
        data TEXT
    """
    def __init__(self, db, table_name):
        self.db = db
        self.table = table_name

    def __contains__(self, key):
        data = self.db.select(self.table, where="session_id=$key", vars=locals())
        return bool(list(data))

    def __getitem__(self, key):
        """Return the decoded session for `key` and refresh its `atime`.

        Raises KeyError(key) when no row exists for `key`.
        """
        now = datetime.datetime.now()
        try:
            s = self.db.select(self.table, where="session_id=$key", vars=locals())[0]
            self.db.update(self.table, where="session_id=$key", atime=now, vars=locals())
        except IndexError:
            # Carry the missing key in the exception, as DiskStore does.
            raise KeyError(key)
        else:
            return self.decode(s.data)

    def __setitem__(self, key, value):
        """Insert the session row for `key`, or update its data if present."""
        pickled = self.encode(value)
        if key in self:
            self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals())
        else:
            self.db.insert(self.table, False, session_id=key, data=pickled )

    def __delitem__(self, key):
        self.db.delete(self.table, where="session_id=$key", vars=locals())

    def cleanup(self, timeout):
        """Delete every row whose `atime` is more than `timeout` seconds old."""
        timeout = datetime.timedelta(seconds=timeout)
        last_allowed_time = datetime.datetime.now() - timeout
        self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals())
class ShelfStore:
    """Store for saving session using `shelve` module.

        import shelve
        store = ShelfStore(shelve.open('session.shelf'))

    Each entry is stored as an ``(atime, value)`` pair, where `atime` is
    the `time.time()` of the last read or write.

    XXX: is shelve thread-safe?
    """
    def __init__(self, shelf):
        self.shelf = shelf

    def __contains__(self, key):
        return key in self.shelf

    def __getitem__(self, key):
        atime, v = self.shelf[key]
        self[key] = v # update atime
        return v

    def __setitem__(self, key, value):
        self.shelf[key] = time.time(), value

    def __delitem__(self, key):
        # Deleting a missing key is silently ignored.
        try:
            del self.shelf[key]
        except KeyError:
            pass

    def cleanup(self, timeout):
        """Delete every entry last accessed more than `timeout` seconds ago."""
        now = time.time()
        # Snapshot the keys first: entries are deleted inside the loop, and
        # mappings whose keys() is a live view raise RuntimeError when they
        # change size during iteration.
        for k in list(self.shelf.keys()):
            atime, v = self.shelf[k]
            if now - atime > timeout :
                del self[k]
if __name__ == '__main__' : | |
import doctest | |
doctest.testmod() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A parser for SGML, using the derived class as a static DTD.""" | |
# XXX This only supports those SGML features used by HTML. | |
# XXX There should be a way to distinguish between PCDATA (parsed | |
# character data -- the normal case), RCDATA (replaceable character | |
# data -- only char and entity references and end tags are special) | |
# and CDATA (character data -- only end tags are special). RCDATA is | |
# not supported at all. | |
import _markupbase | |
import re | |
__all__ = ["SGMLParser", "SGMLParseError"] | |
# Regular expressions used for parsing | |
interesting = re.compile('[&<]') | |
incomplete = re.compile('&([a-zA-Z][a-zA-Z0-9]*|#[0-9]*)?|' | |
'<([a-zA-Z][^<>]*|' | |
'/([a-zA-Z][^<>]*)?|' | |
'![^<>]*)?') | |
entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]') | |
charref = re.compile('&#([0-9]+)[^0-9]') | |
starttagopen = re.compile('<[>a-zA-Z]') | |
shorttagopen = re.compile('<[a-zA-Z][-.a-zA-Z0-9]*/') | |
shorttag = re.compile('<([a-zA-Z][-.a-zA-Z0-9]*)/([^/]*)/') | |
piclose = re.compile('>') | |
endbracket = re.compile('[<>]') | |
tagfind = re.compile('[a-zA-Z][-_.a-zA-Z0-9]*') | |
attrfind = re.compile( | |
r'\s*([a-zA-Z_][-:.a-zA-Z_0-9]*)(\s*=\s*' | |
r'(\'[^\']*\'|"[^"]*"|[][\-a-zA-Z0-9./,:;+*%?!&$\(\)_#=~\'"@]*))?') | |
class SGMLParseError(RuntimeError):
    """Raised by SGMLParser.error() to report a parse error."""
# SGML parser base class -- find tags and call handler functions. | |
# Usage: p = SGMLParser(); p.feed(data); ...; p.close(). | |
# The dtd is defined by deriving a class which defines methods | |
# with special names to handle tags: start_foo and end_foo to handle | |
# <foo> and </foo>, respectively, or do_foo to handle <foo> by itself. | |
# (Tags are converted to lower case for this purpose.) The data | |
# between tags is passed to the parser by calling self.handle_data() | |
# with some data as argument (the data may be split up in arbitrary | |
# chunks). Entity references are passed by calling | |
# self.handle_entityref() with the entity reference as argument. | |
class SGMLParser(_markupbase.ParserBase):
    """Base class of the SGML parser: finds tags and calls handler methods.

    Subclasses define ``start_<tag>``/``end_<tag>`` or ``do_<tag>`` methods
    and override the ``handle_*``/``unknown_*`` hooks; tag names are
    lower-cased before the handler lookup.
    """
    # Matches an entity reference (group 1) or a decimal character
    # reference (group 2) with an optional trailing ';' (group 3);
    # used when converting attribute values.
    entity_or_charref = re.compile('&(?:'
      '([a-zA-Z][-.a-zA-Z0-9]*)|#([0-9]+)'
      ')(;?)')

    def __init__(self, verbose=0):
        """Initialize and reset this instance."""
        self.verbose = verbose
        self.reset()

    def reset(self):
        """Reset this instance. Loses all unprocessed data."""
        self.__starttag_text = None
        self.rawdata = ''       # text fed but not yet consumed
        self.stack = []         # open tags that have a start_<tag> handler
        self.lasttag = '???'    # last start tag seen, used by the '<>' shorthand
        self.nomoretags = 0
        self.literal = 0
        _markupbase.ParserBase.reset(self)

    def setnomoretags(self):
        """Enter literal mode (CDATA) till EOF.

        Intended for derived classes only.
        """
        self.nomoretags = self.literal = 1

    def setliteral(self, *args):
        """Enter literal mode (CDATA).

        Intended for derived classes only.
        """
        self.literal = 1

    def feed(self, data):
        """Feed some data to the parser.

        Call this as often as you want, with as little or as much text
        as you want (may include '\n').  (This just saves the text,
        all the processing is done by goahead().)
        """

        self.rawdata = self.rawdata + data
        self.goahead(0)

    def close(self):
        """Handle the remaining data."""
        self.goahead(1)

    def error(self, message):
        """Raise SGMLParseError with `message`."""
        raise SGMLParseError(message)

    # Internal -- handle data as far as reasonable.  May leave state
    # and data to be processed by a subsequent call.  If 'end' is
    # true, force handling all data as if followed by EOF marker.
    def goahead(self, end):
        rawdata = self.rawdata
        i = 0               # index of the first unconsumed character
        n = len(rawdata)
        while i < n:
            if self.nomoretags:
                # Everything that is left is plain data.
                self.handle_data(rawdata[i:n])
                i = n
                break
            # Plain text runs up to the next '&' or '<'.
            match = interesting.search(rawdata, i)
            if match: j = match.start()
            else: j = n
            if i < j:
                self.handle_data(rawdata[i:j])
            i = j
            if i == n: break
            if rawdata[i] == '<':
                if starttagopen.match(rawdata, i):
                    if self.literal:
                        # In literal mode a start tag is just text.
                        self.handle_data(rawdata[i])
                        i = i+1
                        continue
                    k = self.parse_starttag(i)
                    if k < 0: break     # tag not complete yet
                    i = k
                    continue
                if rawdata.startswith("</", i):
                    k = self.parse_endtag(i)
                    if k < 0: break
                    i = k
                    # An end tag terminates literal mode.
                    self.literal = 0
                    continue
                if self.literal:
                    if n > (i + 1):
                        self.handle_data("<")
                        i = i+1
                    else:
                        # incomplete
                        break
                    continue
                if rawdata.startswith("<!--", i):
                        # Strictly speaking, a comment is --.*--
                        # within a declaration tag <!...>.
                        # This should be removed,
                        # and comments handled only in parse_declaration.
                    k = self.parse_comment(i)
                    if k < 0: break
                    i = k
                    continue
                if rawdata.startswith("<?", i):
                    k = self.parse_pi(i)
                    if k < 0: break
                    # parse_pi returns a length, not an end index.
                    i = i+k
                    continue
                if rawdata.startswith("<!", i):
                    # This is some sort of declaration; in "HTML as
                    # deployed," this should only be the document type
                    # declaration ("<!DOCTYPE html...>").
                    k = self.parse_declaration(i)
                    if k < 0: break
                    i = k
                    continue
            elif rawdata[i] == '&':
                if self.literal:
                    self.handle_data(rawdata[i])
                    i = i+1
                    continue
                match = charref.match(rawdata, i)
                if match:
                    name = match.group(1)
                    self.handle_charref(name)
                    i = match.end(0)
                    # The pattern consumes one terminating character;
                    # give it back unless it was the ';'.
                    if rawdata[i-1] != ';': i = i-1
                    continue
                match = entityref.match(rawdata, i)
                if match:
                    name = match.group(1)
                    self.handle_entityref(name)
                    i = match.end(0)
                    if rawdata[i-1] != ';': i = i-1
                    continue
            else:
                self.error('neither < nor & ??')
            # We get here only if incomplete matches but
            # nothing else
            match = incomplete.match(rawdata, i)
            if not match:
                self.handle_data(rawdata[i])
                i = i+1
                continue
            j = match.end(0)
            if j == n:
                break # Really incomplete
            self.handle_data(rawdata[i:j])
            i = j
        # end while
        if end and i < n:
            # At EOF whatever is left is flushed as data.
            self.handle_data(rawdata[i:n])
            i = n
        # Keep the unconsumed tail for the next call.
        self.rawdata = rawdata[i:]
        # XXX if end: check for empty stack

    # Extensions for the DOCTYPE scanner:
    _decl_otherchars = '='

    # Internal -- parse processing instr, return length or -1 if not terminated
    def parse_pi(self, i):
        rawdata = self.rawdata
        if rawdata[i:i+2] != '<?':
            self.error('unexpected call to parse_pi()')
        match = piclose.search(rawdata, i+2)
        if not match:
            return -1
        j = match.start(0)
        # The handler receives the text between '<?' and '>'.
        self.handle_pi(rawdata[i+2: j])
        j = match.end(0)
        return j-i

    def get_starttag_text(self):
        """Return the text stored by the most recent parse_starttag() call
        (None if none was recorded)."""
        return self.__starttag_text

    # Internal -- handle starttag; returns the index just past the parsed
    # tag (an end position in rawdata, not a length), or -1 if the tag is
    # not terminated yet
    def parse_starttag(self, i):
        self.__starttag_text = None
        start_pos = i
        rawdata = self.rawdata
        if shorttagopen.match(rawdata, i):
            # SGML shorthand: <tag/data/ == <tag>data</tag>
            # XXX Can data contain &... (entity or char refs)?
            # XXX Can data contain < or > (tag characters)?
            # XXX Can there be whitespace before the first /?
            match = shorttag.match(rawdata, i)
            if not match:
                return -1
            tag, data = match.group(1, 2)
            self.__starttag_text = '<%s/' % tag
            tag = tag.lower()
            k = match.end(0)
            self.finish_shorttag(tag, data)
            self.__starttag_text = rawdata[start_pos:match.end(1) + 1]
            return k
        # XXX The following should skip matching quotes (' or ")
        # As a shortcut way to exit, this isn't so bad, but shouldn't
        # be used to locate the actual end of the start tag since the
        # < or > characters may be embedded in an attribute value.
        match = endbracket.search(rawdata, i+1)
        if not match:
            return -1
        j = match.start(0)
        # Now parse the data between i+1 and j into a tag and attrs
        attrs = []
        if rawdata[i:i+2] == '<>':
            # SGML shorthand: <> == <last open tag seen>
            k = j
            tag = self.lasttag
        else:
            match = tagfind.match(rawdata, i+1)
            if not match:
                self.error('unexpected call to parse_starttag')
            k = match.end(0)
            tag = rawdata[i+1:k].lower()
            self.lasttag = tag
        while k < j:
            match = attrfind.match(rawdata, k)
            if not match: break
            attrname, rest, attrvalue = match.group(1, 2, 3)
            if not rest:
                # Attribute without a value: the name doubles as the value.
                attrvalue = attrname
            else:
                if (attrvalue[:1] == "'" == attrvalue[-1:] or
                    attrvalue[:1] == '"' == attrvalue[-1:]):
                    # strip quotes
                    attrvalue = attrvalue[1:-1]
                attrvalue = self.entity_or_charref.sub(
                    self._convert_ref, attrvalue)
            attrs.append((attrname.lower(), attrvalue))
            k = match.end(0)
        # endbracket also stops at '<'; only a closing '>' is consumed.
        if rawdata[j] == '>':
            j = j+1
        self.__starttag_text = rawdata[start_pos:j]
        self.finish_starttag(tag, attrs)
        return j

    # Internal -- convert entity or character reference
    # (substitution callback for entity_or_charref; a reference that
    # cannot be converted is put back as text)
    def _convert_ref(self, match):
        if match.group(2):
            return self.convert_charref(match.group(2)) or \
                '&#%s%s' % match.groups()[1:]
        elif match.group(3):
            return self.convert_entityref(match.group(1)) or \
                '&%s;' % match.group(1)
        else:
            # Entity name without a terminating ';' is left unconverted.
            return '&%s' % match.group(1)

    # Internal -- parse endtag; returns the index just past the tag,
    # or -1 if it is not terminated yet
    def parse_endtag(self, i):
        rawdata = self.rawdata
        match = endbracket.search(rawdata, i+1)
        if not match:
            return -1
        j = match.start(0)
        tag = rawdata[i+2:j].strip().lower()
        if rawdata[j] == '>':
            j = j+1
        self.finish_endtag(tag)
        return j

    # Internal -- finish parsing of <tag/data/ (same as <tag>data</tag>)
    def finish_shorttag(self, tag, data):
        self.finish_starttag(tag, [])
        self.handle_data(data)
        self.finish_endtag(tag)

    # Internal -- finish processing of start tag
    # Return -1 for unknown tag, 0 for open-only tag, 1 for balanced tag
    def finish_starttag(self, tag, attrs):
        try:
            method = getattr(self, 'start_' + tag)
        except AttributeError:
            try:
                method = getattr(self, 'do_' + tag)
            except AttributeError:
                self.unknown_starttag(tag, attrs)
                return -1
            else:
                # do_<tag>: handled, but not pushed on the stack.
                self.handle_starttag(tag, method, attrs)
                return 0
        else:
            # start_<tag>: remember the tag until its end tag arrives.
            self.stack.append(tag)
            self.handle_starttag(tag, method, attrs)
            return 1

    # Internal -- finish processing of end tag
    def finish_endtag(self, tag):
        if not tag:
            # Empty end tag: close the most recently opened element.
            found = len(self.stack) - 1
            if found < 0:
                self.unknown_endtag(tag)
                return
        else:
            if tag not in self.stack:
                try:
                    method = getattr(self, 'end_' + tag)
                except AttributeError:
                    self.unknown_endtag(tag)
                else:
                    self.report_unbalanced(tag)
                return
            found = len(self.stack)
            # range() is fixed before the loop runs, so this ends with
            # `found` at the last (innermost) occurrence of tag.
            for i in range(found):
                if self.stack[i] == tag: found = i
        # Close every element from the top of the stack down to `found`.
        while len(self.stack) > found:
            tag = self.stack[-1]
            try:
                method = getattr(self, 'end_' + tag)
            except AttributeError:
                method = None
            if method:
                self.handle_endtag(tag, method)
            else:
                self.unknown_endtag(tag)
            del self.stack[-1]

    # Overridable -- handle start tag
    def handle_starttag(self, tag, method, attrs):
        method(attrs)

    # Overridable -- handle end tag
    def handle_endtag(self, tag, method):
        method()

    # Example -- report an unbalanced </...> tag.
    def report_unbalanced(self, tag):
        if self.verbose:
            print('*** Unbalanced </' + tag + '>')
            print('*** Stack:', self.stack)

    def convert_charref(self, name):
        """Convert character reference, may be overridden.

        Returns None when `name` is not an integer or lies outside 0..127.
        """
        try:
            n = int(name)
        except ValueError:
            return
        if not 0 <= n <= 127:
            return
        return self.convert_codepoint(n)

    def convert_codepoint(self, codepoint):
        """Return the character for `codepoint`."""
        return chr(codepoint)

    def handle_charref(self, name):
        """Handle character reference, no need to override."""
        replacement = self.convert_charref(name)
        if replacement is None:
            self.unknown_charref(name)
        else:
            self.handle_data(replacement)

    # Definition of entities -- derived classes may override
    entitydefs = \
            {'lt': '<', 'gt': '>', 'amp': '&', 'quot': '"', 'apos': '\''}

    def convert_entityref(self, name):
        """Convert entity references.

        As an alternative to overriding this method; one can tailor the
        results by setting up the self.entitydefs mapping appropriately.
        Returns None for names missing from the mapping.
        """
        table = self.entitydefs
        if name in table:
            return table[name]
        else:
            return

    def handle_entityref(self, name):
        """Handle entity references, no need to override."""
        replacement = self.convert_entityref(name)
        if replacement is None:
            self.unknown_entityref(name)
        else:
            self.handle_data(replacement)

    # Example -- handle data, should be overridden
    def handle_data(self, data):
        pass

    # Example -- handle comment, could be overridden
    def handle_comment(self, data):
        pass

    # Example -- handle declaration, could be overridden
    def handle_decl(self, decl):
        pass

    # Example -- handle processing instruction, could be overridden
    def handle_pi(self, data):
        pass

    # To be overridden -- handlers for unknown objects
    def unknown_starttag(self, tag, attrs): pass
    def unknown_endtag(self, tag): pass
    def unknown_charref(self, ref): pass
    def unknown_entityref(self, ref): pass
class TestSGMLParser(SGMLParser):
    """SGMLParser subclass that prints every event it receives.

    Character data is buffered in `testdata` and printed by `flush()`,
    which runs before any other event is reported and whenever the repr of
    the buffer reaches 70 characters.
    """

    def __init__(self, verbose=0):
        self.testdata = ""
        SGMLParser.__init__(self, verbose)

    def handle_data(self, data):
        self.testdata += data
        if len(repr(self.testdata)) < 70:
            return
        self.flush()

    def flush(self):
        """Print and clear the buffered character data, if there is any."""
        pending = self.testdata
        if not pending:
            return
        self.testdata = ""
        print('data:', repr(pending))

    def handle_comment(self, data):
        self.flush()
        shown = repr(data)
        if len(shown) > 68:
            # Keep only the head and tail of long comments.
            shown = shown[:32] + '...' + shown[-32:]
        print('comment:', shown)

    def unknown_starttag(self, tag, attrs):
        self.flush()
        if not attrs:
            print('start tag: <' + tag + '>')
            return
        # One space-separated line: the tag, each attribute, then '>'.
        pieces = ['start tag: <' + tag]
        for name, value in attrs:
            pieces.append(name + '=' + '"' + value + '"')
        pieces.append('>')
        print(' '.join(pieces))

    def unknown_endtag(self, tag):
        self.flush()
        print('end tag: </' + tag + '>')

    def unknown_entityref(self, ref):
        self.flush()
        print('*** unknown entity ref: &' + ref + ';')

    def unknown_charref(self, ref):
        self.flush()
        print('*** unknown char ref: &#' + ref + ';')

    def unknown_decl(self, data):
        self.flush()
        print('*** unknown decl: [' + data + ']')

    def close(self):
        SGMLParser.close(self)
        # Emit whatever data is still buffered.
        self.flush()
def test(args = None):
    """Command-line driver: parse a file and report what the parser sees.

    `args` defaults to ``sys.argv[1:]``.  A leading ``-s`` selects the
    silent base SGMLParser instead of TestSGMLParser.  The next argument
    is the file to read (``-`` for stdin, ``test.html`` when omitted).
    Exits with status 1 if the file cannot be opened.
    """
    import sys

    if args is None:
        args = sys.argv[1:]

    if args and args[0] == '-s':
        args = args[1:]
        klass = SGMLParser
    else:
        klass = TestSGMLParser

    if args:
        path = args[0]
    else:
        path = 'test.html'

    if path == '-':
        data = sys.stdin.read()
    else:
        try:
            source = open(path, 'r')
        except IOError as msg:
            print(path, ":", msg)
            sys.exit(1)
        # Close the file even if read() fails.
        with source:
            data = source.read()

    x = klass()
    # Feed one character at a time to exercise incremental parsing.
    for c in data:
        x.feed(c)
    x.close()
if __name__ == '__main__': | |
test() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A library for integrating Python's builtin ``ssl`` library with CherryPy. | |
The ssl module must be importable for SSL functionality. | |
To use this module, set ``CherryPyWSGIServer.ssl_adapter`` to an instance of | |
``BuiltinSSLAdapter``. | |
""" | |
try: | |
import ssl | |
except ImportError: | |
ssl = None | |
from cherrypy import wsgiserver | |
class BuiltinSSLAdapter(wsgiserver.SSLAdapter):
    """SSL adapter for CherryPy built on Python's builtin ssl module."""

    certificate = None
    """The filename of the server SSL certificate."""

    private_key = None
    """The filename of the server's private key file."""

    def __init__(self, certificate, private_key, certificate_chain=None):
        if ssl is None:
            raise ImportError("You must install the ssl module to use HTTPS.")
        self.certificate = certificate
        self.private_key = private_key
        self.certificate_chain = certificate_chain

    def bind(self, sock):
        """Return `sock` unchanged; the wrapping happens in wrap()."""
        return sock

    def wrap(self, sock):
        """Wrap `sock` in SSL; return it plus the WSGI environ entries.

        Returns ``(None, {})`` when the handshake ends with SSL_ERROR_EOF
        and raises wsgiserver.NoSSLError when the client spoke plain HTTP;
        any other SSLError propagates.
        """
        try:
            wrapped = ssl.wrap_socket(
                sock,
                do_handshake_on_connect=True,
                server_side=True,
                certfile=self.certificate,
                keyfile=self.private_key,
                ssl_version=ssl.PROTOCOL_SSLv23,
            )
        except ssl.SSLError as e:
            if e.errno == ssl.SSL_ERROR_EOF:
                # This is almost certainly due to the cherrypy engine
                # 'pinging' the socket to assert it's connectable;
                # the 'ping' isn't SSL.
                return None, {}
            if e.errno == ssl.SSL_ERROR_SSL and e.args[1].endswith('http request'):
                # The client is speaking HTTP to an HTTPS server.
                raise wsgiserver.NoSSLError
            raise
        return wrapped, self.get_environ(wrapped)

    # TODO: fill this out more with mod ssl env
    def get_environ(self, sock):
        """Build the WSGI environ entries merged into each request."""
        # sock.cipher() is indexed as [0] = cipher name, [1] = protocol.
        cipher_name, protocol = sock.cipher()[:2]
        return {
            "wsgi.url_scheme": "https",
            "HTTPS": "on",
            'SSL_PROTOCOL': protocol,
            'SSL_CIPHER': cipher_name,
##            SSL_VERSION_INTERFACE 	string 	The mod_ssl program version
##            SSL_VERSION_LIBRARY 	string 	The OpenSSL program version
        }

    def makefile(self, sock, mode='r', bufsize=-1):
        """Return a CP_fileobject for `sock`."""
        return wsgiserver.CP_fileobject(sock, mode, bufsize)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""A library for integrating pyOpenSSL with CherryPy. | |
The OpenSSL module must be importable for SSL functionality. | |
You can obtain it from http://pyopenssl.sourceforge.net/ | |
To use this module, set CherryPyWSGIServer.ssl_adapter to an instance of | |
SSLAdapter. There are two ways to use SSL: | |
Method One | |
---------- | |
* ``ssl_adapter.context``: an instance of SSL.Context. | |
If this is not None, it is assumed to be an SSL.Context instance, | |
and will be passed to SSL.Connection on bind(). The developer is | |
responsible for forming a valid Context object. This approach is | |
to be preferred for more flexibility, e.g. if the cert and key are | |
streams instead of files, or need decryption, or SSL.SSLv3_METHOD | |
is desired instead of the default SSL.SSLv23_METHOD, etc. Consult | |
the pyOpenSSL documentation for complete options. | |
Method Two (shortcut) | |
--------------------- | |
* ``ssl_adapter.certificate``: the filename of the server SSL certificate. | |
* ``ssl_adapter.private_key``: the filename of the server's private key file. | |
Both are None by default. If ssl_adapter.context is None, but .private_key | |
and .certificate are both given and valid, they will be read, and the | |
context will be automatically created from them. | |
""" | |
import socket | |
import threading | |
import time | |
from cherrypy import wsgiserver | |
try: | |
from OpenSSL import SSL | |
from OpenSSL import crypto | |
except ImportError: | |
SSL = None | |
class SSL_fileobject(wsgiserver.CP_fileobject): | |
"""SSL file object attached to a socket object.""" | |
ssl_timeout = 3 | |
ssl_retry = .01 | |
def _safe_call(self, is_reader, call, *args, **kwargs): | |
"""Wrap the given call with SSL error-trapping. | |
is_reader: if False EOF errors will be raised. If True, EOF errors | |
will return "" (to emulate normal sockets). | |
""" | |
start = time.time() | |
while True: | |
try: | |
return call(*args, **kwargs) | |
except SSL.WantReadError: | |
# Sleep and try again. This is dangerous, because it means | |
# the rest of the stack has no way of differentiating | |
# between a "new handshake" error and "client dropped". | |
# Note this isn't an endless loop: there's a timeout below. | |
time.sleep(self.ssl_retry) | |
except SSL.WantWriteError: | |
time.sleep(self.ssl_retry) | |
except SSL.SysCallError, e: | |
if is_reader and e.args == (-1, 'Unexpected EOF'): | |
return "" | |
errnum = e.args[0] | |
if is_reader and errnum in wsgiserver.socket_errors_to_ignore: | |
return "" | |
raise socket.error(errnum) | |
except SSL.Error, e: | |
if is_reader and e.args == (-1, 'Unexpected EOF'): | |
return "" | |
thirdarg = None | |
try: | |
thirdarg = e.args[0][0][2] | |
except IndexError: | |
pass | |
if thirdarg == 'http request': | |
# The client is talking HTTP to an HTTPS server. | |
raise wsgiserver.NoSSLError() | |
raise wsgiserver.FatalSSLAlert(*e.args) | |
except: | |
raise | |
if time.time() - start > self.ssl_timeout: | |
raise socket.timeout("timed out") | |
def recv(self, *args, **kwargs): | |
buf = [] | |
r = super(SSL_fileobject, self).recv | |
while True: | |
data = self._safe_call(True, r, *args, **kwargs) | |
buf.append(data) | |
p = self._sock.pending() | |
if not p: | |
return "".join(buf) | |
def sendall(self, *args, **kwargs): | |
return self._safe_call(False, super(SSL_fileobject, self).sendall, | |
*args, **kwargs) | |
def send(self, *args, **kwargs): | |
return self._safe_call(False, super(SSL_fileobject, self).send, | |
*args, **kwargs) | |
class SSLConnection:
    """A thread-safe wrapper for an SSL.Connection.

    ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.

    Each forwarded method acquires a re-entrant lock, delegates to the
    method of the same name on the wrapped connection, and releases the
    lock again.
    """

    def __init__(self, *args):
        self._ssl_conn = SSL.Connection(*args)
        self._lock = threading.RLock()

    # Generate one locking pass-through method per name below.  The source
    # is exec'd in the class body, so each ``def`` lands in the class
    # namespace.
    for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
              'renegotiate', 'bind', 'listen', 'connect', 'accept',
              'setblocking', 'fileno', 'close', 'get_cipher_list',
              'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
              'makefile', 'get_app_data', 'set_app_data', 'state_string',
              'sock_shutdown', 'get_peer_certificate', 'want_read',
              'want_write', 'set_connect_state', 'set_accept_state',
              'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
        exec("""def %s(self, *args):
        self._lock.acquire()
        try:
            return self._ssl_conn.%s(*args)
        finally:
            self._lock.release()
""" % (f, f))

    def shutdown(self, *args):
        """Shut down the wrapped connection; any arguments are ignored."""
        self._lock.acquire()
        try:
            # pyOpenSSL.socket.shutdown takes no args
            return self._ssl_conn.shutdown()
        finally:
            self._lock.release()
class pyOpenSSLAdapter(wsgiserver.SSLAdapter):
    """A wrapper for integrating pyOpenSSL with CherryPy."""

    context = None
    """An instance of SSL.Context."""

    certificate = None
    """The filename of the server SSL certificate."""

    private_key = None
    """The filename of the server's private key file."""

    certificate_chain = None
    """Optional. The filename of CA's intermediate certificate bundle.

    This is needed for cheaper "chained root" SSL certificates, and should be
    left as None if not required."""

    def __init__(self, certificate, private_key, certificate_chain=None):
        """Store the certificate file names; the context is built in bind().

        Raises ImportError if pyOpenSSL could not be imported.
        """
        if SSL is None:
            raise ImportError("You must install pyOpenSSL to use HTTPS.")

        self.context = None
        self.certificate = certificate
        self.private_key = private_key
        self.certificate_chain = certificate_chain
        self._environ = None

    def bind(self, sock):
        """Wrap and return the given socket."""
        if self.context is None:
            self.context = self.get_context()
        conn = SSLConnection(self.context, sock)
        # Computed once here; wrap() hands out copies per connection.
        self._environ = self.get_environ()
        return conn

    def wrap(self, sock):
        """Wrap and return the given socket, plus WSGI environ entries."""
        return sock, self._environ.copy()

    def get_context(self):
        """Return an SSL.Context from self attributes."""
        # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
        c = SSL.Context(SSL.SSLv23_METHOD)
        c.use_privatekey_file(self.private_key)
        if self.certificate_chain:
            c.load_verify_locations(self.certificate_chain)
        c.use_certificate_file(self.certificate)
        return c

    def get_environ(self):
        """Return WSGI environ entries to be merged into each request."""
        ssl_environ = {
            "HTTPS": "on",
            # pyOpenSSL doesn't provide access to any of these AFAICT
##            'SSL_PROTOCOL': 'SSLv2',
##            SSL_CIPHER            string  The cipher specification name
##            SSL_VERSION_INTERFACE string  The mod_ssl program version
##            SSL_VERSION_LIBRARY   string  The OpenSSL program version
            }

        if self.certificate:
            # Server certificate attributes.  Close the file explicitly
            # instead of leaving the handle to the garbage collector.
            cert_file = open(self.certificate, 'rb')
            try:
                cert = cert_file.read()
            finally:
                cert_file.close()
            cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
            ssl_environ.update({
                'SSL_SERVER_M_VERSION': cert.get_version(),
                'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
##                'SSL_SERVER_V_START': Validity of server's certificate (start time),
##                'SSL_SERVER_V_END': Validity of server's certificate (end time),
                })

            for prefix, dn in [("I", cert.get_issuer()),
                               ("S", cert.get_subject())]:
                # X509Name objects don't seem to have a way to get the
                # complete DN string. Use str() and slice it instead,
                # because str(dn) == "<X509Name object '/C=US/ST=...'>"
                dnstr = str(dn)[18:-2]

                wsgikey = 'SSL_SERVER_%s_DN' % prefix
                ssl_environ[wsgikey] = dnstr

                # The DN should be of the form: /k1=v1/k2=v2, but we must allow
                # for any value to contain slashes itself (in a URL).
                # Hence the string is consumed right-to-left: last '=' first,
                # then the '/' that precedes its key.
                while dnstr:
                    pos = dnstr.rfind("=")
                    dnstr, value = dnstr[:pos], dnstr[pos + 1:]
                    pos = dnstr.rfind("/")
                    dnstr, key = dnstr[:pos], dnstr[pos + 1:]
                    if key and value:
                        wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
                        ssl_environ[wsgikey] = value

        return ssl_environ

    def makefile(self, sock, mode='r', bufsize=-1):
        """Return a file object for ``sock``.

        pyOpenSSL connections get an ``SSL_fileobject`` whose retry timeout
        follows the socket's own timeout; anything else gets a plain
        ``wsgiserver.CP_fileobject``.
        """
        if SSL and isinstance(sock, SSL.ConnectionType):
            timeout = sock.gettimeout()
            f = SSL_fileobject(sock, mode, bufsize)
            # A socket without a timeout reports None, which is not a usable
            # duration for the retry loop's elapsed-time comparison; keep
            # the class default in that case.
            if timeout is not None:
                f.ssl_timeout = timeout
            return f
        else:
            return wsgiserver.CP_fileobject(sock, mode, bufsize)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Interface to various templating engines. | |
""" | |
import os.path | |
# Public names of this module.  render_jinja is defined below alongside the
# other renderers, so it is exported with them.
__all__ = [
    "render_cheetah", "render_genshi", "render_jinja", "render_mako",
    "cache",
]
class render_cheetah:
    """Rendering interface to Cheetah Templates.

    Attribute access maps a name to ``<path>/<name>.html``; the returned
    callable fills the template from its keyword arguments.

    Example:

        render = render_cheetah('templates')
        render.hello(name="cheetah")
    """
    def __init__(self, path):
        # Imported only so that a missing Cheetah fails here, at
        # construction time, rather than on first render.
        from Cheetah.Template import Template
        self.path = path

    def __getattr__(self, name):
        from Cheetah.Template import Template
        template_file = os.path.join(self.path, name + ".html")

        def template(**kw):
            # The keyword arguments become the template's search list.
            return Template(file=template_file, searchList=[kw]).respond()

        return template
class render_genshi:
    """Rendering interface genshi templates.

    Example:

    for xml/html templates.

        render = render_genshi(['templates/'])
        render.hello(name='genshi')

    For text templates:

        render = render_genshi(['templates/'], type='text')
        render.hello(name='genshi')
    """

    def __init__(self, *a, **kwargs):
        from genshi.template import TemplateLoader

        # 'type' is consumed here; everything else goes to TemplateLoader.
        self._type = kwargs.pop('type', None)
        self._loader = TemplateLoader(*a, **kwargs)

    def __getattr__(self, name):
        # Only type == "text" changes anything: it selects TextTemplate for
        # loading and passes "text" to render().  Any other value loads
        # with the loader's default class and renders with no argument.
        template_cls = None
        render_args = ()
        if self._type == "text":
            from genshi.template import TextTemplate
            template_cls = TextTemplate
            render_args = ("text",)

        # Assuming all templates are html
        loaded = self._loader.load(name + ".html", cls=template_cls)

        def template(**kw):
            return loaded.generate(**kw).render(*render_args)

        return template
class render_jinja:
    """Rendering interface to Jinja2 Templates

    Example:

        render= render_jinja('templates')
        render.hello(name='jinja2')
    """
    def __init__(self, *a, **kwargs):
        # 'extensions' and 'globals' configure the Environment; whatever
        # remains is passed on to FileSystemLoader.
        env_extensions = kwargs.pop('extensions', [])
        env_globals = kwargs.pop('globals', {})

        from jinja2 import Environment, FileSystemLoader

        loader = FileSystemLoader(*a, **kwargs)
        self._lookup = Environment(loader=loader, extensions=env_extensions)
        self._lookup.globals.update(env_globals)

    def __getattr__(self, name):
        # Assuming all templates end with .html
        return self._lookup.get_template(name + '.html').render
class render_mako:
    """Rendering interface to Mako Templates.

    Example:

        render = render_mako(directories=['templates'])
        render.hello(name="mako")
    """
    def __init__(self, *a, **kwargs):
        from mako.lookup import TemplateLookup
        # All arguments are passed straight through to TemplateLookup.
        self._lookup = TemplateLookup(*a, **kwargs)

    def __getattr__(self, name):
        # Assuming all templates are html
        return self._lookup.get_template(name + ".html").render
class cache:
    """Cache for any rendering interface.

    The first access to a name is resolved on the wrapped renderer and
    remembered; later accesses return the remembered object.

    Example:

        render = cache(render_cheetah("templates/"))
        render.hello(name='cache')
    """
    def __init__(self, render):
        self._render = render
        self._cache = {}

    def __getattr__(self, name):
        store = self._cache
        if name in store:
            return store[name]
        resolved = getattr(self._render, name)
        store[name] = resolved
        return resolved
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
simple, elegant templating | |
(part of web.py) | |
Template design: | |
Template string is split into tokens and the tokens are combined into nodes. | |
Parse tree is a nodelist. TextNode and ExpressionNode are simple nodes and | |
for-loop, if-loop etc are block nodes, which contain multiple child nodes. | |
Each node can emit some python string. python string emitted by the | |
root node is validated for safeeval and executed using python in the given environment. | |
Enough care is taken to make sure the generated code and the template has line to line match, | |
so that the error messages can point to exact line number in template. (It doesn't work in some cases still.) | |
Grammar: | |
template -> defwith sections | |
defwith -> '$def with (' arguments ')' | '' | |
sections -> section* | |
section -> block | assignment | line | |
assignment -> '$ ' <assignment expression> | |
line -> (text|expr)* | |
text -> <any characters other than $> | |
expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}' | |
pyexpr -> <python expression> | |
""" | |
# Names exported by ``from template import *``.
__all__ = [
    "Template",
    "Render",
    "render",
    "frender",
    "ParseError",
    "SecurityError",
    "test",
]
import tokenize | |
import os | |
import sys | |
import glob | |
import re | |
from UserDict import DictMixin | |
from utils import storage, safeunicode, safestr, re_compile | |
from webapi import config | |
from net import websafe | |
def splitline(text):
    r"""
    Splits the given text at the first newline.

    Returns a pair ``(first_line, rest)``; the newline stays with the first
    line. Without a newline the whole text is the first line.

        >>> splitline('foo\nbar')
        ('foo\n', 'bar')
        >>> splitline('foo')
        ('foo', '')
        >>> splitline('')
        ('', '')
    """
    newline_at = text.find('\n')
    if newline_at < 0:
        return text, ''
    cut = newline_at + 1
    return text[:cut], text[cut:]
class Parser:
    """Parser Base.

    Turns template text into a tree of nodes. Most ``read_*`` methods
    consume a prefix of the text they are given and return a pair of
    (what was read, remaining text).
    """
    def __init__(self):
        self.statement_nodes = STATEMENT_NODES
        self.keywords = KEYWORDS

    def parse(self, text, name="<template>"):
        """Parse a whole template and return its root ``DefwithNode``.

        ``text`` and ``name`` are also recorded on the parser instance.
        """
        self.text = text
        self.name = name

        defwith, text = self.read_defwith(text)
        suite = self.read_suite(text)
        return DefwithNode(defwith, suite)

    def read_defwith(self, text):
        """Split off a leading ``$def with (...)`` line, if there is one.

        Returns ``(defwith, rest)`` where ``defwith`` is the line without
        the ``$`` and surrounding whitespace, or ``''`` when the text does
        not start with ``$def with``.
        """
        if text.startswith('$def with'):
            defwith, text = splitline(text)
            defwith = defwith[1:].strip() # strip $ and spaces
            return defwith, text
        else:
            return '', text

    def read_section(self, text):
        r"""Reads one section from the given text.

        section -> block | assignment | line

            >>> read_section = Parser().read_section
            >>> read_section('foo\nbar\n')
            (<line: [t'foo\n']>, 'bar\n')
            >>> read_section('$ a = b + 1\nfoo\n')
            (<assignment: 'a = b + 1'>, 'foo\n')

        A statement such as '$for i in range(10):' followed by an indented
        body is read as one block section.
        """
        if text.lstrip(' ').startswith('$'):
            index = text.index('$')
            begin_indent, text2 = text[:index], text[index+1:]
            ahead = self.python_lookahead(text2)

            if ahead == 'var':
                return self.read_var(text2)
            elif ahead in self.statement_nodes:
                return self.read_block_section(text2, begin_indent)
            elif ahead in self.keywords:
                return self.read_keyword(text2)
            elif ahead.strip() == '':
                # assignments starts with a space after $
                # ex: $ a = b + 2
                return self.read_assignment(text2)
        # Anything else (including "$expr" at line start) is an ordinary line.
        return self.readline(text)

    def read_var(self, text):
        r"""Reads a var statement.

            >>> read_var = Parser().read_var
            >>> read_var('var x=10\nfoo')
            (<var: x = 10>, 'foo')
            >>> read_var('var x: hello $name\nfoo')
            (<var: x = join_(u'hello ', escape_(name, True))>, 'foo')

        Raises SyntaxError when the statement has fewer than four tokens or
        the separator after the name is neither ``=`` nor ``:``.
        """
        line, text = splitline(text)
        tokens = self.python_tokens(line)
        if len(tokens) < 4:
            raise SyntaxError('Invalid var statement')

        name = tokens[1]
        sep = tokens[2]
        value = line.split(sep, 1)[1].strip()

        if sep == '=':
            pass # no need to process value
        elif sep == ':':
            #@@ Hack for backward-compatibility
            if tokens[3] == '\n': # multi-line var statement
                block, text = self.read_indented_block(text, '    ')
                lines = [self.readline(x)[0] for x in block.splitlines()]
                nodes = []
                for x in lines:
                    nodes.extend(x.nodes)
                    nodes.append(TextNode('\n'))
            else: # single-line var statement
                linenode, _ = self.readline(value)
                nodes = linenode.nodes
            parts = [node.emit('') for node in nodes]
            value = "join_(%s)" % ", ".join(parts)
        else:
            raise SyntaxError('Invalid var statement')
        return VarNode(name, value), text

    def read_suite(self, text):
        r"""Reads section by section till end of text.

            >>> read_suite = Parser().read_suite
            >>> read_suite('hello $name\nfoo\n')
            [<line: [t'hello ', $name, t'\n']>, <line: [t'foo\n']>]
        """
        sections = []
        while text:
            section, text = self.read_section(text)
            sections.append(section)
        return SuiteNode(sections)

    def readline(self, text):
        r"""Reads one line from the text. Newline is suppressed if the line ends with \.

            >>> readline = Parser().readline
            >>> readline('hello $name!\nbye!')
            (<line: [t'hello ', $name, t'!\n']>, 'bye!')
            >>> readline('hello $name!\\\nbye!')
            (<line: [t'hello ', $name, t'!']>, 'bye!')
            >>> readline('$f()\n\n')
            (<line: [$f(), t'\n']>, '\n')
        """
        line, text = splitline(text)

        # suppress new line if line ends with \
        if line.endswith('\\\n'):
            line = line[:-2]

        nodes = []
        while line:
            node, line = self.read_node(line)
            nodes.append(node)

        return LineNode(nodes), text

    def read_node(self, text):
        r"""Reads a node from the given text and returns the node and remaining text.

            >>> read_node = Parser().read_node
            >>> read_node('hello $name')
            (t'hello ', '$name')
            >>> read_node('$name')
            ($name, '')
        """
        if text.startswith('$$'):
            # "$$" is an escaped dollar sign.
            return TextNode('$'), text[2:]
        elif text.startswith('$#'): # comment
            line, text = splitline(text)
            return TextNode('\n'), text
        elif text.startswith('$'):
            text = text[1:] # strip $
            if text.startswith(':'):
                # "$:expr" emits the expression without escaping.
                escape = False
                text = text[1:] # strip :
            else:
                escape = True
            return self.read_expr(text, escape=escape)
        else:
            return self.read_text(text)

    def read_text(self, text):
        r"""Reads a text node from the given text.

            >>> read_text = Parser().read_text
            >>> read_text('hello $name')
            (t'hello ', '$name')
        """
        index = text.find('$')
        if index < 0:
            return TextNode(text), ''
        else:
            return TextNode(text[:index]), text[index:]

    def read_keyword(self, text):
        """Reads a one-line keyword statement and returns it as a ``StatementNode``."""
        line, text = splitline(text)
        return StatementNode(line.strip() + "\n"), text

    def read_expr(self, text, escape=True):
        """Reads a python expression from the text and returns the expression and remaining text.

        expr -> simple_expr | paren_expr
        simple_expr -> id extended_expr
        extended_expr -> attr_access | paren_expr extended_expr | ''
        attr_access -> dot id extended_expr
        paren_expr -> [ tokens ] | ( tokens ) | { tokens }

            >>> read_expr = Parser().read_expr
            >>> read_expr("name")
            ($name, '')
            >>> read_expr("a.b and c")
            ($a.b, ' and c')
            >>> read_expr("a. b")
            ($a, '. b')
            >>> read_expr("name</h1>")
            ($name, '</h1>')
            >>> read_expr("(limit)ing")
            ($(limit), 'ing')
            >>> read_expr('a[1, 2][:3].f(1+2, "weird string[).", 3 + 4) done.')
            ($a[1, 2][:3].f(1+2, "weird string[).", 3 + 4), ' done.')
        """
        # The helpers below walk the token stream ``tokens`` (bound further
        # down); each one consumes exactly the tokens of its grammar rule.
        def simple_expr():
            identifier()
            extended_expr()

        def identifier():
            tokens.next()

        def extended_expr():
            lookahead = tokens.lookahead()
            if lookahead is None:
                return
            elif lookahead.value == '.':
                attr_access()
            elif lookahead.value in parens:
                paren_expr()
                extended_expr()
            else:
                return

        def attr_access():
            from token import NAME # python token constants
            # Only treat the dot as attribute access when a name follows it
            # directly; otherwise it is ordinary text ("$a. b").
            following = tokens.lookahead2()
            if following is not None and following.type == NAME:
                tokens.next() # consume dot
                identifier()
                extended_expr()

        def paren_expr():
            begin = tokens.next().value
            end = parens[begin]
            while True:
                if tokens.lookahead().value in parens:
                    paren_expr()
                else:
                    t = tokens.next()
                    if t.value == end:
                        break
            return

        parens = {
            "(": ")",
            "[": "]",
            "{": "}"
        }

        def get_tokens(text):
            """tokenize text using python tokenizer.
            Python tokenizer ignores spaces, but they might be important in some cases.
            This function introduces dummy space tokens when it identifies any ignored space.
            Each token is a storage object containing type, value, begin and end.
            """
            readline = iter([text]).next
            end = None
            for t in tokenize.generate_tokens(readline):
                t = storage(type=t[0], value=t[1], begin=t[2], end=t[3])
                if end is not None and end != t.begin:
                    # Gap between the previous token and this one: emit the
                    # skipped characters as a dummy token of type -1.
                    _, x1 = end
                    _, x2 = t.begin
                    yield storage(type=-1, value=text[x1:x2], begin=end, end=t.begin)
                end = t.end
                yield t

        class BetterIter:
            """Iterator like object with support for 2 look aheads."""
            def __init__(self, items):
                self.iteritems = iter(items)
                self.items = []
                self.position = 0
                self.current_item = None

            def lookahead(self):
                if len(self.items) <= self.position:
                    self.items.append(self._next())
                return self.items[self.position]

            def _next(self):
                # None marks the end of the underlying iterator.
                try:
                    return self.iteritems.next()
                except StopIteration:
                    return None

            def lookahead2(self):
                # Fill the buffer up to the second item past the cursor, so
                # this also works when lookahead() was not called first.
                while len(self.items) <= self.position + 1:
                    self.items.append(self._next())
                return self.items[self.position + 1]

            def next(self):
                self.current_item = self.lookahead()
                self.position += 1
                return self.current_item

        tokens = BetterIter(get_tokens(text))

        if tokens.lookahead().value in parens:
            paren_expr()
        else:
            simple_expr()
        # The expression ends where the last consumed token ends.
        row, col = tokens.current_item.end
        return ExpressionNode(text[:col], escape=escape), text[col:]

    def read_assignment(self, text):
        r"""Reads assignment statement from text.

            >>> read_assignment = Parser().read_assignment
            >>> read_assignment('a = b + 1\nfoo')
            (<assignment: 'a = b + 1'>, 'foo')
        """
        line, text = splitline(text)
        return AssignmentNode(line.strip()), text

    def python_lookahead(self, text):
        """Returns the first python token from the given text.

            >>> python_lookahead = Parser().python_lookahead
            >>> python_lookahead('for i in range(10):')
            'for'
            >>> python_lookahead('else:')
            'else'
            >>> python_lookahead(' x = 1')
            ' '
        """
        readline = iter([text]).next
        tokens = tokenize.generate_tokens(readline)
        return tokens.next()[1]

    def python_tokens(self, text):
        """Returns the string values of all python tokens in the given text."""
        readline = iter([text]).next
        tokens = tokenize.generate_tokens(readline)
        return [t[1] for t in tokens]

    def read_indented_block(self, text, indent):
        r"""Read a block of text. A block is what typically follows a for or if statement.
        It can be in the same line as that of the statement or an indented block.

        Lines starting with ``indent`` have it removed; blank lines are kept
        as bare newlines; the first other line ends the block.

            >>> read_indented_block = Parser().read_indented_block
            >>> read_indented_block('  a\n  b\nc', '  ')
            ('a\nb\n', 'c')
            >>> read_indented_block('  a\n    b\n  c\nd', '  ')
            ('a\n  b\nc\n', 'd')
            >>> read_indented_block('  a\n\n    b\nc', '  ')
            ('a\n\n  b\n', 'c')
        """
        if indent == '':
            return '', text

        pieces = []
        while text:
            line, text2 = splitline(text)
            if line.strip() == "":
                pieces.append('\n')
            elif line.startswith(indent):
                pieces.append(line[len(indent):])
            else:
                break
            text = text2
        return "".join(pieces), text

    def read_statement(self, text):
        r"""Reads a python statement.

            >>> read_statement = Parser().read_statement
            >>> read_statement('for i in range(10): hello $name')
            ('for i in range(10):', ' hello $name')
        """
        tok = PythonTokenizer(text)
        tok.consume_till(':')
        return text[:tok.index], text[tok.index:]

    def read_block_section(self, text, begin_indent=''):
        r"""Reads a block statement together with its body.

        The body is either the rest of the statement line or the indented
        lines that follow it.

            >>> read_block_section = Parser().read_block_section
            >>> read_block_section('for i in range(10): hello $i\nfoo')
            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
            >>> read_block_section('for i in range(10):\n        hello $i\n    foo', begin_indent='    ')
            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, '    foo')
            >>> read_block_section('for i in range(10):\n  hello $i\nfoo')
            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
        """
        line, text = splitline(text)
        stmt, line = self.read_statement(line)
        keyword = self.python_lookahead(stmt)

        # if there is some thing left in the line
        if line.strip():
            block = line.lstrip()
        else:
            def find_indent(text):
                rx = re_compile('  +')
                match = rx.match(text)
                first_indent = match and match.group(0)
                return first_indent or ""

            # find the indentation of the block by looking at the first line
            first_indent = find_indent(text)[len(begin_indent):]

            #TODO: fix this special case
            if keyword == "code":
                indent = begin_indent + first_indent
            else:
                # Both operands consist only of spaces, so min() picks the
                # shorter one: at most INDENT is stripped per level.
                indent = begin_indent + min(first_indent, INDENT)

            block, text = self.read_indented_block(text, indent)

        return self.create_block_node(keyword, stmt, block, begin_indent), text

    def create_block_node(self, keyword, stmt, block, begin_indent):
        """Builds the node registered for ``keyword`` in ``statement_nodes``.

        Raises ParseError for a keyword that has no registered node class.
        """
        if keyword in self.statement_nodes:
            return self.statement_nodes[keyword](stmt, block, begin_indent)
        else:
            raise ParseError('Unknown statement: %s' % repr(keyword))
class PythonTokenizer:
    """Utility wrapper over python tokenizer.

    ``index`` is the end column of the last token read, so
    ``text[:index]`` is the part of ``text`` consumed so far.
    """

    def __init__(self, text):
        self.text = text
        readline = iter([text]).next
        self.tokens = tokenize.generate_tokens(readline)
        self.index = 0

    def consume_till(self, delim):
        """Consumes tokens till ``delim`` (or the end of the line) is found.

        Bracketed groups are skipped as a whole, so a delimiter inside
        ``(...)``, ``[...]`` or ``{...}`` does not stop the scan.

            >>> tok = PythonTokenizer('for i in range(10): hello $i')
            >>> tok.consume_till(':')
            >>> tok.text[:tok.index]
            'for i in range(10):'
            >>> tok.text[tok.index:]
            ' hello $i'
        """
        try:
            while True:
                t = self.next()
                if t.value == delim:
                    break
                elif t.value == '(':
                    self.consume_till(')')
                elif t.value == '[':
                    self.consume_till(']')
                elif t.value == '{':
                    self.consume_till('}')

                # if end of line is found, it is an exception.
                # Since there is no easy way to report the line number,
                # leave the error reporting to the python parser later
                #@@ This should be fixed.
                if t.value == '\n':
                    break
        except Exception:
            # Running out of tokens (or a tokenizer error) is deliberately
            # ignored: raising ParseError here would not show the line
            # number, whereas the same problem is caught with a line number
            # when the generated python code is compiled.
            # Only Exception is trapped, so KeyboardInterrupt/SystemExit
            # are not swallowed on Pythons where they sit outside it.
            return

    def next(self):
        """Reads the next token, advances ``index`` and returns the token."""
        tok_type, value, begin, end, line = self.tokens.next()
        row, col = end
        self.index = col
        return storage(type=tok_type, value=value, begin=begin, end=end)
class DefwithNode:
    """Root node: wraps the template suite in a ``__template__`` function.

    ``defwith`` is the template's ``def with (...)`` line without the ``$``
    (or '' when the template has none); ``suite`` is the node for the
    template body.
    """
    def __init__(self, defwith, suite):
        if defwith:
            # Replace only the first "with" (the keyword); argument names
            # or defaults that contain "with" must stay untouched.
            self.defwith = defwith.replace('with', '__template__', 1) + ':'
            # offset 4 lines. for encoding, __lineoffset__, loop and self.
            self.defwith += "\n    __lineoffset__ = -4"
        else:
            self.defwith = 'def __template__():'
            # offset 5 lines. for encoding, __template__, __lineoffset__, loop and self.
            self.defwith += "\n    __lineoffset__ = -5"

        self.defwith += "\n    loop = ForLoop()"
        self.defwith += "\n    self = TemplateResult(); extend_ = self.extend"
        self.suite = suite
        self.end = "\n    return self"

    def emit(self, indent):
        """Returns the python source: coding line, function header, body, return."""
        encoding = "# coding: utf-8\n"
        return encoding + self.defwith + self.suite.emit(indent + INDENT) + self.end

    def __repr__(self):
        return "<defwith: %s, %s>" % (self.defwith, self.suite)
class TextNode:
    """A run of literal template text."""

    def __init__(self, value):
        self.value = value

    def emit(self, indent, begin_indent=''):
        # Both arguments are unused: the result is the python literal of
        # the text converted with safeunicode.
        as_unicode = safeunicode(self.value)
        return repr(as_unicode)

    def __repr__(self):
        return 't%r' % (self.value,)
class ExpressionNode:
    """A ``$expr`` substitution; ``escape`` is False for the ``$:expr`` form."""

    def __init__(self, value, escape=True):
        expr = value.strip()

        # convert ${...} to $(...)
        if value.startswith('{') and value.endswith('}'):
            expr = '(' + expr[1:-1] + ')'

        self.value = expr
        self.escape = escape

    def emit(self, indent, begin_indent=''):
        escape_flag = bool(self.escape)
        return 'escape_(%s, %s)' % (self.value, escape_flag)

    def __repr__(self):
        # Unescaped expressions are shown with the ":" marker.
        marker = ':'
        if self.escape:
            marker = ''
        return "$%s%s" % (marker, self.value)
class AssignmentNode:
    """A ``$ name = value`` line; the code is emitted as written."""

    def __init__(self, code):
        self.code = code

    def emit(self, indent, begin_indent=''):
        return "".join([indent, self.code, "\n"])

    def __repr__(self):
        return "<assignment: %r>" % (self.code,)
class LineNode:
    """One template line, held as a list of text and expression nodes."""

    def __init__(self, nodes):
        self.nodes = nodes

    def emit(self, indent, text_indent='', name=''):
        # A non-empty text_indent is emitted as a leading string literal,
        # ahead of the line's own nodes.
        pieces = []
        if text_indent:
            pieces.append(repr(text_indent))
        for node in self.nodes:
            pieces.append(node.emit(''))
        return indent + "extend_([%s])\n" % ", ".join(pieces)

    def __repr__(self):
        return "<line: %r>" % (self.nodes,)
# One level of indentation in the generated python code.  Spelled as a
# product so the width cannot be lost to whitespace mangling.
INDENT = ' ' * 4 # 4 spaces
class BlockNode:
    """A statement (``stmt``) followed by a body that is parsed as a suite."""

    def __init__(self, stmt, block, begin_indent=''):
        self.stmt = stmt
        self.suite = Parser().read_suite(block)
        self.begin_indent = begin_indent

    def emit(self, indent, text_indent=''):
        # The body is emitted one INDENT deeper than the statement; this
        # block's own begin_indent is prepended to the text indent.
        body_text_indent = self.begin_indent + text_indent
        body = self.suite.emit(indent + INDENT, body_text_indent)
        return indent + self.stmt + body

    def __repr__(self):
        return "<block: %r, %r>" % (self.stmt, self.suite)
class ForNode(BlockNode):
    """A ``for`` block whose iterable is wrapped in ``loop.setup(...)``."""

    def __init__(self, stmt, block, begin_indent=''):
        self.original_stmt = stmt

        tok = PythonTokenizer(stmt)
        tok.consume_till('in')
        split_at = tok.index

        head = stmt[:split_at] # for i in
        iterable = stmt[split_at:-1] # rest of for stmt excluding :
        wrapped_stmt = head + ' loop.setup(' + iterable.strip() + '):'

        BlockNode.__init__(self, wrapped_stmt, block, begin_indent)

    def __repr__(self):
        # Show the statement as written in the template, not the rewritten one.
        return "<block: %r, %r>" % (self.original_stmt, self.suite)
class CodeNode:
    """A ``$code:`` block: python source that is copied into the output."""

    def __init__(self, stmt, block, begin_indent=''):
        # compensate one line for $code:
        self.code = "\n" + block

    def emit(self, indent, text_indent=''):
        import re
        # Prefix every line start with the indent, then drop the spaces
        # left dangling at the very end of the result.
        line_start = re.compile('^', re.M)
        indented = line_start.sub(indent, self.code)
        return indented.rstrip(' ')

    def __repr__(self):
        return "<code: %r>" % (self.code,)
class StatementNode:
    """A one-line python statement that is emitted as written."""

    def __init__(self, stmt):
        self.stmt = stmt

    def emit(self, indent, begin_indent=''):
        return "".join([indent, self.stmt])

    def __repr__(self):
        return "<stmt: %r>" % (self.stmt,)
class IfNode(BlockNode):
    """Block node for ``if``; adds nothing to BlockNode."""

class ElseNode(BlockNode):
    """Block node for ``else``; adds nothing to BlockNode."""

class ElifNode(BlockNode):
    """Block node for ``elif``; adds nothing to BlockNode."""
class DefNode(BlockNode):
    """A ``$def`` block: a nested function that builds its own TemplateResult."""

    def __init__(self, *a, **kw):
        BlockNode.__init__(self, *a, **kw)

        def raw_code(source):
            # A CodeNode whose code is set directly.
            node = CodeNode("", "")
            node.code = source
            return node

        # Bracket the body with the result setup and the final return.
        prologue = raw_code("self = TemplateResult(); extend_ = self.extend\n")
        self.suite.sections.insert(0, prologue)

        epilogue = raw_code("return self\n")
        self.suite.sections.append(epilogue)

    def emit(self, indent, text_indent=''):
        body_text_indent = self.begin_indent + text_indent
        body = self.suite.emit(indent + INDENT, body_text_indent)
        definition = indent + self.stmt + body
        return indent + "__lineoffset__ -= 3\n" + definition
class VarNode:
    """A ``$var name = value`` statement; stores the value on the result."""

    def __init__(self, name, value):
        self.name = name
        self.value = value

    def emit(self, indent, text_indent):
        # ``value`` is already python source; only the name is quoted.
        assignment = "self[%r] = %s\n" % (self.name, self.value)
        return indent + assignment

    def __repr__(self):
        return "<var: %s = %s>" % (self.name, self.value)
class SuiteNode:
    """Suite is a list of sections."""

    def __init__(self, sections):
        self.sections = sections

    def emit(self, indent, text_indent=''):
        # A leading newline, then every section's output in order.
        emitted = ["\n"]
        for section in self.sections:
            emitted.append(section.emit(indent, text_indent))
        return "".join(emitted)

    def __repr__(self):
        return repr(self.sections)
# Statement keyword -> node class.
STATEMENT_NODES = dict([
    ('for', ForNode),
    ('while', BlockNode),
    ('if', IfNode),
    ('elif', ElifNode),
    ('else', ElseNode),
    ('def', DefNode),
    ('code', CodeNode),
])
# Single-word Python statements listed as keywords.
KEYWORDS = ["pass", "break", "continue", "return"]
# Names of the builtins that are looked up to build TEMPLATE_BUILTINS.
TEMPLATE_BUILTIN_NAMES = [
    # types
    "dict", "enumerate", "float", "int", "bool", "list",
    "long", "reversed", "set", "slice", "tuple", "xrange",
    # functions
    "abs", "all", "any", "callable", "chr", "cmp", "divmod",
    "filter", "hex", "id", "isinstance", "iter", "len", "max",
    "min", "oct", "ord", "pow", "range",
    # constants
    "True", "False", "None",
    # some c-libraries like datetime require __import__ to be present in the namespace
    "__import__",
]
import __builtin__

# Name -> builtin object, for every listed name this interpreter provides.
TEMPLATE_BUILTINS = {}
for name in TEMPLATE_BUILTIN_NAMES:
    if name in __builtin__.__dict__:
        TEMPLATE_BUILTINS[name] = getattr(__builtin__, name)
class ForLoop:
    """
    Wrapper for the expression in a for statement to support loop.xxx helpers.

    Attributes that are not found on the wrapper itself are looked up on
    the current loop context; when no context is active the lookup raises
    AttributeError.

        >>> loop = ForLoop()
        >>> for x in loop.setup(['a', 'b', 'c']):
        ...     print loop.index, loop.revindex, loop.parity, x
        ...
        1 3 odd a
        2 2 even b
        3 1 odd c
        >>> loop.index
        Traceback (most recent call last):
            ...
        AttributeError: index
    """

    def __init__(self):
        self._ctx = None

    def __getattr__(self, name):
        context = self._ctx
        if context is not None:
            return getattr(context, name)
        raise AttributeError(name)

    def setup(self, seq):
        """Push a new context and return the result of its ``setup(seq)``."""
        self._push()
        return self._ctx.setup(seq)

    def _push(self):
        # The new context keeps the previous one as its parent.
        self._ctx = ForLoopContext(self, self._ctx)

    def _pop(self):
        # Step back to the parent of the current context.
        self._ctx = self._ctx.parent
class ForLoopContext:
    """Stackable context for ForLoop to support nested for loops.

    ``parent`` is the context that was active when this one was created
    (or None); ``_forloop`` is the owner whose ``_pop`` is called once the
    iteration started by ``setup`` has run to completion.
    """

    def __init__(self, forloop, parent):
        self._forloop = forloop
        self.parent = parent

    def setup(self, seq):
        """Yield the items of ``seq`` while keeping ``index`` up to date.

        ``length`` is ``len(seq)``, or 0 when taking the length fails (for
        example for a generator). ``index`` is 1-based and is advanced
        before each item is yielded.
        """
        try:
            self.length = len(seq)
        except Exception:
            # Was a bare ``except:``. Catching Exception keeps the
            # "no length available" fallback without swallowing
            # SystemExit/KeyboardInterrupt (Python >= 2.5).
            self.length = 0

        self.index = 0
        for item in seq:
            self.index += 1
            yield item
        # Reached only when the sequence is exhausted.
        self._forloop._pop()

    # Helpers derived from ``index`` and ``length``.
    index0 = property(lambda self: self.index-1)
    first = property(lambda self: self.index == 1)
    last = property(lambda self: self.index == self.length)
    odd = property(lambda self: self.index % 2 == 1)
    even = property(lambda self: self.index % 2 == 0)
    # ``even`` is a bool and is used as index 0/1 into the list.
    parity = property(lambda self: ['odd', 'even'][self.even])
    revindex0 = property(lambda self: self.length - self.index)
    revindex = property(lambda self: self.length - self.index + 1)
class BaseTemplate:
    """Callable wrapper around the ``__template__`` function defined by compiled code."""

    def __init__(self, code, filename, filter, globals, builtins):
        self.filename = filename
        self.filter = filter
        self._globals = globals
        self._builtins = builtins
        if not code:
            # Nothing to execute: calling the template gives an empty string.
            self.t = lambda: ''
        else:
            self.t = self._compile(code)

    def _compile(self, code):
        """Execute ``code`` in a fresh environment and return its ``__template__``."""
        namespace = self.make_env(self._globals or {}, self._builtins)
        exec(code, namespace)
        return namespace['__template__']

    def __call__(self, *a, **kw):
        __hidetraceback__ = True
        return self.t(*a, **kw)

    def make_env(self, globals, builtins):
        """Return a copy of ``globals`` extended with the template helpers."""
        env = dict(globals)
        env.update({
            '__builtins__': builtins,
            'ForLoop': ForLoop,
            'TemplateResult': TemplateResult,
            'escape_': self._escape,
            'join_': self._join,
        })
        return env

    def _join(self, *items):
        return u"".join(items)

    def _escape(self, value, escape=False):
        """Convert ``value`` to unicode (None becomes '') and optionally filter it."""
        if value is None:
            value = ''
        text = safeunicode(value)
        if not (escape and self.filter):
            return text
        return self.filter(text)
class Template(BaseTemplate):
    """Template built from source text.

    The text is normalized, translated to Python code by the parser,
    compiled, checked by SafeVisitor and then handed to BaseTemplate.
    The file extension of ``filename`` selects the default filter and the
    content type.
    """
    CONTENT_TYPES = {
        '.html' : 'text/html; charset=utf-8',
        '.xhtml' : 'application/xhtml+xml; charset=utf-8',
        '.txt' : 'text/plain',
    }
    FILTERS = {
        '.html': websafe,
        '.xhtml': websafe,
        '.xml': websafe
    }
    # Used as the globals of a template when none are passed in.
    globals = {}

    def __init__(self, text, filename='<template>', filter=None, globals=None, builtins=None, extensions=None):
        self.extensions = extensions or []
        text = Template.normalize_text(text)
        code = self.compile_template(text, filename)

        _, ext = os.path.splitext(filename)
        filter = filter or self.FILTERS.get(ext, None)
        self.content_type = self.CONTENT_TYPES.get(ext, None)

        if globals is None:
            globals = self.globals
        if builtins is None:
            builtins = TEMPLATE_BUILTINS

        BaseTemplate.__init__(self, code=code, filename=filename, filter=filter, globals=globals, builtins=builtins)

    def normalize_text(text):
        r"""Normalizes template text by correcting \r\n, tabs and BOM chars."""
        text = text.replace('\r\n', '\n').replace('\r', '\n').expandtabs()
        if not text.endswith('\n'):
            text += '\n'

        # ignore BOM chars at the beginning of template
        BOM = '\xef\xbb\xbf'
        if isinstance(text, str) and text.startswith(BOM):
            text = text[len(BOM):]

        # support for \$ for backward-compatibility
        text = text.replace(r'\$', '$$')
        return text
    normalize_text = staticmethod(normalize_text)

    def __call__(self, *a, **kw):
        """Render the template, setting the Content-Type header when one applies."""
        __hidetraceback__ = True
        import webapi as web
        if 'headers' in web.ctx and self.content_type:
            web.header('Content-Type', self.content_type, unique=True)

        return BaseTemplate.__call__(self, *a, **kw)

    def generate_code(text, filename, parser=None):
        """Parse ``text`` and return the generated Python code as a str."""
        # parse the text
        parser = parser or Parser()
        rootnode = parser.parse(text, filename)

        # generate python code from the parse tree
        code = rootnode.emit(indent="").strip()
        return safestr(code)
    generate_code = staticmethod(generate_code)

    def create_parser(self):
        """Return a Parser wrapped, in order, by every extension."""
        p = Parser()
        for ext in self.extensions:
            p = ext(p)
        return p

    def compile_template(self, template_string, filename):
        """Generate, compile and safety-check the code for ``template_string``.

        Returns the compiled code object. A SyntaxError from ``compile`` is
        re-raised after the offending template line has been appended to its
        message (best effort). SafeVisitor raises SecurityError for code it
        rejects.
        """
        code = Template.generate_code(template_string, filename, parser=self.create_parser())

        def get_source_line(filename, lineno):
            # Best effort: None when the file or the line is unavailable.
            try:
                f = open(filename)
                try:
                    lines = f.read().splitlines()
                finally:
                    # the handle used to be left for the garbage collector
                    f.close()
                return lines[lineno]
            except Exception:
                return None

        def add_template_traceback(error):
            # Kept in its own function: on Python 2 an exception handled in
            # the same frame as the ``except SyntaxError`` below would become
            # the one re-raised by the bare ``raise``.
            try:
                error.msg += '\n\nTemplate traceback:\n File %s, line %s\n %s' % \
                    (repr(error.filename), error.lineno, get_source_line(error.filename, error.lineno-1))
            except Exception:
                pass

        try:
            # compile the code first to report the errors, if any, with the filename
            compiled_code = compile(code, filename, 'exec')
        except SyntaxError:
            # display template line that caused the error along with the traceback.
            add_template_traceback(sys.exc_info()[1])
            raise

        # make sure code is safe - but not with jython, it doesn't have a working compiler module
        if not sys.platform.startswith('java'):
            import compiler
            ast = compiler.parse(code)
            SafeVisitor().walk(ast, filename)
        else:
            import warnings
            warnings.warn("SECURITY ISSUE: You are using Jython, which does not support checking templates for safety. Your templates can execute arbitrary code.")

        return compiled_code
class CompiledTemplate(Template):
    """Template wrapping an already available template function ``f``."""

    def __init__(self, f, filename):
        # Initialise with empty text, then install the given function.
        Template.__init__(self, '', filename)
        self.t = f

    def compile_template(self, *args):
        """Do nothing: there is no template text to compile."""
        return None

    def _compile(self, *args):
        """Do nothing: there is no code to execute."""
        return None
class Render:
    """The most preferred way of using templates.

        render = web.template.render('templates')
        print render.foo()

    The optional parameter `base` can be used to pass the output of
    every template through the base template.

        render = web.template.render('templates', base='layout')
    """
    def __init__(self, loc='templates', cache=None, base=None, **keywords):
        self._loc = loc
        self._keywords = keywords

        if cache is None:
            cache = not config.get('debug', False)

        if cache:
            self._cache = {}
        else:
            self._cache = None

        if base and not hasattr(base, '__call__'):
            # make base a function, so that it can be passed to sub-renders
            self._base = lambda page: self._template(base)(page)
        else:
            self._base = base

    def _add_global(self, obj, name=None):
        """Add a global to this rendering instance."""
        if 'globals' not in self._keywords: self._keywords['globals'] = {}
        if not name:
            name = obj.__name__
        self._keywords['globals'][name] = obj

    def _lookup(self, name):
        """Return ('dir', path), ('file', path) or ('none', None) for ``name``."""
        path = os.path.join(self._loc, name)
        if os.path.isdir(path):
            return 'dir', path
        else:
            path = self._findfile(path)
            if path:
                return 'file', path
            else:
                return 'none', None

    def _load_template(self, name):
        """Load ``name``: a sub-Render for a directory, a Template for a file.

        Raises AttributeError when nothing matches.
        """
        kind, path = self._lookup(name)

        if kind == 'dir':
            return Render(path, cache=self._cache is not None, base=self._base, **self._keywords)
        elif kind == 'file':
            f = open(path)
            try:
                text = f.read()
            finally:
                # close explicitly instead of leaking the handle
                f.close()
            return Template(text, filename=path, **self._keywords)
        else:
            raise AttributeError("No template named " + name)

    def _findfile(self, path_prefix):
        """Return the first (sorted) ``path_prefix.*`` match, or a false value."""
        p = [f for f in glob.glob(path_prefix + '.*') if not f.endswith('~')] # skip backup files
        p.sort() # sort the matches for deterministic order
        return p and p[0]

    def _template(self, name):
        """Return the template for ``name``, via the cache when caching is on."""
        if self._cache is not None:
            if name not in self._cache:
                self._cache[name] = self._load_template(name)
            return self._cache[name]
        else:
            return self._load_template(name)

    def __getattr__(self, name):
        t = self._template(name)
        if self._base and isinstance(t, Template):
            def template(*a, **kw):
                return self._base(t(*a, **kw))
            return template
        else:
            # Return what was just looked up. Calling self._template(name)
            # again loaded and compiled the template a second time whenever
            # caching is off.
            return t
class GAE_Render(Render):
    """Render variant that looks templates up as attributes of a module."""

    # Render gets over-written. make a copy here.
    super = Render

    def __init__(self, loc, *a, **kw):
        GAE_Render.super.__init__(self, loc, *a, **kw)

        import types
        if not isinstance(loc, types.ModuleType):
            # "a/b/" -> module name "a.b"
            name = loc.rstrip('/').replace('/', '.')
            self.mod = __import__(name, None, None, ['x'])
        else:
            self.mod = loc

        # Later updates win: builtins first, then the Template-wide
        # globals, then the globals given for this renderer.
        namespace = self.mod.__dict__
        for extra in (kw.get('builtins', TEMPLATE_BUILTINS),
                      Template.globals,
                      kw.get('globals', {})):
            namespace.update(extra)

    def _load_template(self, name):
        """Return attribute ``name`` of the module; sub-modules get their own renderer."""
        found = getattr(self.mod, name)
        import types
        if not isinstance(found, types.ModuleType):
            return found
        return GAE_Render(found, cache=self._cache is not None, base=self._base, **self._keywords)
render = Render

# setup render for Google App Engine.
try:
    from google import appengine
except ImportError:
    pass
else:
    render = Render = GAE_Render
def frender(path, **keywords):
    """Creates a template from the given file path.

    ``keywords`` are passed on to Template. The file is closed as soon as
    it has been read instead of being left open until garbage collection.
    """
    f = open(path)
    try:
        text = f.read()
    finally:
        f.close()
    return Template(text, filename=path, **keywords)
def compile_templates(root):
    """Compiles templates to python code.

    Walks ``root`` and writes an ``__init__.py`` into every visited
    directory. It contains the generated code of each template file there,
    wrapped in a CompiledTemplate of the same name. Hidden files and
    directories, backup files (``~``) and ``__init__.py*`` are skipped.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        filenames = [f for f in filenames if not f.startswith('.') and not f.endswith('~') and not f.startswith('__init__.py')]

        for d in dirnames[:]:
            if d.startswith('.'):
                dirnames.remove(d) # don't visit this dir

        out = open(os.path.join(dirpath, '__init__.py'), 'w')
        try:
            out.write('from web.template import CompiledTemplate, ForLoop, TemplateResult\n\n')
            if dirnames:
                out.write("import " + ", ".join(dirnames))
            out.write("\n")

            for f in filenames:
                path = os.path.join(dirpath, f)

                # the template name is the file name up to the first dot
                if '.' in f:
                    name, _ = f.split('.', 1)
                else:
                    name = f

                # read the template once and close the handle right away
                src = open(path)
                try:
                    source = src.read()
                finally:
                    src.close()

                text = Template.normalize_text(source)
                code = Template.generate_code(text, path)

                code = code.replace("__template__", name, 1)

                out.write(code)

                out.write('\n\n')
                out.write('%s = CompiledTemplate(%s, %s)\n' % (name, name, repr(path)))
                out.write("join_ = %s._join; escape_ = %s._escape\n\n" % (name, name))

                # create template to make sure it compiles
                Template(source, path)
        finally:
            # close the output even when a template fails to compile
            out.close()
class ParseError(Exception):
    """Exception type named for parse errors."""
class SecurityError(Exception):
    """The template seems to be trying to do something naughty."""
# Enumerate all the allowed AST nodes.
# Names that are commented out are deliberately not allowed.
ALLOWED_AST_NODES = [
    "Add", "And",
    # "AssAttr",
    "AssList", "AssName", "AssTuple",
    # "Assert",
    "Assign", "AugAssign",
    # "Backquote",
    "Bitand", "Bitor", "Bitxor", "Break",
    "CallFunc", "Class", "Compare", "Const", "Continue",
    "Decorators", "Dict", "Discard", "Div",
    "Ellipsis", "EmptyNode",
    # "Exec",
    "Expression", "FloorDiv", "For",
    # "From",
    "Function",
    "GenExpr", "GenExprFor", "GenExprIf", "GenExprInner", "Getattr",
    # "Global",
    "If", "IfExp",
    # "Import",
    "Invert", "Keyword", "Lambda", "LeftShift",
    "List", "ListComp", "ListCompFor", "ListCompIf",
    "Mod", "Module", "Mul",
    "Name", "Not", "Or", "Pass", "Power",
    # "Print", "Printnl", "Raise",
    "Return", "RightShift",
    "Slice", "Sliceobj", "Stmt", "Sub", "Subscript",
    # "TryExcept", "TryFinally",
    "Tuple", "UnaryAdd", "UnarySub",
    "While", "With", "Yield",
]
class SafeVisitor(object):
    """
    Make sure code is safe by walking through the AST.

    Code is considered unsafe if:

    * it has restricted AST nodes
    * it is trying to access restricted attributes

    Problems are collected in ``self.errors`` and reported together by
    ``walk``.

    Adopted from http://www.zafar.se/bkz/uploads/safe.txt (public domain, Babar K. Zafar)
    """

    def __init__(self):
        "Start with an empty list of collected errors."
        self.errors = []

    def walk(self, ast, filename):
        "Validate each node in AST and raise SecurityError if the code is not safe."
        self.filename = filename
        self.visit(ast)

        if not self.errors:
            return
        messages = [str(err) for err in self.errors]
        raise SecurityError('\n'.join(messages))

    def visit(self, node, *args):
        "Recursively validate node and all of its children."
        nodename = node.__class__.__name__
        handler = getattr(self, 'visit' + nodename, None)

        if handler:
            # a node-specific visitXxx method takes precedence
            handler(node, *args)
        elif nodename not in ALLOWED_AST_NODES:
            self.fail(node, *args)

        for child in node.getChildNodes():
            self.visit(child, *args)

    def visitName(self, node, *args):
        "Names are accepted without any check."
        #self.assert_attr(node.getChildren()[0], node)

    def visitGetattr(self, node, *args):
        "Disallow any attempts to access a restricted attribute."
        self.assert_attr(node.attrname, node)

    def assert_attr(self, attrname, node):
        "Record an error when ``attrname`` is a restricted attribute."
        if not self.is_unallowed_attr(attrname):
            return
        where = (self.filename, self.get_node_lineno(node), attrname)
        self.errors.append(SecurityError("%s:%d - access to attribute '%s' is denied" % where))

    def is_unallowed_attr(self, name):
        "Return True for names starting with '_', 'func_' or 'im_'."
        for prefix in ('_', 'func_', 'im_'):
            if name.startswith(prefix):
                return True
        return False

    def get_node_lineno(self, node):
        "Return the line number of ``node``, or 0 when it has none."
        return node.lineno or 0

    def fail(self, node, *args):
        "Default callback for unallowed AST nodes."
        where = (self.filename, self.get_node_lineno(node), node.__class__.__name__)
        self.errors.append(SecurityError("%s:%d - execution of '%s' statements is denied" % where))
class TemplateResult(object, DictMixin):
    """Dictionary like object for storing template output.

    The result of a template execution is usually a string, but sometimes it
    contains attributes set using $var. This class provides a simple
    dictionary like interface for storing the output of the template and the
    attributes. The output is stored with a special key __body__. Converting
    the TemplateResult to string or unicode returns the value of __body__.

    When the template is in execution, the output is generated part by part
    and those parts are combined at the end. Parts are added to the
    TemplateResult by calling the `extend` method and the parts are combined
    seamlessly when __body__ is accessed.

        >>> d = TemplateResult(__body__='hello, world', x='foo')
        >>> d
        <TemplateResult: {'__body__': 'hello, world', 'x': 'foo'}>
        >>> print d
        hello, world
        >>> d.x
        'foo'
        >>> d = TemplateResult()
        >>> d.extend([u'hello', u'world'])
        >>> d
        <TemplateResult: {'__body__': u'helloworld'}>
    """
    def __init__(self, *a, **kw):
        # Written through __dict__ because __setattr__ stores into the mapping.
        self.__dict__["_d"] = dict(*a, **kw)
        # A second setdefault("__body__", None) used to follow further down;
        # it could never take effect after this one and has been removed.
        self._d.setdefault("__body__", u'')

        self.__dict__['_parts'] = []
        self.__dict__["extend"] = self._parts.extend

    def keys(self):
        return self._d.keys()

    def _prepare_body(self):
        """Prepare value of __body__ by joining parts.
        """
        if self._parts:
            value = u"".join(self._parts)
            self._parts[:] = []
            body = self._d.get('__body__')
            if body:
                self._d['__body__'] = body + value
            else:
                self._d['__body__'] = value

    def __getitem__(self, name):
        if name == "__body__":
            self._prepare_body()
        return self._d[name]

    def __setitem__(self, name, value):
        if name == "__body__":
            self._prepare_body()
        return self._d.__setitem__(name, value)

    def __delitem__(self, name):
        if name == "__body__":
            self._prepare_body()
        return self._d.__delitem__(name)

    def __getattr__(self, key):
        # Attribute access falls back to the mapping; a missing key is
        # reported as AttributeError carrying the original KeyError.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(sys.exc_info()[1])

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(sys.exc_info()[1])

    def __unicode__(self):
        # self["__body__"] joins pending parts before returning
        return self["__body__"]

    def __str__(self):
        return self["__body__"].encode('utf-8')

    def __repr__(self):
        self._prepare_body()
        return "<TemplateResult: %s>" % self._d
# NOTE(review): this docstring was rebuilt from a copy in which blank lines
# and runs of spaces were collapsed. The layout has been restored, but the
# spaces inside the string literals below are reproduced as found -- confirm
# them against the upstream file before relying on these doctests.
def test():
    r"""Doctest for testing template module.

    Define a utility function to run template test.

        >>> class TestResult:
        ...     def __init__(self, t): self.t = t
        ...     def __getattr__(self, name): return getattr(self.t, name)
        ...     def __repr__(self): return repr(unicode(self))
        ...
        >>> def t(code, **keywords):
        ...     tmpl = Template(code, **keywords)
        ...     return lambda *a, **kw: TestResult(tmpl(*a, **kw))
        ...

    Simple tests.

        >>> t('1')()
        u'1\n'
        >>> t('$def with ()\n1')()
        u'1\n'
        >>> t('$def with (a)\n$a')(1)
        u'1\n'
        >>> t('$def with (a=0)\n$a')(1)
        u'1\n'
        >>> t('$def with (a=0)\n$a')(a=1)
        u'1\n'

    Test complicated expressions.

        >>> t('$def with (x)\n$x.upper()')('hello')
        u'HELLO\n'
        >>> t('$(2 * 3 + 4 * 5)')()
        u'26\n'
        >>> t('${2 * 3 + 4 * 5}')()
        u'26\n'
        >>> t('$def with (limit)\nkeep $(limit)ing.')('go')
        u'keep going.\n'
        >>> t('$def with (a)\n$a.b[0]')(storage(b=[1]))
        u'1\n'

    Test html escaping.

        >>> t('$def with (x)\n$x', filename='a.html')('<html>')
        u'<html>\n'
        >>> t('$def with (x)\n$x', filename='a.txt')('<html>')
        u'<html>\n'

    Test if, for and while.

        >>> t('$if 1: 1')()
        u'1\n'
        >>> t('$if 1:\n 1')()
        u'1\n'
        >>> t('$if 1:\n 1\\')()
        u'1'
        >>> t('$if 0: 0\n$elif 1: 1')()
        u'1\n'
        >>> t('$if 0: 0\n$elif None: 0\n$else: 1')()
        u'1\n'
        >>> t('$if 0 < 1 and 1 < 2: 1')()
        u'1\n'
        >>> t('$for x in [1, 2, 3]: $x')()
        u'1\n2\n3\n'
        >>> t('$def with (d)\n$for k, v in d.iteritems(): $k')({1: 1})
        u'1\n'
        >>> t('$for x in [1, 2, 3]:\n\t$x')()
        u' 1\n 2\n 3\n'
        >>> t('$def with (a)\n$while a and a.pop():1')([1, 2, 3])
        u'1\n1\n1\n'

    The space after : must be ignored.

        >>> t('$if True: foo')()
        u'foo\n'

    Test loop.xxx.

        >>> t("$for i in range(5):$loop.index, $loop.parity")()
        u'1, odd\n2, even\n3, odd\n4, even\n5, odd\n'
        >>> t("$for i in range(2):\n $for j in range(2):$loop.parent.parity $loop.parity")()
        u'odd odd\nodd even\neven odd\neven even\n'

    Test assignment.

        >>> t('$ a = 1\n$a')()
        u'1\n'
        >>> t('$ a = [1]\n$a[0]')()
        u'1\n'
        >>> t('$ a = {1: 1}\n$a.keys()[0]')()
        u'1\n'
        >>> t('$ a = []\n$if not a: 1')()
        u'1\n'
        >>> t('$ a = {}\n$if not a: 1')()
        u'1\n'
        >>> t('$ a = -1\n$a')()
        u'-1\n'
        >>> t('$ a = "1"\n$a')()
        u'1\n'

    Test comments.

        >>> t('$# 0')()
        u'\n'
        >>> t('hello$#comment1\nhello$#comment2')()
        u'hello\nhello\n'
        >>> t('$#comment0\nhello$#comment1\nhello$#comment2')()
        u'\nhello\nhello\n'

    Test unicode.

        >>> t('$def with (a)\n$a')(u'\u203d')
        u'\u203d\n'
        >>> t('$def with (a)\n$a')(u'\u203d'.encode('utf-8'))
        u'\u203d\n'
        >>> t(u'$def with (a)\n$a $:a')(u'\u203d')
        u'\u203d \u203d\n'
        >>> t(u'$def with ()\nfoo')()
        u'foo\n'
        >>> def f(x): return x
        ...
        >>> t(u'$def with (f)\n$:f("x")')(f)
        u'x\n'
        >>> t('$def with (f)\n$:f("x")')(f)
        u'x\n'

    Test dollar escaping.

        >>> t("Stop, $$money isn't evaluated.")()
        u"Stop, $money isn't evaluated.\n"
        >>> t("Stop, \$money isn't evaluated.")()
        u"Stop, $money isn't evaluated.\n"

    Test space sensitivity.

        >>> t('$def with (x)\n$x')(1)
        u'1\n'
        >>> t('$def with(x ,y)\n$x')(1, 1)
        u'1\n'
        >>> t('$(1 + 2*3 + 4)')()
        u'11\n'

    Make sure globals are working.

        >>> t('$x')()
        Traceback (most recent call last):
            ...
        NameError: global name 'x' is not defined
        >>> t('$x', globals={'x': 1})()
        u'1\n'

    Can't change globals.

        >>> t('$ x = 2\n$x', globals={'x': 1})()
        u'2\n'
        >>> t('$ x = x + 1\n$x', globals={'x': 1})()
        Traceback (most recent call last):
            ...
        UnboundLocalError: local variable 'x' referenced before assignment

    Make sure builtins are customizable.

        >>> t('$min(1, 2)')()
        u'1\n'
        >>> t('$min(1, 2)', builtins={})()
        Traceback (most recent call last):
            ...
        NameError: global name 'min' is not defined

    Test vars.

        >>> x = t('$var x: 1')()
        >>> x.x
        u'1'
        >>> x = t('$var x = 1')()
        >>> x.x
        1
        >>> x = t('$var x: \n foo\n bar')()
        >>> x.x
        u'foo\nbar\n'

    Test BOM chars.

        >>> t('\xef\xbb\xbf$def with(x)\n$x')('foo')
        u'foo\n'

    Test for with weird cases.

        >>> t('$for i in range(10)[1:5]:\n $i')()
        u'1\n2\n3\n4\n'
        >>> t("$for k, v in {'a': 1, 'b': 2}.items():\n $k $v")()
        u'a 1\nb 2\n'
        >>> t("$for k, v in ({'a': 1, 'b': 2}.items():\n $k $v")()
        Traceback (most recent call last):
            ...
        SyntaxError: invalid syntax

    Test datetime.

        >>> import datetime
        >>> t("$def with (date)\n$date.strftime('%m %Y')")(datetime.datetime(2009, 1, 1))
        u'01 2009\n'
    """
if __name__ == "__main__":
    import sys
    if '--compile' not in sys.argv:
        # default: run the doctests of this module
        import doctest
        doctest.testmod()
    else:
        # the templates root is taken from the second command line argument
        compile_templates(sys.argv[2])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""test utilities | |
(part of web.py) | |
""" | |
import unittest | |
import sys, os | |
import web | |
# Re-export the unittest classes under this module's namespace.
from unittest import TestCase, TestSuite
def load_modules(names):
    """Import every module named in ``names`` and return the modules in order."""
    modules = []
    for name in names:
        # The non-empty fromlist makes __import__ return the named module
        # itself rather than its top-level package.
        modules.append(__import__(name, None, None, "x"))
    return modules
def module_suite(module, classnames=None):
    """Makes a suite from a module.

    With ``classnames`` only those names are loaded from the module.
    Otherwise the module's own ``suite()`` is used when it has one, and
    failing that every test found in the module is loaded.
    """
    if classnames:
        loader = unittest.TestLoader()
        return loader.loadTestsFromNames(classnames, module)
    if hasattr(module, 'suite'):
        return module.suite()
    loader = unittest.TestLoader()
    return loader.loadTestsFromModule(module)
def doctest_suite(module_names):
    """Makes a test suite from the doctests of the named modules."""
    import doctest
    combined = TestSuite()
    for module in load_modules(module_names):
        combined.addTest(doctest.DocTestSuite(module))
    return combined
def suite(module_names):
    """Creates one suite holding the suites of all the named modules."""
    combined = TestSuite()
    for module in load_modules(module_names):
        combined.addTest(module_suite(module))
    return combined
def runTests(suite):
    """Run ``suite`` with a TextTestRunner and return the result object."""
    return unittest.TextTestRunner().run(suite)
def main(suite=None):
    """Run ``suite`` (default: the tests of ``__main__``) and exit.

    The exit status is false (0) when all tests passed, true (1) otherwise.
    """
    if not suite:
        main_module = __import__('__main__')
        # allow command line switches: arguments starting with "-" are not
        # treated as test names
        names = []
        for arg in sys.argv[1:]:
            if not arg.startswith('-'):
                names.append(arg)
        suite = module_suite(main_module, names or None)

    outcome = runTests(suite)
    sys.exit(not outcome.wasSuccessful())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
General Utilities | |
(part of web.py) | |
""" | |
# Names exported by "from utils import *".
__all__ = [
    "Storage", "storage", "storify", "Counter", "counter", "iters",
    "rstrips", "lstrips", "strips", "safeunicode", "safestr", "utf8",
    "TimeoutError", "timelimit", "Memoize", "memoize",
    "re_compile", "re_subm", "group", "uniq", "iterview",
    "IterBetter", "iterbetter", "safeiter", "safewrite",
    "dictreverse", "dictfind", "dictfindall", "dictincr", "dictadd",
    "requeue", "restack", "listget", "intget", "datestr",
    "numify", "denumify", "commify", "dateify", "nthstr", "cond",
    "CaptureStdout", "capturestdout", "Profile", "profile", "tryall",
    "ThreadedDict", "threadeddict", "autoassign", "to36",
    "safemarkdown", "sendmail",
]
import re, sys, time, threading, itertools, traceback, os | |
try: | |
import subprocess | |
except ImportError: | |
subprocess = None | |
try: import datetime | |
except ImportError: pass | |
try: set | |
except NameError: | |
from sets import Set as set | |
try: | |
from threading import local as threadlocal | |
except ImportError: | |
from python23 import threadlocal | |
class Storage(dict):
    """
    A dictionary whose items can also be read, written and deleted as
    attributes: `obj.foo` works in addition to `obj['foo']`.

        >>> o = storage(a=1)
        >>> o.a
        1
        >>> o['a']
        1
        >>> o.a = 2
        >>> o['a']
        2
        >>> del o.a
        >>> o.a
        Traceback (most recent call last):
            ...
        AttributeError: 'a'

    """
    def __getattr__(self, key):
        # A missing key is reported as an AttributeError that carries the
        # original KeyError.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(sys.exc_info()[1])

    def __setattr__(self, key, value):
        self[key] = value

    def __delattr__(self, key):
        try:
            del self[key]
        except KeyError:
            raise AttributeError(sys.exc_info()[1])

    def __repr__(self):
        return '<Storage %s>' % dict.__repr__(self)

storage = Storage
def storify(mapping, *requireds, **defaults):
    """
    Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
    `mapping` doesn't have all of the keys in `requireds` and using the default
    values for keys found in `defaults`.

    For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
    `storage({'a':1, 'b':2, 'c':3})`.

    If a `storify` value is a list (e.g. multiple values in a form submission),
    `storify` returns the last element of the list, unless the key appears in
    `defaults` as a list. Thus:

        >>> storify({'a':[1, 2]}).a
        2
        >>> storify({'a':[1, 2]}, a=[]).a
        [1, 2]
        >>> storify({'a':1}, a=[]).a
        [1]
        >>> storify({}, a=[]).a
        []

    Similarly, if the value has a `value` attribute, `storify` will return _its_
    value, unless the key appears in `defaults` as a dictionary.

        >>> storify({'a':storage(value=1)}).a
        1
        >>> storify({'a':storage(value=1)}, a={}).a
        <Storage {'value': 1}>
        >>> storify({}, a={}).a
        {}

    Optionally, keyword parameter `_unicode` can be passed to convert all values to unicode.

        >>> storify({'x': 'a'}, _unicode=True)
        <Storage {'x': u'a'}>
        >>> storify({'x': storage(value='a')}, x={}, _unicode=True)
        <Storage {'x': <Storage {'value': 'a'}>}>
        >>> storify({'x': storage(value='a')}, _unicode=True)
        <Storage {'x': u'a'}>
    """
    _unicode = defaults.pop('_unicode', False)

    def unicodify(s):
        # only plain str values are converted, and only on request
        if _unicode and isinstance(s, str):
            return safeunicode(s)
        return s

    def getvalue(x):
        if not hasattr(x, 'value'):
            return unicodify(x)
        if hasattr(x, 'file'):
            # objects that have both `file` and `value` keep their raw value
            # (presumably file uploads -- confirm)
            return x.value
        return unicodify(x.value)

    stor = Storage()
    for key in requireds + tuple(mapping.keys()):
        value = mapping[key]
        default = defaults.get(key)
        wants_list = isinstance(default, list)

        if isinstance(value, list):
            if wants_list:
                value = [getvalue(x) for x in value]
            else:
                value = value[-1]
        if not isinstance(default, dict):
            value = getvalue(value)
        if wants_list and not isinstance(value, list):
            value = [value]
        setattr(stor, key, value)

    for key, default in defaults.iteritems():
        if hasattr(stor, key):
            result = stor[key]
        else:
            result = default
        # an empty-tuple default asks for a tuple result
        if default == () and not isinstance(result, tuple):
            result = (result,)
        setattr(stor, key, result)

    return stor
class Counter(storage):
    """Keeps count of how many times something is added.

        >>> c = counter()
        >>> c.add('x')
        >>> c.add('x')
        >>> c.add('x')
        >>> c.add('x')
        >>> c.add('x')
        >>> c.add('y')
        >>> c
        <Counter {'y': 1, 'x': 5}>
        >>> c.most()
        ['x']
    """
    def add(self, n):
        """Increment the count of `n`, starting from 0 for a new key."""
        self.setdefault(n, 0)
        self[n] += 1

    def most(self):
        """Returns the keys with maximum count (an empty list if nothing was added)."""
        if not self:
            # max() of an empty sequence raises ValueError
            return []
        m = max(self.itervalues())
        return [k for k, v in self.iteritems() if v == m]

    def least(self):
        """Returns the keys with minimum count (an empty list if nothing was added)."""
        if not self:
            # min() of an empty sequence raises ValueError
            return []
        m = min(self.itervalues())
        return [k for k, v in self.iteritems() if v == m]

    def percent(self, key):
        """Returns what percentage a certain key is of all entries.

            >>> c = counter()
            >>> c.add('x')
            >>> c.add('x')
            >>> c.add('x')
            >>> c.add('y')
            >>> c.percent('x')
            0.75
            >>> c.percent('y')
            0.25
        """
        return float(self[key])/sum(self.values())

    def sorted_keys(self):
        """Returns keys sorted by value, highest count first.

            >>> c = counter()
            >>> c.add('x')
            >>> c.add('x')
            >>> c.add('y')
            >>> c.sorted_keys()
            ['x', 'y']
        """
        return sorted(self.keys(), key=lambda k: self[k], reverse=True)

    def sorted_values(self):
        """Returns values sorted by value, highest count first.

            >>> c = counter()
            >>> c.add('x')
            >>> c.add('x')
            >>> c.add('y')
            >>> c.sorted_values()
            [2, 1]
        """
        return [self[k] for k in self.sorted_keys()]

    def sorted_items(self):
        """Returns items sorted by value, highest count first.

            >>> c = counter()
            >>> c.add('x')
            >>> c.add('x')
            >>> c.add('y')
            >>> c.sorted_items()
            [('x', 2), ('y', 1)]
        """
        return [(k, self[k]) for k in self.sorted_keys()]

    def __repr__(self):
        return '<Counter ' + dict.__repr__(self) + '>'

counter = Counter
# Collect the iterable container types this interpreter provides.
iters = [list, tuple]
import __builtin__
if hasattr(__builtin__, 'set'):
    iters.append(set)
if hasattr(__builtin__, 'frozenset'):
    # this branch used to append `set` a second time, so frozenset
    # instances were never recognised
    iters.append(frozenset)
if sys.version_info < (2,6): # sets module deprecated in 2.6
    try:
        from sets import Set
        iters.append(Set)
    except ImportError:
        pass

# A tuple subclass is used so that the value still works as the second
# argument of isinstance() while being able to carry a __doc__.
class _hack(tuple): pass
iters = _hack(iters)
iters.__doc__ = """
A list of iterable items (like lists, but not strings). Includes whichever
of lists, tuples, sets, frozensets and Sets are available in this version
of Python.
"""
def _strips(direction, text, remove):
    """Strip `remove` once from the left ('l') or right ('r') of `text`.

    When `remove` is one of the `iters` types, each element is stripped in
    turn, in iteration order. Raises ValueError for any other `direction`.
    """
    if isinstance(remove, iters):
        for subr in remove:
            text = _strips(direction, text, subr)
        return text

    if direction == 'l':
        if text.startswith(remove):
            return text[len(remove):]
    elif direction == 'r':
        if text.endswith(remove):
            # Slice by absolute position: text[:-len(remove)] is text[:0],
            # i.e. '', when `remove` is empty.
            return text[:len(text) - len(remove)]
    else:
        raise ValueError("Direction needs to be r or l.")
    return text
def rstrips(text, remove):
    """
    Returns `text` with the string `remove` taken off its right end.

        >>> rstrips("foobar", "bar")
        'foo'

    """
    stripped = _strips('r', text, remove)
    return stripped
def lstrips(text, remove):
    """
    Returns `text` with the string `remove` taken off its left end.
    `remove` may also be a list of strings, which are tried in order.

        >>> lstrips("foobar", "foo")
        'bar'
        >>> lstrips('http://foo.org/', ['http://', 'https://'])
        'foo.org/'
        >>> lstrips('FOOBARBAZ', ['FOO', 'BAR'])
        'BAZ'
        >>> lstrips('FOOBARBAZ', ['BAR', 'FOO'])
        'BARBAZ'

    """
    stripped = _strips('l', text, remove)
    return stripped
def strips(text, remove):
    """
    Strip `remove` from both ends of `text` (left end first, then right).

        >>> strips("foobarfoo", "foo")
        'bar'
    """
    without_prefix = lstrips(text, remove)
    return rstrips(without_prefix, remove)
def safeunicode(obj, encoding='utf-8'):
    r"""
    Coerce `obj` to a unicode string.

    Exact `unicode` objects are returned unchanged, exact `str` objects are
    decoded with `encoding`, numbers and booleans go through `unicode()`, as
    does anything with a `__unicode__` method (or a unicode subclass).
    Everything else is passed through `str()` and then decoded.

        >>> safeunicode('hello')
        u'hello'
        >>> safeunicode(2)
        u'2'
        >>> safeunicode('\xe1\x88\xb4')
        u'\u1234'
    """
    kind = type(obj)
    if kind is unicode:
        return obj
    if kind is str:
        return obj.decode(encoding)
    if kind in (int, float, bool):
        return unicode(obj)
    if hasattr(obj, '__unicode__') or isinstance(obj, unicode):
        return unicode(obj)
    # Last resort: the object's str() form, interpreted in `encoding`.
    return str(obj).decode(encoding)
def safestr(obj, encoding='utf-8'):
    r"""
    Convert `obj` to a byte string, encoding unicode with `encoding`
    (utf-8 by default).

    Byte strings are returned unchanged. An iterator (anything with a `next`
    attribute) is not consumed: a lazy iterator is returned that applies
    `safestr` with the same `encoding` to each item. Any other object is
    converted with `str()`.

        >>> safestr('hello')
        'hello'
        >>> safestr(u'\u1234')
        '\xe1\x88\xb4'
        >>> safestr(2)
        '2'
    """
    if isinstance(obj, unicode):
        return obj.encode(encoding)
    elif isinstance(obj, str):
        return obj
    elif hasattr(obj, 'next'): # iterator
        # Forward `encoding`; mapping the bare function would silently fall
        # back to utf-8 for every item.
        return itertools.imap(lambda item: safestr(item, encoding), obj)
    else:
        return str(obj)

# for backward-compatibility
utf8 = safestr
class TimeoutError(Exception): pass

def timelimit(timeout):
    """
    A decorator to limit a function to `timeout` seconds, raising `TimeoutError`
    if it takes longer.

        >>> import time
        >>> def meaningoflife():
        ...     time.sleep(.2)
        ...     return 42
        >>>
        >>> timelimit(.1)(meaningoflife)()
        Traceback (most recent call last):
            ...
        TimeoutError: took too long
        >>> timelimit(1)(meaningoflife)()
        42

    An exception raised by the function is re-raised in the calling thread.

    _Caveat:_ The function isn't stopped after `timeout` seconds but continues
    executing in a separate thread. (There seems to be no way to kill a thread.)

    inspired by <http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/473878>
    """
    def _1(function):
        def _2(*args, **kw):
            class Dispatch(threading.Thread):
                """Runs `function` in a daemon thread, recording its outcome."""
                def __init__(self):
                    threading.Thread.__init__(self)
                    self.result = None
                    self.error = None

                    self.setDaemon(True)
                    self.start()

                def run(self):
                    try:
                        self.result = function(*args, **kw)
                    except:
                        # Deliberately broad: whatever the function raised
                        # is handed back to the caller below.
                        self.error = sys.exc_info()

            c = Dispatch()
            c.join(timeout)
            # Thread.isAlive() was removed in Python 3.9; very old Pythons
            # only have the camelCase spelling.
            still_running = getattr(c, 'is_alive', None) or c.isAlive
            if still_running():
                raise TimeoutError('took too long')
            if c.error:
                raise c.error[1]
            return c.result
        return _2
    return _1
class Memoize:
    """
    'Memoizes' a function, caching its return values for each input.
    If `expires` is specified, values are recalculated after `expires` seconds.
    If `background` is specified, values are recalculated in a separate thread.

    Calls are keyed on the positional arguments plus the keyword arguments
    sorted by name, so the order in which keywords are written does not
    matter. All arguments must be hashable.

        >>> calls = 0
        >>> def howmanytimeshaveibeencalled():
        ...     global calls
        ...     calls += 1
        ...     return calls
        >>> fastcalls = memoize(howmanytimeshaveibeencalled)
        >>> howmanytimeshaveibeencalled()
        1
        >>> howmanytimeshaveibeencalled()
        2
        >>> fastcalls()
        3
        >>> fastcalls()
        3
        >>> import time
        >>> fastcalls = memoize(howmanytimeshaveibeencalled, .1, background=False)
        >>> fastcalls()
        4
        >>> fastcalls()
        4
        >>> time.sleep(.2)
        >>> fastcalls()
        5
        >>> def slowfunc():
        ...     time.sleep(.1)
        ...     return howmanytimeshaveibeencalled()
        >>> fastcalls = memoize(slowfunc, .2, background=True)
        >>> fastcalls()
        6
        >>> timelimit(.05)(fastcalls)()
        6
        >>> time.sleep(.2)
        >>> timelimit(.05)(fastcalls)()
        6
        >>> timelimit(.05)(fastcalls)()
        6
        >>> time.sleep(.2)
        >>> timelimit(.05)(fastcalls)()
        7
        >>> fastcalls = memoize(slowfunc, None, background=True)
        >>> threading.Thread(target=fastcalls).start()
        >>> time.sleep(.01)
        >>> fastcalls()
        9
    """
    def __init__(self, func, expires=None, background=True):
        self.func = func
        self.cache = {}      # key -> (value, time the value was computed)
        self.expires = expires
        self.background = background
        self.running = {}    # key -> lock held while that key is recomputed

    def __call__(self, *args, **keywords):
        # Sort the keyword items: f(a=1, b=2) and f(b=2, a=1) are the same
        # call and must share one cache entry. Keyword names are unique, so
        # the sort never has to compare the values.
        key = (args, tuple(sorted(keywords.items())))
        lock = self.running.get(key)
        if lock is None:
            # setdefault so that two threads racing here end up with the
            # same lock object.
            lock = self.running.setdefault(key, threading.Lock())

        def update(block=False):
            # Non-blocking by default: when a refresh of this key is already
            # under way the caller just keeps the cached value.
            if lock.acquire(block):
                try:
                    self.cache[key] = (self.func(*args, **keywords), time.time())
                finally:
                    lock.release()

        if key not in self.cache:
            update(block=True)
        elif self.expires and (time.time() - self.cache[key][1]) > self.expires:
            if self.background:
                threading.Thread(target=update).start()
            else:
                update()
        return self.cache[key][0]

memoize = Memoize

re_compile = memoize(re.compile) #@@ threadsafe?
re_compile.__doc__ = """
A memoized version of re.compile.
"""
class _re_subm_proxy: | |
def __init__(self): | |
self.match = None | |
def __call__(self, match): | |
self.match = match | |
return '' | |
def re_subm(pat, repl, string):
    """
    Like re.sub, but returns the replacement _and_ the match object.

    The pattern is applied twice: once with `repl` to build the new string
    and once with a recording callback to capture the match object (the
    last one, when the pattern matches several times; None when it does
    not match).

        >>> t, m = re_subm('g(oo+)fball', r'f\\1lish', 'goooooofball')
        >>> t
        'foooooolish'
        >>> m.groups()
        ('oooooo',)
    """
    regex = re_compile(pat)
    replaced = regex.sub(repl, string)

    recorder = _re_subm_proxy()
    regex.sub(recorder.__call__, string)
    return replaced, recorder.match
def group(seq, size):
    """
    Returns an iterator over a series of lists of length size from iterable.
    The last list is shorter when the input does not divide evenly; a
    `size` of zero or less produces nothing.

        >>> list(group([1,2,3,4], 2))
        [[1, 2], [3, 4]]
        >>> list(group([1,2,3,4,5], 2))
        [[1, 2], [3, 4], [5]]
    """
    seq = iter(seq)
    # islice rejects negative counts; treat them like zero, as the
    # range-based loop this replaces did.
    width = max(size, 0)
    while True:
        # islice simply stops at exhaustion, so nothing depends on a
        # StopIteration escaping from a nested generator (an error under
        # PEP 479).
        chunk = list(itertools.islice(seq, width))
        if not chunk:
            break
        yield chunk
def uniq(seq, key=None):
    """
    Return the elements of `seq` as a list with duplicates dropped; the
    first occurrence of each element is kept, in its original position.

        >>> uniq([9,0,2,1,0])
        [9, 0, 2, 1]

    `key`, when given, is a one-argument function; two elements count as
    duplicates when it returns the same (hashable) value for both.

        >>> uniq(["Foo", "foo", "bar"], key=lambda s: s.lower())
        ['Foo', 'bar']
    """
    if not key:
        key = lambda item: item
    marks = set()
    kept = []
    for item in seq:
        mark = key(item)
        if mark not in marks:
            marks.add(mark)
            kept.append(item)
    return kept
def iterview(x):
    """
    Takes an iterable `x` and returns an iterator over it
    which prints its progress to stderr as it iterates through.

    `x` must support `len()`. A progress line is written before each item
    is yielded and a final line (ending in a newline) once the input is
    exhausted; an empty `x` is reported as complete.
    """
    WIDTH = 70

    def plainformat(n, lenx):
        # Guard the ratio: a zero-length input counts as 100% done.
        if lenx:
            percent = (float(n)/lenx)*100
        else:
            percent = 100.0
        return '%5.1f%% (%*d/%d)' % (percent, len(str(lenx)), n, lenx)

    def bars(size, n, lenx):
        if lenx:
            val = int((float(n)*size)/lenx + 0.5)
        else:
            val = size
        if size - val:
            spacing = ">" + (" "*(size-val))[1:]
        else:
            spacing = ""
        return "[%s%s]" % ("="*val, spacing)

    def eta(elapsed, n, lenx):
        if n == 0:
            return '--:--:--'
        if n == lenx:
            # Finished: show the total elapsed time instead of an estimate.
            secs = int(elapsed)
        else:
            secs = int((elapsed/n) * (lenx-n))
        mins, secs = divmod(secs, 60)
        hrs, mins = divmod(mins, 60)

        return '%02d:%02d:%02d' % (hrs, mins, secs)

    def render(starttime, n, lenx):
        out = plainformat(n, lenx) + ' '
        if n == lenx:
            end = ' '
        else:
            end = ' ETA '
        end += eta(time.time() - starttime, n, lenx)
        out += bars(WIDTH - len(out) - len(end), n, lenx)
        out += end
        return out

    starttime = time.time()
    lenx = len(x)
    # Counted separately from the loop variable so the closing line also
    # works when the loop body never runs.
    done = 0
    for n, y in enumerate(x):
        sys.stderr.write('\r' + render(starttime, n, lenx))
        yield y
        done = n + 1
    sys.stderr.write('\r' + render(starttime, done, lenx) + '\n')
class IterBetter:
    """
    Returns an object that can be used as an iterator
    but can also be used via __getitem__ (although it
    cannot go backwards -- that is, you cannot request
    `iterbetter[0]` after requesting `iterbetter[1]`).

        >>> import itertools
        >>> c = iterbetter(itertools.count())
        >>> c[1]
        1
        >>> c[5]
        5
        >>> c[3]
        Traceback (most recent call last):
            ...
        IndexError: already passed 3

    For boolean test, IterBetter peeks at the first value in the iterator
    without affecting the iteration or later indexing.

        >>> c = iterbetter(iter(range(5)))
        >>> bool(c)
        True
        >>> list(c)
        [0, 1, 2, 3, 4]
        >>> c = iterbetter(iter([]))
        >>> bool(c)
        False
        >>> list(c)
        []
    """
    def __init__(self, iterator):
        # self.i: the wrapped iterator; self.c: index of the next item that
        # will be handed out.
        self.i, self.c = iterator, 0

    def _pull(self):
        """Return the next item, handing out a peeked head first (once)."""
        if hasattr(self, "_head"):
            head = self._head
            del self._head
            return head
        return self.i.next()

    def __iter__(self):
        while 1:
            try:
                item = self._pull()
            except StopIteration:
                # End the generator explicitly rather than letting
                # StopIteration escape from its body.
                return
            # Count before yielding so the position stays right even if the
            # consumer stops iterating after this item.
            self.c += 1
            yield item

    def __getitem__(self, i):
        #todo: slices
        if i < self.c:
            raise IndexError("already passed " + str(i))
        try:
            while i > self.c:
                self._pull()
                self.c += 1
            # now self.c == i
            item = self._pull()
        except StopIteration:
            raise IndexError(str(i))
        self.c += 1
        return item

    def __nonzero__(self):
        if hasattr(self, "__len__"):
            return len(self) != 0
        elif hasattr(self, "_head"):
            return True
        else:
            # Peek: the item is stashed in _head, which _pull() hands out
            # before touching the underlying iterator again.
            try:
                self._head = self.i.next()
            except StopIteration:
                return False
            else:
                return True

iterbetter = IterBetter
def safeiter(it, cleanup=None, ignore_errors=True):
    """Makes an iterator safe by ignoring the exceptions occurred during the iteration.

    Any exception other than StopIteration raised while fetching an item is
    printed with `traceback.print_exc()` and the fetch is retried.

    NOTE(review): `cleanup` and `ignore_errors` are accepted but not used
    by this implementation.
    """
    source = iter(it)
    while True:
        try:
            item = source.next()
        except StopIteration:
            # Finish the generator explicitly instead of letting
            # StopIteration escape from its body (an error under PEP 479).
            return
        except:
            traceback.print_exc()
            continue
        yield item
def safewrite(filename, content):
    """Writes the content to a temp file and then moves the temp file to
    given filename to avoid overwriting the existing file in case of errors.

    The temp file is `filename + '.tmp'`. If writing fails, the exception
    propagates and `filename` itself is left untouched.
    """
    tmpname = filename + '.tmp'
    f = open(tmpname, 'w')
    try:
        f.write(content)
    finally:
        # Close even when write() fails so the handle is not leaked.
        f.close()
    os.rename(tmpname, filename)
def dictreverse(mapping):
    """
    Build a new dict that maps each value of `mapping` back to its key.
    When several keys share a value, the key visited last wins.

        >>> dictreverse({1: 2, 3: 4})
        {2: 1, 4: 3}
    """
    flipped = {}
    for key, value in mapping.items():
        flipped[value] = key
    return flipped
def dictfind(dictionary, element):
    """
    Returns a key whose value in `dictionary` is `element`
    or, if none exists, None.

    Values are matched by equality (an identical object always matches).

        >>> d = {1:2, 3:4}
        >>> dictfind(d, 4)
        3
        >>> dictfind(d, 5)
    """
    for (key, value) in dictionary.items():
        # Identity alone only works by accident for small ints and interned
        # strings; equal values held in distinct objects must match too.
        if value is element or value == element:
            return key
def dictfindall(dictionary, element):
    """
    Returns the keys whose values in `dictionary` are `element`
    or, if none exists, [].

    Values are matched by equality (an identical object always matches).

        >>> d = {1:4, 3:4}
        >>> dictfindall(d, 4)
        [1, 3]
        >>> dictfindall(d, 5)
        []
    """
    res = []
    for (key, value) in dictionary.items():
        # Equality rather than identity alone: equal values stored as
        # distinct objects have to be found as well.
        if value is element or value == element:
            res.append(key)
    return res
def dictincr(dictionary, element):
    """
    Add one to the count stored under `element` in `dictionary` (a missing
    entry counts as zero) and return the new count.

        >>> d = {1:2, 3:4}
        >>> dictincr(d, 1)
        3
        >>> d[1]
        3
        >>> dictincr(d, 5)
        1
        >>> d[5]
        1
    """
    bumped = dictionary.get(element, 0) + 1
    dictionary[element] = bumped
    return bumped
def dictadd(*dicts):
    """
    Merge the given dictionaries into one new dictionary. The arguments are
    applied left to right, so for a shared key the rightmost value wins.
    None of the arguments is modified.

        >>> dictadd({1: 0, 2: 0}, {2: 1, 3: 1})
        {1: 0, 2: 1, 3: 1}
    """
    merged = {}
    for layer in dicts:
        merged.update(layer)
    return merged
def requeue(queue, index=-1):
    """Move the element at `index` (the last one by default) to the front of
    `queue`, in place, and return that element.

        >>> x = [1, 2, 3, 4]
        >>> requeue(x)
        4
        >>> x
        [4, 1, 2, 3]
    """
    queue.insert(0, queue.pop(index))
    return queue[0]
def restack(stack, index=0):
    """Move the element at `index` (the first one by default) to the end of
    `stack`, in place, and return that element.

        >>> x = [1, 2, 3, 4]
        >>> restack(x)
        1
        >>> x
        [2, 3, 4, 1]
    """
    stack.append(stack.pop(index))
    return stack[-1]
def listget(lst, ind, default=None):
    """
    Returns `lst[ind]` if it exists, `default` otherwise.

    Negative indexes count from the end as usual; an index that is out of
    range in either direction yields `default`.

        >>> listget(['a'], 0)
        'a'
        >>> listget(['a'], 1)
        >>> listget(['a'], 1, 'b')
        'b'
    """
    # Ask the sequence itself: comparing against len() only caught indexes
    # past the end and let too-negative ones raise IndexError.
    try:
        return lst[ind]
    except IndexError:
        return default
def intget(integer, default=None):
    """
    Convert `integer` with `int()`; when that raises TypeError or
    ValueError, return `default` instead.

        >>> intget('3')
        3
        >>> intget('3a')
        >>> intget('3a', 0)
        0
    """
    try:
        value = int(integer)
    except (TypeError, ValueError):
        value = default
    return value
def datestr(then, now=None):
    """
    Converts a (UTC) datetime object to a nice string representation,
    relative to `now` (the current UTC time when omitted).

    Differences of less than four days are spelled out relatively
    ("2 hours ago", "1 microsecond from now"); anything further away is
    shown as month and day, with the year appended when it differs from
    `now`'s year or when `then` lies in the future. A false `then` gives ''.

        >>> from datetime import datetime, timedelta
        >>> d = datetime(1970, 5, 1)
        >>> datestr(d, now=d)
        '0 microseconds ago'
        >>> for t, v in {
        ...   timedelta(microseconds=1): '1 microsecond ago',
        ...   timedelta(microseconds=2): '2 microseconds ago',
        ...   -timedelta(microseconds=1): '1 microsecond from now',
        ...   -timedelta(microseconds=2): '2 microseconds from now',
        ...   timedelta(microseconds=2000): '2 milliseconds ago',
        ...   timedelta(seconds=2): '2 seconds ago',
        ...   timedelta(seconds=2*60): '2 minutes ago',
        ...   timedelta(seconds=2*60*60): '2 hours ago',
        ...   timedelta(days=2): '2 days ago',
        ... }.iteritems():
        ...     assert datestr(d, now=d+t) == v
        >>> datestr(datetime(1970, 1, 1), now=d)
        'January  1'
        >>> datestr(datetime(1969, 1, 1), now=d)
        'January  1, 1969'
        >>> datestr(datetime(1970, 6, 1), now=d)
        'June  1, 1970'
        >>> datestr(None)
        ''
    """
    def agohence(n, what, divisor=None):
        # Builds e.g. '2 days ago' or '1 hour from now'; negative n means
        # the moment lies in the future.
        if divisor:
            n = n // divisor
        plural = ''
        if abs(n) != 1:
            plural = 's'
        if n < 0:
            direction = 'from now'
        else:
            direction = 'ago'
        return '%s %s%s %s' % (abs(n), what, plural, direction)

    oneday = 24 * 60 * 60
    onehour = 60 * 60

    if not then:
        return ""
    if not now:
        now = datetime.datetime.utcnow()

    # Inputs are recognised by type name only.
    # NOTE(review): "DateTime" presumably means mx.DateTime -- confirm.
    if type(now).__name__ == "DateTime":
        now = datetime.datetime.fromtimestamp(now)
    then_kind = type(then).__name__
    if then_kind == "DateTime":
        then = datetime.datetime.fromtimestamp(then)
    elif then_kind == "date":
        then = datetime.datetime(then.year, then.month, then.day)

    delta = now - then
    deltaseconds = int(delta.days * oneday + delta.seconds + delta.microseconds * 1e-06)
    deltadays = abs(deltaseconds) // oneday
    if deltaseconds < 0:
        deltadays *= -1 # fix for oddity of floor

    if deltadays:
        if abs(deltadays) < 4:
            return agohence(deltadays, 'day')

        out = then.strftime('%B %e') # e.g. 'June 13'
        if then.year != now.year or deltadays < 0:
            out += ', %s' % then.year
        return out

    if int(deltaseconds):
        magnitude = abs(deltaseconds)
        if magnitude > onehour:
            return agohence(deltaseconds, 'hour', onehour)
        if magnitude > 60:
            return agohence(deltaseconds, 'minute', 60)
        return agohence(deltaseconds, 'second')

    deltamicroseconds = delta.microseconds
    if delta.days:
        # datetime oddity: a negative delta is stored as days=-1 plus a
        # positive remainder.
        deltamicroseconds = int(delta.microseconds - 1e6)
    if abs(deltamicroseconds) > 1000:
        return agohence(deltamicroseconds, 'millisecond', 1000)

    return agohence(deltamicroseconds, 'microsecond')
def numify(string):
    """
    Return the `str()` form of `string` with everything except digits
    dropped.

        >>> numify('800-555-1212')
        '8005551212'
        >>> numify('800.555.1212')
        '8005551212'
    """
    digits = []
    for char in str(string):
        if char.isdigit():
            digits.append(char)
    return ''.join(digits)
def denumify(string, pattern):
    """
    Fill `pattern` with the characters of `string`: every letter X in the
    pattern is replaced by the next unused character of `string`, all other
    pattern characters are copied through. Raises IndexError when `string`
    runs out before the X's do.

        >>> denumify("8005551212", "(XXX) XXX-XXXX")
        '(800) 555-1212'
    """
    pieces = []
    used = 0
    for char in pattern:
        if char != "X":
            pieces.append(char)
            continue
        pieces.append(string[used])
        used += 1
    return ''.join(pieces)
def commify(n):
    """
    Add commas to an integer `n`.

    Works on the `str()` form of `n`, so floats and numeric strings are
    accepted too; only the part before the decimal point is grouped. A
    leading sign is kept in front of the digits. None is passed through.

        >>> commify(1)
        '1'
        >>> commify(123)
        '123'
        >>> commify(1234)
        '1,234'
        >>> commify(1234567890)
        '1,234,567,890'
        >>> commify(-1234)
        '-1,234'
        >>> commify(123.0)
        '123.0'
        >>> commify(1234.5)
        '1,234.5'
        >>> commify(1234.56789)
        '1,234.56789'
        >>> commify('%.2f' % 1234.5)
        '1,234.50'
        >>> commify(None)
        >>>
    """
    if n is None: return None
    n = str(n)

    # Set the sign aside so it is not counted as a digit when grouping
    # (otherwise -123 comes out as '-,123').
    sign = ''
    if n[:1] in ('-', '+'):
        sign, n = n[0], n[1:]

    if '.' in n:
        dollars, cents = n.split('.')
    else:
        dollars, cents = n, None

    r = []
    for i, c in enumerate(dollars[::-1]):
        if i and (not (i % 3)):
            r.insert(0, ',')
        r.insert(0, c)
    out = sign + ''.join(r)
    if cents:
        out += '.' + cents
    return out
def dateify(datestring):
    """
    Lay out the characters of a numified `datestring` as
    `XXXX-XX-XX XX:XX:XX`, each X taking the next character in turn.
    """
    layout = "XXXX-XX-XX XX:XX:XX"
    return denumify(datestring, layout)
def nthstr(n):
    """
    Format `n` as an English ordinal ('1st', '2nd', '3rd', '4th', ...).

    Negative numbers are not handled (an assertion rejects them).

        >>> nthstr(1)
        '1st'
        >>> nthstr(0)
        '0th'
        >>> [nthstr(x) for x in [2, 3, 4, 5, 10, 11, 12, 13, 14, 15]]
        ['2nd', '3rd', '4th', '5th', '10th', '11th', '12th', '13th', '14th', '15th']
        >>> [nthstr(x) for x in [91, 92, 93, 94, 99, 100, 101, 102]]
        ['91st', '92nd', '93rd', '94th', '99th', '100th', '101st', '102nd']
        >>> [nthstr(x) for x in [111, 112, 113, 114, 115]]
        ['111th', '112th', '113th', '114th', '115th']
    """
    assert n >= 0
    suffix = 'th'
    # 11, 12 and 13 (in any hundred) keep 'th' despite ending in 1, 2, 3.
    if n % 100 not in [11, 12, 13]:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return '%s%s' % (n, suffix)
def cond(predicate, consequence, alternative=None):
    """
    If-else as a function, for use inside expressions: gives back
    `consequence` when `predicate` is true and `alternative` otherwise.
    Both values are evaluated by the caller before the call.

        >>> x = 2
        >>> cond(x % 2 == 0, "even", "odd")
        'even'
        >>> cond(x % 2 == 0, "even", "odd") + '_row'
        'even_row'
    """
    if not predicate:
        return alternative
    return consequence
class CaptureStdout:
    """
    Captures everything `func` prints to stdout and returns it instead.
    The return value of `func` itself is discarded.

        >>> def idiot():
        ...     print "foo"
        >>> capturestdout(idiot)()
        'foo\\n'

    **WARNING:** Not threadsafe!
    """
    def __init__(self, func):
        self.func = func

    def __call__(self, *args, **keywords):
        try:
            from cStringIO import StringIO
        except ImportError:
            # cStringIO does not exist on every interpreter; io.StringIO
            # offers the same write()/getvalue() interface.
            from io import StringIO
        # Not threadsafe!
        out = StringIO()
        oldstdout = sys.stdout
        sys.stdout = out
        try:
            self.func(*args, **keywords)
        finally:
            # Restore the real stdout even when func raises.
            sys.stdout = oldstdout
        return out.getvalue()

capturestdout = CaptureStdout
class Profile:
    """
    Profiles `func` and returns a tuple containing its output
    and a string with human-readable profiling information.

    Only positional arguments are passed on to `func`.

        >>> import time
        >>> out, inf = profile(time.sleep)(.001)
        >>> out
        >>> inf[:10].strip()
        'took 0.0'
    """
    def __init__(self, func):
        self.func = func

    def __call__(self, *args): ##, **kw):   kw unused
        import hotshot, hotshot.stats, os, tempfile ##, time already imported
        import cStringIO

        # hotshot writes its raw data to a file; reserve a name for it.
        f, filename = tempfile.mkstemp()
        os.close(f)

        try:
            prof = hotshot.Profile(filename)
            stime = time.time()
            try:
                result = prof.runcall(self.func, *args)
            finally:
                # Close the profiler even if func raised.
                stime = time.time() - stime
                prof.close()

            out = cStringIO.StringIO()
            stats = hotshot.stats.load(filename)
            stats.stream = out
            stats.strip_dirs()
            stats.sort_stats('time', 'calls')
            stats.print_stats(40)
            stats.print_callers()

            x = '\n\ntook '+ str(stime) + ' seconds\n'
            x += out.getvalue()
        finally:
            # remove the tempfile -- os.remove reports failure as OSError,
            # which a bare ``except IOError`` does not catch.
            try:
                os.remove(filename)
            except (IOError, OSError):
                pass

        return result, x

profile = Profile
import traceback
# hack for compatibility with Python 2.3:
try:
    traceback.format_exc
except AttributeError:
    from cStringIO import StringIO

    def format_exc(limit=None):
        """Return what `traceback.print_exc` would print, as a string."""
        strbuf = StringIO()
        traceback.print_exc(limit, strbuf)
        return strbuf.getvalue()

    # Install the stand-in on the module so callers can use
    # traceback.format_exc() unconditionally.
    traceback.format_exc = format_exc
def tryall(context, prefix=None): | |
""" | |
Tries a series of functions and prints their results. | |
`context` is a dictionary mapping names to values; | |
the value will only be tried if it's callable. | |
>>> tryall(dict(j=lambda: True)) | |
j: True | |
---------------------------------------- | |
results: | |
True: 1 | |
For example, you might have a file `test/stuff.py` | |
with a series of functions testing various things in it. | |
At the bottom, have a line: | |
if __name__ == "__main__": tryall(globals()) | |
Then you can run `python test/stuff.py` and get the results of | |
all the tests. | |
""" | |
context = context.copy() # vars() would update | |
results = {} | |
for (key, value) in context.iteritems(): | |
if not hasattr(value, '__call__'): | |
continue | |
if prefix and not key.startswith(prefix): | |
continue | |
print key + ':', | |
try: | |
r = value() | |
dictincr(results, r) | |
print r | |
except: | |
print 'ERROR' | |
dictincr(results, 'ERROR') | |
print ' ' + '\n '.join(traceback.format_exc().split('\n')) | |
print '-'*40 | |
print 'results:' | |
for (key, value) in results.iteritems(): | |
print ' '*2, str(key)+':', value | |
class ThreadedDict(threadlocal):
    """
    Thread local storage.

    Values can be read and written as attributes or with the dict
    interface; each thread sees its own contents.

        >>> d = ThreadedDict()
        >>> d.x = 1
        >>> d.x
        1
        >>> import threading
        >>> def f(): d.x = 2
        ...
        >>> t = threading.Thread(target=f)
        >>> t.start()
        >>> t.join()
        >>> d.x
        1
    """
    # Every instance created, so that clear_all() can reach them. The set
    # holds strong references to the instances.
    _instances = set()

    def __init__(self):
        ThreadedDict._instances.add(self)

    def __del__(self):
        ThreadedDict._instances.remove(self)

    def __hash__(self):
        return id(self)

    def clear_all():
        """Clears all ThreadedDict instances.
        """
        # Iterate over a snapshot: another thread may create an instance
        # (and so resize the set) while we are clearing.
        for t in list(ThreadedDict._instances):
            t.clear()
    clear_all = staticmethod(clear_all)

    # Define all these methods to more or less fully emulate dict -- attribute access
    # is built into threading.local.

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __delitem__(self, key):
        del self.__dict__[key]

    def __contains__(self, key):
        return key in self.__dict__

    has_key = __contains__

    def __iter__(self):
        # Without this, ``for k in d`` falls back to calling
        # __getitem__(0) and fails with KeyError.
        return iter(self.__dict__)

    def clear(self):
        self.__dict__.clear()

    def copy(self):
        return self.__dict__.copy()

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def items(self):
        return self.__dict__.items()

    def iteritems(self):
        return self.__dict__.iteritems()

    def keys(self):
        return self.__dict__.keys()

    def iterkeys(self):
        return self.__dict__.iterkeys()

    # Kept for backward compatibility; iteration goes through __iter__.
    iter = iterkeys

    def values(self):
        return self.__dict__.values()

    def itervalues(self):
        return self.__dict__.itervalues()

    def pop(self, key, *args):
        return self.__dict__.pop(key, *args)

    def popitem(self):
        return self.__dict__.popitem()

    def setdefault(self, key, default=None):
        return self.__dict__.setdefault(key, default)

    def update(self, *args, **kwargs):
        self.__dict__.update(*args, **kwargs)

    def __repr__(self):
        return '<ThreadedDict %r>' % self.__dict__

    __str__ = __repr__

threadeddict = ThreadedDict
def autoassign(self, locals):
    """
    Copy every entry of the `locals` dictionary onto `self` as an
    attribute, skipping the entry named 'self'.

        >>> self = storage()
        >>> autoassign(self, dict(a=1, b=2))
        >>> self
        <Storage {'a': 1, 'b': 2}>

    Generally used in `__init__` methods, as in:

        def __init__(self, foo, bar, baz=1): autoassign(self, locals())
    """
    for (name, value) in locals.items():
        if name != 'self':
            setattr(self, name, value)
def to36(q):
    """
    Render the non-negative integer `q` in base 36, using the digits 0-9
    followed by the lowercase letters a-z (a useful scheme for
    human-sayable IDs). Raises ValueError for negative input.

        >>> to36(35)
        'z'
        >>> to36(119292)
        '2k1o'
        >>> int(to36(939387374), 36)
        939387374
        >>> to36(0)
        '0'
        >>> to36(-393)
        Traceback (most recent call last):
            ...
        ValueError: must supply a positive integer
    """
    if q < 0:
        raise ValueError("must supply a positive integer")
    alphabet = "0123456789abcdefghijklmnopqrstuvwxyz"
    text = ''
    while q != 0:
        # Peel off the least significant digit and put it in front.
        q, remainder = divmod(q, 36)
        text = alphabet[remainder] + text
    return text or '0'
# A raw http:// URL, unless it directly follows an opening parenthesis
# (i.e. it is already the target of a Markdown link).
r_url = re_compile('(?<!\()(http://(\S+))')
def safemarkdown(text):
    """
    Converts text to HTML following the rules of Markdown, but blocking any
    outside HTML input, so that only the things supported by Markdown
    can be used. Also converts raw URLs to links.

    A false `text` (None, '') is returned unchanged.

    (requires [markdown.py](http://webpy.org/markdown.py))
    """
    from markdown import markdown
    if text:
        # Escape every '<' so that HTML typed by the user is shown as text
        # instead of being passed through to the page.
        text = text.replace('<', '&lt;')
        # TODO: automatically get page title?
        # Wrap raw URLs in angle brackets -- Markdown's autolink syntax.
        text = r_url.sub(r'<\1>', text)
        text = markdown(text)
        return text
def sendmail(from_address, to_address, subject, message, headers=None, **kw):
    """
    Sends the email message `message` with mail and envelope headers
    from `from_address` to `to_address` with `subject`.
    Additional email headers can be specified with the dictionary
    `headers`.

    Optionally cc, bcc and attachments can be specified as keyword arguments.
    Attachments must be an iterable and each attachment can be either a
    filename or a file object or a dictionary with filename, content and
    optionally content_type keys. Anything else raises ValueError.

    If `web.config.smtp_server` is set, it will send the message
    to that SMTP server. Otherwise it will look for
    `/usr/sbin/sendmail`, the typical location for the sendmail-style
    binary. To use sendmail from a different path, set `web.config.sendmail_path`.
    """
    attachments = kw.pop("attachments", [])
    mail = _EmailMessage(from_address, to_address, subject, message, headers, **kw)

    for item in attachments:
        if isinstance(item, dict):
            mail.attach(item['filename'], item['content'], item.get('content_type'))
            continue

        if hasattr(item, 'read'): # file
            name = os.path.basename(getattr(item, "name", ""))
            kind = getattr(item, 'content_type', None)
            mail.attach(name, item.read(), kind)
            continue

        if not isinstance(item, basestring):
            raise ValueError("Invalid attachment: %s" % repr(item))

        # A plain string is taken to be the path of a file to attach.
        f = open(item, 'rb')
        payload = f.read()
        f.close()
        mail.attach(os.path.basename(item), payload, None)

    mail.send()
class _EmailMessage:
    """
    An outgoing email: headers, body and optional attachments, plus the
    logic to deliver it.

    Delivery (`send`) goes through SMTP when `smtp_server` is configured,
    through Amazon SES when `email_engine` is 'aws', and through the local
    sendmail binary otherwise.
    """
    def __init__(self, from_address, to_address, subject, message, headers=None, **kw):
        def listify(x):
            # Accept a single address or a list; always return a list of
            # byte strings.
            if not isinstance(x, list):
                return [safestr(x)]
            else:
                return [safestr(a) for a in x]

        subject = safestr(subject)
        message = safestr(message)

        from_address = safestr(from_address)
        to_address = listify(to_address)
        cc = listify(kw.get('cc', []))
        bcc = listify(kw.get('bcc', []))
        recipients = to_address + cc + bcc

        import email.Utils
        # Envelope addresses: the bare address without any display name.
        self.from_address = email.Utils.parseaddr(from_address)[1]
        self.recipients = [email.Utils.parseaddr(r)[1] for r in recipients]

        self.headers = dictadd({
          'From': from_address,
          'To': ", ".join(to_address),
          'Subject': subject
        }, headers or {})

        # bcc recipients get the mail but are left out of the headers.
        if cc:
            self.headers['Cc'] = ", ".join(cc)

        self.message = self.new_message()
        self.message.add_header("Content-Transfer-Encoding", "7bit")
        self.message.add_header("Content-Disposition", "inline")
        self.message.add_header("MIME-Version", "1.0")
        self.message.set_payload(message, 'utf-8')
        self.multipart = False

    def new_message(self):
        """Return a fresh, empty `email` Message object."""
        from email.Message import Message
        return Message()

    def attach(self, filename, content, content_type=None):
        """
        Add `content` as an attachment named `filename`.

        When `content_type` is not given it is guessed from the filename,
        falling back to application/octet-stream. Non-text attachments are
        base64-encoded.
        """
        if not self.multipart:
            # First attachment: wrap the plain body in a multipart container.
            msg = self.new_message()
            msg.add_header("Content-Type", "multipart/mixed")
            msg.attach(self.message)
            self.message = msg
            self.multipart = True

        import mimetypes
        try:
            from email import encoders
        except ImportError:
            # older spelling of the module name
            from email import Encoders as encoders

        content_type = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"

        msg = self.new_message()
        msg.set_payload(content)
        msg.add_header('Content-Type', content_type)
        msg.add_header('Content-Disposition', 'attachment', filename=filename)

        if not content_type.startswith("text/"):
            encoders.encode_base64(msg)

        self.message.attach(msg)

    def prepare_message(self):
        """Move the collected headers onto the message object."""
        for k, v in self.headers.iteritems():
            if k.lower() == "content-type":
                self.message.set_type(v)
            else:
                self.message.add_header(k, v)

        # Emptied so a second call does not add the headers twice.
        self.headers = {}

    def send(self):
        """Deliver the message using the transport selected in the config."""
        try:
            import webapi
        except ImportError:
            webapi = Storage(config=Storage())

        self.prepare_message()
        message_text = self.message.as_string()

        if webapi.config.get('smtp_server'):
            server = webapi.config.get('smtp_server')
            port = webapi.config.get('smtp_port', 0)
            username = webapi.config.get('smtp_username')
            password = webapi.config.get('smtp_password')
            debug_level = webapi.config.get('smtp_debuglevel', None)
            starttls = webapi.config.get('smtp_starttls', False)

            import smtplib
            smtpserver = smtplib.SMTP(server, port)

            if debug_level:
                smtpserver.set_debuglevel(debug_level)

            if starttls:
                smtpserver.ehlo()
                smtpserver.starttls()
                smtpserver.ehlo()

            if username and password:
                smtpserver.login(username, password)

            smtpserver.sendmail(self.from_address, self.recipients, message_text)
            smtpserver.quit()
        elif webapi.config.get('email_engine') == 'aws':
            import boto.ses
            c = boto.ses.SESConnection(
              aws_access_key_id=webapi.config.get('aws_access_key_id'),
              aws_secret_access_key=webapi.config.get('aws_secret_access_key'))
            # NOTE(review): passed by keyword because the positional order
            # of send_raw_email appears to differ between boto releases --
            # confirm against the installed version.
            c.send_raw_email(source=self.from_address,
                             raw_message=message_text,
                             destinations=self.recipients)
        else:
            sendmail = webapi.config.get('sendmail_path', '/usr/sbin/sendmail')

            # An address starting with '-' would be read by sendmail as a
            # command line option.
            assert not self.from_address.startswith('-'), 'security'
            for r in self.recipients:
                assert not r.startswith('-'), 'security'

            cmd = [sendmail, '-f', self.from_address] + self.recipients

            if subprocess:
                p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
                p.stdin.write(message_text)
                p.stdin.close()
                p.wait()
            else:
                i, o = os.popen2(cmd)
                i.write(message_text)
                i.close()
                o.close()
                del i, o

    def __repr__(self):
        return "<EmailMessage>"

    def __str__(self):
        return self.message.as_string()
if __name__ == "__main__": | |
import doctest | |
doctest.testmod() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Web API (wrapper around WSGI) | |
(from web.py) | |
""" | |
__all__ = [ | |
"config", | |
"header", "debug", | |
"input", "data", | |
"setcookie", "cookies", | |
"ctx", | |
"HTTPError", | |
# 200, 201, 202 | |
"OK", "Created", "Accepted", | |
"ok", "created", "accepted", | |
# 301, 302, 303, 304, 307 | |
"Redirect", "Found", "SeeOther", "NotModified", "TempRedirect", | |
"redirect", "found", "seeother", "notmodified", "tempredirect", | |
# 400, 401, 403, 404, 405, 406, 409, 410, 412 | |
"BadRequest", "Unauthorized", "Forbidden", "NotFound", "NoMethod", "NotAcceptable", "Conflict", "Gone", "PreconditionFailed", | |
"badrequest", "unauthorized", "forbidden", "notfound", "nomethod", "notacceptable", "conflict", "gone", "preconditionfailed", | |
# 500 | |
"InternalError", | |
"internalerror", | |
] | |
import sys, cgi, Cookie, pprint, urlparse, urllib | |
from utils import storage, storify, threadeddict, dictadd, intget, safestr | |
config = storage() | |
config.__doc__ = """ | |
A configuration object for various aspects of web.py. | |
`debug` | |
: when True, enables reloading, disabled template caching and sets internalerror to debugerror. | |
""" | |
class HTTPError(Exception): | |
def __init__(self, status, headers={}, data=""): | |
ctx.status = status | |
for k, v in headers.items(): | |
header(k, v) | |
self.data = data | |
Exception.__init__(self, status) | |
def _status_code(status, data=None, classname=None, docstring=None):
    """Create an `HTTPError` subclass for the status line `status`.

    `status` is a complete status line such as "304 Not Modified".
    `data` is the default body of the generated class; when it is None the
    reason phrase ("Not Modified") is used.
    `classname` is the name of the generated class; when it is omitted the
    reason phrase with its spaces removed ("NotModified") is used.
    `docstring` becomes the class docstring and defaults to
    '`<status>` status'.

    Raises IndexError when `status` contains no space.
    """
    reason = status.split(" ", 1)[1]
    if data is None:
        data = reason
    # An explicitly passed `classname` used to be overwritten here
    # unconditionally; derive one only when the caller did not supply it.
    if not classname:
        classname = reason.replace(' ', '')  # 304 Not Modified -> NotModified
    docstring = docstring or '`%s` status' % status

    def __init__(self, data=data, headers=None):
        # `None` instead of a shared `{}` default; HTTPError always
        # receives a mapping.
        HTTPError.__init__(self, status, headers or {}, data)

    # trick to create class dynamically with dynamic docstring.
    return type(classname, (HTTPError, object), {
        '__doc__': docstring,
        '__init__': __init__
    })
# 2xx statuses built by `_status_code`; each lowercase name is bound to the
# same class as its CamelCase twin.  "200 OK" gets an empty default body,
# the other two fall back to their reason phrase.
ok = OK = _status_code("200 OK", data="")
created = Created = _status_code("201 Created")
accepted = Accepted = _status_code("202 Accepted")
class Redirect(HTTPError):
    """A `301 Moved Permanently` redirect."""
    def __init__(self, url, status='301 Moved Permanently', absolute=False):
        """
        Returns a `status` redirect to the new URL.
        `url` is joined with the base URL so that things like
        `redirect("about")` will work properly.

        When the joined location starts with `/` it is prefixed with
        `ctx.realhome` if `absolute` is true and with `ctx.home` otherwise.
        """
        target = urlparse.urljoin(ctx.path, url)
        if target.startswith('/'):
            # Only a path came out of the join: anchor it on the
            # application's base.
            if not absolute:
                base = ctx.home
            else:
                base = ctx.realhome
            target = base + target

        HTTPError.__init__(self, status, {
            'Content-Type': 'text/html',
            'Location': target
        }, "")

redirect = Redirect
class Found(Redirect):
    """A `302 Found` redirect."""
    def __init__(self, url, absolute=False):
        # Same as Redirect, with the status line pinned to 302.
        Redirect.__init__(self, url, status='302 Found', absolute=absolute)

found = Found
class SeeOther(Redirect):
    """A `303 See Other` redirect."""
    def __init__(self, url, absolute=False):
        # Same as Redirect, with the status line pinned to 303.
        Redirect.__init__(self, url, status='303 See Other', absolute=absolute)

seeother = SeeOther
class NotModified(HTTPError):
    """A `304 Not Modified` status."""
    def __init__(self):
        # No extra headers and the default (empty) body.
        status = "304 Not Modified"
        HTTPError.__init__(self, status)

notmodified = NotModified
class TempRedirect(Redirect):
    """A `307 Temporary Redirect` redirect."""
    def __init__(self, url, absolute=False):
        # Same as Redirect, with the status line pinned to 307.
        Redirect.__init__(self, url, status='307 Temporary Redirect', absolute=absolute)

tempredirect = TempRedirect
class BadRequest(HTTPError):
    """`400 Bad Request` error."""
    # Response body; a subclass may override it.
    message = "bad request"

    def __init__(self):
        HTTPError.__init__(
            self,
            "400 Bad Request",
            {'Content-Type': 'text/html'},
            self.message
        )

badrequest = BadRequest
class Unauthorized(HTTPError):
    """`401 Unauthorized` error."""
    # Response body; a subclass may override it.
    message = "unauthorized"

    def __init__(self):
        HTTPError.__init__(
            self,
            "401 Unauthorized",
            {'Content-Type': 'text/html'},
            self.message
        )

unauthorized = Unauthorized
class Forbidden(HTTPError):
    """`403 Forbidden` error."""
    # Response body; a subclass may override it.
    message = "forbidden"

    def __init__(self):
        HTTPError.__init__(
            self,
            "403 Forbidden",
            {'Content-Type': 'text/html'},
            self.message
        )

forbidden = Forbidden
class _NotFound(HTTPError):
    """`404 Not Found` error."""
    # Body used when no (or an empty) `message` is passed in.
    message = "not found"

    def __init__(self, message=None):
        body = message or self.message
        HTTPError.__init__(
            self,
            '404 Not Found',
            {'Content-Type': 'text/html'},
            body
        )
def NotFound(message=None):
    """Returns HTTPError with '404 Not Found' error from the active application.

    A truthy `message` always yields a plain `_NotFound(message)`.  Without
    one, the `notfound()` of the last entry in `ctx.app_stack` is used when
    such a stack exists, and a default `_NotFound()` otherwise.
    """
    if message:
        return _NotFound(message)
    if not ctx.get('app_stack'):
        return _NotFound()
    return ctx.app_stack[-1].notfound()

notfound = NotFound
class NoMethod(HTTPError):
    """A `405 Method Not Allowed` error."""
    def __init__(self, cls=None):
        """
        Builds the `Allow` header from the attributes of `cls`.

        Without `cls`, all of GET, HEAD, POST, PUT and DELETE are listed;
        with it, only those names that `cls` has as attributes.
        """
        allowed = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE']
        if cls:
            allowed = [verb for verb in allowed if hasattr(cls, verb)]

        headers = {
            'Content-Type': 'text/html',
            'Allow': ', '.join(allowed)
        }
        # The body is explicitly None for this status.
        HTTPError.__init__(self, '405 Method Not Allowed', headers, None)

nomethod = NoMethod
class NotAcceptable(HTTPError):
    """`406 Not Acceptable` error."""
    # Response body; a subclass may override it.
    message = "not acceptable"

    def __init__(self):
        HTTPError.__init__(
            self,
            "406 Not Acceptable",
            {'Content-Type': 'text/html'},
            self.message
        )

notacceptable = NotAcceptable
class Conflict(HTTPError):
    """`409 Conflict` error."""
    # Response body; a subclass may override it.
    message = "conflict"

    def __init__(self):
        HTTPError.__init__(
            self,
            "409 Conflict",
            {'Content-Type': 'text/html'},
            self.message
        )

conflict = Conflict
class Gone(HTTPError):
    """`410 Gone` error."""
    # Response body; a subclass may override it.
    message = "gone"

    def __init__(self):
        HTTPError.__init__(
            self,
            '410 Gone',
            {'Content-Type': 'text/html'},
            self.message
        )

gone = Gone
class PreconditionFailed(HTTPError):
    """`412 Precondition Failed` error."""
    # Response body; a subclass may override it.
    message = "precondition failed"

    def __init__(self):
        HTTPError.__init__(
            self,
            "412 Precondition Failed",
            {'Content-Type': 'text/html'},
            self.message
        )

preconditionfailed = PreconditionFailed
class _InternalError(HTTPError):
    """`500 Internal Server Error`."""
    # Body used when no (or an empty) `message` is passed in.
    message = "internal server error"

    def __init__(self, message=None):
        body = message or self.message
        HTTPError.__init__(
            self,
            '500 Internal Server Error',
            {'Content-Type': 'text/html'},
            body
        )
def InternalError(message=None):
    """Returns HTTPError with '500 internal error' error from the active application.

    A truthy `message` always yields a plain `_InternalError(message)`.
    Without one, the `internalerror()` of the last entry in `ctx.app_stack`
    is used when such a stack exists, and a default `_InternalError()`
    otherwise.
    """
    if message:
        return _InternalError(message)
    if not ctx.get('app_stack'):
        return _InternalError()
    return ctx.app_stack[-1].internalerror()

internalerror = InternalError
def header(hdr, value, unique=False):
    """
    Adds the header `hdr: value` with the response.

    If `unique` is True and a header with that name already exists,
    it doesn't add a new one.  Header names are compared
    case-insensitively.

    Raises ValueError when the name or the value contains CR or LF.
    """
    hdr, value = safestr(hdr), safestr(value)

    # protection against HTTP response splitting attack
    for text in (hdr, value):
        if '\n' in text or '\r' in text:
            raise ValueError('invalid characters in header')

    # Only the literal True enables the duplicate check.
    if unique is True:
        wanted = hdr.lower()
        for existing, _ in ctx.headers:
            if existing.lower() == wanted:
                return

    ctx.headers.append((hdr, value))
def rawinput(method=None):
    """Returns storage object with GET or POST arguments.

    `method` selects the source and is matched case-insensitively:
    "post"/"put" read the request body (only when the request itself is a
    POST or PUT), "get" reads the query string, and "both" (the default)
    reads both.  Plain fields are reduced to their value, fields carrying a
    filename are returned as the field object, and lists are processed
    element by element.
    """
    from cStringIO import StringIO

    mode = (method or "both").lower()

    def dictify(fs):
        # hack to make web.input work with enctype='text/plain.
        if fs.list is None:
            fs.list = []
        return dict([(k, fs[k]) for k in fs.keys()])

    def process_fieldstorage(fs):
        if isinstance(fs, list):
            return [process_fieldstorage(x) for x in fs]
        if fs.filename is None:
            return fs.value
        return fs

    # Work on a copy: REQUEST_METHOD is overwritten below.
    env = ctx.env.copy()
    body_args = {}
    query_args = {}

    if mode in ['both', 'post', 'put'] and env['REQUEST_METHOD'] in ['POST', 'PUT']:
        if env.get('CONTENT_TYPE', '').lower().startswith('multipart/'):
            # since wsgi.input is directly passed to cgi.FieldStorage,
            # it can not be called multiple times. Saving the FieldStorage
            # object in ctx to allow calling web.input multiple times.
            parsed = ctx.get('_fieldstorage')
            if not parsed:
                parsed = cgi.FieldStorage(fp=env['wsgi.input'], environ=env, keep_blank_values=1)
                ctx._fieldstorage = parsed
        else:
            # Non-multipart bodies go through data(), which keeps the raw
            # body on ctx.
            parsed = cgi.FieldStorage(fp=StringIO(data()), environ=env, keep_blank_values=1)
        body_args = dictify(parsed)

    if mode in ['both', 'get']:
        # Forcing GET makes FieldStorage parse the query string.
        env['REQUEST_METHOD'] = 'GET'
        query_args = dictify(cgi.FieldStorage(environ=env, keep_blank_values=1))

    merged = dictadd(query_args, body_args)
    return storage([(key, process_fieldstorage(field)) for key, field in merged.items()])
def input(*requireds, **defaults):
    """
    Returns a `storage` object with the GET and POST arguments.
    See `storify` for how `requireds` and `defaults` work.

    The keyword `_method` (default "both") is consumed here and passed to
    `rawinput`.  A KeyError from `storify` is turned into `badrequest()`.
    """
    source = defaults.pop('_method', 'both')
    raw = rawinput(source)

    # force unicode conversion by default.
    defaults.setdefault('_unicode', True)
    try:
        return storify(raw, *requireds, **defaults)
    except KeyError:
        raise badrequest()
def data():
    """Returns the data sent with the request.

    The body is read from `wsgi.input` once, using CONTENT_LENGTH as the
    byte count, and kept on `ctx.data` for later calls.
    """
    if 'data' in ctx:
        return ctx.data

    # presumably 0 when the header is missing or not an integer -- confirm
    # against `intget`.
    length = intget(ctx.env.get('CONTENT_LENGTH'), 0)
    ctx.data = ctx.env['wsgi.input'].read(length)
    return ctx.data
def setcookie(name, value, expires='', domain=None,
              secure=False, httponly=False, path=None):
    """Sets a cookie.

    Emits a `Set-Cookie` header for `name`, with `value` URL-quoted.
    A negative `expires` is replaced by a large negative number; `path`
    defaults to `ctx.homepath + '/'`.  `domain` and `secure` are only added
    when truthy, and `httponly` appends `; httponly` to the header value.
    """
    name, value = safestr(name), safestr(value)

    cookie = Cookie.Morsel()
    cookie.set(name, value, urllib.quote(value))

    if expires < 0:
        expires = -1000000000
    cookie['expires'] = expires
    cookie['path'] = path or ctx.homepath+'/'
    if domain:
        cookie['domain'] = domain
    if secure:
        cookie['secure'] = secure

    rendered = cookie.OutputString()
    if httponly:
        # Appended by hand rather than set as a morsel attribute.
        rendered += '; httponly'
    header('Set-Cookie', rendered)
def cookies(*requireds, **defaults):
    """
    Returns a `storage` object with all the cookies in it.
    See `storify` for how `requireds` and `defaults` work.

    Cookies are parsed from the `HTTP_COOKIE` entry of the environment and
    every truthy value is URL-unquoted.
    """
    jar = Cookie.SimpleCookie()
    jar.load(ctx.env.get('HTTP_COOKIE', ''))
    try:
        found = storify(jar, *requireds, **defaults)
        for key, raw in found.items():
            # Falsy values (e.g. an empty default) are kept as they are.
            found[key] = raw and urllib.unquote(raw)
        return found
    except KeyError:
        # NOTE(review): the BadRequest instance is built (which sets the
        # response status) but not raised; the caller sees StopIteration.
        # `input()` raises `badrequest()` instead -- confirm which is meant.
        badrequest()
        raise StopIteration
def _debug_stream():
    """Return the stream that debug output is written to.

    The request's `wsgi.errors` stream is preferred; `sys.stderr` is used
    whenever that stream cannot be looked up.
    """
    try:
        return ctx.environ['wsgi.errors']
    except Exception:
        # Deliberately broad: debug output must keep working without a
        # request context.  Unlike the former bare `except:`, this lets
        # KeyboardInterrupt and SystemExit through.
        return sys.stderr

def debug(*args):
    """
    Prints a prettyprinted version of `args` to stderr.

    Each argument is written on its own line.  Always returns the empty
    string.
    """
    out = _debug_stream()
    for arg in args:
        out.write(pprint.pformat(arg) + '\n')
    return ''

def _debugwrite(x):
    """Write `x` unchanged to the debug stream."""
    _debug_stream().write(x)

# Gives `debug` a `write` method, so it can be used where a writable
# stream is expected.
debug.write = _debugwrite
# The request context read and written by the helpers in this module;
# `context` is a second name bound to the same `threadeddict` instance.
ctx = context = threadeddict()
ctx.__doc__ = """
A `storage` object containing various information about the request:
`environ` (aka `env`)
: A dictionary containing the standard WSGI environment variables.
`host`
: The domain (`Host` header) requested by the user.
`home`
: The base path for the application.
`ip`
: The IP address of the requester.
`method`
: The HTTP method used.
`path`
: The path request.
`query`
: If there are no query arguments, the empty string. Otherwise, a `?` followed
by the query string.
`fullpath`
: The full path requested, including query arguments (`== path + query`).
### Response Data
`status` (default: "200 OK")
: The status code to be used in the response.
`headers`
: A list of 2-tuples to be used in the response.
`output`
: A string to be used as the response.
"""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""openid.py: an openid library for web.py | |
Notes: | |
- This will create a file called .openid_secret_key in the | |
current directory with your secret key in it. If someone | |
has access to this file they can log in as any user. And | |
if the app can't find this file for any reason (e.g. you | |
moved the app somewhere else) then each currently logged | |
in user will get logged out. | |
- State must be maintained through the entire auth process | |
-- this means that if you have multiple web.py processes | |
serving one set of URLs or if you restart your app often | |
then log ins will fail. You have to replace sessions and | |
store for things to work. | |
- We set cookies starting with "openid_". | |
""" | |
import os | |
import random | |
import hmac | |
import __init__ as web | |
import openid.consumer.consumer | |
import openid.store.memstore | |
# State of logins that are in flight, held in this process only.
# `sessions` maps a session id (see `_random_session`) to the dict that is
# handed to the OpenID consumer; `store` is passed to every consumer as well.
sessions = {}
store = openid.store.memstore.MemoryStore()
def _secret():
    """Return the secret key, creating `.openid_secret_key` if necessary.

    The key is read from `.openid_secret_key` in the current working
    directory.  When the file cannot be read, 20 bytes from `os.urandom`
    are generated, written to that file and returned.

    The file is opened in binary mode because the key is arbitrary bytes,
    and every handle is closed explicitly instead of being left to the
    garbage collector.
    """
    try:
        keyfile = open('.openid_secret_key', 'rb')
        try:
            secret = keyfile.read()
        finally:
            keyfile.close()
    except IOError:
        # file doesn't exist
        secret = os.urandom(20)
        keyfile = open('.openid_secret_key', 'wb')
        try:
            keyfile.write(secret)
        finally:
            keyfile.close()
    return secret
def _hmac(identity_url):
    """Return the hex HMAC of `identity_url`, keyed with `_secret()`."""
    signer = hmac.new(_secret(), identity_url)
    return signer.hexdigest()
def _random_session():
    """Return a new session id (a string) that is not yet a key of `sessions`."""
    # The id ends up in a cookie and selects server-side login state, so it
    # is drawn from the OS entropy source rather than the default
    # (predictable) generator.
    rng = random.SystemRandom()
    n = str(rng.random())
    # `sessions` is keyed by the string returned here.  The previous check
    # tested the float before conversion and so could never match a key.
    while n in sessions:
        n = str(rng.random())
    return n
def status():
    """Return the identity URL of the logged-in user, or None.

    The `openid_identity_hash` cookie has the form `<hmac>,<identity url>`;
    the URL is returned only when the HMAC matches `_hmac(url)`.
    """
    cookie = web.cookies().get('openid_identity_hash', '')
    parts = cookie.split(',', 1)
    if len(parts) <= 1:
        # Cookie missing or without a comma.
        return None

    signature, identity_url = parts
    if signature != _hmac(identity_url):
        return None
    return identity_url
def form(openid_loc):
    """Return the HTML of the log-in form, or of the log-out form.

    `openid_loc` is the URL both forms post to.  When `status()` reports a
    logged-in identity, the log-out form showing that identity is returned;
    otherwise the log-in form is.  Both forms carry the current
    `web.ctx.fullpath` in a hidden `return_to` field.

    Every interpolated value is HTML-escaped (including double quotes).
    `web.ctx.fullpath` comes from the request and lands inside an attribute
    value, so it must not be inserted verbatim.
    """
    # Local import: this module does not import `cgi` at the top.
    from cgi import escape

    oid = status()
    action = escape(openid_loc, True)
    return_to = escape(web.ctx.fullpath, True)
    if oid:
        return '''
        <form method="post" action="%s">
          <img src="http://openid.net/login-bg.gif" alt="OpenID" />
          <strong>%s</strong>
          <input type="hidden" name="action" value="logout" />
          <input type="hidden" name="return_to" value="%s" />
          <button type="submit">log out</button>
        </form>''' % (action, escape(oid, True), return_to)
    else:
        return '''
        <form method="post" action="%s">
          <input type="text" name="openid" value=""
            style="background: url(http://openid.net/login-bg.gif) no-repeat; padding-left: 18px; background-position: 0 50%%;" />
          <input type="hidden" name="return_to" value="%s" />
          <button type="submit">log in</button>
        </form>''' % (action, return_to)
def logout():
    """Log the user out by overwriting the identity cookie.

    The cookie is set to an empty value with a negative `expires`.
    """
    cookie_name = 'openid_identity_hash'
    web.setcookie(cookie_name, '', expires=-1)
class host:
    """Handler for the OpenID endpoint: POST starts a login or logs out, GET finishes a login."""

    def POST(self):
        """Log out, or start a login and redirect to the URL the consumer returns."""
        # unlike the usual scheme of things, the POST is actually called
        # first here
        params = web.input(return_to='/')
        if params.get('action') == 'logout':
            logout()
            return web.redirect(params.return_to)

        # From here on the `openid` field is required.
        params = web.input('openid', return_to='/')

        session_id = _random_session()
        session = {'webpy_return_to': params.return_to}
        sessions[session_id] = session

        consumer = openid.consumer.consumer.Consumer(session, store)
        auth_request = consumer.begin(params.openid)
        here = web.ctx.home + web.ctx.fullpath
        provider_url = auth_request.redirectURL(web.ctx.home, here)

        # The cookie lets GET find this session again.
        web.setcookie('openid_session_id', session_id)
        return web.redirect(provider_url)

    def GET(self):
        """Complete the login and redirect to the stored `return_to`."""
        session_id = web.cookies('openid_session_id').openid_session_id
        # The session cookie is single-use: clear it right away.
        web.setcookie('openid_session_id', '', expires=-1)

        session = sessions[session_id]
        return_to = session['webpy_return_to']

        consumer = openid.consumer.consumer.Consumer(session, store)
        here = web.ctx.home + web.ctx.fullpath
        result = consumer.complete(web.input(), here)

        if result.status.lower() == 'success':
            # Cookie format is `<hmac>,<identity url>`; `status()` checks it.
            signed = _hmac(result.identity_url) + ',' + result.identity_url
            web.setcookie('openid_identity_hash', signed)

        del sessions[session_id]
        return web.redirect(return_to)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
WSGI Utilities | |
(from web.py) | |
""" | |
import os, sys | |
import http | |
import webapi as web | |
from utils import listget | |
from net import validaddr, validip | |
import httpserver | |
def runfcgi(func, addr=('localhost', 8000)):
    """Runs a WSGI function as a FastCGI server.

    `addr` is passed to flup as the bind address.
    """
    # Imported here, so flup is only needed when this function is called.
    import flup.server.fcgi as fcgi_server
    server = fcgi_server.WSGIServer(func, multiplexed=True, bindAddress=addr, debug=False)
    return server.run()
def runscgi(func, addr=('localhost', 4000)):
    """Runs a WSGI function as an SCGI server.

    `addr` is passed to flup as the bind address.
    """
    # Imported here, so flup is only needed when this function is called.
    import flup.server.scgi as scgi_server
    server = scgi_server.WSGIServer(func, bindAddress=addr, debug=False)
    return server.run()
def runwsgi(func):
    """
    Runs a WSGI-compatible `func` using FCGI, SCGI, or a simple web server,
    as appropriate based on context and `sys.argv`.

    Checks, in order: the `SERVER_SOFTWARE` / `PHP_FCGI_CHILDREN`
    environment variables (FastCGI without a bind address), an
    `fcgi`/`fastcgi` argument, an `scgi` argument, and finally the
    built-in server on the address given as the first argument.
    """
    env = os.environ
    argv = sys.argv

    if 'SERVER_SOFTWARE' in env:  # cgi
        env['FCGI_FORCE_CGI'] = 'Y'

    if 'PHP_FCGI_CHILDREN' in env or 'SERVER_SOFTWARE' in env:  # lighttpd fastcgi
        return runfcgi(func, None)

    if 'fcgi' in argv or 'fastcgi' in argv:
        args = argv[1:]
        # Drop one mode keyword, preferring 'fastcgi'; what is left, if
        # anything, starts with the address.
        for keyword in ('fastcgi', 'fcgi'):
            if keyword in args:
                args.remove(keyword)
                break
        if not args:
            return runfcgi(func, None)
        return runfcgi(func, validaddr(args[0]))

    if 'scgi' in argv:
        args = argv[1:]
        args.remove('scgi')
        if not args:
            return runscgi(func)
        return runscgi(func, validaddr(args[0]))

    return httpserver.runsimple(func, validip(listget(argv, 1, '')))
def _is_dev_mode():
    """Return False when the process looks like it runs under a server, else True."""
    # quick hack to check if the program is running in dev mode.
    for variable in ('SERVER_SOFTWARE', 'PHP_FCGI_CHILDREN'):
        if variable in os.environ:
            return False
    for keyword in ('fcgi', 'fastcgi', 'mod_wsgi'):
        if keyword in sys.argv:
            return False
    return True
# When running the builtin-server, enable debug mode if not already set.
# This runs at import time; `setdefault` leaves a `debug` value that was
# configured earlier untouched.
web.config.setdefault('debug', _is_dev_mode())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment