Skip to content

Instantly share code, notes, and snippets.

@hallkk
Last active November 10, 2016 07:25
Show Gist options
  • Save hallkk/da5adbeb8df88774864bc0e1b1ac4151 to your computer and use it in GitHub Desktop.
Save hallkk/da5adbeb8df88774864bc0e1b1ac4151 to your computer and use it in GitHub Desktop.
hive UDTF样例,将列转化为多行
/**
 * Hive UDTF that expands a delimited list of cumulative "n+" frequency counts into one
 * output row per exact frequency.
 *
 * <p>The i-th value of the input (1-based) is read as the count for "frequency i or more".
 * For every position before the last, the emitted count is the value at that position minus
 * the value at the next position; the last value is emitted unchanged. Expansion stops at the
 * first value equal to zero. For example {@code "3,2,1,0,0"} yields the rows
 * {@code (1,1)}, {@code (2,1)}, {@code (3,1)}.
 *
 * <p>How to write a UDTF (notes from the original author):
 * <ol>
 *   <li>Extend org.apache.hadoop.hive.ql.udf.generic.GenericUDTF.</li>
 *   <li>Implement the three methods initialize(), process() and close().</li>
 *   <li>initialize() is called first and returns the description of the output rows
 *       (number of columns and their types).</li>
 *   <li>After initialization, process() is called to handle the incoming arguments; results
 *       are returned through forward().</li>
 *   <li>Finally close() is called to clean up whatever needs cleaning up.</li>
 * </ol>
 */
@Description(name = "convert_nplus_freq",
value = "convert_nplus_freq(nPlusFreqInfo,[freqSeparator]) - convert n+ freq info to n freq")
public class ConvertNPlusFreqUdf extends GenericUDTF {

    /** Separator used when the caller does not supply one (or supplies NULL). */
    private static final String DEFAULT_SEPARATOR = ",";

    /**
     * Validates the argument inspectors and declares the output schema: two int columns
     * named {@code freq} and {@code count}.
     *
     * @param args inspectors for the call arguments; at least one is required and the first
     *             must be of primitive category
     * @return struct inspector describing the {@code (freq, count)} output row
     * @throws UDFArgumentException if no argument is given or the first one is not primitive
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] args)
            throws UDFArgumentException {
        if (args.length < 1) {
            throw new UDFArgumentLengthException(
                    "convert_nplus_freq requires at least one argument");
        }
        if (args[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentException("convert_nplus_freq takes string as parameter");
        }

        List<String> fieldNames = new ArrayList<>();
        fieldNames.add("freq");
        fieldNames.add("count");

        List<ObjectInspector> fieldOIs = new ArrayList<>();
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Converts one input record and forwards every resulting {@code (freq, count)} row.
     *
     * @param args {@code args[0]} is the n+ frequency string; the optional {@code args[1]} is
     *             the separator. Both are converted with {@code toString()}.
     * @throws HiveException if forwarding a row fails
     */
    @Override
    public void process(Object[] args) throws HiveException {
        // A missing or NULL input value produces no rows instead of a NullPointerException.
        if (args == null || args.length == 0 || args[0] == null) {
            return;
        }

        String freqSeparator = DEFAULT_SEPARATOR;
        // A NULL separator falls back to the default rather than failing.
        if (args.length > 1 && args[1] != null) {
            freqSeparator = args[1].toString();
        }

        for (Object[] row : processRecord(args[0].toString(), freqSeparator)) {
            forward(row);
        }
    }

    /**
     * Turns a delimited list of n+ frequency counts into {@code (freq, count)} rows.
     *
     * @param nPlusFreqInfo delimited n+ counts, e.g. {@code "3,2,1,0,0"}; {@code null} yields
     *                      an empty result
     * @param freqSeparator literal separator between counts; {@code null} means the default
     *                      separator
     * @return one {@code Object[]{freq, count}} (both {@code Integer}) per frequency, in
     *         ascending frequency order; empty if the input cannot be parsed
     */
    public List<Object[]> processRecord(String nPlusFreqInfo, String freqSeparator) {
        List<Object[]> resultRows = new ArrayList<>();
        if (nPlusFreqInfo == null) {
            return resultRows;
        }

        String separator = (freqSeparator == null) ? DEFAULT_SEPARATOR : freqSeparator;
        List<Integer> nPlusFreqList = parseIntList(nPlusFreqInfo, separator);
        int size = nPlusFreqList.size();

        for (int index = 0; index < size; index++) {
            int current = nPlusFreqList.get(index);
            // Stop as soon as an n+ count of zero is reached.
            if (current == 0) {
                break;
            }
            int freq = index + 1;
            // The last entry has no successor, so it is emitted as-is; otherwise the count
            // for exactly `freq` is the difference between this entry and the next one.
            int count = (freq == size) ? current : current - nPlusFreqList.get(index + 1);
            resultRows.add(new Object[]{freq, count});
        }
        return resultRows;
    }

    /** No resources are held, so there is nothing to release. */
    @Override
    public void close() throws HiveException {
    }

    /**
     * Parses the n+ frequency string into a list of integers using the given separator.
     * The separator is matched literally, so characters that are special in regular
     * expressions (such as {@code |} or {@code .}) work as plain delimiters.
     *
     * @param nPlusFreqInfo delimited n+ counts
     * @param freqSeparator literal separator between counts
     * @return the parsed counts in input order, or an empty list if any token is not a
     *         valid integer
     */
    private List<Integer> parseIntList(String nPlusFreqInfo, String freqSeparator) {
        List<Integer> freqList = new ArrayList<>();
        // String.split interprets its argument as a regex; quote it to split literally.
        String literalSeparator = java.util.regex.Pattern.quote(freqSeparator);
        for (String freq : nPlusFreqInfo.split(literalSeparator)) {
            try {
                freqList.add(Integer.parseInt(freq));
            } catch (NumberFormatException ex) {
                // Any unparsable token invalidates the whole record.
                return new ArrayList<>();
            }
        }
        return freqList;
    }
}
~
/**
* 对UDTF进行单元测试。
**/
public class ConvertNPlusFreqUdfTest {
/**
 * Initializes the UDTF with a single string argument and checks that
 * {@code "3,2,1,0,0"} expands to the rows (1,1), (2,1), (3,1).
 */
@Test
public void testconvertion() throws UDFArgumentException {
    ConvertNPlusFreqUdf udf = new ConvertNPlusFreqUdf();

    // Any exception thrown by initialize() propagates and fails the test.
    udf.initialize(new ObjectInspector[] {
        PrimitiveObjectInspectorFactory.javaStringObjectInspector
    });

    // processRecord is called directly here (no reflection is used).
    List<Object[]> rows = udf.processRecord("3,2,1,0,0", ",");

    // Expected {freq, count} pairs, in output order.
    int[][] expected = {{1, 1}, {2, 1}, {3, 1}};
    Assert.assertEquals(expected.length, rows.size());
    for (int rowIdx = 0; rowIdx < expected.length; rowIdx++) {
        for (int colIdx = 0; colIdx < expected[rowIdx].length; colIdx++) {
            Assert.assertEquals(
                Integer.valueOf(expected[rowIdx][colIdx]), rows.get(rowIdx)[colIdx]);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment