Skip to content

Instantly share code, notes, and snippets.

@duck2
Created August 16, 2019 11:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save duck2/bc024ed712237b85a912675760370bfa to your computer and use it in GitHub Desktop.
Save duck2/bc024ed712237b85a912675760370bfa to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
from typing import List, Tuple, Dict, Set, Union, Optional
import xmlschema # type: ignore
from xmlschema.validators import ( # type: ignore
XsdAttribute,
XsdAtomicBuiltin,
XsdAtomicRestriction,
XsdComplexType,
XsdElement,
XsdGroup,
XsdSimpleType,
XsdList,
XsdType,
XsdUnion,
XMLSchema10,
)
import utils
from dfa import dfa_from_group, XsdDFA
# https://docs.python.org/3/library/dataclasses.html
@dataclass
class UxsdType:
    """Base record for an XSD type that corresponds to a type in C++.

    Fields are positional in declaration order: ``cpp_type`` first,
    then ``name``.
    """

    # Spelling of the corresponding type in C++ code.
    cpp_type: str
    # Name of the XSD type.
    name: str
class UxsdSimple(UxsdType):
    """Marker base class for the simple (non-complex) XSD types."""
@dataclass
class UxsdUnion(UxsdSimple):
    """A simple type that is a union of other simple types."""

    # The simple types this union is made of.
    member_types: List[UxsdSimple]
@dataclass
class UxsdEnum(UxsdSimple):
    """A simple type restricted to a fixed list of string values."""

    # The allowed values of the enumeration.
    enumeration: List[str]
@dataclass
class UxsdAtomic(UxsdSimple):
    """A simple type with no further structure of its own."""

    @property
    def cpp_load_format(self):
        """Look up this type's entry in ``utils.atomic_builtin_load_formats``.

        The lookup is keyed by ``self.name`` and raises ``KeyError`` if the
        table has no entry for it.
        """
        load_formats = utils.atomic_builtin_load_formats
        return load_formats[self.name]
@dataclass
class UxsdNumber(UxsdAtomic):
    """Atomic type tag for numbers; adds no fields of its own."""
@dataclass
class UxsdString(UxsdAtomic):
    """Atomic type tag for strings; adds no fields of its own."""
@dataclass
class UxsdComplex(UxsdType):
    """An XSD complex type. It has attributes and content."""

    # UxsdAttribute and UxsdContentType are declared further down in this
    # module, so they are referenced as strings: a bare name here would be
    # evaluated at class-creation time and raise NameError.
    attrs: List["UxsdAttribute"]
    # None when the complex type has no content.
    content: Optional["UxsdContentType"]
class UxsdContentType:
    """Marker base class for the content of a complex type."""
@dataclass
class UxsdAll(UxsdContentType):
    """Content holding a flat list of child elements (the "all" model)."""

    # UxsdElement is declared later in this module; the string forward
    # reference avoids a NameError when this class body is evaluated.
    children: List["UxsdElement"]
@dataclass
class UxsdDfa(UxsdContentType):
    """Content holding child elements together with a DFA over them."""

    # UxsdElement is declared later in this module; the string forward
    # reference avoids a NameError when this class body is evaluated.
    children: List["UxsdElement"]
    # DFA object for the content group (XsdDFA comes from the dfa module).
    dfa: XsdDFA
@dataclass
class UxsdLeaf(UxsdContentType):
    """Content consisting of a single simple-typed value."""

    # The simple type of the value.
    type: UxsdSimple
@dataclass
class UxsdAttribute:
    """An attribute of a complex type."""

    # Attribute name.
    name: str
    # Default value, or None when there is none.
    default_value: Optional[str]
    # Whether the attribute may be left out.
    optional: bool
    # Simple type of the attribute value.
    type: UxsdSimple
@dataclass
class UxsdElement:
    """A child or root element together with its occurrence flags."""

    # Element name.
    name: str
    # Whether the element may occur more than once.
    many: bool
    # Whether the element may be absent.
    optional: bool
    # Simple or complex type of the element.
    type: UxsdType
# Helper types.
# UxsdNonstring: complex types, unions, enums and numbers - every type except strings.
UxsdNonstring = Union[UxsdComplex, UxsdUnion, UxsdEnum, UxsdNumber]
# UxsdAny: any node of the tree - a type, a content type, an element or an attribute.
UxsdAny = Union[UxsdType, UxsdContentType, UxsdElement, UxsdAttribute]
class UxsdSchema:
    """A schema tree derived from the xmlschema tree.

    It includes convenient data structures, such as ordered
    access to children and attributes, C++ type names of complex types etc.

    The tree is built in ``__init__`` by a recursive walk over the xmlschema
    object. The ``visit_*`` methods have side effects: they append to the
    schema-wide lists declared below.
    """

    # All user-defined complex types and root elements.
    complex_types: List[UxsdComplex]
    root_elements: List[UxsdElement]

    # Complex types found inside elements. They are not found in the global map,
    # so we have to reserve them while traversing types in the global map
    # and generate them afterwards.
    anonymous_complex_types: List[UxsdComplex]

    # Enumerations and unions, which need C++ declarations of their own.
    enums: List[UxsdEnum]
    unions: List[UxsdUnion]

    # Simple types found inside unions.
    # We generate a special "type_tag" enum from this.
    simple_types_in_unions: List[UxsdSimple]

    # In C++ code, we allocate global pools for types
    # which may occur more than once, so that we can avoid
    # frequent allocations.
    pool_types: List[UxsdNonstring]

    # A special pool is generated for strings.
    has_string_pool: bool = False

    # Complex types can be recursive - this cache keeps the walk from getting
    # trapped in a loop. It is created per instance in __init__.
    _visited: Dict[XsdComplexType, UxsdComplex]

    def visit_group(self, t: XsdGroup, many: bool = False, optional: bool = False) -> List[UxsdElement]:
        """Flatten a model group into the list of elements it contains.

        ``many`` and ``optional`` are inherited from enclosing groups and are
        switched on by this group's own occurrence bounds. Nested groups are
        flattened recursively.

        Raises NotImplementedError for members that are neither groups nor
        elements.
        """
        out: List[UxsdElement] = []
        # occurs is (min, max); a max of None means unbounded.
        if t.occurs[1] is None or t.occurs[1] > 1:
            many = True
        # Compare by value: identity comparison against an int literal is unreliable.
        if t.occurs[0] == 0:
            optional = True
        for e in t._group:
            if isinstance(e, XsdGroup):
                out += self.visit_group(e, many, optional)
            elif isinstance(e, XsdElement):
                out.append(self.visit_element(e, many, optional))
            else:
                raise NotImplementedError("I don't know what to do with group member %s." % e)
        return out

    def visit_element(self, t: XsdElement, many: bool = False, optional: bool = False) -> UxsdElement:
        """Convert an element, visiting its type along the way.

        Elements that can occur more than once get a pluralized name, and
        their type is recorded in ``pool_types`` when it is a complex type,
        union, enum or number.
        """
        if t.occurs[1] is None or t.occurs[1] > 1:
            many = True
        if t.occurs[0] == 0:
            optional = True

        elem_type: UxsdType
        if isinstance(t.type, XsdComplexType):
            elem_type = self.visit_complex_type(t.type)
        else:
            elem_type = self.visit_simple_type(t.type)

        name = utils.pluralize(t.name) if many else t.name
        # NOTE(review): visit_simple_type builds plain UxsdAtomic objects, so
        # the UxsdNumber check below never matches a type produced in this
        # class - confirm whether numbers were meant to be pooled.
        if many and isinstance(elem_type, (UxsdComplex, UxsdUnion, UxsdEnum, UxsdNumber)):
            self.pool_types.append(elem_type)
        return UxsdElement(name, many, optional, elem_type)

    # Only enumerations are supported as restrictions.
    def visit_restriction(self, t: XsdAtomicRestriction) -> UxsdEnum:
        """Convert an enumeration restriction and record it in ``enums``."""
        assert len(t.validators) == 1, "I can only handle simple enumerations."
        # Possibly member of an XsdList or XsdUnion if it doesn't have a name attribute.
        name = t.name if t.name else t.parent.name
        # Build the C++ name from the resolved name, so that an unnamed
        # restriction does not become "enum_None".
        cpp_type = "enum_%s" % name
        enumeration = t.validators[0].enumeration
        # Keyword arguments: UxsdType declares cpp_type before name.
        out = UxsdEnum(cpp_type=cpp_type, name=name, enumeration=enumeration)
        self.enums.append(out)
        return out

    def visit_union(self, t: XsdUnion) -> UxsdUnion:
        """Convert a union and each of its member types.

        Members are also recorded in ``simple_types_in_unions`` and the union
        itself in ``unions``.
        """
        member_types = []
        for m in t.member_types:
            # Visit the member, not the union itself (which would never terminate).
            x = self.visit_simple_type(m)
            member_types.append(x)
            self.simple_types_in_unions.append(x)
        cpp_type = "union_%s" % t.name
        out = UxsdUnion(cpp_type=cpp_type, name=t.name, member_types=member_types)
        self.unions.append(out)
        return out

    def visit_simple_type(self, t: XsdSimpleType) -> UxsdSimple:
        """Convert a simple type: builtin, list, enumeration or union.

        Raises NotImplementedError for any other kind of simple type.
        """
        # Remove w3.org namespace from built-in type names.
        # Other names are used as they are; unnamed types keep name None.
        name = t.name
        if name is not None and "w3.org" in name:
            name = name.split("}")[1]

        if isinstance(t, XsdAtomicBuiltin):
            # Test the stripped name; t.name still carries the namespace.
            if name == "string":
                self.has_string_pool = True
            return UxsdAtomic(name=name, cpp_type=utils.atomic_builtins[name])
        elif isinstance(t, XsdList):
            # Just read xs:lists into a string for now.
            # That simplifies validation and keeps heap allocation to nodes only.
            # VPR just reads list types into a string, too.
            return UxsdAtomic(name=name, cpp_type="const char *")
        elif isinstance(t, XsdAtomicRestriction):
            return self.visit_restriction(t)
        elif isinstance(t, XsdUnion):
            return self.visit_union(t)
        else:
            raise NotImplementedError("I don't know what to do with type %s." % t)

    def visit_attribute(self, a: XsdAttribute) -> UxsdAttribute:
        """Convert an attribute declaration.

        Raises NotImplementedError when ``use`` is neither "optional" nor
        "required".
        """
        if a.use == "optional":
            optional = True
        elif a.use == "required":
            optional = False
        else:
            raise NotImplementedError("I don't know what to do with attribute use=%s." % a.use)
        default_value = getattr(a, "default", None)
        attr_type = self.visit_simple_type(a.type)
        return UxsdAttribute(a.name, default_value, optional, attr_type)

    def visit_complex_type(self, t: XsdComplexType) -> UxsdComplex:
        """Convert a complex type, reusing the result for types already seen.

        Unnamed types take the name of their parent and are recorded in
        ``anonymous_complex_types``.

        Raises NotImplementedError for model groups other than "all",
        "choice" and "sequence".
        """
        cached = self._visited.get(t)
        if cached is not None:
            return cached

        # Same fallback as visit_restriction uses for unnamed types.
        anonymous = not t.name
        name = t.parent.name if anonymous else t.name
        cpp_type = "t_%s" % name

        # Register the object before descending, so that a type which contains
        # itself finds this entry instead of recursing forever.
        out = UxsdComplex(cpp_type=cpp_type, name=name, attrs=[], content=None)
        self._visited[t] = out
        if anonymous:
            self.anonymous_complex_types.append(out)

        # Order attributes by name. The dataclass defines no ordering of its
        # own, so an explicit key is required.
        out.attrs = sorted(
            (self.visit_attribute(a) for a in t.attributes.values()),
            key=lambda attr: attr.name,
        )

        if isinstance(t.content_type, XsdGroup) and len(t.content_type._group) > 0:
            model = t.content_type.model
            if model == "all":
                out.content = UxsdAll(self.visit_group(t.content_type))
            elif model in ["choice", "sequence"]:
                children = self.visit_group(t.content_type)
                dfa = dfa_from_group(t.content_type)
                out.content = UxsdDfa(children, dfa)
            else:
                raise NotImplementedError("Model group %s is not supported." % model)
        elif t.has_simple_content():
            out.content = UxsdLeaf(self.visit_simple_type(t.content_type))
        return out

    @staticmethod
    def _dedupe(items: list) -> list:
        """Return ``items`` without duplicates, keeping first occurrences.

        Membership is tested with ``in`` (identity, then ``==``) because the
        dataclasses in this module are not hashable.
        """
        unique: list = []
        for item in items:
            if item not in unique:
                unique.append(item)
        return unique

    @staticmethod
    def _key_ctype(x: UxsdComplex, on_path: Optional[Set[int]] = None) -> int:
        """Return the tree height of a complex type, used as a sort key.

        A type without child elements has height 1. ``on_path`` holds the
        ids of the types currently being measured; meeting one of them again
        means the type is recursive, and that branch counts as 0.
        """
        if on_path is None:
            on_path = set()
        if id(x) in on_path:
            return 0
        if not isinstance(x.content, (UxsdAll, UxsdDfa)):
            return 1

        on_path.add(id(x))
        tree_heights: List[int] = []
        for child in x.content.children:
            if isinstance(child.type, UxsdComplex):
                tree_heights.append(UxsdSchema._key_ctype(child.type, on_path))
            else:
                tree_heights.append(1)
        # Leave the path again so sibling subtrees sharing this type are measured fully.
        on_path.discard(id(x))
        return max(tree_heights, default=0) + 1

    # Build a UxsdSchema out of an XsdSchema using a recursive walk.
    def __init__(self, parent: XMLSchema10) -> None:
        """Walk ``parent`` and fill in the schema-wide lists."""
        # Create per-instance containers; the class body only annotates them.
        self.complex_types = []
        self.root_elements = []
        self.anonymous_complex_types = []
        self.enums = []
        self.unions = []
        self.simple_types_in_unions = []
        self.pool_types = []
        self._visited = {}

        for k, v in parent.types.items():
            if "w3.org" not in k and isinstance(v, XsdComplexType):
                self.complex_types.append(self.visit_complex_type(v))
        for v in parent.elements.values():
            self.root_elements.append(self.visit_element(v))

        # The visit_foo functions have side effects, they update schema-wide lists.
        # Remove duplicates from schema-wide lists while preserving order.
        self.enums = self._dedupe(self.enums)
        self.unions = self._dedupe(self.unions)
        self.simple_types_in_unions = self._dedupe(self.simple_types_in_unions)
        self.pool_types = self._dedupe(self.pool_types)

        # Collect complex types and sort by tree height.
        self.complex_types += self.anonymous_complex_types
        self.complex_types.sort(key=self._key_ctype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment