@qxj · Created March 24, 2020 12:30
Generate tfrecord in MapReduce
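
Two files follow: DecodeProcessor.java, which parses one text sample per line into a TensorFlow Example proto, and TFPreprocess.java, the MapReduce driver that shuffles the serialized examples and writes them out as TFRecord files via org.tensorflow.hadoop.io.TFRecordFileOutputFormat from the tensorflow/ecosystem Hadoop connector. Judging from processLine, each input line looks like this (the values here are made up for illustration):

    1,0 3:1001:1 7:2005:0.5;0.25

The first token is the label, either "clickNums,orderNums" or a single value; every following token is fieldId:featureId:value, where continuous fields carry a ";"-separated float vector.

DecodeProcessor.java: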
package net.jqian.tutorial.tfrecord;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Mapper;
import org.tensorflow.example.*;

import java.util.*;
/**
 * Created by jqian on 17/10/10.
 */
public class DecodeProcessor {
    private static final Log LOG = LogFactory.getLog(DecodeProcessor.class);

    // fieldId -> (featureId -> encoded id); expected to be loaded from the "dict" file
    protected Map<Integer, Map<String, Long>> features = new HashMap<>();
    protected List<Integer> contFea = new ArrayList<>();    // continuous (dense) fields
    protected List<Integer> deleteFea = new ArrayList<>();  // fields to drop
    protected List<Integer> lineFea = new ArrayList<>();    // linear fields
    protected List<Integer> sparseFea = new ArrayList<>();  // sparse id fields
    private Configuration conf;
    public void init(Configuration conf) {
        LOG.info("start init decode processor");
        String dictFilePath = conf.get("dict").trim();
        this.conf = conf;
        String[] linears = conf.getStrings("linear_fea");
        if (linears != null) {
            for (String linear : linears) {
                lineFea.add(Integer.parseInt(linear.trim()));
            }
        }
        String[] deletes = conf.getStrings("delete_fea");
        if (deletes != null) {
            for (String delete : deletes) {
                deleteFea.add(Integer.parseInt(delete.trim()));
            }
        }
        String[] conts = conf.getStrings("cont_fea");
        if (conts != null) {
            for (String cont : conts) {
                contFea.add(Integer.parseInt(cont.trim()));
            }
        }
        // NOTE: the code that loads dictFilePath into `features` (and fieldPath into
        // the per-field metadata and sparseFea list used below) is not included in
        // this gist.
        LOG.info("load dict success");
        String fieldPath = conf.get("field");
    }
    public void setConf(Configuration conf) {
        this.conf = conf;
    }
    /** Parse one input line (format shown at the top) into a TensorFlow Example. */
    public Example processLine(String line, Mapper.Context context) {
        String[] contents = line.trim().split(" ");
        Map<Integer, LinkedList<Long>> x1_0 = new HashMap<>();  // sparse fields
        Map<Integer, List<Float>> x2 = new HashMap<>();         // continuous fields
        Map<Integer, LinkedList<Long>> x3_0 = new HashMap<>();  // linear fields
        for (int i = 1; i < contents.length; i++) {
            String content = contents[i];
            String[] idsStr = content.split(":");
            Integer fieldId = Integer.parseInt(idsStr[0]);
            String featureId = idsStr[1];
            String val = idsStr[2];
            if (this.features.containsKey(fieldId)) {
                long featureValue;
                try {
                    featureValue = this.features.get(fieldId).get(featureId);
                } catch (Exception e) {
                    // unknown featureId: fall back to the "rare" bucket 0
                    LOG.debug(String.format("use rare for field[%d] feature[%s]", fieldId, featureId));
                    featureValue = 0L;
                }
                if (!this.contFea.isEmpty() && this.contFea.contains(fieldId)) {
                    // continuous field: value is a ";"-separated float vector
                    if (!val.trim().isEmpty()) {
                        List<Float> values = new ArrayList<>();
                        for (String v : val.trim().split(";")) {
                            values.add(Float.parseFloat(v));
                        }
                        x2.put(fieldId, values);
                    }
                } else if (!this.lineFea.isEmpty() && this.lineFea.contains(fieldId)) {
                    putValue(x3_0, fieldId, featureValue);
                } else if (this.sparseFea.contains(fieldId)) {
                    putValue(x1_0, fieldId, featureValue);
                }
            } else {
                context.getCounter("counter", "feature not exist " + fieldId).increment(1);
            }
        }
        // sparse/linear fields that never appeared get the "rare" bucket 0
        putRares(x1_0, sparseFea);
        if (!contFea.isEmpty()) {
            for (Integer fieldId : contFea) {
                if (!x2.containsKey(fieldId)) {
                    // fill missing continuous fields with the default value -10000
                    // (globalFieldInfo, the per-field metadata holding contFeaVecSize,
                    // is defined outside this gist)
                    List<Float> arr = new ArrayList<>();
                    for (int idx = 0; idx < globalFieldInfo.getField(fieldId).contFeaVecSize; idx++) {
                        arr.add(-10000.0f);
                    }
                    x2.put(fieldId, arr);
                }
            }
        }
        putRares(x3_0, lineFea);
        Features.Builder featuresBuilder = Features.newBuilder();
        for (Map.Entry<Integer, LinkedList<Long>> entry : x1_0.entrySet()) {
            String keyId = "x_" + entry.getKey();
            featuresBuilder.putFeature(keyId, createLongListFeature(entry.getValue()));
            saveLongCounter(context, entry, keyId);
        }
        if (!x2.isEmpty()) {
            for (Map.Entry<Integer, List<Float>> entry : x2.entrySet()) {
                String keyId = "x2_" + entry.getKey();
                featuresBuilder.putFeature(keyId, createFloatListFeature(entry.getValue()));
            }
        }
        for (Map.Entry<Integer, LinkedList<Long>> entry : x3_0.entrySet()) {
            String keyId = "x_" + entry.getKey();
            featuresBuilder.putFeature(keyId, createLongListFeature(entry.getValue()));
            saveLongCounter(context, entry, keyId);
        }
        String[] labels = contents[0].split(",");
        if (labels.length == 2) {
            // label is "clickNums,orderNums": derive binary CTR and CVR targets
            int clickNums = Integer.parseInt(labels[0]);
            int orderNums = Integer.parseInt(labels[1]);
            int y_ctr = 0;
            int y_cvr = 0;
            if (clickNums > 0) {
                y_ctr = 1;
                y_cvr = (orderNums > 0) ? 1 : 0;
            }
            featuresBuilder.putFeature("y_ctr", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y_ctr).build()).build());
            featuresBuilder.putFeature("y_cvr", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y_cvr).build()).build());
            featuresBuilder.putFeature("clickNums", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(clickNums).build()).build());
            featuresBuilder.putFeature("orderNums", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(orderNums).build()).build());
        } else {
            // single label: binarize to {0, 1}
            int y = Integer.parseInt(contents[0]) > 0 ? 1 : 0;
            featuresBuilder.putFeature("y", Feature.newBuilder().setInt64List(Int64List.newBuilder().addValue(y).build()).build());
        }
        return Example.newBuilder().setFeatures(featuresBuilder.build()).build();
    }
    private void saveLongCounter(Mapper.Context context, Map.Entry<Integer, LinkedList<Long>> entry, String keyId) {
        if (context != null) {
            if (entry.getValue().size() > 1) {
                context.getCounter("field_size_max_than_one", keyId).increment(1L);
            } else {
                context.getCounter("field_size_eq_one", keyId).increment(1L);
            }
        }
    }
    private void putValue(Map<Integer, LinkedList<Long>> map, Integer fieldId, long featureValue) {
        if (map.containsKey(fieldId)) {
            map.get(fieldId).add(featureValue);
        } else {
            LinkedList<Long> feaValues = new LinkedList<>();
            feaValues.add(featureValue);
            map.put(fieldId, feaValues);
        }
    }

    // give every listed field that is still absent a single "rare" value 0
    private void putRares(Map<Integer, LinkedList<Long>> map, List<Integer> features) {
        for (Integer fieldId : features) {
            if (!map.containsKey(fieldId)) {
                LinkedList<Long> feaValues = new LinkedList<>();
                feaValues.add(0L);
                map.put(fieldId, feaValues);
            }
        }
    }
    // currently unused variant of putRares that also records an index feature
    private void putRaresIndex(Map<Integer, LinkedList<Long>> x1, int index, Map<Integer, LinkedList<Long>> x2, List<Integer> features) {
        for (Integer fieldId : features) {
            if (!x2.containsKey(fieldId)) {
                LinkedList<Long> x2Values = new LinkedList<>();
                x2Values.add(0L);
                x2.put(fieldId, x2Values);
                LinkedList<Long> x1Values = new LinkedList<>();
                x1Values.add((long) index);
                x1.put(fieldId, x1Values);
            }
        }
    }
    private Feature createFloatListFeature(Iterable<? extends Float> values) {
        FloatList floatList = FloatList.newBuilder().addAllValue(values).build();
        return Feature.newBuilder().setFloatList(floatList).build();
    }

    private Feature createLongListFeature(Iterable<? extends Long> values) {
        Int64List int64List = Int64List.newBuilder().addAllValue(values).build();
        return Feature.newBuilder().setInt64List(int64List).build();
    }
}
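
The driver below wires DecodeProcessor into a MapReduce job. For a quick sanity check outside Hadoop, a produced record can be round-tripped with the standard protobuf accessors; a minimal sketch (RoundTripCheck and readYCtr are hypothetical names, not part of the gist):

import com.google.protobuf.InvalidProtocolBufferException;
import org.tensorflow.example.Example;

class RoundTripCheck {
    // Parse a serialized record back into an Example and read the CTR label.
    static long readYCtr(byte[] record) throws InvalidProtocolBufferException {
        Example parsed = Example.parseFrom(record);
        return parsed.getFeatures().getFeatureOrThrow("y_ctr")
                .getInt64List().getValue(0);
    }
}

TFPreprocess.java: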
package net.jqian.tutorial.tfrecord;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.*;
import org.apache.hadoop.mapreduce.*;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;
import org.tensorflow.example.Example;
import org.tensorflow.hadoop.io.TFRecordFileOutputFormat;

import java.io.IOException;
import java.util.List;
import java.util.Random;
public class TFPreprocess {
    private static final Log LOG = LogFactory.getLog(TFPreprocess.class);

    static class ToTFRecordMapper extends Mapper<LongWritable, Text, Text, BytesWritable> {
        private DecodeProcessor processor;
        // random output keys spread the examples evenly across reducers
        private Random random = new Random(System.currentTimeMillis());

        ToTFRecordMapper() {
        }

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            processor = new DecodeProcessor();
            processor.init(context.getConfiguration());
            // surface the parsed field lists as job counters for debugging
            context.getCounter("counter", "sparse fea: " + join(processor.sparseFea)).increment(1);
            context.getCounter("counter", "cont fea: " + join(processor.contFea)).increment(1);
            context.getCounter("counter", "delete fea: " + join(processor.deleteFea)).increment(1);
            context.getCounter("counter", "linear fea: " + join(processor.lineFea)).increment(1);
        }

        private static String join(List<Integer> fields) {
            StringBuilder sb = new StringBuilder();
            for (Integer fea : fields) {
                sb.append(fea).append(';');
            }
            return sb.toString();
        }
        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            Example example = processor.processLine(value.toString(), context);
            String outKey = String.valueOf(random.nextInt());
            context.write(new Text(outKey), new BytesWritable(example.toByteArray()));
        }
    }
    public static class ToTFRecordReducer extends Reducer<Text, BytesWritable, BytesWritable, NullWritable> {
        @Override
        protected void reduce(Text key, Iterable<BytesWritable> values, Context context)
                throws IOException, InterruptedException {
            // pass serialized Examples through; TFRecordFileOutputFormat frames each value
            for (BytesWritable v : values) {
                context.write(v, NullWritable.get());
            }
        }
    }
    public static boolean convert(String jobName, Configuration conf,
                                  Class<? extends Mapper> mapperClass,
                                  Class<? extends Reducer> reducerClass,
                                  Class<? extends Writable> mapKeyClass,
                                  Class<? extends Writable> mapValueClass,
                                  Class<? extends Writable> outputKeyClass,
                                  Class<? extends Writable> outputValueClass,
                                  Class<? extends InputFormat> inFormatClass,
                                  Class<? extends OutputFormat> outFormatClass)
            throws InterruptedException, IOException, ClassNotFoundException {
        Job job = Job.getInstance(conf, jobName);
        job.setJarByClass(mapperClass);
        job.setMapperClass(mapperClass);
        job.setReducerClass(reducerClass);
        job.setNumReduceTasks(conf.getInt("rednum", 128));
        job.setInputFormatClass(inFormatClass);
        job.setOutputFormatClass(outFormatClass);
        job.setMapOutputKeyClass(mapKeyClass);
        job.setMapOutputValueClass(mapValueClass);
        job.setOutputKeyClass(outputKeyClass);
        job.setOutputValueClass(outputValueClass);
        // limit map-task parallelism: 512 MB minimum split size
        // ("mapred.min.split.size" is the deprecated name of
        // "mapreduce.input.fileinputformat.split.minsize")
        job.getConfiguration().setLong("mapred.min.split.size", 536870912);
        job.getConfiguration().setInt("mapreduce.map.memory.mb", 4096);
        job.getConfiguration().set("mapreduce.map.java.opts", "-Xmx4096m");
        String inputPath = conf.get("in");
        String outputPath = conf.get("out");
        String day = conf.get("day");
        FileSystem fileSystem = FileSystem.get(conf);
        FileInputFormat.addInputPath(job, new Path(inputPath + "/" + day));
        LOG.info("input:" + inputPath + "/" + day);
        FileInputFormat.setInputDirRecursive(job, true);
        // overwrite any previous output for the same day
        Path out = new Path(outputPath + "/" + day);
        fileSystem.delete(out, true);
        FileOutputFormat.setOutputPath(job, out);
        return job.waitForCompletion(true);
    }
    public static void main(String[] args) throws Exception {
        GenericOptionsParser parser = new GenericOptionsParser(new Configuration(), args);
        Configuration conf = parser.getConfiguration();
        boolean ret = convert("ToTFR", conf, ToTFRecordMapper.class, ToTFRecordReducer.class,
                Text.class, BytesWritable.class,
                BytesWritable.class, NullWritable.class,
                TextInputFormat.class, TFRecordFileOutputFormat.class);
        System.exit(ret ? 0 : 1);
    }
}
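
The job takes its settings as plain -D properties through GenericOptionsParser; the keys read above are dict, field, linear_fea, delete_fea, cont_fea, rednum, in, out, and day. A hypothetical invocation (jar name and paths are placeholders, not from the gist):

hadoop jar tfrecord-tutorial.jar net.jqian.tutorial.tfrecord.TFPreprocess \
    -D dict=/data/dict -D field=/data/fields \
    -D cont_fea=7,8 -D linear_fea=5 -D delete_fea=9 -D rednum=128 \
    -D in=/data/samples -D out=/data/tfrecord -D day=20200324

The files written under /data/tfrecord/20200324 can then be consumed directly with tf.data.TFRecordDataset.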