create DaoTestName.queryOneRecord.when.wiki
f_qor = open(DaoTestName+".queryOneRecord.when.wiki", 'w')
f_qor.write('\n'.join([conn, query_stmt, query_fields]))
f_qor.close
# create DaoTestName.testUpdate.when.wiki
f_update = open(DaoTestName+".testUpdate.when.wiki", 'w')
f_update.write('\n'.join([conn, query_stmt, query_fields]))
f_update.close
def gene_daotest_java(daoTestName, tableName, fieldArray):
    """Generate ``<daoTestName>.java`` from the ``TemplateDefaultDAOTest.java`` template.

    Placeholders substituted in the template text:
      * ``XXX``              -> DAO prefix (``daoTestName`` up to ``'DAOTest'``)
      * ``YYY``              -> the same prefix with its first letter lower-cased
      * ``$setFields``       -> one ``yyy.setField();`` line per non-``gmt_`` field
      * ``$AssertGetValues`` -> one ``Assert.assertEquals(yyy.getField(), )`` line
                                per non-``gmt_`` field

    :param daoTestName: test class name, e.g. ``'VmDAOTest'``; also the output file stem.
    :param tableName: not used by this function; kept for interface compatibility.
    :param fieldArray: column names of the table under test.
    """
    # Read the whole template in one go; `with` guarantees the handle is closed.
    with open('TemplateDefaultDAOTest.java') as f_daotest_tmpl:
        content = f_daotest_tmpl.read()

    daoPrefixIndex = daoTestName.find('DAOTest')
    if daoPrefixIndex == -1:
        # No 'DAOTest' suffix: use the whole name (slicing with -1 would
        # silently drop the last character).
        daoPrefix = daoTestName
    else:
        daoPrefix = daoTestName[0:daoPrefixIndex]
    XXXReplacer = daoPrefix
    YYYReplacer = firstLowerCase(XXXReplacer)

    # list(): the filtered fields are consumed twice below, which a lazy
    # (Python 3) filter iterator would not survive.
    filteredFieldArray = list(getFilteredFields(fieldArray, filterTimeFieldFunc))

    # XXX/YYY are replaced first; the generated snippets already contain the
    # concrete YYY value, so the order is harmless.
    contentReplaced = content.replace('XXX', XXXReplacer) \
        .replace('YYY', YYYReplacer) \
        .replace('$setFields', geneSetFields(filteredFieldArray, YYYReplacer)) \
        .replace('$AssertGetValues', geneAssertGetValues(filteredFieldArray, YYYReplacer))

    with open(daoTestName + '.java', 'w') as f_daotest_java:
        f_daotest_java.write(contentReplaced)
def geneAssertGetValues(fieldArray, YYYReplacer):
    """Build the text that replaces ``$AssertGetValues`` in the DAO test template.

    For every field, emits ``Assert.assertEquals(<YYYReplacer>.get<Field>(), )``.
    The expected-value argument is left empty (presumably a placeholder to be
    filled in by hand). Each statement is followed by a newline and
    ``indentTimes(2)``.

    :param fieldArray: field names (snake_case) to generate assertions for.
    :param YYYReplacer: Java variable name of the object under test.
    :returns: the concatenated assertion lines.
    """
    pieces = []
    for fieldName in fieldArray:
        getterCall = YYYReplacer + '.get' + transformField(fieldName) + '()'
        pieces.append('Assert.assertEquals(' + getterCall + ', )\n' + indentTimes(2))
    return ''.join(pieces)
def geneSetFields(fieldArray, YYYReplacer):
    """Build the text that replaces ``$setFields`` in the DAO test template.

    For every field, emits ``<YYYReplacer>.set<Field>();`` with an empty
    argument list, followed by a newline and ``indentTimes(2)``.

    :param fieldArray: field names (snake_case) to generate setter calls for.
    :param YYYReplacer: Java variable name of the object being populated.
    :returns: the concatenated setter lines.
    """
    return ''.join(
        YYYReplacer + '.set' + transformField(name) + '();\n' + indentTimes(2)
        for name in fieldArray
    )
def transformField(field):
    """Convert a snake_case column name to the CamelCase used in Java accessors.

    Examples: ``'vm_name'`` -> ``'VmName'``, ``'id'`` -> ``'Id'``.

    Surrounding whitespace is stripped, because a config list written as
    ``'avail_cpu, avail_mem'`` may leave a leading space on the name. Empty
    segments are skipped instead of raising IndexError; these come from
    leading, trailing or doubled underscores.

    :param field: column name.
    :returns: the CamelCase form of ``field``.
    """
    segments = field.strip().split('_')
    # Upper-case only the first character of each segment and keep the rest
    # as-is. str.capitalize() would also lower-case the remainder.
    return ''.join(seg[:1].upper() + seg[1:] for seg in segments if seg)
def indentTimes(num, unit='\t'):
    """Return an indentation string made of ``num`` copies of ``unit``.

    :param num: indentation depth; values <= 0 yield an empty string.
    :param unit: one indentation level (a tab by default).
    :returns: ``unit * num``, or ``''`` when ``num <= 0``.
    """
    if num <= 0:
        return ''
    return unit * num
def firstLowerCase(input):
    """Return ``input`` with its first character lower-cased ('VmDao' -> 'vmDao').

    An empty string is returned unchanged instead of raising IndexError.
    """
    # Slicing ([:1]) is safe on '' where indexing ([0]) is not.
    # The parameter name shadows the builtin but is kept for caller compatibility.
    return input[:1].lower() + input[1:]
def firstSuperCase(input):
    """Return ``input`` with its first character upper-cased ('vmName' -> 'VmName').

    The rest of the string is left untouched, unlike ``str.capitalize()``.
    An empty string is returned unchanged instead of raising IndexError.
    """
    # Slicing ([:1]) is safe on '' where indexing ([0]) is not.
    return input[:1].upper() + input[1:]
def nopFunc(field):
    """Accept-all filter predicate: reports every field as kept."""
    _ = field  # argument intentionally ignored
    return True
def getfieldsWithSep(fieldArray, index=0, sep='|', filterFunc=nopFunc):
    """Join the filtered fields, starting at position ``index``, with ``sep``.

    :param fieldArray: list of field names.
    :param index: number of leading filtered fields to skip. Must lie in
        ``[0, len(fieldArray)]``; the bound is checked against the unfiltered list.
    :param sep: separator placed between fields.
    :param filterFunc: predicate; only fields for which it returns true are kept.
        The default keeps everything.
    :returns: the joined string (``''`` when nothing remains).
    :raises ValueError: if ``index`` is outside ``[0, len(fieldArray)]``.
        This is a subclass of Exception, so existing ``except Exception``
        handlers still apply.
    """
    if index < 0 or index > len(fieldArray):
        # str() is required: str + int raises TypeError and would hide this message.
        raise ValueError('index ' + str(index) + ' invalid: must be in [0,'
                         + str(len(fieldArray)) + ']')
    # list(): the slice below fails on a lazy iterator (Python 3 filter()).
    fieldFilteredArray = list(getFilteredFields(fieldArray, filterFunc))
    return sep.join(fieldFilteredArray[index:])
def filterTimeFieldFunc(field):
    """Filter predicate: keep ``field`` unless it contains the substring 'gmt_'.

    Used to drop columns such as ``gmt_create`` / ``gmt_modify``.
    """
    position = field.find('gmt_')
    return position < 0
def getFilteredFields(fieldArray, filterFunc):
    """Return a new list of the items of ``fieldArray`` for which ``filterFunc`` is true.

    Always returns a real list. Python 3's ``filter()`` returns a one-shot
    iterator, which breaks callers that slice the result or iterate it twice.
    A ``filterFunc`` of ``None`` keeps truthy items, matching ``filter()``.

    :param fieldArray: iterable of field names.
    :param filterFunc: predicate taking one field, or ``None``.
    :returns: list of the kept fields, in their original order.
    """
    predicate = bool if filterFunc is None else filterFunc
    return [field for field in fieldArray if predicate(field)]
if __name__ == '__main__':
    # Generate the wiki files for every DAO-test entry returned by readcfg.
    allDAOTest = readcfg.getAllDAOTestInfo()
    # items() works on both Python 2 and 3; iteritems() is Python 2 only.
    # NOTE(review): assumes getAllDAOTestInfo() returns a dict — confirm.
    for daoTestName, daoTestInfo in allDAOTest.items():
        gene_daotest_wiki(daoTestInfo)
# daotest.conf: 配置文件 (configuration file)
[VmDAOTest]
DaoTestName=VmDAOTest
TableName=vm
FieldArray=id,gmt_create,gmt_modify,vm_name,cores,mem,disk,status,nc_id,is_deleted
[NcDAOTest]
DaoTestName=NcDAOTest
TableName=nc
FieldArray=id,gmt_create,gmt_modify,hostname,ip,avail_cpu,avail_mem,avail_disk
DAO Java 文件模板 (DAO Java file template):
package xxx.dao.regiondb.impl;
import java.util.Date;
import java.util.List;
import org.jtester.unitils.dbfit.DbFit;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.unitils.spring.annotation.SpringBeanByName;
import xxx.Base