Improve edgelist processing speed of PyTorch-BigGraph by using the pyarrow batched Parquet reader.
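The patch below replaces the row-at-a-time fastparquet reader with pyarrow's batched Parquet reader and adds tqdm progress bars to the converter loops. The core speedup pattern: read large record batches, convert each column to a numpy array once, then zip the columns back into per-edge tuples. Here is a minimal standalone sketch of that pattern; the file name and column names are placeholders, not part of the patch:

import pyarrow.parquet as pq
from tqdm import tqdm

parquet_file = pq.ParquetFile("edges.parquet")  # placeholder path
num_rows = parquet_file.metadata.num_rows

def iter_edges():
    # Large batches amortize Python-level overhead over millions of rows.
    for batch in parquet_file.iter_batches(batch_size=10_000_000, columns=["lhs", "rhs"]):
        lhs = batch["lhs"].to_numpy(zero_copy_only=False)  # string columns cannot be zero-copy
        rhs = batch["rhs"].to_numpy(zero_copy_only=False)
        yield from zip(lhs, rhs)

for lhs_word, rhs_word in tqdm(iter_edges(), total=num_rows, desc="edges"):
    pass  # hand each edge to downstream processing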
diff --git a/torchbiggraph/converters/importers.py b/torchbiggraph/converters/importers.py
index fa84bc6..765e9fa 100644
--- a/torchbiggraph/converters/importers.py
+++ b/torchbiggraph/converters/importers.py
@@ -28,6 +28,7 @@ from torchbiggraph.graph_storages import (
     RELATION_TYPE_STORAGES,
 )
 from torchbiggraph.types import UNPARTITIONED
+from tqdm import tqdm
 
 
 def log(msg):
@@ -58,7 +59,7 @@ class TSVEdgelistReader(EdgelistReader):
     def read(self, path: Path):
         with path.open("rt") as tf:
-            for line_num, line in enumerate(tf, start=1):
+            for line_num, line in tqdm(enumerate(tf, start=1), desc=f"{self.__class__.__name__} read"):
                 words = line.split(self.delimiter)
                 try:
                     lhs_word = words[self.lhs_col]
@@ -76,7 +77,7 @@ class TSVEdgelistReader(EdgelistReader):
                 ) from None
 
 
-class ParquetEdgelistReader(EdgelistReader):
+class OldParquetEdgelistReader(EdgelistReader):
     def __init__(
         self,
         lhs_col: str,
@@ -105,7 +106,8 @@ class ParquetEdgelistReader(EdgelistReader):
         with path.open("rb") as tf:
             columns = [self.lhs_col, self.rhs_col, self.rel_col, self.weight_col]
             fetch_columns = [c for c in columns if c is not None]
-            for row in parquet.reader(tf, columns=fetch_columns):
+            fetch_col_idx = [i for i, c in enumerate(columns) if c is not None]
+            for row in tqdm(parquet.reader(tf, columns=fetch_columns), desc=f"{self.__class__.__name__} read"):
                 offset = 0
                 ret = []
                 for c in columns:
@@ -116,6 +118,46 @@ class ParquetEdgelistReader(EdgelistReader):
                         ret.append(None)
                 yield tuple(ret)
+
+
+class ParquetEdgelistReader(EdgelistReader):
+    def __init__(
+        self,
+        lhs_col: str,
+        rhs_col: str,
+        rel_col: Optional[str],
+        weight_col: Optional[str],
+    ):
+        """Reads edgelists from a Parquet file.
+
+        col arguments must be column names (the pyarrow reader does not
+        accept column offsets).
+        """
+        self.lhs_col = lhs_col
+        self.rhs_col = rhs_col
+        self.rel_col = rel_col
+        self.weight_col = weight_col
+
+    def read(self, path: Path):
+        try:
+            import pyarrow.parquet as pq
+        except ImportError as e:
+            raise ImportError(
+                f"{e}. HINT: You can install pyarrow by running "
+                "'pip install pyarrow'"
+            )
+        columns = [self.lhs_col, self.rhs_col, self.rel_col, self.weight_col]
+        fetch_columns = [c for c in columns if c is not None]
+        parquet_file = pq.ParquetFile(path)
+        num_rows = parquet_file.metadata.num_rows
+
+        def _null_gen():
+            # Endless stream of None, standing in for columns that were not requested.
+            while True:
+                yield None
+
+        def _reader():
+            for batch in parquet_file.iter_batches(batch_size=10_000_000, columns=fetch_columns):
+                # Convert each fetched column to a numpy array once per batch;
+                # unrequested columns fall back to the None stream.
+                nulled_batch = [batch[c].to_numpy(zero_copy_only=False) if c is not None else _null_gen() for c in columns]
+                yield from zip(*nulled_batch)
+            parquet_file.close()
+
+        yield from tqdm(_reader(), total=num_rows, desc=f"{self.__class__.__name__} read")
 
 
 def collect_relation_types(
@@ -173,7 +215,7 @@ def collect_entities_by_type(
         counters[entity_name] = Counter()
 
     log("Searching for the entities in the edge files...")
-    for edgepath in edge_paths:
+    for edgepath in tqdm(edge_paths, desc="edge_paths"):
         for lhs_word, rhs_word, rel_word, _weight in edgelist_reader.read(edgepath):
             if dynamic_relations or rel_word is None:
                 rel_id = 0
@@ -219,7 +261,7 @@ def generate_entity_path_files(
     entity_storage.prepare()
     relation_type_storage.prepare()
 
-    for entity_name, entities in entities_by_type.items():
+    for entity_name, entities in tqdm(entities_by_type.items(), desc="entities_by_type"):
         for part in range(entities.num_parts):
             log(
                 f"- Writing count of entity type {entity_name} " f"and partition {part}"
@@ -255,7 +297,7 @@ def generate_edge_path_files(
     relation_configs: List[RelationSchema],
     dynamic_relations: bool,
     edgelist_reader: EdgelistReader,
-    n_flush_edges: int = 100000,
+    n_flush_edges: int = 10_000_000,
 ) -> None:
     log(
         f"Preparing edge path {edge_path_out}, "
@@ -322,8 +364,8 @@ def generate_edge_path_files(
                 part_data.clear()
 
             processed = processed + 1
-            if processed % 100000 == 0:
-                log(f"- Processed {processed} edges so far...")
+            # if processed % 100000 == 0:
+            #     log(f"- Processed {processed} edges so far...")
 
     for (lhs_part, rhs_part), part_data in data.items():
         if len(part_data) > 0:
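As a sanity check, the new reader can be exercised on its own. This is a minimal sketch assuming the patched module is importable; the file name, column names, and sample data are made up for illustration:

from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

from torchbiggraph.converters.importers import ParquetEdgelistReader

# Write a tiny illustrative edge file.
table = pa.table({"src": ["a", "b"], "dst": ["b", "c"], "rel": ["r1", "r2"]})
pq.write_table(table, "edges.parquet")

# weight_col=None: the reader yields None in that position for every edge.
reader = ParquetEdgelistReader(lhs_col="src", rhs_col="dst", rel_col="rel", weight_col=None)
for lhs, rhs, rel, weight in reader.read(Path("edges.parquet")):
    print(lhs, rhs, rel, weight)  # prints: a b r1 None, then b c r2 None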