Created
October 24, 2012 20:35
-
-
Save fdelbos/3948701 to your computer and use it in GitHub Desktop.
A Python JSON implementation of the Thrift protocol
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import base64
import json
import math
import sys

from TProtocol import *
__all__ = ['TJSONProtocol', 'TJSONProtocolFactory']

# Protocol version; written and checked as the first element of a message.
VERSION = 1

# JSON syntax characters.
COMMA = ','
COLON = ':'
LBRACE = '{'
RBRACE = '}'
LBRACKET = '['
RBRACKET = ']'
QUOTE = '"'
BACKSLASH = '\\'
ZERO = '0'

# Prefix of a \u00XX escape sequence.
ESCSEQ = '\\u00'

# Characters that may follow a backslash inside a JSON string, and the
# character each one stands for.  The two sequences are index-aligned.
# '/' is included because JSON permits the "\/" escape.
ESCAPE_CHAR = '"\\bfnrt/'
ESCAPE_CHAR_VALS = ['"', '\\', '\b', '\f', '\n', '\r', '\t', '/']

HEXA_CHAR = '0123456789abcdef'
NUMERIC_CHAR = '+-.0123456789Ee'

# Thrift type -> type tag used on the wire.  The struct tag is 'rec' in the
# Thrift JSON protocol (it was misspelled 'rect' here, which other
# implementations would not recognise).
CTYPES = {TType.BOOL: 'tf',
          TType.BYTE: 'i8',
          TType.I16: 'i16',
          TType.I32: 'i32',
          TType.I64: 'i64',
          TType.DOUBLE: 'dbl',
          TType.STRING: 'str',
          TType.STRUCT: 'rec',
          TType.LIST: 'lst',
          TType.SET: 'set',
          TType.MAP: 'map'}

# Reverse mapping: wire type tag -> Thrift type.
JTYPES = dict((name, ttype) for ttype, name in CTYPES.items())
class JSONBaseContext(object):
    """Base JSON context: emits and expects no separators.

    A context is consulted before each value is written or read so that
    subclasses can insert or consume the separator that belongs in front of
    it.  This base implementation does nothing.
    """

    def __init__(self, protocol):
        """Remember the owning protocol and mark the context as fresh."""
        self.protocol = protocol
        # True until the first value has passed through the context.
        self.first = True

    def doIO(self, function):
        """Hook for subclasses; `function` is called with a separator."""
        pass

    def write(self):
        """Called before a value is written; writes nothing here."""
        pass

    def read(self):
        """Called before a value is read; consumes nothing here."""
        pass

    def escapeNum(self):
        """Return whether numbers must be quoted in this context (never)."""
        return False
class JSONListContext(JSONBaseContext):
    """Context for the inside of a JSON array: values separated by commas."""

    def doIO(self, function):
        """Pass COMMA to `function` before every value except the first."""
        if self.first:
            self.first = False
        else:
            function(COMMA)

    def write(self):
        """Write the separator (if any) to the protocol's transport."""
        self.doIO(self.protocol.trans.write)

    def read(self):
        """Consume the separator (if any) through the protocol's reader."""
        self.doIO(self.protocol.readJSONSyntaxChar)
class JSONPairContext(JSONBaseContext):
    """Context for the inside of a JSON object: key ':' value ',' key ..."""

    # True when the next separator is a colon, i.e. the value that just
    # passed through the context (or is about to, for the first one) is a key.
    colon = True

    def doIO(self, function):
        """Pass the pending separator to `function`, alternating ':' and ','.

        Nothing is passed before the very first value.
        """
        if self.first:
            self.first = False
            self.colon = True
        else:
            function(COLON if self.colon else COMMA)
            self.colon = not self.colon

    def write(self):
        """Write the separator (if any) to the protocol's transport."""
        self.doIO(self.protocol.trans.write)

    def read(self):
        """Consume the separator (if any) through the protocol's reader."""
        self.doIO(self.protocol.readJSONSyntaxChar)

    def escapeNum(self):
        """Return True while in key position, where numbers are quoted."""
        return self.colon
class LookaheadReader(object):
    """One-character lookahead on top of the protocol's transport."""

    def __init__(self, protocol):
        """Bind to `protocol`; characters come from `protocol.trans.read(1)`."""
        self.protocol = protocol
        # Whether `data` holds a peeked character that has not been consumed.
        self.hasData = False
        self.data = ''

    def read(self):
        """Consume and return the next character.

        Returns the previously peeked character if there is one, otherwise
        reads a fresh one from the transport.
        """
        if self.hasData:
            self.hasData = False
        else:
            self.data = self.protocol.trans.read(1)
        return self.data

    def peek(self):
        """Return the next character without consuming it."""
        if not self.hasData:
            self.data = self.protocol.trans.read(1)
            self.hasData = True
        return self.data
class TJSONProtocolBase(TProtocolBase):
    """Low-level JSON reading and writing shared by the JSON protocol.

    Separators are handled by a stack of context objects: the current
    context is asked to write/read whatever belongs in front of the next
    value, and to say whether numbers must be quoted at that position.
    """

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        # Set up the context stack and lookahead reader immediately so the
        # protocol is usable even when no message envelope is read or
        # written first (e.g. serialising a bare struct).
        self.resetReadContext()

    def resetWriteContext(self):
        """Drop all nested contexts and start from a base context."""
        self.contextStack = []
        self.context = JSONBaseContext(self)

    def resetReadContext(self):
        """Reset the context stack and create a fresh lookahead reader."""
        self.resetWriteContext()
        self.reader = LookaheadReader(self)

    def pushContext(self, ctx):
        """Enter `ctx`, making it the current context."""
        self.contextStack.append(ctx)
        self.context = ctx

    def popContext(self):
        """Leave the current context and restore the enclosing one.

        Without restoring `self.context`, separators after a nested object
        or array would be driven by the stale inner context.
        """
        self.contextStack.pop()
        if self.contextStack:
            self.context = self.contextStack[-1]
        else:
            self.context = JSONBaseContext(self)

    def writeJSONString(self, string):
        """Write `string` as a JSON string literal (escaped by json.dumps)."""
        self.context.write()
        self.trans.write(json.dumps(string))

    def writeJSONNumber(self, number):
        """Write `number`, quoted if the current context requires it."""
        self.context.write()
        jsNumber = str(number)
        if self.context.escapeNum():
            jsNumber = "%s%s%s" % (QUOTE, jsNumber, QUOTE)
        self.trans.write(jsNumber)

    def writeJSONBase64(self, binary):
        """Write `binary` as a quoted base64 string."""
        self.context.write()
        self.trans.write(QUOTE)
        self.trans.write(base64.b64encode(binary))
        self.trans.write(QUOTE)

    def writeJSONObjectStart(self):
        """Write '{' and enter a pair context."""
        self.context.write()
        self.trans.write(LBRACE)
        self.pushContext(JSONPairContext(self))

    def writeJSONObjectEnd(self):
        """Leave the pair context and write '}'."""
        self.popContext()
        self.trans.write(RBRACE)

    def writeJSONArrayStart(self):
        """Write '[' and enter a list context."""
        self.context.write()
        self.trans.write(LBRACKET)
        self.pushContext(JSONListContext(self))

    def writeJSONArrayEnd(self):
        """Leave the list context and write ']'."""
        self.popContext()
        self.trans.write(RBRACKET)

    def readJSONSyntaxChar(self, character):
        """Consume one character and require it to equal `character`.

        Raises:
            TProtocolException: INVALID_DATA if a different character
                (or nothing) was read.
        """
        current = self.reader.read()
        if character != current:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Unexpected character: %s" % current)

    def readJSONString(self, skipContext):
        """Read a quoted JSON string and return its unescaped contents.

        Args:
            skipContext: when False the current context is asked to consume
                its separator first; pass True if that already happened.

        Raises:
            TProtocolException: INVALID_DATA on a missing opening quote, an
                unknown or malformed escape, or end of input inside the
                string.
        """
        string = []
        if skipContext is False:
            self.context.read()
        self.readJSONSyntaxChar(QUOTE)
        while True:
            character = self.reader.read()
            if not character:
                # The reader yields an empty value at end of input; without
                # this check the loop would never terminate.
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Unexpected end of input in string")
            if character == QUOTE:
                break
            if character == ESCSEQ[0]:
                character = self.reader.read()
                if character == ESCSEQ[1]:
                    # Only \u00XX escapes are accepted.
                    self.readJSONSyntaxChar(ZERO)
                    self.readJSONSyntaxChar(ZERO)
                    hexDigits = self.reader.read() + self.reader.read()
                    try:
                        character = json.loads('"\\u00%s"' % hexDigits)
                    except ValueError:
                        raise TProtocolException(
                            TProtocolException.INVALID_DATA,
                            "Expected hexadecimal escape")
                else:
                    # find('') would return 0, so treat end of input as
                    # "not found" explicitly.
                    off = ESCAPE_CHAR.find(character) if character else -1
                    if off == -1:
                        raise TProtocolException(
                            TProtocolException.INVALID_DATA,
                            "Expected control char")
                    character = ESCAPE_CHAR_VALS[off]
            string.append(character)
        return ''.join(string)

    def isJSONNumeric(self, character):
        """Return True if `character` may appear in a JSON number.

        An empty value (end of input) is not numeric.
        """
        return bool(character) and character in NUMERIC_CHAR

    def readJSONQuotes(self):
        """Consume a quote if the current context quotes its numbers."""
        if self.context.escapeNum():
            self.readJSONSyntaxChar(QUOTE)

    def readJSONNumericChars(self):
        """Consume and return the longest run of numeric characters."""
        numeric = []
        while self.isJSONNumeric(self.reader.peek()):
            numeric.append(self.reader.read())
        return ''.join(numeric)

    def readJSONInteger(self):
        """Read an integer, quoted or not as the context dictates.

        Raises:
            TProtocolException: INVALID_DATA if the characters read do not
                form an integer or an expected quote is missing.
        """
        self.context.read()
        self.readJSONQuotes()
        numeric = self.readJSONNumericChars()
        self.readJSONQuotes()
        try:
            return int(numeric)
        except ValueError:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Bad data encounted in numeric data")

    def readJSONDouble(self):
        """Read a floating point number.

        A quoted number is accepted when the context quotes numbers, or when
        the value is infinite or NaN; any other quoted number is rejected.

        Raises:
            TProtocolException: INVALID_DATA on malformed numeric data, an
                unexpectedly quoted number, or a missing expected quote.
        """
        self.context.read()
        if self.reader.peek() == QUOTE:
            string = self.readJSONString(True)
            try:
                double = float(string)
            except ValueError:
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Bad data encounted in numeric data")
            if (not self.context.escapeNum()
                    and not math.isinf(double)
                    and not math.isnan(double)):
                raise TProtocolException(TProtocolException.INVALID_DATA,
                                         "Numeric data unexpectedly quoted")
            return double
        if self.context.escapeNum():
            # A quote is required here but the lookahead is not one, so this
            # call raises the appropriate protocol error.
            self.readJSONSyntaxChar(QUOTE)
        try:
            return float(self.readJSONNumericChars())
        except ValueError:
            raise TProtocolException(TProtocolException.INVALID_DATA,
                                     "Bad data encounted in numeric data")

    def readJSONBase64(self):
        """Read a quoted base64 string and return the decoded data."""
        string = self.readJSONString(False)
        return base64.b64decode(string)

    def readJSONObjectStart(self):
        """Consume '{' and enter a pair context."""
        self.context.read()
        self.readJSONSyntaxChar(LBRACE)
        self.pushContext(JSONPairContext(self))

    def readJSONObjectEnd(self):
        """Consume '}' and leave the pair context."""
        self.readJSONSyntaxChar(RBRACE)
        self.popContext()

    def readJSONArrayStart(self):
        """Consume '[' and enter a list context."""
        self.context.read()
        self.readJSONSyntaxChar(LBRACKET)
        self.pushContext(JSONListContext(self))

    def readJSONArrayEnd(self):
        """Consume ']' and leave the list context."""
        self.readJSONSyntaxChar(RBRACKET)
        self.popContext()
class TJSONProtocol(TJSONProtocolBase):
    """Thrift protocol that encodes messages as JSON.

    A message is an array [version, name, type, seqid, ...]; a struct is an
    object keyed by field id, each value being an object {type-tag: value}.
    """

    # ---- reading -------------------------------------------------------

    def readMessageBegin(self):
        """Read the message header and return (name, type, seqid).

        Raises:
            TProtocolException: BAD_VERSION if the version does not match.
        """
        self.resetReadContext()
        self.readJSONArrayStart()
        if self.readJSONInteger() != VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     "Message contained bad version.")
        name = self.readJSONString(False)
        typen = self.readJSONInteger()
        seqid = self.readJSONInteger()
        return (name, typen, seqid)

    def readMessageEnd(self):
        self.readJSONArrayEnd()

    def readStructBegin(self):
        self.readJSONObjectStart()

    def readStructEnd(self):
        self.readJSONObjectEnd()

    def readFieldBegin(self):
        """Return (None, type, id) for the next field.

        The type is TType.STOP (with id 0) when the struct's closing brace
        is next.
        """
        if self.reader.peek() == RBRACE:
            return (None, TType.STOP, 0)
        fieldId = self.readJSONInteger()
        self.readJSONObjectStart()
        fieldType = JTYPES[self.readJSONString(False)]
        return (None, fieldType, fieldId)

    def readFieldEnd(self):
        self.readJSONObjectEnd()

    def readMapBegin(self):
        """Read a map header and return (keyType, valueType, size)."""
        self.readJSONArrayStart()
        keyType = JTYPES[self.readJSONString(False)]
        valueType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        self.readJSONObjectStart()
        return (keyType, valueType, size)

    def readMapEnd(self):
        self.readJSONObjectEnd()
        self.readJSONArrayEnd()

    def readCollectionBegin(self):
        """Read a list/set header and return (elemType, size)."""
        self.readJSONArrayStart()
        elemType = JTYPES[self.readJSONString(False)]
        size = self.readJSONInteger()
        # Return the element type that was read, not the builtin `type`.
        return (elemType, size)

    readListBegin = readCollectionBegin
    readSetBegin = readCollectionBegin

    def readCollectionEnd(self):
        self.readJSONArrayEnd()

    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd

    def readBool(self):
        """Read an integer and return False for 0, True otherwise."""
        return self.readJSONInteger() != 0

    def readNumber(self):
        return self.readJSONInteger()

    readByte = readNumber
    readI16 = readNumber
    readI32 = readNumber
    readI64 = readNumber

    def readDouble(self):
        return self.readJSONDouble()

    def readString(self):
        return self.readJSONString(False)

    def readBinary(self):
        return self.readJSONBase64()

    # ---- writing -------------------------------------------------------

    def writeMessageBegin(self, name, request_type, seqid):
        """Write the message header: [VERSION, name, request_type, seqid."""
        self.resetWriteContext()
        self.writeJSONArrayStart()
        self.writeJSONNumber(VERSION)
        self.writeJSONString(name)
        self.writeJSONNumber(request_type)
        self.writeJSONNumber(seqid)

    def writeMessageEnd(self):
        self.writeJSONArrayEnd()

    def writeStructBegin(self, name):
        # The struct name is not written.
        self.writeJSONObjectStart()

    def writeStructEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldBegin(self, name, type, id):
        """Write the field id as key and open {type-tag: ...}.

        The field name is not written.
        """
        self.writeJSONNumber(id)
        self.writeJSONObjectStart()
        self.writeJSONString(CTYPES[type])

    def writeFieldEnd(self):
        self.writeJSONObjectEnd()

    def writeFieldStop(self):
        # Nothing is written for the stop marker; readFieldBegin detects the
        # end of a struct from its closing brace.
        pass

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header: [key-tag, value-tag, size, {."""
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[ktype])
        self.writeJSONString(CTYPES[vtype])
        self.writeJSONNumber(size)
        self.writeJSONObjectStart()

    def writeMapEnd(self):
        self.writeJSONObjectEnd()
        self.writeJSONArrayEnd()

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header: [elem-tag, size."""
        self.writeJSONArrayStart()
        self.writeJSONString(CTYPES[etype])
        self.writeJSONNumber(size)

    writeListBegin = writeCollectionBegin
    writeSetBegin = writeCollectionBegin

    def writeCollectionEnd(self):
        self.writeJSONArrayEnd()

    writeListEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd

    def writeBool(self, boolean):
        """Write 1 if `boolean` is the object True, 0 for anything else."""
        self.writeJSONNumber(1 if boolean is True else 0)

    def writeInteger(self, integer):
        self.writeJSONNumber(integer)

    writeByte = writeInteger
    writeI16 = writeInteger
    writeI32 = writeInteger
    writeI64 = writeInteger

    def writeDouble(self, dbl):
        self.writeJSONNumber(dbl)

    def writeString(self, string):
        self.writeJSONString(string)

    def writeBinary(self, binary):
        self.writeJSONBase64(binary)
class TJSONProtocolFactory(object):
    """Factory that builds a TJSONProtocol for a given transport."""

    def __init__(self):
        pass

    def getProtocol(self, trans):
        """Return a new TJSONProtocol bound to `trans`."""
        return TJSONProtocol(trans)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
How can I use this to serialize thrift to JSON?