Skip to content

Instantly share code, notes, and snippets.

@JoaoFelipe
Created May 31, 2013 17:57
Show Gist options
  • Save JoaoFelipe/5686728 to your computer and use it in GitHub Desktop.
Save JoaoFelipe/5686728 to your computer and use it in GitHub Desktop.
Gera model, DAO e xml para um create statement de bd
from cStringIO import StringIO
from tokenize import generate_tokens
# Root Java package interpolated into the generated import statements
# (e.g. "<PACKAGE>.dal.util.JPAUtil" and "<PACKAGE>.dal.model...").
PACKAGE = 'com.example.app'
# Sub-package under "<PACKAGE>.dal.model" from which the generated DAO
# imports its entity class.
SUB_PACKAGE = 'sub'
def uncapitalize(s):
    """Return ``s`` with its first character lower-cased.

    Every character after the first is left exactly as it was, and an empty
    string yields an empty string.
    """
    head, tail = s[:1], s[1:]
    return head.lower() + tail
def capitalize(s):
    """Return ``s`` with its first character upper-cased.

    Unlike ``str.capitalize`` the remaining characters are not lower-cased,
    so camelCase names keep their inner capitals.
    """
    head, tail = s[:1], s[1:]
    return head.upper() + tail
def to_camel_case(s):
    """Convert an underscore-separated name to camelCase.

    The name is first converted to PascalCase and then its leading
    character is lower-cased.
    """
    pascal = to_pascal_case(s)
    return uncapitalize(pascal)
def to_pascal_case(s):
    """Convert an underscore-separated name to PascalCase.

    The input is lower-cased, split on ``_``, and each piece is capitalised
    before the pieces are glued back together without a separator.
    """
    pieces = []
    for word in s.lower().split('_'):
        pieces.append(word.capitalize())
    return ''.join(pieces)
class Entity(object):
    """A database table rendered as Java source text.

    Produces three artefacts for one table: a JPA entity class
    (``model_text``), a DAO class (``dao_text``) and an ORM XML ``<entity>``
    fragment with two named queries (``xml_text``).

    Attributes:
        db_name: table name as given to the constructor.
        name: PascalCase form of ``db_name``.
        cls_name: ``name`` with a leading ``Mas`` removed, used as the Java
            class name.
        fields: list of field objects; starts empty and is expected to be
            assigned by the caller once the columns have been parsed.
        imports: Java import statements needed by the entity itself.
    """

    def __init__(self, db_name):
        """Derive the Java names for the table called ``db_name``."""
        self.db_name = db_name
        self.name = to_pascal_case(db_name)
        # Strip the "Mas" prefix from the class name when present.
        if self.name.startswith('Mas'):
            self.cls_name = self.name[3:]
        else:
            self.cls_name = self.name
        # Initialised here so model_text() works even before any columns
        # have been attached.
        self.fields = []
        self.imports = set((
            'import javax.persistence.Entity;',
            'import javax.persistence.Table;',
        ))

    def model_text(self, ident=''):
        """Return the Java source of the JPA entity class.

        Args:
            ident: prefix placed at the start of the text and after most
                line breaks.

        Returns:
            The class source as a single string: imports (sorted, so the
            output is deterministic), the annotated class header, field
            declarations, constructors, and a getter/setter pair per field.
        """
        new_line = '\n%s' % (ident,)
        # A bare line break followed by an indented one, i.e. an empty line.
        blank_line = '\n' + new_line

        imports = set(self.imports)
        for field in self.fields:
            imports.update(field.imports)

        parts = [ident]
        for imp in sorted(imports):
            parts.append(imp + new_line)
        parts.append(new_line)
        parts.append('@Entity' + new_line)
        parts.append('@Table(name = "%s")' % (self.db_name,) + new_line)
        parts.append('public class %s implements java.io.Serializable {'
                     % (self.cls_name,) + new_line)
        parts.append(new_line)

        for field in self.fields:
            parts.append(field.declaration(' ') + blank_line)

        parts.append(' public %s() {' % (self.cls_name,) + new_line)
        parts.append(' }' + blank_line)

        # The second constructor takes every non-id field.  It is skipped
        # when there are none, because it would duplicate the no-argument
        # constructor above and the generated Java would not compile.
        const_fields = [field for field in self.fields if not field.is_id]
        if const_fields:
            params = ', '.join(field.param() for field in const_fields)
            parts.append(' public %s(%s) {' % (self.cls_name, params)
                         + new_line)
            for field in const_fields:
                parts.append(' ' + field.set() + new_line)
            parts.append(' }' + blank_line)

        for field in self.fields:
            parts.append(field.getter(' ') + blank_line)
            parts.append(field.setter(' ') + blank_line)

        parts.append('}')
        return ''.join(parts)

    def dao_text(self, ident=''):
        """Return the Java source of the DAO class for this entity.

        The DAO exposes ``get<Cls>ByCompId(int)`` and ``getCurrent<Cls>()``,
        both backed by the named queries emitted by ``xml_text``.

        Args:
            ident: prefix placed at the start of the text and after most
                line breaks.

        Returns:
            The DAO source as a single string.
        """
        new_line = '\n%s' % (ident,)
        blank_line = '\n' + new_line
        names = {
            'cls_name': self.cls_name,
            'lower_cls_name': uncapitalize(self.cls_name),
            'package': PACKAGE,
            'sub_package': SUB_PACKAGE,
        }
        # (template, separator appended after it) in output order.
        template = [
            ('import java.util.List;', new_line),
            ('import javax.persistence.EntityManager;', new_line),
            ('import javax.persistence.TypedQuery;', new_line),
            ('import org.apache.log4j.Logger;', new_line),
            ('import %(package)s.dal.util.JPAUtil;', new_line),
            ('import %(package)s.dal.model.%(sub_package)s.%(cls_name)s;', blank_line),
            ('public class %(cls_name)sDAO {', new_line),
            (' private static Logger logger = Logger.getLogger(%(cls_name)sDAO.class);', blank_line),
            (' public %(cls_name)s get%(cls_name)sByCompId(int compid) {', new_line),
            (' EntityManager entityManager = JPAUtil.getEntityManager();', new_line),
            (' TypedQuery<%(cls_name)s> query = entityManager.createNamedQuery("%(cls_name)s.get%(cls_name)sByCompId", %(cls_name)s.class);', new_line),
            (' query.setParameter("compId", compid);', new_line),
            (' List<%(cls_name)s> %(lower_cls_name)sId = query.getResultList();', new_line),
            (' if (%(lower_cls_name)sId == null || %(lower_cls_name)sId.size() < 1) {', new_line),
            (' logger.info("%(cls_name)s is not found for id: " + compid + " in MAS database");', new_line),
            (' return null;', new_line),
            (' }', new_line),
            (' return %(lower_cls_name)sId.get(0);', new_line),
            (' }', blank_line),
            (' public List<%(cls_name)s> getCurrent%(cls_name)s() {', new_line),
            (' EntityManager entityManager = JPAUtil.getEntityManager();', new_line),
            (' TypedQuery<%(cls_name)s> query = entityManager.createNamedQuery("%(cls_name)s.getCurrent%(cls_name)s", %(cls_name)s.class);', new_line),
            (' List<%(cls_name)s> %(lower_cls_name)sList = query.getResultList();', new_line),
            (' return %(lower_cls_name)sList;', new_line),
            (' }', blank_line),
        ]

        # The output starts with one empty line before the imports; this is
        # kept so the generated text is unchanged.
        parts = [ident, new_line]
        for text, separator in template:
            parts.append(text % names + separator)
        parts.append('}')
        return ''.join(parts)

    def xml_text(self, ident=''):
        """Return the ORM XML ``<entity>`` fragment with the named queries.

        Two queries are declared: ``<Cls>.get<Cls>ByCompId`` (filter on
        ``rec.component.id``) and ``<Cls>.getCurrent<Cls>`` (rows whose
        snapshot id equals the maximum snapshot id).

        Args:
            ident: prefix placed at the start of the text and after every
                line break.

        Returns:
            The XML fragment as a single string.
        """
        new_line = '\n%s' % (ident,)
        names = {'cls_name': self.cls_name}
        template = [
            '<entity class="%(cls_name)s">',
            ' <named-query name="%(cls_name)s.get%(cls_name)sByCompId">',
            ' <query>',
            ' <![CDATA[SELECT rec FROM %(cls_name)s rec',
            ' WHERE rec.component.id = :compId]]>',
            ' </query>',
            ' </named-query>',
            ' <named-query name="%(cls_name)s.getCurrent%(cls_name)s">',
            ' <query>',
            ' <![CDATA[SELECT rec FROM %(cls_name)s rec',
            ' WHERE rec.snapshot.snapshotId = (SELECT max(m.snapshot.snapshotId) FROM %(cls_name)s m)]]>',
            ' </query>',
            ' </named-query>',
            '</entity>',
        ]
        return ident + new_line.join(text % names for text in template)
class Field(object):
    """One table column rendered as an annotated Java field.

    Attributes:
        db_name: column name (first token of the column definition).
        name: Java field name; camelCase of ``db_name`` unless replaced by
            ``set_fk``.
        type: Java type; ``"void"`` when no known SQL type token is found.
        is_id: True when the definition contains ``PRIMARY KEY``.
        modifiers: attribute fragments for the ``@Column`` annotation.
        annotations: annotation lines emitted above the declaration.
        imports: Java import statements this field requires.
    """

    def __init__(self, entity, tokens):
        """Build the field from a column definition.

        Args:
            entity: owning entity; its ``cls_name`` and ``db_name`` are used
                to name the sequence generator of a primary-key column.
            tokens: upper-cased tokens of one column definition, column
                name first.
        """
        self.db_name = tokens[0]
        self.name = to_camel_case(tokens[0])
        self.is_id = False
        self.modifiers = []
        self.annotations = []
        self.imports = set()
        self.type = "void"

        if "INTEGER" in tokens:
            self.type = "Long"

        if "TIMESTAMP" in tokens:
            self.type = "Date"
            self.imports.add('import java.util.Date;')
            self.imports.add('import javax.persistence.Temporal;')
            self.imports.add('import javax.persistence.TemporalType;')
            self.annotations.append("@Temporal(TemporalType.TIMESTAMP)")

        if "VARCHAR" in tokens:
            self.type = 'String'
            varchar_pos = tokens.index('VARCHAR')
            # "VARCHAR ( n" -> length = n.  The bound check requires the
            # length token itself to exist, not just the parenthesis.
            if (varchar_pos + 2 < len(tokens)
                    and tokens[varchar_pos + 1] == '('):
                self.modifiers.append(
                    "length = %s" % (tokens[varchar_pos + 2],))

        if self._has_pair(tokens, 'NOT', 'NULL'):
            self.modifiers.append("nullable = false")

        if self._has_pair(tokens, 'PRIMARY', 'KEY'):
            seq_name = uncapitalize(entity.cls_name)
            self.modifiers.append("unique = true")
            self.annotations.append("@Id")
            self.annotations.append(
                '@SequenceGenerator(name = "%sSeq", sequenceName = "%s_SEQ")'
                % (seq_name, entity.db_name))
            self.annotations.append(
                '@GeneratedValue(strategy = GenerationType.AUTO, generator = "%sSeq")'
                % (seq_name,))
            self.imports.add('import javax.persistence.SequenceGenerator;')
            self.imports.add('import javax.persistence.GeneratedValue;')
            self.imports.add('import javax.persistence.GenerationType;')
            self.imports.add('import javax.persistence.Id;')
            self.is_id = True

        extra = (', ' + ', '.join(self.modifiers)) if self.modifiers else ''
        self.annotations.append(
            '@Column(name = "%s"%s)' % (self.db_name, extra))
        self.imports.add('import javax.persistence.Column;')

    @staticmethod
    def _has_pair(tokens, first, second):
        """Return True when ``first`` is immediately followed by ``second``."""
        return any(a == first and b == second
                   for a, b in zip(tokens, tokens[1:]))

    def set_fk(self, tup):
        """Turn the field into a many-to-one reference.

        All existing annotations are replaced by ``@ManyToOne`` and
        ``@JoinColumn``; the relation is optional unless the column was
        declared ``NOT NULL``.

        Args:
            tup: ``(java_type, field_name, import_statement)`` describing
                the referenced entity.
        """
        optional = 'false' if 'nullable = false' in self.modifiers else 'true'
        self.annotations = [
            '@ManyToOne(optional = %s, fetch = FetchType.LAZY)' % (optional,),
            '@JoinColumn(name = "%s")' % (self.db_name,),
        ]
        self.imports.add('import javax.persistence.ManyToOne;')
        self.imports.add('import javax.persistence.FetchType;')
        self.imports.add('import javax.persistence.JoinColumn;')
        self.name = tup[1]
        self.type = tup[0]
        self.imports.add(tup[2])

    def __repr__(self):
        """Return the annotation lines, one per line, each space-prefixed."""
        return ' ' + '\n '.join(self.annotations)

    def declaration(self, ident):
        """Return the annotations followed by the ``private`` declaration.

        Every line is prefixed with ``ident``.
        """
        new_line = '\n%s' % (ident,)
        # A plain (non-bytes) literal keeps the result a text string on
        # Python 3 as well as Python 2.
        return (ident + new_line.join(self.annotations) + new_line
                + 'private %s %s;' % (self.type, self.name))

    def getter(self, ident):
        """Return the Java getter method, each line prefixed with ``ident``."""
        new_line = '\n%s' % (ident,)
        lines = [
            'public %s get%s() {' % (self.type, capitalize(self.name)),
            ' return %s;' % (self.name,),
            '}',
        ]
        return ident + new_line.join(lines)

    def setter(self, ident):
        """Return the Java setter method, each line prefixed with ``ident``."""
        new_line = '\n%s' % (ident,)
        lines = [
            'public void set%s(%s) {' % (capitalize(self.name), self.param()),
            ' %s' % (self.set(),),
            '}',
        ]
        return ident + new_line.join(lines)

    def param(self):
        """Return the field as a Java parameter: ``"<type> <name>"``."""
        return '%s %s' % (self.type, self.name)

    def set(self):
        """Return the Java assignment ``this.<name> = <name>;``."""
        return 'this.%s = %s;' % (self.name, self.name)
def parse_create(tokens, relationship=None):
    """Parse a ``create table`` token list and print the generated DAO.

    The third token is taken as the table name and everything from the
    fifth token on as the column list.  ``Error`` is printed when the token
    list is too short or does not start with ``create`` (compared
    case-insensitively).

    Args:
        tokens: tokens of the statement, comments and newlines removed.
        relationship: optional mapping from an upper-cased column name to
            the ``(java_type, field_name, import_statement)`` tuple of the
            entity it references.
    """
    if relationship is None:
        relationship = {}
    # Guard the indexing below: a table name needs at least three tokens.
    if len(tokens) < 3 or tokens[0].lower() != 'create':
        print('Error')
        return
    entity = Entity(tokens[2])
    entity.fields = parse_fields(tokens[4:], entity, relationship)
    print(entity.dao_text())
def parse_fields(tokens, entity, relationship=None):
    """Parse the column list of a ``create table`` statement.

    Column definitions are consumed one at a time until the tokens run out
    or the closing ``)`` of the column list is reached.

    Args:
        tokens: tokens following the opening parenthesis of the column list.
        entity: the entity the fields belong to.
        relationship: optional mapping from an upper-cased column name to
            the tuple passed to ``Field.set_fk`` for that column.

    Returns:
        A list of ``Field`` objects in declaration order.
    """
    if relationship is None:
        relationship = {}
    fields = []
    while tokens and tokens[0] != ')':
        definition, tokens = parse_field(tokens)
        if not definition:
            # A stray comma yields an empty definition with no column name.
            continue
        field = Field(entity, definition)
        if field.db_name in relationship:
            field.set_fk(relationship[field.db_name])
        fields.append(field)
    return fields
def parse_field(tokens):
    """Split the first column definition off a token list.

    Tokens are collected (upper-cased) until a top-level ``,`` or ``)`` is
    met; separators nested inside parentheses, such as the ones in
    ``VARCHAR(10)``, do not end the definition.

    Args:
        tokens: tokens starting at a column definition.

    Returns:
        ``(elements, rest)``: the upper-cased tokens of the definition and
        the remaining tokens.  A terminating ``,`` is dropped from ``rest``;
        a terminating ``)`` is kept so the caller can detect the end of the
        column list.  ``rest`` is empty when no terminator is found.
    """
    level = 0
    elements = []
    for i, v in enumerate(tokens):
        if level == 0 and v in (',', ')'):
            rest_start = i + 1 if v == ',' else i
            return elements, tokens[rest_start:]
        if v == '(':
            level += 1
        elif v == ')':
            # level > 0 is guaranteed here; level 0 returned above.
            level -= 1
        elements.append(v.upper())
    # No terminator found: everything was consumed.  Slicing by length
    # avoids relying on the loop variable, which is unbound for empty input.
    return elements, tokens[len(tokens):]
# Index of the token text inside the tuples yielded by generate_tokens.
STRING = 1

# Columns that are rendered as many-to-one references instead of plain
# columns: upper-cased column name -> (Java type, field name, import).
relationship = {
    'COMPONENT_ID' : ('Component', 'component', 'import %s.dal.model.Component;'%(PACKAGE)),
    'SNAPSHOT_ID' : ('Snapshot', 'snapshot', 'import %s.dal.model.Snapshot;'%(PACKAGE)),
    'TIME_PERIOD_ID' : ('TimePeriod', 'timePeriod', 'import %s.dal.model.TimePeriod;'%(PACKAGE)),
}

# Only create_cmd8.sql is processed; use range(1, 9) to process all files.
for file_number in [8]:
    print('')
    with open("create_cmd%d.sql" % (file_number,), 'r') as f:
        sql_text = f.read()

    # The Python tokenizer is reused to split the SQL text; tokens with an
    # empty string are dropped.
    tokens = [token[STRING]
              for token in generate_tokens(StringIO(sql_text).readline)
              if token[STRING]]

    # Drop "--" comments (two consecutive "-" tokens up to the end of the
    # line) and the newline tokens themselves.
    new_tokens = []
    comment = False
    for pos, v in enumerate(tokens):
        if v == '-' and len(tokens) > pos + 1 and tokens[pos + 1] == '-':
            comment = True
        if v == '\n':
            comment = False
        if (not comment) and v != '\n':
            new_tokens.append(v)

    parse_create(new_tokens, relationship)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment