Skip to content

Instantly share code, notes, and snippets.

@nowell-jana
Created April 24, 2014 20:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nowell-jana/11268603 to your computer and use it in GitHub Desktop.
Save nowell-jana/11268603 to your computer and use it in GitHub Desktop.
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import org.apache.avro.hadoop.io.AvroSerialization;
import org.apache.avro.mapred.AvroWrapper;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.lib.MultipleOutputs;
import org.apache.hadoop.mrunit.ReduceDriver;
import org.apache.hadoop.mrunit.internal.mapred.MockReporter;
import org.apache.hadoop.mrunit.internal.output.OutputCollectable;
import org.apache.hadoop.mrunit.internal.util.Errors;
import org.apache.hadoop.mrunit.types.Pair;
import org.apache.hadoop.util.ReflectionUtils;
import org.junit.Assert;
//TODO: Add/complete comments about differences from prototype methods
/**
 * A {@link ReduceDriver} that, in addition to the regular reducer output, checks
 * the {@link Text} key/value pairs a reducer writes to named outputs.
 *
 * <p>Before each run the registered {@code MockMultipleTextOutputs} is cleared and
 * handed to {@code ReflectUtils.injectMockByClass(MultipleOutputs.class, ...)}
 * together with the reducer (presumably replacing the reducer's
 * {@link MultipleOutputs} member - confirm against {@code ReflectUtils}). After the
 * run, whatever the mock reports via {@code getOutput()} is compared against the
 * expectations registered on this driver.
 *
 * <p>Typical use: create the driver, call
 * {@link #patchAndRegisterMultipleOutputs(MockMultipleTextOutputs)}, register
 * expected named outputs (or ignore some/all of them), then call
 * {@code runTest(...)}.
 *
 * @param <K1> reducer input key type
 * @param <V1> reducer input value type
 * @param <K2> reducer output key type
 * @param <V2> reducer output value type
 */
public class MultipleTextOutputsReduceDriver<K1, V1, K2, V2> extends ReduceDriver<K1, V1, K2, V2> {

  /** Mock that collects named-output writes; must be registered before {@link #run()}. */
  private MockMultipleTextOutputs mockMultipleOutputs;

  /** Expected pairs per named output, keyed by output name. */
  private final TreeMap<String, List<Pair<Text, Text>>> expectedMultipleOutputs =
      new TreeMap<String, List<Pair<Text, Text>>>();

  /** Names of named outputs whose contents are not compared. */
  private final Set<String> ignoredNamedOutputs = new HashSet<String>();

  /** When true, {@link #runTest(boolean)} skips named-output validation entirely. */
  private boolean ignoreMultipleOutputs = false;

  /**
   * Creates a driver for the given reducer.
   *
   * @param reducer the reducer under test
   */
  public MultipleTextOutputsReduceDriver(Reducer<K1, V1, K2, V2> reducer) {
    super(reducer);
  }

  /** Creates a driver with no reducer set yet. */
  public MultipleTextOutputsReduceDriver() {
    super();
  }

  /**
   * Registers the mock that will be cleared and injected on every {@link #run()}.
   *
   * @param mockMultipleOutputs the mock that records named-output writes
   */
  public void patchAndRegisterMultipleOutputs(MockMultipleTextOutputs mockMultipleOutputs) {
    this.mockMultipleOutputs = mockMultipleOutputs;
  }

  /**
   * Appends one expected pair to the named output {@code name}.
   *
   * <p>The list for {@code name} is created on first use, so this method can be
   * called on a fresh driver without registering the name beforehand.
   *
   * @param name       named output the pair is expected on
   * @param outputPair expected key/value pair
   */
  public void addMultipleOutput(String name, final Pair<Text, Text> outputPair) {
    expectedPairsFor(name).add(outputPair);
  }

  /**
   * Adds every expected pair in {@code outputMap} to this driver's expectations.
   *
   * <p>Entries are merged into the driver's own map: pairs for a name that is
   * already registered are appended after the existing ones, and the caller's map
   * and lists are copied rather than retained. A name mapped to an empty (or
   * {@code null}) list is still registered as an expected named output with no
   * pairs. A {@code null} map is a no-op.
   *
   * @param outputMap expected pairs keyed by named output; may be {@code null}
   */
  public void addAllMultipleOutput(TreeMap<String, List<Pair<Text, Text>>> outputMap) {
    if (outputMap == null) {
      return;
    }
    for (final Entry<String, List<Pair<Text, Text>>> entry : outputMap.entrySet()) {
      final List<Pair<Text, Text>> target = expectedPairsFor(entry.getKey());
      if (entry.getValue() != null) {
        target.addAll(entry.getValue());
      }
    }
  }

  /**
   * Marks one named output as ignored: its contents are not compared.
   *
   * @param name named output to ignore
   */
  public void addIgnoredNamedOutput(String name) {
    ignoredNamedOutputs.add(name);
  }

  /**
   * Marks every name in {@code outputs} as ignored, in addition to any names
   * already ignored. A {@code null} set is a no-op.
   *
   * @param outputs names of named outputs to ignore; may be {@code null}
   */
  public void addAllIgnoredNamedOutputs(Set<String> outputs) {
    if (outputs != null) {
      ignoredNamedOutputs.addAll(outputs);
    }
  }

  /** Disables named-output validation in {@link #runTest(boolean)}. */
  public void ignoreMultipleOutputs() {
    ignoreMultipleOutputs = true;
  }

  /**
   * Runs the reducer over all configured inputs and returns its regular output.
   *
   * <p>The registered mock is cleared and injected before the first
   * {@code reduce} call; the reducer is closed after the last one.
   *
   * @return the pairs collected from the reducer's regular output collector
   * @throws IOException           if the reducer or the driver infrastructure fails
   * @throws IllegalStateException if no mock was registered via
   *                               {@link #patchAndRegisterMultipleOutputs(MockMultipleTextOutputs)}
   */
  @Override
  public List<Pair<K2, V2>> run() throws IOException {
    // Fail fast with a readable message instead of a NullPointerException mid-run.
    final MockMultipleTextOutputs mock = requireMockMultipleOutputs();
    try {
      preRunChecks(getReducer());
      initDistributedCache();
      final OutputCollectable<K2, V2> outputCollectable =
          mockOutputCreator.createMapredOutputCollectable(getConfiguration(),
              getOutputSerializationConfiguration());
      final MockReporter reporter = new MockReporter(
          MockReporter.ReporterType.Reducer, getCounters());
      ReflectionUtils.setConf(getReducer(), new JobConf(getConfiguration()));
      // Drop anything recorded by a previous run before handing the mock to the reducer.
      mock.clear();
      String resultMsg = ReflectUtils.injectMockByClass(MultipleOutputs.class, getReducer(),
          mock);
      LOG.info(resultMsg);
      for (Pair<K1, List<V1>> kv : inputs) {
        getReducer().reduce(kv.getFirst(), kv.getSecond().iterator(),
            outputCollectable, reporter);
      }
      getReducer().close();
      return outputCollectable.getOutputs();
    } finally {
      cleanupDistributedCache();
    }
  }

  /**
   * Runs the reducer and validates, in order: the regular output, the named
   * outputs (unless {@link #ignoreMultipleOutputs()} was called) and the counters.
   *
   * @param orderMatters whether the regular output must match the expected order;
   *                     named outputs are always compared with order significant
   * @throws IOException if {@link #run()} fails
   */
  @Override
  public void runTest(final boolean orderMatters) throws IOException {
    if (LOG.isDebugEnabled()) {
      printPreTestDebugLog();
    }
    final List<Pair<K2, V2>> outputs = run();
    validate(outputs, expectedOutputs, orderMatters);
    if (!ignoreMultipleOutputs) {
      validate(getMultipleOutputsResults(), expectedMultipleOutputs);
    }
    validate(counterWrapper);
  }

  /**
   * Compares a list of actual pairs with a list of expected pairs and fails (via
   * {@code Errors.assertNone()}) if any mismatch was recorded.
   *
   * @param outputs         actual pairs
   * @param expectedOutputs expected pairs
   * @param orderMatters    if true, each pair must occur at exactly the expected
   *                        positions; otherwise only occurrence counts must match
   */
  protected <K, V> void validate(final List<Pair<K, V>> outputs, final List<Pair<K, V>> expectedOutputs,
      final boolean orderMatters) {
    final Errors errors = new Errors(LOG);
    if (!outputs.isEmpty()) {
      /* Were we supposed to get output in the first place? */
      if (expectedOutputs.isEmpty()) {
        errors.record("Expected no outputs; got %d outputs.", outputs.size());
      }
      /* Check that user's key and value Writables implement equals, hashCode, toString */
      checkOverrides(outputs.get(0));
    }
    final Map<Pair<K, V>, List<Integer>> expectedPositions = buildPositionMap(expectedOutputs);
    final Map<Pair<K, V>, List<Integer>> actualPositions = buildPositionMap(outputs);
    for (final Pair<K, V> output : expectedPositions.keySet()) {
      final List<Integer> expectedPositionList = expectedPositions.get(output);
      final List<Integer> actualPositionList = actualPositions.get(output);
      if (actualPositionList != null) {
        /* The expected value has been seen - check positions */
        final int expectedPositionsCount = expectedPositionList.size();
        final int actualPositionsCount = actualPositionList.size();
        if (orderMatters) {
          /* Order is important, so the positions must match exactly */
          if (expectedPositionList.equals(actualPositionList)) {
            LOG.debug(String.format("Matched expected output %s at "
                + "positions %s", output, expectedPositionList.toString()));
          } else {
            /* Walk both position lists in lock-step and report each divergence */
            int i = 0;
            while (expectedPositionsCount > i || actualPositionsCount > i) {
              if (expectedPositionsCount > i && actualPositionsCount > i) {
                final int expectedPosition = expectedPositionList.get(i);
                final int actualPosition = actualPositionList.get(i);
                if (expectedPosition == actualPosition) {
                  LOG.debug(String.format("Matched expected output %s at "
                      + "position %d", output, expectedPosition));
                } else {
                  errors.record("Matched expected output %s but at "
                      + "incorrect position %d (expected position %d)", output,
                      actualPosition, expectedPosition);
                }
              } else if (expectedPositionsCount > i) {
                /* Not ok, value wasn't seen enough times */
                errors.record("Missing expected output %s at position %d.",
                    output, expectedPositionList.get(i));
              } else {
                /* Not ok, value seen too many times */
                errors.record("Received unexpected output %s at position %d.",
                    output, actualPositionList.get(i));
              }
              i++;
            }
          }
        } else {
          /* Order is unimportant, just check that the count of times seen match */
          if (expectedPositionsCount == actualPositionsCount) {
            /* Ok, counts match */
            LOG.debug(String.format("Matched expected output %s in "
                + "%d positions", output, expectedPositionsCount));
          } else if (expectedPositionsCount > actualPositionsCount) {
            /* Not ok, value wasn't seen enough times */
            for (int i = 0; i < expectedPositionsCount - actualPositionsCount; i++) {
              errors.record("Missing expected output %s", output);
            }
          } else {
            /* Not ok, value seen too many times */
            for (int i = 0; i < actualPositionsCount - expectedPositionsCount; i++) {
              errors.record("Received unexpected output %s", output);
            }
          }
        }
        /* Handled; whatever remains in actualPositions afterwards is unexpected */
        actualPositions.remove(output);
      } else {
        /* The expected value was not found anywhere - output errors */
        checkTypesAndLogError(outputs, expectedOutputs, output, expectedPositionList,
            orderMatters, errors, "Missing expected output");
      }
    }
    for (final Pair<K, V> output : actualPositions.keySet()) {
      /* Anything left in actual set is unexpected */
      checkTypesAndLogError(outputs, expectedOutputs, output, actualPositions.get(output),
          orderMatters, errors, "Received unexpected output");
    }
    errors.assertNone();
  }

  /**
   * Validates the named outputs.
   *
   * <p>First the set of actual output names must equal the union of expected and
   * ignored names (so an ignored output must still have been produced); otherwise
   * the test fails. Then every non-ignored named output is compared, with order
   * significant, against the expected list registered under the same name.
   *
   * @param outputs         actual pairs per named output
   * @param expectedOutputs expected pairs per named output
   */
  protected <K, V> void validate(TreeMap<String, List<Pair<K, V>>> outputs,
      TreeMap<String, List<Pair<K, V>>> expectedOutputs) {
    HashSet<String> expectedAndIgnoredOutputs = new HashSet<String>(expectedOutputs.keySet());
    expectedAndIgnoredOutputs.addAll(ignoredNamedOutputs);
    if (!outputs.keySet().equals(expectedAndIgnoredOutputs)) {
      String msg = "Actual and (required + ignored) named outputs do not match.\n"
          + "RECEIVED: " + outputs.keySet() + "\n"
          + "REQUIRED: " + expectedOutputs.keySet() + "\n"
          + "IGNORED: " + ignoredNamedOutputs;
      LOG.error(msg);
      Assert.fail(msg);
    } else {
      for (final Entry<String, List<Pair<K, V>>> actual : outputs.entrySet()) {
        final String name = actual.getKey();
        if (ignoredNamedOutputs.contains(name)) {
          continue;
        }
        LOG.info("Validating " + name);
        // Look the expectation up by name: walking both maps with parallel iterators
        // goes out of step when a name is both expected and ignored.
        validate(actual.getValue(), expectedOutputs.get(name), true);
        LOG.info("Validation succeeded for " + name);
      }
    }
  }

  /** Logs warnings if the key or value class of {@code outputPair} lacks its own overrides. */
  private <K, V> void checkOverrides(final Pair<K,V> outputPair) {
    checkOverride(outputPair.getFirst().getClass());
    checkOverride(outputPair.getSecond().getClass());
  }

  /**
   * Logs a warning for each of {@code equals}, {@code hashCode} and
   * {@code toString} that is not declared by {@code clazz} itself.
   * {@link AvroWrapper} subtypes are skipped.
   */
  private void checkOverride(final Class<?> clazz) {
    try {
      if (AvroWrapper.class.isAssignableFrom(clazz)) {
        return;
      } else {
        if (clazz.getMethod("equals", Object.class).getDeclaringClass() != clazz) {
          LOG.warn(clazz.getCanonicalName() + ".equals(Object) " +
              "is not being overridden - tests may fail!");
        }
        if (clazz.getMethod("hashCode").getDeclaringClass() != clazz) {
          LOG.warn(clazz.getCanonicalName() + ".hashCode() " +
              "is not being overridden - tests may fail!");
        }
        if (clazz.getMethod("toString").getDeclaringClass() != clazz) {
          LOG.warn(clazz.getCanonicalName() + ".toString() " +
              "is not being overridden - test failures may be difficult to diagnose.");
          LOG.warn("Consider executing test using run() to access outputs");
        }
      }
    } catch (SecurityException e) {
      LOG.error(e);
    } catch (NoSuchMethodException e) {
      LOG.error(e);
    }
  }

  /**
   * Records one error per position in {@code positions}. When both lists have an
   * element at that position and their key or value classes differ, the error
   * reports the class mismatch; otherwise it reports {@code errorString} with the
   * pair (and the position when {@code orderMatters}).
   */
  private <K, V> void checkTypesAndLogError(
      final List<Pair<K, V>> outputs,
      final List<Pair<K, V>> expectedOutputs,
      final Pair<K, V> output, final List<Integer> positions,
      final boolean orderMatters, final Errors errors,
      final String errorString) {
    for (final int pos : positions) {
      String msg = null;
      if (expectedOutputs.size() > pos && outputs.size() > pos) {
        final Pair<K, V> actual = outputs.get(pos);
        final Pair<K, V> expected = expectedOutputs.get(pos);
        final Class<?> actualKeyClass = actual.getFirst().getClass();
        final Class<?> actualValueClass = actual.getSecond().getClass();
        final Class<?> expectedKeyClass = expected.getFirst().getClass();
        final Class<?> expectedValueClass = expected.getSecond().getClass();
        if (actualKeyClass != expectedKeyClass) {
          msg = String.format(
              "%s %s: Mismatch in key class: expected: %s actual: %s",
              errorString, output, expectedKeyClass, actualKeyClass);
        } else if (actualValueClass != expectedValueClass) {
          msg = String.format(
              "%s %s: Mismatch in value class: expected: %s actual: %s",
              errorString, output, expectedValueClass, actualValueClass);
        }
      }
      if (msg == null) {
        if (orderMatters) {
          msg = String
              .format("%s %s at position %d.", errorString, output, pos);
        } else {
          msg = String.format("%s %s", errorString, output);
        }
      }
      errors.record(msg);
    }
  }

  /** Maps each distinct pair in {@code values} to the ascending list of indexes where it occurs. */
  private <K,V> Map<Pair<K, V>, List<Integer>> buildPositionMap(
      final List<Pair<K, V>> values) {
    final Map<Pair<K, V>, List<Integer>> valuePositions = new HashMap<Pair<K, V>, List<Integer>>();
    for (int i = 0; i < values.size(); i++) {
      final Pair<K, V> output = values.get(i);
      List<Integer> positions;
      if (valuePositions.containsKey(output)) {
        positions = valuePositions.get(output);
      } else {
        positions = new ArrayList<Integer>();
        valuePositions.put(output, positions);
      }
      positions.add(i);
    }
    return valuePositions;
  }

  /** Returns the expected-pair list for {@code name}, creating and registering it if absent. */
  private List<Pair<Text, Text>> expectedPairsFor(final String name) {
    List<Pair<Text, Text>> pairs = expectedMultipleOutputs.get(name);
    if (pairs == null) {
      pairs = new ArrayList<Pair<Text, Text>>();
      expectedMultipleOutputs.put(name, pairs);
    }
    return pairs;
  }

  /**
   * Returns the registered mock.
   *
   * @throws IllegalStateException if none has been registered
   */
  private MockMultipleTextOutputs requireMockMultipleOutputs() {
    if (mockMultipleOutputs == null) {
      throw new IllegalStateException("No MockMultipleTextOutputs registered; call "
          + "patchAndRegisterMultipleOutputs(...) before running the driver.");
    }
    return mockMultipleOutputs;
  }

  /**
   * Returns what the registered mock reports from {@code getOutput()}.
   *
   * @return recorded pairs per named output
   * @throws IllegalStateException if no mock was registered via
   *                               {@link #patchAndRegisterMultipleOutputs(MockMultipleTextOutputs)}
   */
  public TreeMap<String, List<Pair<Text, Text>>> getMultipleOutputsResults() {
    return requireMockMultipleOutputs().getOutput();
  }

  /**
   * Static factory for a driver with no reducer set yet.
   *
   * @return a new driver
   */
  public static <K1, V1, K2, V2> MultipleTextOutputsReduceDriver<K1, V1, K2, V2> newMultipleOutputsReduceDriver() {
    return new MultipleTextOutputsReduceDriver<K1, V1, K2, V2>();
  }

  /**
   * Static factory for a driver bound to {@code reducer}.
   *
   * @param reducer the reducer under test
   * @return a new driver
   */
  public static <K1, V1, K2, V2> MultipleTextOutputsReduceDriver<K1, V1, K2, V2> newMultipleOutputsReduceDriver(
      final Reducer<K1, V1, K2, V2> reducer) {
    return new MultipleTextOutputsReduceDriver<K1, V1, K2, V2>(reducer);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment