Skip to content

Instantly share code, notes, and snippets.

@robertwb
Created August 16, 2016 08:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save robertwb/8073710c51a415af68259f902fdfe521 to your computer and use it in GitHub Desktop.
Save robertwb/8073710c51a415af68259f902fdfe521 to your computer and use it in GitHub Desktop.
Attachment dagss-branch-changes.diff for http://trac.cython.org/ticket/11
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210947140 -7200
# Node ID 1caa3ed4962ba04c165bda6187a194fd05032b7a
# Parent 9a731464ea49650627ac6a988518aee93e6991aa
Replace filename strings with more generic source descriptors.
This facilitates using the parser and compiler with runtime sources (such as
strings), while still being able to provide context for error messages/C debugging comments.
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/Code.py
--- a/Cython/Compiler/Code.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/Code.py Fri May 16 16:12:20 2008 +0200
@@ -8,6 +8,7 @@ from Cython.Utils import open_new_file,
from Cython.Utils import open_new_file, open_source_file
from PyrexTypes import py_object_type, typecast
from TypeSlots import method_coexist
+from Scanning import SourceDescriptor
class CCodeWriter:
# f file output file
@@ -89,21 +90,22 @@ class CCodeWriter:
def get_py_version_hex(self, pyversion):
return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
- def file_contents(self, file):
+ def file_contents(self, source_desc):
try:
- return self.input_file_contents[file]
+ return self.input_file_contents[source_desc]
except KeyError:
F = [line.encode('ASCII', 'replace').replace(
'*/', '*[inserted by cython to avoid comment closer]/')
- for line in open_source_file(file)]
- self.input_file_contents[file] = F
+ for line in source_desc.get_lines(decode=True)]
+ self.input_file_contents[source_desc] = F
return F
def mark_pos(self, pos):
if pos is None:
return
- filename, line, col = pos
- contents = self.file_contents(filename)
+ source_desc, line, col = pos
+ assert isinstance(source_desc, SourceDescriptor)
+ contents = self.file_contents(source_desc)
context = ''
for i in range(max(0,line-3), min(line+2, len(contents))):
@@ -112,7 +114,7 @@ class CCodeWriter:
s = s.rstrip() + ' # <<<<<<<<<<<<<< ' + '\n'
context += " * " + s
- marker = '"%s":%d\n%s' % (filename.encode('ASCII', 'replace'), line, context)
+ marker = '"%s":%d\n%s' % (str(source_desc).encode('ASCII', 'replace'), line, context)
if self.last_marker != marker:
self.marker = marker
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/Errors.py
--- a/Cython/Compiler/Errors.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/Errors.py Fri May 16 16:12:20 2008 +0200
@@ -12,13 +12,17 @@ class PyrexWarning(Exception):
class PyrexWarning(Exception):
pass
+
def context(position):
- F = open(position[0]).readlines()
- s = ''.join(F[position[1]-6:position[1]])
+ source = position[0]
+ assert not (isinstance(source, unicode) or isinstance(source, str)), (
+ "Please replace filename strings with Scanning.FileSourceDescriptor instances %r" % source)
+ F = list(source.get_lines())
+ s = ''.join(F[max(0, position[1]-6):position[1]]) # clamp start at 0: the (up to) 6 lines ending at the error line
s += ' '*(position[2]-1) + '^'
s = '-'*60 + '\n...\n' + s + '\n' + '-'*60 + '\n'
return s
-
+
class CompileError(PyrexError):
def __init__(self, position = None, message = ""):
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/Main.py
--- a/Cython/Compiler/Main.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/Main.py Fri May 16 16:12:20 2008 +0200
@@ -9,7 +9,7 @@ if sys.version_info[:2] < (2, 2):
from time import time
import Version
-from Scanning import PyrexScanner
+from Scanning import PyrexScanner, FileSourceDescriptor
import Errors
from Errors import PyrexError, CompileError, error
import Parsing
@@ -85,7 +85,8 @@ class Context:
try:
if debug_find_module:
print("Context.find_module: Parsing %s" % pxd_pathname)
- pxd_tree = self.parse(pxd_pathname, scope.type_names, pxd = 1,
+ source_desc = FileSourceDescriptor(pxd_pathname)
+ pxd_tree = self.parse(source_desc, scope.type_names, pxd = 1,
full_module_name = module_name)
pxd_tree.analyse_declarations(scope)
except CompileError:
@@ -116,7 +117,10 @@ class Context:
# None if not found, but does not report an error.
dirs = self.include_directories
if pos:
- here_dir = os.path.dirname(pos[0])
+ file_desc = pos[0]
+ if not isinstance(file_desc, FileSourceDescriptor):
+ raise RuntimeError("Only file sources for code supported")
+ here_dir = os.path.dirname(file_desc.filename)
dirs = [here_dir] + dirs
for dir in dirs:
path = os.path.join(dir, filename)
@@ -137,19 +141,21 @@ class Context:
self.modules[name] = scope
return scope
- def parse(self, source_filename, type_names, pxd, full_module_name):
- name = Utils.encode_filename(source_filename)
+ def parse(self, source_desc, type_names, pxd, full_module_name):
+ if not isinstance(source_desc, FileSourceDescriptor):
+ raise RuntimeError("Only file sources for code supported")
+ source_filename = Utils.encode_filename(source_desc.filename)
# Parse the given source file and return a parse tree.
try:
f = Utils.open_source_file(source_filename, "rU")
try:
- s = PyrexScanner(f, name, source_encoding = f.encoding,
+ s = PyrexScanner(f, source_desc, source_encoding = f.encoding,
type_names = type_names, context = self)
tree = Parsing.p_module(s, pxd, full_module_name)
finally:
f.close()
except UnicodeDecodeError, msg:
- error((name, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
+ error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
if Errors.num_errors > 0:
raise CompileError
return tree
@@ -197,6 +203,7 @@ class Context:
except EnvironmentError:
pass
module_name = full_module_name # self.extract_module_name(source, options)
+ source = FileSourceDescriptor(source)
initial_pos = (source, 1, 0)
scope = self.find_module(module_name, pos = initial_pos, need_pxd = 0)
errors_occurred = False
@@ -339,6 +346,8 @@ def main(command_line = 0):
if any_failures:
sys.exit(1)
+
+
#------------------------------------------------------------------------
#
# Set the default options depending on the platform
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/ModuleNode.py
--- a/Cython/Compiler/ModuleNode.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/ModuleNode.py Fri May 16 16:12:20 2008 +0200
@@ -427,8 +427,8 @@ class ModuleNode(Nodes.Node, Nodes.Block
code.putln("")
code.putln("static char *%s[] = {" % Naming.filenames_cname)
if code.filename_list:
- for filename in code.filename_list:
- filename = os.path.basename(filename)
+ for source_desc in code.filename_list:
+ filename = os.path.basename(str(source_desc))
escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"')
code.putln('"%s",' %
escaped_filename)
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/Parsing.py
--- a/Cython/Compiler/Parsing.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/Parsing.py Fri May 16 16:12:20 2008 +0200
@@ -5,7 +5,7 @@ import os, re
import os, re
from string import join, replace
from types import ListType, TupleType
-from Scanning import PyrexScanner
+from Scanning import PyrexScanner, FileSourceDescriptor
import Nodes
import ExprNodes
from ModuleNode import ModuleNode
@@ -1182,7 +1182,8 @@ def p_include_statement(s, level):
include_file_path = s.context.find_include_file(include_file_name, pos)
if include_file_path:
f = Utils.open_source_file(include_file_path, mode="rU")
- s2 = PyrexScanner(f, include_file_path, s, source_encoding=f.encoding)
+ source_desc = FileSourceDescriptor(include_file_path)
+ s2 = PyrexScanner(f, source_desc, s, source_encoding=f.encoding)
try:
tree = p_statement_list(s2, level)
finally:
diff -r 9a731464ea49 -r 1caa3ed4962b Cython/Compiler/Scanning.py
--- a/Cython/Compiler/Scanning.py Sat May 10 14:46:37 2008 +0200
+++ b/Cython/Compiler/Scanning.py Fri May 16 16:12:20 2008 +0200
@@ -16,6 +16,8 @@ from Cython.Plex.Errors import Unrecogni
from Cython.Plex.Errors import UnrecognizedInput
from Errors import CompileError, error
from Lexicon import string_prefixes, make_lexicon
+
+from Cython import Utils
plex_version = getattr(Plex, '_version', None)
#print "Plex version:", plex_version ###
@@ -203,6 +205,57 @@ def initial_compile_time_env():
#------------------------------------------------------------------
+class SourceDescriptor:
+ pass
+
+class FileSourceDescriptor(SourceDescriptor):
+ """
+ Represents a code source. A code source is a more generic abstraction
+ for a "filename" (as sometimes the code doesn't come from a file).
+ Instances of code sources are passed to Scanner.__init__ as the
+ optional name argument and will be passed back when asking for
+ the position()-tuple.
+ """
+ def __init__(self, filename):
+ self.filename = filename
+
+ def get_lines(self, decode=False):
+ # decode is True when called from Code.py (which reserializes in a standard way to ASCII),
+ # while decode is False when called from Errors.py.
+ #
+ # Note that if changing Errors.py in this respect, raising errors over wrong encoding
+ # will no longer be able to produce the line where the encoding problem occurs ...
+ if decode:
+ return Utils.open_source_file(self.filename)
+ else:
+ return open(self.filename)
+
+ def __str__(self):
+ return self.filename
+
+ def __repr__(self):
+ return "<FileSourceDescriptor:%s>" % self
+
+class StringSourceDescriptor(SourceDescriptor):
+ """
+ Instances of this class can be used instead of a filenames if the
+ code originates from a string object.
+ """
+ def __init__(self, name, code):
+ self.name = name
+ self.codelines = [x + "\n" for x in code.split("\n")]
+
+ def get_lines(self, decode=False):
+ return self.codelines
+
+ def __str__(self):
+ return self.name
+
+ def __repr__(self):
+ return "<StringSourceDescriptor:%s>" % self
+
+#------------------------------------------------------------------
+
class PyrexScanner(Scanner):
# context Context Compilation context
# type_names set Identifiers to be treated as type names
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210951140 -7200
# Node ID e12195139626875f71dd708e3dab648b233dc330
# Parent 1caa3ed4962ba04c165bda6187a194fd05032b7a
VisitorTransform + smaller Transform changes
diff -r 1caa3ed4962b -r e12195139626 Cython/Compiler/Transform.py
--- a/Cython/Compiler/Transform.py Fri May 16 16:12:20 2008 +0200
+++ b/Cython/Compiler/Transform.py Fri May 16 17:19:00 2008 +0200
@@ -3,11 +3,22 @@
#
import Nodes
import ExprNodes
+import inspect
class Transform(object):
- # parent_stack [Node] A stack providing information about where in the tree
- # we currently are. Nodes here should be considered
- # read-only.
+ # parent_stack [Node] A stack providing information about where in the tree
+ # we currently are. Nodes here should be considered
+ # read-only.
+ #
+ # attr_stack [(string,int|None)]
+ # A stack providing information about the attribute names
+ # followed to get to the current location in the tree.
+ # The first tuple item is the attribute name, the second is
+ # the index if the attribute is a list, or None otherwise.
+ #
+ #
+ # Additionally, any keyword arguments to __call__ will be set as fields while in
+ # a transformation.
# Transforms for the parse tree should usually extend this class for convenience.
# The caller of a transform will only first call initialize and then process_node on
@@ -18,12 +29,6 @@ class Transform(object):
# return the input node untouched. Returning None will remove the node from the
# parent.
- def __init__(self):
- self.parent_stack = []
-
- def initialize(self, phase, **options):
- pass
-
def process_children(self, node):
"""For all children of node, either process_list (if isinstance(node, list))
or process_node (otherwise) is called."""
@@ -36,28 +41,102 @@ class Transform(object):
newchild = self.process_list(child, childacc.name())
if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!")
else:
- newchild = self.process_node(child, childacc.name())
+ self.attr_stack.append((childacc.name(), None))
+ newchild = self.process_node(child)
if newchild is not None and not isinstance(newchild, Nodes.Node):
raise Exception("Cannot replace Node with non-Node!")
+ self.attr_stack.pop()
childacc.set(newchild)
self.parent_stack.pop()
- def process_list(self, l, name):
- """Calls process_node on all the items in l, using the name one gets when appending
- [idx] to the name. Each item in l is transformed in-place by the item process_node
- returns, then l is returned."""
- # Comment: If moving to a copying strategy, it might makes sense to return a
- # new list instead.
+ def process_list(self, l, attrname):
+ """Calls process_node on all the items in l. Each item in l is transformed
+ in-place by the item process_node returns, then l is returned. If process_node
+ returns None, the item is removed from the list."""
for idx in xrange(len(l)):
- l[idx] = self.process_node(l[idx], "%s[%d]" % (name, idx))
- return l
+ self.attr_stack.append((attrname, idx))
+ l[idx] = self.process_node(l[idx])
+ self.attr_stack.pop()
+ return [x for x in l if x is not None]
- def process_node(self, node, name):
+ def process_node(self, node):
"""Override this method to process nodes. name specifies which kind of relation the
parent has with child. This method should always return the node which the parent
should use for this relation, which can either be the same node, None to remove
the node, or a different node."""
raise NotImplementedError("Not implemented")
+
+    def __call__(self, root, **params):
+        self.parent_stack = []
+        self.attr_stack = []
+        for key, value in params.iteritems():
+            setattr(self, key, value)
+        try:
+            return self.process_node(root)
+        finally:  # undo the per-call state even if process_node raises
+            for key in params:
+                delattr(self, key)
+            del self.parent_stack, self.attr_stack
+
+
+class VisitorTransform(Transform):
+
+    # Note: If needed, this can be replaced with a more efficient metaclass
+    # approach, resolving the jump table at module load time.
+
+    def __init__(self, readonly=False, **kw):
+        """readonly - If this is set to True, the results of process_node
+        will be discarded (so that one can return None without changing
+        the tree)."""
+        super(VisitorTransform, self).__init__(**kw)
+        self.visitmethods = {'process_' : {}, 'pre_' : {}, 'post_' : {}}
+        self.attrname = ""
+        self.readonly = readonly
+
+    def get_visitfunc(self, prefix, cls):
+        mname = prefix + cls.__name__
+        m = self.visitmethods[prefix].get(mname)
+        if m is None:
+            # Must resolve, try entire hierarchy
+            for base in inspect.getmro(cls):  # don't rebind cls: the error below must name the visited class
+                m = getattr(self, prefix + base.__name__, None)
+                if m is not None:
+                    break
+            if m is None: raise RuntimeError("Not a Node descendant: " + cls.__name__)
+            self.visitmethods[prefix][mname] = m
+        return m
+
+    def process_node(self, node, name="_"):
+        # Pass on to calls registered in self.visitmethods
+        self.attrname = name
+        if node is None:
+            return None
+        result = self.get_visitfunc("process_", node.__class__)(node)
+        if self.readonly:
+            return node
+        else:
+            return result
+
+    def process_Node(self, node):
+        descend = self.get_visitfunc("pre_", node.__class__)(node)
+        if descend:
+            self.process_children(node)
+        self.get_visitfunc("post_", node.__class__)(node)
+        return node
+
+    def pre_Node(self, node):
+        return True
+
+    def post_Node(self, node):
+        pass
+
+
+# Utils
+def ensure_statlist(node):
+ if not isinstance(node, Nodes.StatListNode):
+ node = Nodes.StatListNode(pos=node.pos, stats=[node])
+ return node
+
class PrintTree(Transform):
"""Prints a representation of the tree to standard output.
@@ -72,15 +151,24 @@ class PrintTree(Transform):
def unindent(self):
self._indent = self._indent[:-2]
- def initialize(self, phase, **options):
+ def __call__(self, tree, phase=None, **params):
print("Parse tree dump at phase '%s'" % phase)
+ super(PrintTree, self).__call__(tree, phase=phase, **params)
# Don't do anything about process_list, the defaults gives
# nice-looking name[idx] nodes which will visually appear
# under the parent-node, not displaying the list itself in
# the hierarchy.
- def process_node(self, node, name):
+ def process_node(self, node):
+ if len(self.attr_stack) == 0:
+ name = "(root)"
+ else:
+ attr, idx = self.attr_stack[-1]
+ if idx is not None:
+ name = "%s[%d]" % (attr, idx)
+ else:
+ name = attr
print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
self.indent()
self.process_children(node)
@@ -92,9 +180,14 @@ class PrintTree(Transform):
return "(none)"
else:
result = node.__class__.__name__
- if isinstance(node, ExprNodes.ExprNode):
+ if isinstance(node, ExprNodes.NameNode):
+ result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
+ elif isinstance(node, Nodes.DefNode):
+ result += "(name=\"%s\")" % node.name
+ elif isinstance(node, ExprNodes.ExprNode):
t = node.type
result += "(type=%s)" % repr(t)
+
return result
@@ -108,9 +201,8 @@ class TransformSet(dict):
for name in PHASES:
self[name] = []
def run(self, name, node, **options):
- assert name in self
+ assert name in self, "Transform phase %s not defined" % name
for transform in self[name]:
- transform.initialize(phase=name, **options)
- transform.process_node(node, "(root)")
+ transform(node, phase=name, **options)
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210951173 -7200
# Node ID 7141900aa6a46503724501409bdac66e582829a8
# Parent e12195139626875f71dd708e3dab648b233dc330
Fixed typo children_attrs -> child_attrs
diff -r e12195139626 -r 7141900aa6a4 Cython/Compiler/ModuleNode.py
--- a/Cython/Compiler/ModuleNode.py Fri May 16 17:19:00 2008 +0200
+++ b/Cython/Compiler/ModuleNode.py Fri May 16 17:19:33 2008 +0200
@@ -33,7 +33,7 @@ class ModuleNode(Nodes.Node, Nodes.Block
# module_temp_cname string
# full_module_name string
- children_attrs = ["body"]
+ child_attrs = ["body"]
def analyse_declarations(self, env):
if Options.embed_pos_in_docstring:
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210953293 -7200
# Node ID c93feb4713475cbb5218157ea61f1c751716a3e0
# Parent 7141900aa6a46503724501409bdac66e582829a8
Added ReadonlyVisitor.
There was an option in VisitorTransform for this, but it was far too obscure;
it is better to have a separate class.
diff -r 7141900aa6a4 -r c93feb471347 Cython/Compiler/Transform.py
--- a/Cython/Compiler/Transform.py Fri May 16 17:19:33 2008 +0200
+++ b/Cython/Compiler/Transform.py Fri May 16 17:54:53 2008 +0200
@@ -84,14 +84,12 @@ class VisitorTransform(Transform):
# Note: If needed, this can be replaced with a more efficient metaclass
# approach, resolving the jump table at module load time.
- def __init__(self, readonly=False, **kw):
+ def __init__(self, **kw):
"""readonly - If this is set to True, the results of process_node
will be discarded (so that one can return None without changing
the tree)."""
super(VisitorTransform, self).__init__(**kw)
self.visitmethods = {'process_' : {}, 'pre_' : {}, 'post_' : {}}
- self.attrname = ""
- self.readonly = readonly
def get_visitfunc(self, prefix, cls):
mname = prefix + cls.__name__
@@ -106,16 +104,12 @@ class VisitorTransform(Transform):
self.visitmethods[prefix][mname] = m
return m
- def process_node(self, node, name="_"):
+ def process_node(self, node):
# Pass on to calls registered in self.visitmethods
- self.attrname = name
if node is None:
return None
result = self.get_visitfunc("process_", node.__class__)(node)
- if self.readonly:
- return node
- else:
- return result
+ return node
def process_Node(self, node):
descend = self.get_visitfunc("pre_", node.__class__)(node)
@@ -130,6 +124,15 @@ class VisitorTransform(Transform):
def post_Node(self, node):
pass
+class ReadonlyVisitor(VisitorTransform):
+ """
+ Like VisitorTransform, however process_X methods do not have to return
+ the result node -- the result of process_X is always discarded and the
+ structure of the original tree is not changed.
+ """
+ def process_node(self, node):
+ super(ReadonlyVisitor, self).process_node(node) # discard result
+ return node
# Utils
def ensure_statlist(node):
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210954141 -7200
# Node ID 3ed80b6f894b5f7faba4872d74565b44cfcdd22a
# Parent c93feb4713475cbb5218157ea61f1c751716a3e0
Added Node.clone_node utility.
A method for cloning nodes. I expect this one to work on all descendants, but
it can be overridden if a node has special needs. It seems natural to put
such core functionality in the node classes rather than in a visitor.
diff -r c93feb471347 -r 3ed80b6f894b Cython/Compiler/Nodes.py
--- a/Cython/Compiler/Nodes.py Fri May 16 17:54:53 2008 +0200
+++ b/Cython/Compiler/Nodes.py Fri May 16 18:09:01 2008 +0200
@@ -2,7 +2,7 @@
# Pyrex - Parse tree nodes
#
-import string, sys, os, time
+import string, sys, os, time, copy
import Code
from Errors import error, warning, InternalError
@@ -149,6 +149,19 @@ class Node(object):
"""Utility method for more easily implementing get_child_accessors.
If you override get_child_accessors then this method is not used."""
return self.child_attrs
+
+ def clone_node(self):
+ """Clone the node. This is defined as a shallow copy, except for member lists
+ amongst the child attributes (from get_child_accessors) which are also
+ copied. Lists containing child nodes are thus seen as a way for the node
+ to hold multiple children directly; the list is not treated as a seperate
+ level in the tree."""
+ c = copy.copy(self)
+ for acc in c.get_child_accessors():
+ value = acc.get()
+ if isinstance(value, list):
+ acc.set([x for x in value])
+ return c
#
# HG changeset patch
# User Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
# Date 1210954341 -7200
# Node ID 7e8ca264b2812a5ce9bce28a9be56e2a6b333e5f
# Parent 3ed80b6f894b5f7faba4872d74565b44cfcdd22a
New features: CodeWriter, TreeFragment, and a transform unit test framework.
See the documentation of each class for details.
It is a rather big commit; however, separating it is non-trivial. The tests
for all of these features rely on each other, so there's a
circular dependency in the tests and I wanted to commit the tests and
features at the same time. (However, the non-test-code does not have a circular
dependency.)
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/CodeWriter.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/CodeWriter.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,202 @@
+from Cython.Compiler.Transform import ReadonlyVisitor
+from Cython.Compiler.Nodes import *
+
+"""
+Serializes a Cython code tree to Cython code. This is primarily useful for
+debugging and testing purposes.
+
+The output is in a strict format, no whitespace or comments from the input
+is preserved (and it could not be as it is not present in the code tree).
+"""
+
+class LinesResult(object):
+ def __init__(self):
+ self.lines = []
+ self.s = u""
+
+ def put(self, s):
+ self.s += s
+
+ def newline(self):
+ self.lines.append(self.s)
+ self.s = u""
+
+ def putline(self, s):
+ self.put(s)
+ self.newline()
+
+class CodeWriter(ReadonlyVisitor):
+
+    indent_string = u" "
+
+    def __init__(self, result = None):
+        super(CodeWriter, self).__init__()
+        if result is None:
+            result = LinesResult()
+        self.result = result
+        self.numindents = 0
+
+    def indent(self):
+        self.numindents += 1
+
+    def dedent(self):
+        self.numindents -= 1
+
+    def startline(self, s = u""):
+        self.result.put(self.indent_string * self.numindents + s)
+
+    def put(self, s):
+        self.result.put(s)
+
+    def endline(self, s = u""):
+        self.result.putline(s)
+
+    def line(self, s):
+        self.startline(s)
+        self.endline()
+
+    def comma_seperated_list(self, items, output_rhs=False):
+        # ", "-separated; with output_rhs every item (the last one too) is followed by " = rhs"
+        for idx, item in enumerate(items):
+            if idx > 0:
+                self.put(u", ")
+            self.process_node(item)
+            if output_rhs and item.rhs is not None:
+                self.put(u" = ")
+                self.process_node(item.rhs)
+
+    def process_Node(self, node):
+        raise AssertionError("Node not handled by serializer: %r" % node)
+
+    def process_ModuleNode(self, node):
+        self.process_children(node)
+
+    def process_StatListNode(self, node):
+        self.process_children(node)
+
+    def process_FuncDefNode(self, node):
+        self.startline(u"def %s(" % node.name)
+        self.comma_seperated_list(node.args)
+        self.endline(u"):")
+        self.indent()
+        self.process_node(node.body)
+        self.dedent()
+
+    def process_CArgDeclNode(self, node):
+        if node.base_type.name is not None:
+            self.process_node(node.base_type)
+            self.put(u" ")
+        self.process_node(node.declarator)
+        if node.default is not None:
+            self.put(u" = ")
+            self.process_node(node.default)
+
+    def process_CNameDeclaratorNode(self, node):
+        self.put(node.name)
+
+    def process_CSimpleBaseTypeNode(self, node):
+        # See Parsing.p_sign_and_longness
+        if node.is_basic_c_type:
+            self.put(("unsigned ", "", "signed ")[node.signed])
+            if node.longness < 0:
+                self.put("short " * -node.longness)
+            elif node.longness > 0:
+                self.put("long " * node.longness)
+
+        self.put(node.name)
+
+    def process_SingleAssignmentNode(self, node):
+        self.startline()
+        self.process_node(node.lhs)
+        self.put(u" = ")
+        self.process_node(node.rhs)
+        self.endline()
+
+    def process_NameNode(self, node):
+        self.put(node.name)
+
+    def process_IntNode(self, node):
+        self.put(node.value)
+
+    def process_IfStatNode(self, node):
+        # The IfClauseNode is handled directly without a seperate match
+        # for clariy.
+        self.startline(u"if ")
+        self.process_node(node.if_clauses[0].condition)
+        self.endline(":")
+        self.indent()
+        self.process_node(node.if_clauses[0].body)
+        self.dedent()
+        for clause in node.if_clauses[1:]:
+            self.startline("elif ")
+            self.process_node(clause.condition)
+            self.endline(":")
+            self.indent()
+            self.process_node(clause.body)
+            self.dedent()
+        if node.else_clause is not None:
+            self.line("else:")
+            self.indent()
+            self.process_node(node.else_clause)
+            self.dedent()
+
+    def process_PassStatNode(self, node):
+        self.startline(u"pass")
+        self.endline()
+
+    def process_PrintStatNode(self, node):
+        self.startline(u"print ")
+        self.comma_seperated_list(node.args)
+        if node.ends_with_comma:
+            self.put(u",")
+        self.endline()
+
+    def process_BinopNode(self, node):
+        self.process_node(node.operand1)
+        self.put(u" %s " % node.operator)
+        self.process_node(node.operand2)
+
+    def process_CVarDefNode(self, node):
+        self.startline(u"cdef ")
+        self.process_node(node.base_type)
+        self.put(u" ")
+        self.comma_seperated_list(node.declarators, output_rhs=True)
+        self.endline()
+
+    def process_ForInStatNode(self, node):
+        self.startline(u"for ")
+        self.process_node(node.target)
+        self.put(u" in ")
+        self.process_node(node.iterator.sequence)
+        self.endline(u":")
+        self.indent()
+        self.process_node(node.body)
+        self.dedent()
+        if node.else_clause is not None:
+            self.line(u"else:")
+            self.indent()
+            self.process_node(node.else_clause)
+            self.dedent()
+
+    def process_SequenceNode(self, node):
+        self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
+
+    def process_SimpleCallNode(self, node):
+        self.put(node.function.name + u"(")
+        self.comma_seperated_list(node.args)
+        self.put(")")
+
+    def process_ExprStatNode(self, node):
+        self.startline()
+        self.process_node(node.expr)
+        self.endline()
+
+    def process_InPlaceAssignmentNode(self, node):
+        self.startline()
+        self.process_node(node.lhs)
+        self.put(" %s= " % node.operator)
+        self.process_node(node.rhs)
+        self.endline()
+
+
+
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Compiler/Tests/TestTreeFragment.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Compiler/Tests/TestTreeFragment.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,26 @@
+from Cython.TestUtils import CythonTest
+from Cython.Compiler.TreeFragment import *
+
+class TestTreeFragments(CythonTest):
+ def test_basic(self):
+ F = self.fragment(u"x = 4")
+ T = F.copy()
+ self.assertCode(u"x = 4", T)
+
+ def test_copy_is_independent(self):
+ F = self.fragment(u"if True: x = 4")
+ T1 = F.root
+ T2 = F.copy()
+ self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name)
+ T2.body.if_clauses[0].body.lhs.name = "other"
+ self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
+
+ def test_substitution(self):
+ F = self.fragment(u"x = 4")
+ y = NameNode(pos=None, name=u"y")
+ T = F.substitute({"x" : y})
+ self.assertCode(u"y = 4", T)
+
+if __name__ == "__main__":
+ import unittest
+ unittest.main()
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Compiler/Tests/__init__.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Compiler/Tests/__init__.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,1 @@
+#empty
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Compiler/Transform.py
--- a/Cython/Compiler/Transform.py Fri May 16 18:09:01 2008 +0200
+++ b/Cython/Compiler/Transform.py Fri May 16 18:12:21 2008 +0200
@@ -109,7 +109,7 @@ class VisitorTransform(Transform):
if node is None:
return None
result = self.get_visitfunc("process_", node.__class__)(node)
- return node
+ return result
def process_Node(self, node):
descend = self.get_visitfunc("pre_", node.__class__)(node)
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Compiler/TreeFragment.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Compiler/TreeFragment.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,122 @@
+#
+# TreeFragments - parsing of strings to trees
+#
+
+import re
+from cStringIO import StringIO
+from Scanning import PyrexScanner, StringSourceDescriptor
+from Symtab import BuiltinScope, ModuleScope
+from Transform import Transform, VisitorTransform
+from Nodes import Node
+from ExprNodes import NameNode
+import Parsing
+import Main
+
+"""
+Support for parsing strings into code trees.
+"""
+
+class StringParseContext(Main.Context):
+ def __init__(self, include_directories, name):
+ Main.Context.__init__(self, include_directories)
+ self.module_name = name
+
+ def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
+ if module_name != self.module_name:
+ raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
+ return ModuleScope(module_name, parent_module = None, context = self)
+
+def parse_from_strings(name, code, pxds={}):
+    """
+    Utility method to parse a (unicode) string of code. This is mostly
+    used for internal Cython compiler purposes (creating code snippets
+    that transforms should emit, as well as unit testing).
+
+    code - a unicode string containing Cython (module-level) code
+    name - a descriptive name for the code source (to use in error messages etc.)
+    """
+
+    # Since source files carry an encoding, it makes sense in this context
+    # to use a unicode string so that code fragments don't have to bother
+    # with encoding. This means that test code passed in should not have an
+    # encoding header.
+    assert isinstance(code, unicode), "unicode code snippets only please"
+    encoding = "UTF-8"
+
+    module_name = name
+    code_source = StringSourceDescriptor(name, code)
+    initial_pos = (code_source, 1, 0)  # positions must carry a SourceDescriptor, not a name string
+
+    context = StringParseContext([], name)
+    scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
+
+    buf = StringIO(code.encode(encoding))
+
+    scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
+                           type_names = scope.type_names, context = context)
+    tree = Parsing.p_module(scanner, 0, module_name)
+    return tree
+
+class TreeCopier(Transform):
+ def process_node(self, node):
+ if node is None:
+ return node
+ else:
+ c = node.clone_node()
+ self.process_children(c)
+ return c
+
+class SubstitutionTransform(VisitorTransform):
+ def process_Node(self, node):
+ if node is None:
+ return node
+ else:
+ c = node.clone_node()
+ self.process_children(c)
+ return c
+
+ def process_NameNode(self, node):
+ if node.name in self.substitute:
+ # Name matched, substitute node
+ return self.substitute[node.name]
+ else:
+ # Clone
+ return self.process_Node(node)
+
+def copy_code_tree(node):
+ return TreeCopier()(node)
+
+INDENT_RE = re.compile(ur"^ *")
+def strip_common_indent(lines):
+    "Strips empty lines and common indentation from the list of strings given in lines"
+    lines = [x for x in lines if x.strip() != u""]
+    # "or [0]": min() of an empty sequence raises, so all-blank input now yields []
+    minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines] or [0])
+    return [x[minindent:] for x in lines]
+
+class TreeFragment(object):
+ def __init__(self, code, name, pxds={}):
+ if isinstance(code, unicode):
+ def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
+
+ fmt_code = fmt(code)
+ fmt_pxds = {}
+ for key, value in pxds.iteritems():
+ fmt_pxds[key] = fmt(value)
+
+ self.root = parse_from_strings(name, fmt_code, fmt_pxds)
+ elif isinstance(code, Node):
+ if pxds != {}: raise NotImplementedError()
+ self.root = code
+ else:
+ raise ValueError("Unrecognized code format (accepts unicode and Node)")
+
+ def copy(self):
+ return copy_code_tree(self.root)
+
+ def substitute(self, nodes={}):
+ return SubstitutionTransform()(self.root, substitute = nodes)
+
+
+
+
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/TestUtils.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/TestUtils.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,61 @@
+import Cython.Compiler.Errors as Errors
+from Cython.CodeWriter import CodeWriter
+import unittest
+from Cython.Compiler.ModuleNode import ModuleNode
+import Cython.Compiler.Main as Main
+from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
+
+class CythonTest(unittest.TestCase):
+    def assertCode(self, expected, result_tree):
+        "Serialize result_tree with CodeWriter and check it matches expected line by line."
+        writer = CodeWriter()
+        writer(result_tree)
+        got_lines = writer.result.lines
+        want_lines = strip_common_indent(expected.split("\n"))
+        for lineno, (got, want) in enumerate(zip(got_lines, want_lines)):
+            self.assertEqual(want, got, "Line %d:\nGot: %s\nExp: %s" % (lineno, got, want))
+        self.assertEqual(len(got_lines), len(want_lines),
+            "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(got_lines), expected))
+
+    def fragment(self, code, pxds={}):
+        "Build a TreeFragment named after this test case, so parse errors identify it."
+        main_prefix = "__main__."
+        name = self.id()
+        if name.startswith(main_prefix):
+            name = name[len(main_prefix):]
+        return TreeFragment(code, name.replace(".", "_"), pxds)
+
+
+class TransformTest(CythonTest):
+    """
+    Utility base class for transform unit tests. It is based around constructing
+    test trees (either explicitly or by parsing a Cython code string), running
+    the transform, serializing the result with a customized Cython serializer
+    (with special markup for nodes that cannot be represented in Cython),
+    and doing a line-by-line string comparison against the expected code.
+
+    To create a test case:
+     - Call run_pipeline. The pipeline should at least contain the transform you
+       are testing; pyx should be either a unicode string (passed to the parser to
+       create a post-parse tree) or a ModuleNode representing input to pipeline.
+       The result will be a transformed result (usually a ModuleNode).
+
+     - Check that the tree is correct. If wanted, assertCode can be used, which
+       takes a code string as expected, and a ModuleNode in result_tree
+       (it serializes the ModuleNode to a string and compares line-by-line).
+
+    All code strings are first stripped of whitespace-only lines and then of
+    their common indentation.
+
+    pxds: optional dict of pxd code strings; only usable when pyx is a string.
+    """
+
+    def run_pipeline(self, pipeline, pyx, pxds={}):
+        "Build the input tree from pyx (and pxds), then apply each transform in order."
+        tree = self.fragment(pyx, pxds).root
+        self.assertTrue(isinstance(tree, ModuleNode))
+        # Run pipeline: each transform's return value is fed to the next one
+        for T in pipeline:
+            tree = T(tree)
+        return tree
+
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Tests/TestCodeWriter.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Tests/TestCodeWriter.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,79 @@
+from Cython.TestUtils import CythonTest
+
+class TestCodeWriter(CythonTest):
+    # CythonTest uses the CodeWriter heavily, so do some checking by
+    # roundtripping Cython code through the test framework.
+
+    # Note that this test is dependent upon the normal Cython parser
+    # to generate the input trees to the CodeWriter. This saves *a lot*
+    # of time; better to spend that time writing other tests than perfecting
+    # this one...
+
+    # Whitespace is very significant in this process:
+    #  - always newline on new block (!)
+    #  - indent 4 spaces
+    #  - 1 space around every operator
+
+    def t(self, codestr):  # round-trip: parse codestr, write it back out, expect the same code
+        self.assertCode(codestr, self.fragment(codestr).root)
+
+    def test_print(self):
+        self.t(u"""
+                print x, y
+                print x + y ** 2
+                print x, y, z,
+        """)
+
+    def test_if(self):
+        self.t(u"if x:\n    pass")
+
+    def test_ifelifelse(self):
+        self.t(u"""
+            if x:
+                pass
+            elif y:
+                pass
+            elif z + 34 ** 34 - 2:
+                pass
+            else:
+                pass
+        """)
+
+    def test_def(self):  # NOTE(review): 2nd def puts non-default z after defaults; confirm the parser should accept it
+        self.t(u"""
+            def f(x, y, z):
+                pass
+            def f(x = 34, y = 54, z):
+                pass
+        """)
+
+    def test_longness_and_signedness(self):
+        self.t(u"def f(unsigned long long long long long int y):\n    pass")
+
+    def test_signed_short(self):
+        self.t(u"def f(signed short int y):\n    pass")
+
+    def test_typed_args(self):
+        self.t(u"def f(int x, unsigned long int y):\n    pass")
+
+    def test_cdef_var(self):
+        self.t(u"""
+            cdef int hello
+            cdef int hello = 4, x = 3, y, z
+        """)
+
+    def test_for_loop(self):
+        self.t(u"""
+            for x, y, z in f(g(h(34) * 2) + 23):
+                print x, y, z
+            else:
+                print 43
+        """)
+
+    def test_inplace_assignment(self):
+        self.t(u"x += 43")
+
+if __name__ == "__main__":
+    from unittest import main
+    main()
+
diff -r 3ed80b6f894b -r 7e8ca264b281 Cython/Tests/__init__.py
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Tests/__init__.py Fri May 16 18:12:21 2008 +0200
@@ -0,0 +1,1 @@
+#empty
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment