{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Inferring Query Intent from Reformulations and Clicks\n",
"============================================\n",
"Radlinski, Szummer, Craswell.\n",
"\n",
"* https://dl.acm.org/citation.cfm?id=1772859\n",
"* http://www.ambuehler.ethz.ch/CDstore/www2010/www/p1171.pdf\n",
"\n",
"Implementation Status: mostly complete, probably broken"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import findspark\n",
"findspark.init('/usr/lib/spark2')\n",
"from pyspark.sql import SparkSession\n",
"import os\n",
"\n",
"master = 'yarn'\n",
"builder = (\n",
" SparkSession.builder\n",
" .appName('query understanding')\n",
" .config(\n",
" 'spark.driver.extraJavaOptions',\n",
" '-Dhttp.proxyHost=webproxy.eqiad.wmnet -Dhttp.proxyPort=8080 -Dhttps.proxyHost=webproxy.eqiad.wmnet -Dhttps.proxyPort=8080')\n",
" .config('spark.jars.packages', 'graphframes:graphframes:0.6.0-spark2.3-s_2.11')\n",
")\n",
"if master == 'yarn':\n",
" os.environ['PYSPARK_SUBMIT_ARGS'] = '--archives spark_venv.zip#venv pyspark-shell'\n",
" os.environ['PYSPARK_PYTHON'] = 'venv/bin/python'\n",
" builder = (\n",
" builder\n",
" .master('yarn')\n",
" .config('spark.sql.shuffle.partitions', 400)\n",
" .config('spark.dynamicAllocation.maxExecutors', 200)\n",
" .config('spark.executor.memory', '2048m')\n",
" )\n",
"elif master == 'local':\n",
" builder = (\n",
" builder\n",
" .master('local[12]')\n",
" .config('spark.driver.memory', '8g')\n",
" )\n",
"else:\n",
" raise Exception()\n",
"\n",
"spark = builder.getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"BASE_PATH = '/user/ebernhardson/query_understanding/201811'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import functions as F, types as T, Window\n",
"from pyspark.ml.feature import StringIndexer\n",
"import numpy as np\n",
"\n",
"df = (\n",
" spark.read.table('discovery.query_clicks_hourly')\n",
" .where(F.col('year') == 2018)\n",
" .where(F.col('month') == 11)\n",
" .withColumn('query', F.lower(F.trim(F.col('query'))))\n",
" # high volume ip's tend to skew things and blow up executors. They\n",
" # aren't as meaningful as a wide variety of users either.\n",
" .withColumn(\n",
" 'q_by_ip_day',\n",
" F.count(F.lit(1)).over(Window\n",
" .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)\n",
" .partitionBy('year', 'month', 'day', 'ip')))\n",
" .where(F.col('q_by_ip_day') < 50)\n",
" .drop('q_by_ip_day')\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Expand (substep 1)\n",
"================\n",
"\n",
"Find k=10 most frequent valid reformulations of q. q' is a reformulation of q if:\n",
"\n",
"* q was followed by q' within 10 minutes by 2 distinct users\n",
"* Of all pairs of queries (q_i, q') issued by any user within 10 minutes, (q, q') occured at least a fraction δ = 0.001 of the time\n"
]
},
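{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal pure-Python sketch of the validity criteria above, added for illustration; the toy events and names are not from the original notebook, and the top-k selection per source query is omitted."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: apply the two validity criteria to a toy list of\n",
"# (q, q', user) reformulation events observed within the 10 minute window.\n",
"# Thresholds are written inline here; the named constants are defined below.\n",
"from collections import defaultdict\n",
"\n",
"events = [\n",
"    ('apple', 'apple pie', 'u1'), ('apple', 'apple pie', 'u1'),\n",
"    ('apple', 'apple pie', 'u2'), ('apple', 'fruit', 'u3'),\n",
"]\n",
"users = defaultdict(set)\n",
"for q, q2, user in events:\n",
"    users[(q, q2)].add(user)\n",
"# weight = number of distinct users making the reformulation\n",
"weight = {pair: len(u) for pair, u in users.items()}\n",
"# total reformulation weight flowing into each destination query\n",
"dst_weight = defaultdict(int)\n",
"for (q, q2), w in weight.items():\n",
"    dst_weight[q2] += w\n",
"valid = {\n",
"    pair: w for pair, w in weight.items()\n",
"    if w >= 2 and w / dst_weight[pair[1]] > 0.001\n",
"}\n",
"assert valid == {('apple', 'apple pie'): 2}"
]
},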
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Accept the top 10 most frequent valid reformulations q' of q\n",
"TOP_K = 10\n",
"# Reformulations must occur within 10 minutes of each other\n",
"SESSION_TIMEOUT = 600\n",
"# For q' to be a valid reformulation of q, q must make up at\n",
"# least 0.1% of the reformulations of q'.\n",
"MIN_REFORMULATE_RATIO = 0.001\n",
"# For q' to be a valid reformulation of q, the reformulation\n",
"# must be made by at least 2 distinct users\n",
"MIN_REFORMULATE_USERS = 2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class Lagged:\n",
" def __init__(self, timeout):\n",
" self.timeout = timeout\n",
" self.timestamps = []\n",
" self.data = []\n",
" \n",
" def __iter__(self):\n",
" return iter(self.data)\n",
" \n",
" def add(self, timestamp, value):\n",
" assert not self.timestamps or self.timestamps[-1] <= timestamp, \"timestamps out of order\"\n",
" self.timestamps.append(timestamp)\n",
" self.data.append(value)\n",
" timeout_before = timestamp - self.timeout\n",
" while self.timestamps and self.timestamps[0] < timeout_before:\n",
" self.timestamps.pop()\n",
" self.data.pop()\n",
"\n",
"def expand_query_pairs(row, session_timeout=SESSION_TIMEOUT):\n",
" # q was followed by q' within 10 minutes\n",
" current_queries = Lagged(session_timeout)\n",
" reformulations = set()\n",
" for timestamp, query in sorted(row.queries, key=lambda q: q.timestamp):\n",
" current_queries.add(timestamp, query)\n",
" for source_query in current_queries:\n",
" if query != source_query:\n",
" reformulations.add((row.wikiid, row.identity, source_query, query))\n",
" yield from reformulations"
]
},
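{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check of the sliding-window behaviour of Lagged, added here for illustration (not in the original notebook): entries older than the timeout relative to the newest entry are evicted from the front."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the Lagged sliding window.\n",
"lag = Lagged(timeout=600)\n",
"lag.add(0, 'a')\n",
"lag.add(300, 'b')\n",
"assert list(lag) == ['a', 'b']\n",
"lag.add(700, 'c')  # 'a' (t=0) is now outside the 600 second window\n",
"assert list(lag) == ['b', 'c']"
]
},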
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def full_window(*fields):\n",
" return (\n",
" Window\n",
" .partitionBy(*fields)\n",
" .rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))\n",
"\n",
"rdd_possible_reformulations = (\n",
" df\n",
" .groupby('wikiid', 'identity')\n",
" .agg(F.collect_list(F.struct('timestamp', 'query')).alias('queries'))\n",
" .rdd.flatMap(expand_query_pairs)\n",
")\n",
"\n",
"# graph edges from query to reformulated query weighted by the\n",
"# number of distinct identities performing the reformulation\n",
"df_valid_reformulations = (\n",
" spark.createDataFrame(rdd_possible_reformulations, T.StructType([\n",
" T.StructField('wikiid', T.StringType(), nullable=False),\n",
" T.StructField('identity', T.StringType(), nullable=False),\n",
" T.StructField('src', T.StringType(), nullable=False),\n",
" T.StructField('dst', T.StringType(), nullable=False),\n",
" ]))\n",
" # Some identities may have performed the same reformulation many times, drop them\n",
" .drop_duplicates()\n",
" # Weight of the edge is the number of times separate identities made this reformulation\n",
" .groupBy('wikiid', 'src', 'dst')\n",
" .count()\n",
" .withColumnRenamed('count', 'weight')\n",
" # Of all pairs of queries (q_i, q') issued by any user within 10 minutes\n",
" # (q, q') by at least two users\n",
" .where(F.col('weight') >= MIN_REFORMULATE_USERS)\n",
" # (q, q') occured at least a fraction δ = 0.001 of the time\n",
" .withColumn('dst_weight', F.sum('weight').over(full_window('wikiid', 'dst')))\n",
" .where(F.col('weight') / F.col('dst_weight') > MIN_REFORMULATE_RATIO)\n",
" # top k per query\n",
" .withColumn('rank', F.rank().over(Window.partitionBy('wikiid', 'src').orderBy(F.col('weight').desc())))\n",
" .where(F.col('rank') < TOP_K)\n",
" .select('wikiid', 'src', 'dst', 'weight')\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Expand (substep 2)\n",
"================\n",
"Gather up query neighborhoods. Each neighborhood is defined by it's root and all nodes within 2 degrees."
]
},
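{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small dict-based sketch of this expansion (illustrative only, not from the original notebook): a root's neighborhood is its direct reformulations plus, for each of those, every other query that also reformulates into it. Note that a forward-then-back path can place the root inside its own neighborhood; the join in the next cell does not filter that out either."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: expand a neighborhood two steps out from a toy edge list.\n",
"from collections import defaultdict\n",
"\n",
"toy_edges = [('apple', 'apple pie'), ('apple tart', 'apple pie'), ('apple', 'fruit')]\n",
"fwd = defaultdict(set)   # q  -> reformulations of q\n",
"rev = defaultdict(set)   # q' -> queries reformulated into q'\n",
"for s, d in toy_edges:\n",
"    fwd[s].add(d)\n",
"    rev[d].add(s)\n",
"\n",
"def neighborhood(root):\n",
"    # direct reformulations, plus everything one reversed edge further out\n",
"    members = set(fwd[root])\n",
"    for d in fwd[root]:\n",
"        members |= rev[d]\n",
"    return members\n",
"\n",
"assert neighborhood('apple') == {'apple pie', 'fruit', 'apple', 'apple tart'}"
]
},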
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# The second step from reformulated query to related query goes\n",
"# in the opposite direction. This renaming is mostly for clarity\n",
"df_valid_reformulations.persist()\n",
"df_second_step_edges = df_valid_reformulations.select(\n",
" 'wikiid',\n",
" F.col('dst').alias('src'),\n",
" F.col('src').alias('dst'),\n",
")\n",
"# Edge from src to all nodes two steps away. We don't need the path,\n",
"# these edges are all going to be replaced with edges from cosine sim\n",
"# within the neighborhoods.\n",
"df_two_step_edges = (\n",
" df_valid_reformulations.alias('first')\n",
" .join(df_second_step_edges.alias('second'), \n",
" on=(F.col('first.wikiid') == F.col('second.wikiid')) & (F.col('first.dst') == F.col('second.src')))\n",
" .select('first.wikiid', 'first.src', 'second.dst')\n",
")\n",
"\n",
"df_neighborhoods = (\n",
" df_valid_reformulations\n",
" .select(\n",
" # Each neighborhood is named for its root\n",
" 'wikiid',\n",
" F.col('src').alias('neighborhood'),\n",
" F.col('dst').alias('member'))\n",
" .union(df_two_step_edges.select(\n",
" 'wikiid',\n",
" F.col('first.src').alias('neighborhood'),\n",
" F.col('second.dst').alias('member')))\n",
" # There can be multiple paths between src and dst,\n",
" # but we only care that one exists.\n",
" .drop_duplicates()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataFrame[wikiid: string, src: string, dst: string, weight: bigint]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path_df_neighborhoods = os.path.join(BASE_PATH, 'df_neighborhoods')\n",
"#df_neighborhoods.write.parquet(path_df_neighborhoods)\n",
"df_neighborhoods = spark.read.parquet(path_df_neighborhoods)\n",
"df_valid_reformulations.unpersist()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Filter (substep 1)\n",
"==============\n",
"\n",
"Perform a random walk over the bipartite click graph to find nearby queries"
]
},
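{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before the Spark implementation, a small numpy worked example (illustrative, not from the original notebook) of the two-step walk q -> clicked page -> q' on a toy bipartite click graph, using a row-stochastic transition matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy bipartite click graph with nodes ordered [q1, q2, p1, p2]:\n",
"# q1 clicked p1 twice and p2 once; q2 clicked p1 once.\n",
"import numpy as np\n",
"\n",
"W = np.array([\n",
"    [0, 0, 2, 1],   # q1\n",
"    [0, 0, 1, 0],   # q2\n",
"    [2, 1, 0, 0],   # p1\n",
"    [1, 0, 0, 0],   # p2\n",
"], dtype=np.float64)\n",
"trans = W / W.sum(axis=1, keepdims=True)   # row-stochastic transition matrix\n",
"p = np.array([1., 0., 0., 0.])             # start the walk at q1\n",
"p = p @ trans                              # step 1: q1 -> clicked pages\n",
"p = p @ trans                              # step 2: clicked pages -> queries\n",
"# After two steps all of the mass is back on queries: ~0.78 on q1, ~0.22 on q2\n",
"assert np.allclose(p, [7/9, 2/9, 0, 0])"
]
},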
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"\n",
"def random_walk(row):\n",
" \"\"\"Calculate random walk destination probability of a cluster\n",
" \n",
" The random walk to want to perform is:\n",
" \n",
" q_i -> click page -> q_j\n",
" \"\"\"\n",
" G = nx.Graph()\n",
" for edge in row.edges:\n",
" # q_i -> click page\n",
" G.add_edge(row.src, edge.dst, weight=edge.weight)\n",
" for src in edge.srcs:\n",
" # click page -> q_j\n",
" G.add_edge(edge.dst, src.src, weight=src.weight)\n",
" A = nx.adj_matrix(G).todense()\n",
" # matrix -> ndarray\n",
" A = np.array(A, dtype=np.float64)\n",
" # normalize weights by row, this is equiv to: A[i] = A[i] / sum(A[i])\n",
" row_weights = np.sum(A, axis=1)\n",
" A = np.divide(A.T, row_weights, out=np.zeros_like(A), where=row_weights != 0.).T\n",
" # weighted degree matrix\n",
" D = np.diag(np.sum(A, axis=1))\n",
" # transition probabilities\n",
" T = np.dot(np.linalg.inv(D), A)\n",
" # start at the source query\n",
" p = np.zeros((A.shape[0], 1), dtype=np.float64)\n",
" row_src_idx = list(G.nodes()).index(row.src)\n",
" p[row_src_idx, 0] = 1.\n",
" # node weight is the sum of the steps of the transition matrix, basically\n",
" # the reach \n",
" node_weight = 1\n",
" # walk 2 steps, first to the clicked page and then to the queries that clicked on that page.\n",
" walk_length = 2\n",
" for i in range(walk_length):\n",
" p = np.dot(T, p)\n",
" node_weight += np.sum(p.ravel())\n",
" node_prob = [(node, float(prob)) for node, prob in zip(G.nodes(), p.ravel()) if prob > 0]\n",
" return (row.wikiid, row.src, node_prob, node_weight)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Edges from queries performed to the page clicked weighted by\n",
"# the number of identities clicking.\n",
"df_edge = (\n",
" df\n",
" .select(\n",
" 'wikiid',\n",
" 'identity',\n",
" F.col('query').alias('src'),\n",
" F.explode('clicks.pageid').alias('dst'))\n",
" # Weight by number of identities\n",
" .groupBy('wikiid', 'src', 'dst')\n",
" .agg(F.countDistinct('identity').alias('weight'))\n",
")\n",
"\n",
"# Attach to each edge the set of queries that lead into it's clicked page.\n",
"# This is very expensive on the large neighborhoods, split into many thousands\n",
"# of partitions to try and keep them separate to somewhat reduce skew \n",
"# (before, 400 partitions. 395 finish in 2 minutes, remaining 40+ minutes)\n",
"df_query_to_click_to_query = (\n",
" df_edge\n",
" .repartition(10000, ['wikiid', 'dst'])\n",
" # This has the undesirable side effect of duplicating the neighborhood. A dst\n",
" # with 10k inbound edges will duplicate itself 10k times.\n",
" # TODO: Find a way that doesn't duplicate. \n",
" .withColumn(\n",
" 'srcs',\n",
" F.collect_list(F.struct('src', 'weight')).over(full_window('wikiid', 'dst')))\n",
" .repartition(10000, ['wikiid', 'src'])\n",
" .groupby('wikiid', 'src')\n",
" .agg(F.collect_list(F.struct('dst', 'weight', 'srcs')).alias('edges'))\n",
")\n",
"\n",
"# perform the random walk.\n",
"df_rw = spark.createDataFrame(\n",
" df_query_to_click_to_query.rdd.map(random_walk),\n",
" T.StructType([\n",
" df_query_to_click_to_query.schema['wikiid'],\n",
" df_query_to_click_to_query.schema['src'],\n",
" # Would be more efficient as a SparseVector, but we would need to\n",
" # turn all the queries into 0-indexed values with StringIndexer,\n",
" # and then debugging has a harder time with less context.\n",
" T.StructField('rw', T.ArrayType(T.StructType([\n",
" T.StructField('q', T.StringType(), nullable=False),\n",
" T.StructField('prob', T.DoubleType(), nullable=False) \n",
" ]))),\n",
" T.StructField('weight', T.DoubleType(), nullable=False)\n",
" ]))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DataFrame[wikiid: string, src: string, dst: int, weight: bigint]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path_df_rw = os.path.join(BASE_PATH, 'df_random_walk_prob')\n",
"#df_rw.write.parquet(path_df_rw)\n",
"df_rw = spark.read.parquet(path_df_rw)\n",
"df_edge.unpersist()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Filter (substep 2)\n",
"==============\n",
"Generate a new graph using the similarity between nodes in a neighborhood\n",
"to define the edges. All nodes in this graph are queries."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# The similarity threshold here can be quite low, it only needs to filter\n",
"# absolute junk. The emitted graph is very small and can easily\n",
"# be post filtered\n",
"SIM_THRESHOLD = 0.01\n",
"\n",
"# Calculate the new set of edges from random walk similarities within each neighborhood\n",
"def neighborhood_similarity(row):\n",
" queries = list(set(rw.q for node in row.nodes for rw in node.rw))\n",
" idx = {q: i for i, q in enumerate(queries)}\n",
" A = np.zeros((len(row.nodes), len(queries)), dtype=np.float32)\n",
" for i, node in enumerate(row.nodes):\n",
" A[i, [idx[rw.q] for rw in node.rw]] = [rw.prob for rw in node.rw]\n",
" # cosine similarity between all pairs\n",
" sim = np.dot(A, A.T)\n",
" square_mag = np.diag(sim)\n",
" inv_square_mag = 1. / square_mag\n",
" inv_square_mag[np.isinf(inv_square_mag)] = 0\n",
" inv_mag = np.sqrt(inv_square_mag)\n",
" cosine = (sim * inv_mag).T * inv_mag\n",
" x, y = np.where(cosine > SIM_THRESHOLD)\n",
" for i, j in zip(x, y):\n",
" if i == j:\n",
" continue\n",
" # Emits edges for (i, j) and (j, i), making it an undirected graph\n",
" yield row.wikiid, row.nodes[i].member, row.nodes[j].member, float(cosine[i, j])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Couple tests of neighborhood similarity\n",
"from pyspark.sql import Row\n",
"\n",
"row = Row(\n",
" wikiid='test',\n",
" nodes=[\n",
" Row(member='a', rw=[Row(q='hi', prob=0.123), Row(q='mom', prob=0.543)]),\n",
" Row(member='b', rw=[Row(q='hi', prob=0.123), Row(q='mom', prob=0.543)]),\n",
" Row(member='c', rw=[Row(q='other', prob=0.123), Row(q='thing', prob=0.543)]),\n",
" ])\n",
"seen = 0\n",
"for wikiid, src, dst, weight in neighborhood_similarity(row):\n",
" seen += 1\n",
" assert wikiid == 'test'\n",
" assert src in ('a', 'b')\n",
" assert dst in ('a', 'b')\n",
" assert src != dst\n",
" assert weight == 1.0\n",
"assert seen == 2"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"df_neighborhoods_grouped_with_rw = (\n",
" df_neighborhoods\n",
" .alias('left')\n",
" .join(\n",
" # inner join removes all the unclicked queries that aren't in\n",
" # the random walk but were used to expand the neighborhoods.\n",
" # TODO: The original paper isn't explicit, but this must be what\n",
" # they are doing? Unfortunately this removes most of the potential\n",
" # vertices from the graph for our click logs.\n",
" df_rw.alias('right'), how='inner',\n",
" on=(F.col('left.wikiid') == F.col('right.wikiid')) & (F.col('left.member') == F.col('right.src'))\n",
" )\n",
" .select('left.wikiid', 'left.neighborhood', 'left.member', 'right.rw')\n",
" .groupby('wikiid', 'neighborhood')\n",
" .agg(F.collect_list(F.struct('member', 'rw')).alias('nodes'))\n",
")\n",
"\n",
"# Contains edges from queries to reformulations of those queries\n",
"# weighted by their random walk similarity on the bipartite click graph.\n",
"df_random_walk_sim_edges = spark.createDataFrame(\n",
" df_neighborhoods_grouped_with_rw.rdd.flatMap(neighborhood_similarity),\n",
" T.StructType([\n",
" T.StructField('wikiid', T.StringType(), nullable=False),\n",
" T.StructField('src', T.StringType(), nullable=False),\n",
" T.StructField('dst', T.StringType(), nullable=False),\n",
" T.StructField('weight', T.DoubleType(), nullable=False),\n",
" ])\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"path_random_walk_sim_edges = os.path.join(BASE_PATH, 'random_walk_sim_edges')\n",
"#df_random_walk_sim_edges.write.parquet(path_random_walk_sim_edges)\n",
"df_random_walk_sim_edges = spark.read.parquet(path_random_walk_sim_edges)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Clustering\n",
"========="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"random_walk_sim_edges = (\n",
" df_random_walk_sim_edges\n",
" .where(F.col('wikiid') == 'enwiki')\n",
" .collect()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx\n",
"G = nx.Graph()\n",
"for edge in random_walk_sim_edges:\n",
" if edge.weight > 0.000001:\n",
" G.add_edge(edge.src, edge.dst, weight=edge.weight)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from networkx.algorithms import community as nx_c\n",
"communities = list(nx_c.asyn_lpa_communities(G, weight='weight', seed=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"len(communities)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"for community in sorted(communities, key=len, reverse=True)[300:500]:\n",
" print('*' * 20)\n",
" for node_id in community:\n",
" print(node_id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from networkx.algorithms.centrality import edge_current_flow_betweenness_centrality\n",
"\n",
"edge_current_flow_betweenness_centrality(G, weight='weight')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import networkx.algorithms.components\n",
"conn_comp = list(networkx.algorithms.components.connected_components(G))\n",
"conn_comp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of query reformulations after initial filtering\n",
"df_valid_reformulations.count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of queries reformulated from\n",
"df_valid_reformulations.groupby('wikiid', 'src').count().drop('wikiid', 'src').describe().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of queries reformulated into\n",
"df_valid_reformulations.groupby('wikiid', 'dst').count().drop('wikiid', 'dst').describe().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of distinct queries\n",
"df_valid_reformulations.select('wikiid', F.explode(F.array(F.col('src'), F.col('dst')))).drop_duplicates().count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Total members of all neighborhoods\n",
"df_neighborhoods.count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 516k unique members of all neighborhoods. 2.3 neighborhoods per query, up to 522.\n",
"df_neighborhoods.groupby('wikiid', 'member').count().drop('wikiid', 'member').describe().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 294k neighborhoods. 4.1 members per neighborhood, up to 444\n",
"df_neighborhoods.groupby('wikiid', 'neighborhood').count().drop('wikiid', 'neighborhood').describe().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of non-zero random walk probabilities per source query\n",
"df_rw.select(F.size('rw'), 'weight').describe().show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"num_edges = (\n",
" df_edge\n",
" .where(F.col('wikiid') == 'enwiki')\n",
" .count()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Convert the bipartite query/click graph into a sparse adjacency matrix\n",
"import scipy.sparse\n",
"import itertools\n",
"from collections import defaultdict\n",
"\n",
"VERT_TYPE_QUERY = 1\n",
"VERT_TYPE_PAGE = 2\n",
"\n",
"data = np.empty(2 * num_edges, dtype=np.float32)\n",
"row_ind = np.empty(2 * num_edges, dtype=np.int32)\n",
"col_ind = np.empty(2 * num_edges, dtype=np.int32)\n",
"next_vert_id = itertools.count()\n",
"vert_ids = defaultdict(lambda: next(next_vert_id))\n",
"for i, row in enumerate(df_edge.where(F.col('wikiid') == 'enwiki').toLocalIterator()):\n",
" src_vert = vert_ids[(VERT_TYPE_QUERY, row.wikiid, row.src)]\n",
" dst_vert = vert_ids[(VERT_TYPE_PAGE, row.wikiid, row.dst)]\n",
" \n",
" i *= 2\n",
" row_ind[i] = src_vert\n",
" col_ind[i] = dst_vert\n",
" data[i] = row.weight\n",
" \n",
" row_ind[i + 1] = dst_vert\n",
" col_ind[i + 1] = src_vert\n",
" data[i + 1] = row.weight"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"num_vertices = max(vert_ids.values()) + 1\n",
"A = scipy.sparse.csr_matrix((data, (row_ind, col_ind)), shape=(num_vertices, num_vertices))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min, sys: 44 ms, total: 1min\n",
"Wall time: 1min\n"
]
}
],
"source": [
"import numba\n",
"\n",
"# Normalize each row so it sums to 1.\n",
"# This makes A[i, j] the probability a user at page vertex i will transition to vertex j\n",
"#@numba.njit()\n",
"def norm_by_indptr(indptr, data):\n",
" for i in range(indptr.shape[0] - 1):\n",
" start = indptr[i]\n",
" end = indptr[i + 1]\n",
" data[start:end] /= np.sum(data[start:end])\n",
" \n",
"%time norm_by_indptr(A.indptr, A.data)"
]
},
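{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, added for illustration (not in the original notebook): after the in-place normalization, every row of A with at least one stored entry should sum to roughly 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rows of A with at least one stored entry should now sum to ~1.\n",
"row_sums = np.asarray(A.sum(axis=1)).ravel()\n",
"nonempty = np.diff(A.indptr) > 0\n",
"assert np.allclose(row_sums[nonempty], 1.0, atol=1e-3)"
]
},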
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import tqdm\n",
"\n",
"indptr = A.indptr\n",
"indices = A.indices\n",
"vert_id = 3\n",
"\n",
"@numba.njit(nogil=True)\n",
"def csr_plus_csr(\n",
" n_row, n_col,\n",
" Ap, Aj, Ax,\n",
" Bp, Bj, Bx,\n",
" Cp, Cj, Cx\n",
"):\n",
" Cp[0] = 0\n",
" nnz = 0\n",
" for i in range(n_row):\n",
" A_pos = Ap[i]\n",
" B_pos = Bp[i]\n",
" A_end = Ap[i+1]\n",
" B_end = Bp[i+1]\n",
" \n",
" while A_pos < A_end and B_pos < B_end:\n",
" A_j = Aj[A_pos]\n",
" B_j = Bj[B_pos]\n",
" if A_j == B_j:\n",
" result = Ax[A_pos] + Bx[B_pos]\n",
" if result != 0:\n",
" Cj[nnz] = A_j\n",
" Cx[nnz] = result\n",
" nnz += 1\n",
" A_pos += 1\n",
" B_pos += 1\n",
" elif A_j < B_j:\n",
" result = Ax[A_pos]\n",
" Cj[nnz] = A_j\n",
" Cx[nnz] = result\n",
" nnz += 1\n",
" else: # B_j < A_j\n",
" result = Bx[B_pos]\n",
" if result != 0:\n",
" Cj[nnz] = B_j\n",
" Cx[nnz] = result\n",
" nnz += 1\n",
" B_pos += 1\n",
" # tail\n",
" while A_pos < A_end:\n",
" result = Ax[A_pos]\n",
" if result != 0:\n",
" Cj[nnz] = Aj[A-pos]\n",
" Cx[nnz] = result\n",
" nnz += 1\n",
" A_pos += 1\n",
" while B_pos < B_end:\n",
" result = Bx[B_pos]\n",
" if result != 0:\n",
" Cj[nnz] = Bj[B_pos]\n",
" Cx[nnz] = result\n",
" nnz += 1\n",
" B_pos += 1\n",
" Cp[i+1] = nnz\n",
" \n",
"def something_add(A, B):\n",
" maxnnz = A.nnz + B.nnz\n",
" indptr = np.empty_like(A.indptr)\n",
" indices = np.empty(maxnnz, dtype=A.indices.dtype)\n",
" data = np.empty(maxnnz, dtype=A.data.dtype)\n",
" \n",
" csr_plus_csr(\n",
" A.indptr, A.indices, A.data,\n",
" B.indptr, B.indices, B.data,\n",
" indptr, indices, data)\n",
" return indptr, indices[:indptr[-1]], data[:indptr[-1]]\n",
" \n",
"@numba.njit(nogil=True)\n",
"def foo(indptr, indices, data, vert_id):\n",
" s1_start = indptr[vert_id]\n",
" s1_end = indptr[vert_id + 1]\n",
" s1_indices = indptr[s1_start:s1_end]\n",
" s1_prob = data[s1_indices]\n",
" for idx in range(s1_start, s1_end):\n",
" pos = indptr[idx]\n",
" to_vert_id = indices[idx]\n",
" transition_prob = data[idx]\n",
" \n",
" return last_idx\n",
"\n",
"print(foo(A.indptr, A.indices, 3))\n",
"i = 0\n",
"for (vert_type, vert_wiki, vert_name), vert_id in tqdm.tqdm_notebook(vert_ids.items(), total=len(vert_ids)):\n",
" if vert_type != VERT_TYPE_QUERY:\n",
" continue\n",
" foo(A.indptr, A.indices, vert_id)\n",
" #i += 1\n",
" #if i >= 5:\n",
" # break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"from itertools import islice\n",
"import tqdm\n",
"\n",
"@numba.njit()\n",
"def size_est(indptr, indices):\n",
" nnz = 0\n",
" for to_vert_id in indices:\n",
" nnz += indptr[to_vert_id + 1] - indptr[to_vert_id]\n",
" return nnz\n",
" \n",
"@numba.njit()\n",
"def walk_numba(indptr, indices, data, vert_id):\n",
" \"\"\"Walk two steps in a sparse adjancecy matrix\n",
" \n",
" The ugliest part of this is the result handling. To achieve\n",
" any kind of speed we can't do allocation here, but we don't\n",
" know ahead of time the result set size.\n",
" \"\"\"\n",
" # step 1 starting at vert_id\n",
" start = indptr[vert_id]\n",
" end = indptr[vert_id + 1]\n",
" nnz = size_est(indptr, indices[start:end])\n",
" # step 2, starting at step 1 results\n",
" #p_step_2[:] = 0\n",
" out_indices = np.empty(nnz, dtype=indices.dtype)\n",
" out_data = np.empty(nnz, dtype=data.dtype)\n",
" \n",
" nz = 0\n",
" for i in range(start, end):\n",
" to_vert_id = indices[i]\n",
" transition_prob = data[i]\n",
" #for to_vert_id, transition_prob in zip(indices[start:end], data[start:end]):\n",
" start = indptr[to_vert_id]\n",
" end = indptr[to_vert_id + 1]\n",
" size = end - start\n",
" out_indices[nz:nz+size] = indices[start:end]\n",
" out_data[nz:nz+size] = transition_prob * data[start:end]\n",
" nz += size\n",
" return out_indices[:nz], out_data[:nz]\n",
"\n",
"def walk_scipy(A, vert_id):\n",
" # Probability of ariving at each clicked page from source query vert_name\n",
" p_step_1 = A[vert_id]\n",
" # Probability of ariving at each query from clicked page\n",
" p_step_2 = None\n",
" for to_vert, transition_prob in zip(p_step_1.indices, p_step_1.data):\n",
" one_step_2_prob = transition_prob * A[to_vert]\n",
" if p_step_2 is None:\n",
" p_step_2 = one_step_2_prob\n",
" else:\n",
" p_step_2 += one_step_2_prob\n",
"\n",
"nnz = 0\n",
"for vert_id in tqdm.tqdm_notebook(vert_ids.values(), total=len(vert_ids)):\n",
" if vert_type != VERT_TYPE_QUERY:\n",
" continue\n",
" indices = A.indices[A.indptr[vert_id]:A.indptr[vert_id + 1]]\n",
" nnz += size_est(A.indptr, indices)\n",
"print(nnz)\n",
" \n",
"results = []\n",
"for (vert_type, vert_wiki, vert_name), vert_id in tqdm.tqdm_notebook(vert_ids.items(), total=len(vert_ids)):\n",
" if vert_type != VERT_TYPE_QUERY:\n",
" continue\n",
" if True:\n",
" res = walk_numba(A.indptr, A.indices, A.data, vert_id)\n",
" else:\n",
" res = walk_scipy(A, vert_id)\n",
" results.append(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(sum(x.size for x in results))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}