Skip to content

Instantly share code, notes, and snippets.

@mlimotte
Created March 16, 2011 17:44
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save mlimotte/872918 to your computer and use it in GitHub Desktop.
Save mlimotte/872918 to your computer and use it in GitHub Desktop.
A Cascalog function to join a small file that can fit in memory, map-side.
package foo.cascalog;
import cascading.flow.FlowProcess;
import cascading.flow.hadoop.HadoopFlowProcess;
import cascading.operation.FunctionCall;
import cascading.operation.OperationCall;
import cascading.tuple.Tuple;
import cascading.tuple.TupleEntry;
import cascalog.CascalogFunction;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import java.io.*;
import java.util.HashMap;
import java.util.Map;
/**
 * A Cascalog function that performs a map-side join against a small key/value
 * file which fits in memory. The file must be shipped to each task through
 * Hadoop's DistributedCache (add its HDFS path to "mapred.cache.files");
 * prepare() loads it into an in-memory map, and operate() emits the joined
 * value for each input key.
 */
public class MemoryJoin extends CascalogFunction {

    // Lookup table, populated from the cached file in prepare().
    private Map<String, String> theMap = new HashMap<String, String>();
    private String name;           // basename of the file to find in the distributed cache
    private int keyIndex;          // 0-based column index of the join key
    private int valueIndex;        // 0-based column index of the joined value
    private int splitLimit;        // String.split limit: one past the last column we need
    private String defaultValue = null;  // emitted when a key has no match
    private boolean caseSensitive;       // if false, keys are lower-cased on both sides
    private boolean warnOnNoJoin;        // if true, warn (once per key) on a join miss
    private String delim;                // field delimiter; interpreted as a regex by String.split

    /**
     * @param name          basename of the file in the distributed cache
     * @param delim         field delimiter (a regex, per {@link String#split})
     * @param keyIndex      0-based index of the key column
     * @param valueIndex    0-based index of the value column
     * @param defaultValue  value emitted when no match is found (may be null)
     * @param caseSensitive only set false if you need case-insensitivity;
     *                      false has slightly worse performance
     * @param warnOnNoJoin  if true, print a warning the first time a key misses
     */
    public MemoryJoin(String name, String delim, int keyIndex, int valueIndex, String defaultValue,
                      boolean caseSensitive, boolean warnOnNoJoin) {
        this.name = name;
        this.delim = delim;
        this.keyIndex = keyIndex;
        this.valueIndex = valueIndex;
        int maxIdx = (keyIndex > valueIndex) ? keyIndex : valueIndex;
        // Split only as many columns as we actually need; +2 so a trailing
        // delimiter after the last needed column does not merge into it.
        this.splitLimit = maxIdx + 2;
        this.defaultValue = defaultValue;
        this.caseSensitive = caseSensitive;
        this.warnOnNoJoin = warnOnNoJoin;
    }

    /** Convenience constructor: tab-delimited. */
    public MemoryJoin(String name, int keyIndex, int valueIndex, String defaultValue,
                      boolean caseSensitive, boolean warnOnNoJoin) {
        this(name, "\t", keyIndex, valueIndex, defaultValue, caseSensitive, warnOnNoJoin);
    }

    /** Convenience constructor: case-sensitive, warn on join misses. */
    public MemoryJoin(String name, String delim, int keyIndex, int valueIndex, String defaultValue) {
        this(name, delim, keyIndex, valueIndex, defaultValue, true, true);
    }

    /** Convenience constructor: tab-delimited, case-sensitive, warn on join misses. */
    public MemoryJoin(String name, int keyIndex, int valueIndex, String defaultValue) {
        this(name, "\t", keyIndex, valueIndex, defaultValue, true, true);
    }

    /**
     * Locates the local copy of the named file among the distributed-cache files.
     *
     * @throws IOException      if the cache file list cannot be read
     * @throws RuntimeException if no cache file has the requested name
     */
    private Path getFileFor(HadoopFlowProcess hfp, String name) throws IOException {
        Path[] files = DistributedCache.getLocalCacheFiles(hfp.getJobConf());
        for (Path cachePath : files) {
            if (cachePath.getName().equals(name)) {
                return cachePath;
            }
        }
        // Error, if we haven't returned by now.
        throw new RuntimeException("Did not find " + name + " in local cache files (distributedCache contains "
                + files.length + " files).");
    }

    /**
     * Loads the cached file into {@link #theMap}. Lines that do not contain
     * both the key and value columns are skipped.
     */
    @Override
    public void prepare(FlowProcess flowProcess, OperationCall operationCall) {
        super.prepare(flowProcess, operationCall);
        Path f;
        try {
            HadoopFlowProcess hfp = (HadoopFlowProcess) flowProcess;
            f = getFileFor(hfp, name);
        } catch (IOException e) {
            // BUG FIX: the original created this exception but never threw it,
            // silently continuing with a null path.
            throw new RuntimeException("Error getting files from distributedCache for " + name, e);
        }
        BufferedReader r = null;
        try {
            FileSystem fs = FileSystem.getLocal(new Configuration());
            // NOTE(review): reader uses the platform default charset, as the
            // original did — confirm the lookup files are encoded accordingly.
            r = new BufferedReader(new InputStreamReader(fs.open(f, 1000000)));
            String line;
            while ((line = r.readLine()) != null) {
                String[] flds = line.split(delim, splitLimit);
                // Keep only lines that have every column up to max(key, value) index.
                if (flds.length >= (splitLimit - 1)) {
                    String key = caseSensitive ? flds[keyIndex] : flds[keyIndex].toLowerCase();
                    theMap.put(key, flds[valueIndex]);
                }
            }
        } catch (IOException e) {
            // BUG FIX: the original created this exception but never threw it,
            // silently proceeding with a partially loaded (or empty) map.
            throw new RuntimeException("Error reading file " + f.toString() + " from distributedCache", e);
        } finally {
            // BUG FIX: close the reader even when readLine throws (was leaked
            // on the error path in the original).
            if (r != null) {
                try {
                    r.close();
                } catch (IOException ignored) {
                    // Best-effort close; the data has already been read or the
                    // read failure has already been reported above.
                }
            }
        }
        System.out.printf("MemoryJoin.prepare() for %s(mappings=%d, path=%s)%n", name, theMap.size(), f.toString());
    }

    /**
     * Emits the joined value for the first argument field. An empty key joins
     * to the empty string; a missing key joins to {@code defaultValue}.
     */
    @Override
    public void operate(FlowProcess flowProcess, FunctionCall functionCall) {
        Tuple result = new Tuple();
        TupleEntry arguments = functionCall.getArguments();
        String key = arguments.getString(0);
        String s = "";
        if (!"".equals(key)) {
            // BUG FIX: guard the lower-casing against a null key (getString
            // may return null); previously this NPE'd when caseSensitive was
            // false. A null key still falls through to the default-value path.
            if (!caseSensitive && key != null) {
                key = key.toLowerCase();
            }
            s = theMap.get(key);
            if (s == null) {
                if (warnOnNoJoin) {
                    System.err.println("MemoryJoin " + name + ": no join found for key \'" + key + "\'");
                    // Cache the miss so we only warn once per key per JVM.
                    // (If defaultValue is null the miss recurs and so does the
                    // warning — same as the original behavior.)
                    theMap.put(key, defaultValue);
                }
                s = defaultValue;
            }
        }
        result.add(s);
        functionCall.getOutputCollector().add(result);
    }
}
@mlimotte
Copy link
Author

The files that you want to join will need to exist in HDFS, and then you can use the distributed cache to pass them around to each tasktracker. To do this, add to mapred.cache.files in your job-conf. E.g.:

(with-job-conf {"mapred.cache.files" csv-of-paths-in-hdfs}
  ...)

Where csv-of-paths-in-hdfs is a comma-separated list of paths in HDFS where the files to be joined exist.

In clojure, you can use this like so:

; import the java class
(ns ... (:import foo.cascalog.MemoryJoin))
; create an instance
(def scatg (MemoryJoin. "site_categories" "\001" 0 1 "unknown"))
; and then as a predicate in a Cascalog query
(<-  ...  (scatg ?site_category-id :> ?site_category))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment