Gist for SO question: http://stackoverflow.com/questions/42043842/apache-spark-using-all-available-memory-on-small-dataset
import numpy as np
import scipy.stats as sts

import pyspark
import pyspark.sql.types as stypes
import pyspark.sql.functions as sfunctions

def update_vectors(K, list_of_weights, list_of_dicts):
    """Compute a weighted sum of sparse term vectors (dicts), normalize it,
    and keep only the K largest entries (plus any ties with the K-th value)."""
    agg_dict = {}
    for weight, dict_ in zip(list_of_weights, list_of_dicts):
        for k, v in dict_.items():
            agg_dict[k] = agg_dict.get(k, 0.0) + v * weight
    # normalize so the entries sum to 1
    sum_values = sum(agg_dict.values())
    for k in agg_dict:
        agg_dict[k] /= sum_values
    # keep only the top K items (or items tied with the K-th value)
    if len(agg_dict) > K:
        sorted_items = sorted(agg_dict.items(), key=lambda x: x[1], reverse=True)
        _, Kth_value = sorted_items[K - 1]
        return dict((k, v) for k, v in sorted_items if v >= Kth_value)
    return agg_dict
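
# A minimal sanity check (not part of the original gist): combine two term
# vectors with weights 2.0 and 1.0 and keep only the top-1 entry.
#
#   update_vectors(1, [2.0, 1.0], [{0: 0.5, 1: 0.5}, {1: 1.0}])
#
# The weighted sums are {0: 1.0, 1: 2.0}; normalizing gives {0: 1/3, 1: 2/3};
# truncating to K = 1 keeps only {1: 2/3}, since term 0 falls below the
# K-th largest value.
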
def propagation_step(to_, from_, K, nodes, edges):
    """
    @param to_ - the type of the nodes whose vectors are updated (str in ['query', 'doc', 'task'])
    @param from_ - the type of the nodes whose vectors are aggregated (str in ['query', 'doc', 'task'])
    @param K - the number of terms to keep per vector
    @param nodes - the nodes dataframe
    @param edges - the edges dataframe
    Returns a new nodes dataframe with updated vectors.
    ---
    The algorithm is equivalent to the following query, but written in Spark
    dataframe syntax. The query can't be used with HiveQL because temporary
    tables can't be updated after each step of the propagation algorithm --
    and creating new temporary tables is not yet supported.

    SELECT n.id,
           update_vectors({K}, collect_list(e.weight), collect_list(nn.dst_vector)) AS updated_vector
    FROM nodes n
    JOIN ( SELECT src, dst, weight
           FROM edges
           WHERE type = '{to_}-{from_}' ) e
      ON n.id = e.src
    JOIN ( SELECT id, vector AS dst_vector
           FROM nodes
           WHERE type = '{from_}' ) nn
      ON e.dst = nn.id
    WHERE n.type = '{to_}'
    GROUP BY n.id
    """
    # Edges of type '<to_>-<from_>' carry the propagation weights.
    filtered_edges = (edges.filter(edges.type == '%s-%s' % (to_, from_))
                      .select(['src', 'dst', 'weight']))
    # Source-side vectors, renamed to avoid column clashes after the join.
    filtered_nodes = (nodes.filter(nodes.type == from_)
                      .select(['id', 'vector'])
                      .withColumnRenamed('vector', 'dst_vector')
                      .withColumnRenamed('id', 'dst_id'))
    filtered_edges = filtered_edges.join(filtered_nodes, filtered_edges.dst == filtered_nodes.dst_id)
    # For each destination node, collect the (weight, vector) pairs of its neighbours.
    agg_nodes = (nodes.filter(nodes.type == to_)
                 .join(filtered_edges, nodes.id == filtered_edges.src)
                 .select(['id', 'weight', 'dst_vector']))
    agg_df = (agg_nodes.groupBy('id')
              .agg(sfunctions.collect_list(agg_nodes.weight).alias('weight'),
                   sfunctions.collect_list(agg_nodes.dst_vector).alias('dst_vector')))
    agg_df = agg_df.withColumn('updated_vector',
                               update_vectorsUDF(sfunctions.lit(K), agg_df.weight, agg_df.dst_vector))
    # Left join so that nodes without incoming edges keep their old vector.
    nodes = nodes.join(agg_df.select(['id', 'updated_vector']), on='id', how='left')
    agg_df.unpersist()  # a no-op here, since agg_df was never persisted
    nodes = nodes.withColumn('new_vector',
                             sfunctions.when(sfunctions.isnull(nodes.updated_vector), nodes.vector)
                             .otherwise(nodes.updated_vector))
    nodes = nodes.drop('vector').drop('updated_vector').withColumnRenamed('new_vector', 'vector')
    return nodes
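
# A possible mitigation for the memory blow-up described in the linked SO
# question (a sketch, not part of the original gist): every propagation_step
# call joins `nodes` back onto itself, so the query plan and lineage grow
# with each step. Materializing the frame between steps -- by persisting it
# and forcing evaluation -- lets later steps reuse the computed result
# instead of re-expanding the whole lineage. The helper below is a
# hypothetical illustration of that idea.
def materialize(df):
    df = df.persist()
    df.count()  # force evaluation so downstream steps reuse the cached result
    return df
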
if __name__ == '__main__':
    sc = pyspark.SparkContext()
    # HiveContext lives in pyspark.sql (and is deprecated in Spark 2.x in favour of SparkSession)
    hv = pyspark.sql.HiveContext(sc)
    #
    # Generate fake data
    #
    doc_data = []
    query_data = []
    task_data = []
    # Docs get sparse vectors of up to 10 random terms (duplicate terms collapse), each with weight 0.1.
    for i in range(50):
        doc_data.append(('doc_%02d' % i, 'doc',
                         dict(zip(list(map(int, np.random.randint(0, high=40, size=10))), [0.1] * 10))))
    # Queries and tasks start with empty vectors; propagation fills them in.
    for i in range(25):
        query_data.append(('query_%02d' % i, 'query', {}))
    for i in range(10):
        task_data.append(('task_%02d' % i, 'task', {}))
    all_data = []
    all_data.extend(doc_data)
    all_data.extend(query_data)
    all_data.extend(task_data)
    # Create fake edges: each candidate pair is kept with some random probability,
    # and kept pairs get an edge in each direction with a random integer weight in [1, 10].
    edge_data = []
    for q in range(25):
        for d in range(50):
            if sts.binom.rvs(10, 0.2, size=1)[0] > 3:
                edge_data.append(('query_%02d' % q, 'doc_%02d' % d, "query-doc", int(np.random.randint(1, high=11))))
                edge_data.append(('doc_%02d' % d, 'query_%02d' % q, "doc-query", int(np.random.randint(1, high=11))))
    for t in range(10):
        for q in range(25):
            if sts.binom.rvs(10, 0.2, size=1)[0] > 3:
                edge_data.append(('task_%02d' % t, 'query_%02d' % q, "task-query", int(np.random.randint(1, high=11))))
                edge_data.append(('query_%02d' % q, 'task_%02d' % t, "query-task", int(np.random.randint(1, high=11))))
    for t in range(10):
        for tt in range(t, 10):  # note: tt starts at t, so task self-loops are possible
            if sts.binom.rvs(10, 0.2, size=1)[0] > 3:
                edge_data.append(('task_%02d' % t, 'task_%02d' % tt, "task-task", int(np.random.randint(1, high=11))))
                edge_data.append(('task_%02d' % tt, 'task_%02d' % t, "task-task", int(np.random.randint(1, high=11))))
    node_schema = stypes.StructType([
        stypes.StructField("id", stypes.StringType(), False),
        stypes.StructField("type", stypes.StringType(), False),
        stypes.StructField("vector", stypes.MapType(stypes.IntegerType(), stypes.FloatType()), False)
    ])
    #
    # Register the UDF as a module-level global; propagation_step refers to it by name.
    update_vectorsUDF = sfunctions.udf(update_vectors, stypes.MapType(stypes.IntegerType(), stypes.FloatType()))
    nodes = hv.createDataFrame(all_data, schema=node_schema)
    edges = hv.createDataFrame(edge_data, ["src", "dst", "type", "weight"])
    K = 20
    # Propagate vectors around the doc/query/task graph, one edge type per step.
    nodes = propagation_step('query', 'doc', K, nodes, edges)
    nodes = propagation_step('task', 'query', K, nodes, edges)
    nodes = propagation_step('task', 'task', K, nodes, edges)
    nodes = propagation_step('query', 'task', K, nodes, edges)
    nodes = propagation_step('doc', 'query', K, nodes, edges)
    nodes.show(5)