Skip to content

Instantly share code, notes, and snippets.

@balidani
Created February 23, 2015 14:06
Show Gist options
  • Save balidani/1139eddcb1b4cbf60404 to your computer and use it in GitHub Desktop.
Save balidani/1139eddcb1b4cbf60404 to your computer and use it in GitHub Desktop.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.graph.gsa;
import org.apache.commons.lang3.Validate;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.CustomUnaryOperation;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.graph.Edge;
import org.apache.flink.graph.Vertex;
import org.apache.flink.graph.spargel.MessagingFunction;
import org.apache.flink.util.Collector;
import java.io.Serializable;
/**
 * This class represents iterative graph computations, programmed in a gather-sum-apply perspective.
 * <p>
 * The computation runs as a Flink delta iteration whose solution set and initial workset are the
 * input vertices. In every superstep:
 * <ol>
 * <li>edges whose target vertex is in the workset (the vertices changed in the previous superstep)
 * are joined with the current value of their source vertex from the solution set, forming
 * (source vertex, edge, target vertex) triplets;</li>
 * <li>the gather function maps every triplet to a keyed intermediate value;</li>
 * <li>the sum function reduces the intermediate values per key;</li>
 * <li>the apply function combines each reduced value with the <em>current</em> value of the vertex
 * with that key. The vertices it emits update the solution set and form the next workset.</li>
 * </ol>
 *
 * @param <K> The type of the vertex key in the graph
 * @param <VV> The type of the vertex value in the graph
 * @param <EV> The type of the edge value in the graph
 * @param <M> The intermediate type used by the gather, sum and apply functions
 */
public class GatherSumApplyIteration<K extends Comparable<K> & Serializable,
		VV extends Serializable, EV extends Serializable, M> implements CustomUnaryOperation<Vertex<K, VV>,
		Vertex<K, VV>> {

	private DataSet<Vertex<K, VV>> vertexDataSet;
	private DataSet<Edge<K, EV>> edgeDataSet;

	private final GatherFunction<K, VV, EV, M> gather;
	private final SumFunction<K, VV, EV, M> sum;
	private final ApplyFunction<K, VV, EV, M> apply;
	private final int maximumNumberOfIterations;

	private String name;
	private int parallelism = -1;

	// ----------------------------------------------------------------------------------

	private GatherSumApplyIteration(GatherFunction<K, VV, EV, M> gather, SumFunction<K, VV, EV, M> sum,
			ApplyFunction<K, VV, EV, M> apply, DataSet<Edge<K, EV>> edges, int maximumNumberOfIterations) {

		Validate.notNull(gather, "The gather function must not be null.");
		Validate.notNull(sum, "The sum function must not be null.");
		Validate.notNull(apply, "The apply function must not be null.");
		Validate.notNull(edges, "The edge data set must not be null.");
		Validate.isTrue(maximumNumberOfIterations > 0, "The maximum number of iterations must be at least one.");

		this.gather = gather;
		this.sum = sum;
		this.apply = apply;
		this.edgeDataSet = edges;
		this.maximumNumberOfIterations = maximumNumberOfIterations;
	}

	/**
	 * Sets the name for the gather-sum-apply iteration. The name is displayed in logs and messages.
	 * If no name is set, a default name built from the gather, sum and apply functions is used.
	 *
	 * @param name The name for the iteration.
	 */
	public void setName(String name) {
		this.name = name;
	}

	/**
	 * Gets the name from this gather-sum-apply iteration.
	 *
	 * @return The name of the iteration, or {@code null} if none was set explicitly.
	 */
	public String getName() {
		return name;
	}

	/**
	 * Sets the degree of parallelism for the iteration.
	 *
	 * @param parallelism The degree of parallelism; must be positive, or -1 to use the default.
	 */
	public void setParallelism(int parallelism) {
		Validate.isTrue(parallelism > 0 || parallelism == -1,
				"The degree of parallelism must be positive, or -1 (use default).");
		this.parallelism = parallelism;
	}

	/**
	 * Gets the iteration's degree of parallelism.
	 *
	 * @return The iteration's parallelism, or -1, if not set.
	 */
	public int getParallelism() {
		return parallelism;
	}

	// --------------------------------------------------------------------------------------------
	//  Custom Operator behavior
	// --------------------------------------------------------------------------------------------

	/**
	 * Sets the input data set for this operator. In the case of this operator this input data set represents
	 * the set of vertices with their initial state.
	 *
	 * @param dataSet The input data set, which in the case of this operator represents the set of
	 *                vertices with their initial state.
	 */
	@Override
	public void setInput(DataSet<Vertex<K, VV>> dataSet) {
		this.vertexDataSet = dataSet;
	}

	/**
	 * Computes the results of the gather-sum-apply iteration.
	 *
	 * @return The resulting DataSet
	 * @throws IllegalStateException if {@link #setInput(DataSet)} has not been called.
	 */
	@Override
	public DataSet<Vertex<K, VV>> createResult() {
		if (vertexDataSet == null) {
			throw new IllegalStateException("The input data set has not been set.");
		}

		// Prepare type information: the intermediate records are (vertex key, M) pairs, where M is
		// the fourth type parameter (index 3) of the user's GatherFunction.
		TypeInformation<K> keyType = ((TupleTypeInfo<?>) vertexDataSet.getType()).getTypeAt(0);
		TypeInformation<M> messageType = TypeExtractor.createTypeInfo(GatherFunction.class, gather.getClass(), 3, null, null);
		TypeInformation<Tuple2<K, M>> innerType = new TupleTypeInfo<Tuple2<K, M>>(keyType, messageType);

		GatherUdf<K, VV, EV, M> gatherUdf = new GatherUdf<K, VV, EV, M>(gather, innerType);
		SumUdf<K, VV, EV, M> sumUdf = new SumUdf<K, VV, EV, M>(sum);
		ApplyUdf<K, VV, EV, M> applyUdf = new ApplyUdf<K, VV, EV, M>(apply);

		// The vertex key (field 0) identifies solution set entries.
		final int[] zeroKeyPos = new int[] {0};
		final DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration =
				vertexDataSet.iterateDelta(vertexDataSet, maximumNumberOfIterations, zeroKeyPos);

		// Propagate the user-supplied configuration to the iteration operator; without this,
		// setName() and setParallelism() would have no effect.
		final String iterationName = (name != null) ? name
				: "Gather-sum-apply iteration (" + gather + " | " + sum + " | " + apply + ")";
		iteration.name(iterationName);
		iteration.parallelism(parallelism);

		// Build (source vertex, edge, target vertex) triplets. Only edges whose target changed in
		// the previous superstep (i.e. is in the workset) are considered; the source vertex is
		// looked up in the solution set so it carries its current value.
		DataSet<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> triplets = iteration
				.getSolutionSet()
				.join(edgeDataSet
						.join(iteration.getWorkset())
						.where(1)
						.equalTo(0)
						.with(new PairJoinFunction<K, VV, EV>()))
				.where(0)
				.equalTo(0)
				.with(new TripletJoinFunction<K, VV, EV>());

		DataSet<Tuple2<K, M>> gatheredSet = triplets.map(gatherUdf);

		DataSet<Tuple2<K, M>> summedSet = gatheredSet.groupBy(0).reduce(sumUdf);

		// The apply function must see the *current* vertex values, which live in the solution set.
		// Joining with the static input data set would always hand it the initial values.
		DataSet<Vertex<K, VV>> appliedSet = summedSet
				.join(iteration.getSolutionSet())
				.where(0)
				.equalTo(0)
				.with(applyUdf);

		// Updated vertices are both the solution set delta and the next workset.
		return iteration.closeWith(appliedSet, appliedSet);
	}

	/**
	 * Creates a new gather-sum-apply iteration operator for graphs.
	 *
	 * @param edges The edge DataSet
	 *
	 * @param gather The gather function of the GSA iteration
	 * @param sum The sum function of the GSA iteration
	 * @param apply The apply function of the GSA iteration
	 *
	 * @param maximumNumberOfIterations The maximum number of iterations executed
	 *
	 * @param <K> The type of the vertex key in the graph
	 * @param <VV> The type of the vertex value in the graph
	 * @param <EV> The type of the edge value in the graph
	 * @param <M> The intermediate type used by the gather, sum and apply functions
	 *
	 * @return An instance of the gather-sum-apply graph computation operator.
	 */
	public static final <K extends Comparable<K> & Serializable, VV extends Serializable, EV extends Serializable, M>
			GatherSumApplyIteration<K, VV, EV, M> withEdges(DataSet<Edge<K, EV>> edges,
			GatherFunction<K, VV, EV, M> gather, SumFunction<K, VV, EV, M> sum, ApplyFunction<K, VV, EV, M> apply,
			int maximumNumberOfIterations) {
		return new GatherSumApplyIteration<K, VV, EV, M>(gather, sum, apply, edges, maximumNumberOfIterations);
	}

	// --------------------------------------------------------------------------------------------
	//  Triplet Utils
	// --------------------------------------------------------------------------------------------

	/**
	 * Pairs an edge with its (workset) target vertex and keys the result by the edge's source id,
	 * so that the source vertex can be joined in afterwards.
	 */
	private static final class PairJoinFunction<K extends Comparable<K> & Serializable, VV extends Serializable,
			EV extends Serializable> implements FlatJoinFunction<Edge<K, EV>, Vertex<K, VV>,
			Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> {

		private static final long serialVersionUID = 1L;

		@Override
		public void join(Edge<K, EV> edge, Vertex<K, VV> vertex,
				Collector<Tuple3<K, Edge<K, EV>, Vertex<K, VV>>> collector) throws Exception {
			collector.collect(new Tuple3<K, Edge<K, EV>, Vertex<K, VV>>(edge.getSource(), edge, vertex));
		}
	}

	/**
	 * Replaces the source key of a (source id, edge, target vertex) record with the source vertex
	 * itself, producing a full (source vertex, edge, target vertex) triplet.
	 */
	private static final class TripletJoinFunction<K extends Comparable<K> & Serializable, VV extends Serializable,
			EV extends Serializable> implements FlatJoinFunction<Vertex<K, VV>, Tuple3<K, Edge<K, EV>, Vertex<K, VV>>,
			Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> {

		private static final long serialVersionUID = 1L;

		@Override
		public void join(Vertex<K, VV> vertex, Tuple3<K, Edge<K, EV>, Vertex<K, VV>> edgeVertex,
				Collector<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>> collector) throws Exception {
			collector.collect(new Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>(
					vertex, edgeVertex.f1, edgeVertex.f2));
		}
	}

	// --------------------------------------------------------------------------------------------
	//  Wrapping UDFs
	//
	//  Each wrapper calls the user function's init() once (in the first superstep) and its
	//  preSuperstep()/postSuperstep() hooks from open()/close().
	// --------------------------------------------------------------------------------------------

	private static final class GatherUdf<K extends Comparable<K> & Serializable, VV extends Serializable,
			EV extends Serializable, M> extends RichMapFunction<Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>>,
			Tuple2<K, M>> implements ResultTypeQueryable<Tuple2<K, M>> {

		private static final long serialVersionUID = 1L;

		// Transient: presumably only queried via getProducedType() while the plan is built on the
		// client; it is not shipped with the function.
		private transient TypeInformation<Tuple2<K, M>> resultType;

		private final GatherFunction<K, VV, EV, M> gatherFunction;

		private GatherUdf(GatherFunction<K, VV, EV, M> gatherFunction, TypeInformation<Tuple2<K, M>> resultType) {
			this.gatherFunction = gatherFunction;
			this.resultType = resultType;
		}

		@Override
		public Tuple2<K, M> map(Tuple3<Vertex<K, VV>, Edge<K, EV>, Vertex<K, VV>> triplet) throws Exception {
			return this.gatherFunction.gather(triplet);
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.gatherFunction.init(getIterationRuntimeContext());
			}
			this.gatherFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.gatherFunction.postSuperstep();
		}

		@Override
		public TypeInformation<Tuple2<K, M>> getProducedType() {
			return this.resultType;
		}
	}

	private static final class SumUdf<K extends Comparable<K> & Serializable, VV extends Serializable,
			EV extends Serializable, M> extends RichReduceFunction<Tuple2<K, M>> {

		private static final long serialVersionUID = 1L;

		private final SumFunction<K, VV, EV, M> sumFunction;

		private SumUdf(SumFunction<K, VV, EV, M> sumFunction) {
			this.sumFunction = sumFunction;
		}

		@Override
		public Tuple2<K, M> reduce(Tuple2<K, M> arg0, Tuple2<K, M> arg1) throws Exception {
			return this.sumFunction.sum(arg0, arg1);
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.sumFunction.init(getIterationRuntimeContext());
			}
			this.sumFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.sumFunction.postSuperstep();
		}
	}

	private static final class ApplyUdf<K extends Comparable<K> & Serializable,
			VV extends Serializable, EV extends Serializable, M> extends RichFlatJoinFunction<Tuple2<K, M>,
			Vertex<K, VV>, Vertex<K, VV>> {

		private static final long serialVersionUID = 1L;

		private final ApplyFunction<K, VV, EV, M> applyFunction;

		private ApplyUdf(ApplyFunction<K, VV, EV, M> applyFunction) {
			this.applyFunction = applyFunction;
		}

		/**
		 * @param arg0 the summed (vertex key, value) record
		 * @param arg1 the vertex with that key
		 * @param out collects vertices whose value changed
		 */
		@Override
		public void join(Tuple2<K, M> arg0, Vertex<K, VV> arg1, Collector<Vertex<K, VV>> out) throws Exception {
			this.applyFunction.apply(arg0, arg1, out);
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
				this.applyFunction.init(getIterationRuntimeContext());
			}
			this.applyFunction.preSuperstep();
		}

		@Override
		public void close() throws Exception {
			this.applyFunction.postSuperstep();
		}
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment