@napsternxg
Created April 11, 2023 02:41
Improve edgelist processing speed of PyTorch-BigGraph by replacing the row-by-row parquet reader with pyarrow's batched Parquet reader and adding tqdm progress bars. A standalone sketch of the batched reading pattern follows the diff.
diff --git a/torchbiggraph/converters/importers.py b/torchbiggraph/converters/importers.py
index fa84bc6..765e9fa 100644
--- a/torchbiggraph/converters/importers.py
+++ b/torchbiggraph/converters/importers.py
@@ -28,6 +28,7 @@ from torchbiggraph.graph_storages import (
     RELATION_TYPE_STORAGES,
 )
 from torchbiggraph.types import UNPARTITIONED
+from tqdm import tqdm


 def log(msg):
@@ -58,7 +59,7 @@ class TSVEdgelistReader(EdgelistReader):

     def read(self, path: Path):
         with path.open("rt") as tf:
-            for line_num, line in enumerate(tf, start=1):
+            for line_num, line in tqdm(enumerate(tf, start=1), desc=f"{self.__class__.__name__} read"):
                 words = line.split(self.delimiter)
                 try:
                     lhs_word = words[self.lhs_col]
@@ -76,7 +77,7 @@ class TSVEdgelistReader(EdgelistReader):
                 ) from None


-class ParquetEdgelistReader(EdgelistReader):
+class OldParquetEdgelistReader(EdgelistReader):
     def __init__(
         self,
         lhs_col: str,
@@ -105,7 +106,8 @@ class ParquetEdgelistReader(EdgelistReader):
         with path.open("rb") as tf:
             columns = [self.lhs_col, self.rhs_col, self.rel_col, self.weight_col]
             fetch_columns = [c for c in columns if c is not None]
-            for row in parquet.reader(tf, columns=fetch_columns):
+            fetch_col_idx = [i for i, c in enumerate(columns) if c is not None]
+            for row in tqdm(parquet.reader(tf, columns=fetch_columns), desc=f"{self.__class__.__name__} read"):
                 offset = 0
                 ret = []
                 for c in columns:
@@ -116,6 +118,46 @@ class ParquetEdgelistReader(EdgelistReader):
                 else:
                     ret.append(None)
             yield tuple(ret)
+
+
+class ParquetEdgelistReader(EdgelistReader):
+    def __init__(
+        self,
+        lhs_col: str,
+        rhs_col: str,
+        rel_col: Optional[str],
+        weight_col: Optional[str],
+    ):
+        """Reads edgelists from a Parquet file using pyarrow's batched reader.
+
+        col arguments must be column names, since pyarrow fetches columns by name.
+        """
+        self.lhs_col = lhs_col
+        self.rhs_col = rhs_col
+        self.rel_col = rel_col
+        self.weight_col = weight_col
+
+    def read(self, path: Path):
+        try:
+            import pyarrow.parquet as pq
+        except ImportError as e:
+            raise ImportError(
+                f"{e}. HINT: You can install pyarrow by running "
+                "'pip install pyarrow'"
+            )
+        columns = [self.lhs_col, self.rhs_col, self.rel_col, self.weight_col]
+        fetch_columns = [c for c in columns if c is not None]
+        parquet_file = pq.ParquetFile(path)
+        num_rows = parquet_file.metadata.num_rows
+        def _null_gen():
+            while True:
+                yield None
+        def _reader():
+            for batch in parquet_file.iter_batches(batch_size=10_000_000, columns=fetch_columns):
+                nulled_batch = [batch[c].to_numpy(zero_copy_only=False) if c is not None else _null_gen() for c in columns]
+                yield from zip(*nulled_batch)
+            parquet_file.close()
+        yield from tqdm(_reader(), total=num_rows, desc=f"{self.__class__.__name__} read")


 def collect_relation_types(
@@ -173,7 +215,7 @@ def collect_entities_by_type(
         counters[entity_name] = Counter()

     log("Searching for the entities in the edge files...")
-    for edgepath in edge_paths:
+    for edgepath in tqdm(edge_paths, desc="edge_paths"):
         for lhs_word, rhs_word, rel_word, _weight in edgelist_reader.read(edgepath):
             if dynamic_relations or rel_word is None:
                 rel_id = 0
@@ -219,7 +261,7 @@ def generate_entity_path_files(
     entity_storage.prepare()
     relation_type_storage.prepare()

-    for entity_name, entities in entities_by_type.items():
+    for entity_name, entities in tqdm(entities_by_type.items(), desc="entities_by_type"):
         for part in range(entities.num_parts):
             log(
                 f"- Writing count of entity type {entity_name} " f"and partition {part}"
@@ -255,7 +297,7 @@ def generate_edge_path_files(
     relation_configs: List[RelationSchema],
     dynamic_relations: bool,
     edgelist_reader: EdgelistReader,
-    n_flush_edges: int = 100000,
+    n_flush_edges: int = 10_000_000,
 ) -> None:
     log(
         f"Preparing edge path {edge_path_out}, "
@@ -322,8 +364,8 @@ def generate_edge_path_files(
                 part_data.clear()

         processed = processed + 1
-        if processed % 100000 == 0:
-            log(f"- Processed {processed} edges so far...")
+        # if processed % 100000 == 0:
+        #     log(f"- Processed {processed} edges so far...")

     for (lhs_part, rhs_part), part_data in data.items():
         if len(part_data) > 0:
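
For reference, here is a standalone sketch of the batched reading pattern the new ParquetEdgelistReader uses, outside of the PBG class structure. The file name edges.parquet and the column names lhs/rhs are hypothetical; only the pyarrow calls mirror the patch above.

import pyarrow.parquet as pq
from tqdm import tqdm

parquet_file = pq.ParquetFile("edges.parquet")  # hypothetical input file
num_rows = parquet_file.metadata.num_rows  # row count from the Parquet footer, used as the tqdm total

def rows():
    # Decode each requested column to numpy once per batch, then zip the
    # column arrays back into row tuples. This replaces per-row Python
    # parsing with vectorized per-batch decoding, which is where the
    # speedup comes from.
    for batch in parquet_file.iter_batches(batch_size=1_000_000, columns=["lhs", "rhs"]):
        lhs = batch["lhs"].to_numpy(zero_copy_only=False)
        rhs = batch["rhs"].to_numpy(zero_copy_only=False)
        yield from zip(lhs, rhs)
    parquet_file.close()

for lhs_word, rhs_word in tqdm(rows(), total=num_rows, desc="read edges"):
    pass  # consume one (lhs, rhs) edge per iteration

As in the patch, zero_copy_only=False lets string columns be materialized as numpy object arrays, and tqdm gets an accurate total from the file metadata instead of counting rows as it goes.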