Last active
March 22, 2023 11:34
-
-
Save drew2a/440455a0099b27bcb7bc1fe329b20d02 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import datetime | |
import logging | |
import math | |
import time | |
from pathlib import Path | |
from typing import Optional | |
from elastic_transport import TransportError | |
from elasticsearch import Elasticsearch | |
from elasticsearch.helpers import bulk | |
from pony.orm import db_session | |
from pony.utils import between | |
from tribler.core.components.component import Component | |
from tribler.core.components.ipv8.ipv8_component import Ipv8Component | |
from tribler.core.components.key.key_component import KeyComponent | |
from tribler.core.components.knowledge.rules.tag_rules_processor import LAST_PROCESSED_TORRENT_ID | |
from tribler.core.components.metadata_store.db.serialization import REGULAR_TORRENT | |
from tribler.core.components.metadata_store.metadata_store_component import MetadataStoreComponent | |
from tribler.core.utilities.tiny_tribler_service import TinyTriblerService | |
from tribler.core.utilities.unicode import hexlify | |
INDEX = 'tribler' | |
ELASTIC_HOST = "http://localhost:9200" | |
_logger = logging.getLogger('Indexer') | |
class TitlesProcessor:
    """Index torrent titles from the Tribler metadata store into Elasticsearch.

    Torrents are processed in batches of ``batch_size`` rowids. The last
    processed rowid is persisted in the metadata store under
    ``LAST_PROCESSED_TORRENT_ID`` and advanced after every batch; the counter
    is reset to ``'0'`` on construction, so each run re-indexes from scratch.
    """

    def __init__(self, mds, elastic: "Elasticsearch"):
        """
        Args:
            mds: Tribler metadata store (provides TorrentMetadata and the
                key/value API used for progress tracking).
            elastic: connected Elasticsearch client used for bulk indexing.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        self.start_time = time.time()
        self.batch_size = 10000
        self.mds = mds
        self.elastic = elastic
        # Keep a reference to the scheduled task so it is not garbage-collected
        # while running (asyncio only holds weak references to tasks).
        self._task = None
        # Restart indexing from the beginning on every run.
        with db_session:
            self.mds.set_value(LAST_PROCESSED_TORRENT_ID, '0')

    def start(self):
        """Schedule processing of the first batch on the running event loop."""
        self._task = asyncio.get_event_loop().create_task(self.process_batch())

    async def process_batch(self):
        """Index one batch of torrents, then reschedule itself until finished.

        Note: the original version decorated this coroutine with pony's
        ``@db_session``, which does not support ``async def`` functions; the
        explicit ``with db_session`` block below is the correct form. The bulk
        call is best-effort: on transport errors the failure is logged and the
        batch is skipped so indexing can still make progress.
        """
        # Pre-initialize so the summary log line below cannot raise NameError
        # when bulk() fails with TransportError.
        successful, errors = 0, 0
        with db_session:
            start = int(self.mds.get_value(LAST_PROCESSED_TORRENT_ID, default='0'))
            max_row_id = self.mds.get_max_rowid()
            if max_row_id <= 0:
                # Empty store: avoid ZeroDivisionError in the progress math.
                self.logger.info('Metadata store is empty, nothing to process')
                return
            end = min(start + self.batch_size, max_row_id)

            # Progress / ETA reporting.
            percent = 100 * start // max_row_id
            elapsed = math.floor(time.time() - self.start_time)
            estimated = elapsed * max_row_id // end
            estimated_str = datetime.timedelta(seconds=estimated)
            remaining_str = datetime.timedelta(seconds=estimated - elapsed)
            self.logger.info(
                f'Processing batch [{start}..{end}] of {max_row_id}, {percent}%. '
                f'Remaining: {remaining_str}. Estimated: {estimated_str}'
            )

            # Pony translates `and` inside the lambda to SQL AND.
            batch = self.mds.TorrentMetadata.select(
                lambda t: between(t.rowid, start, end) and t.metadata_type == REGULAR_TORRENT
            )
            actions = [
                doc for torrent in batch
                if (doc := self.create_doc(torrent.infohash, torrent.title, torrent.id_))
            ]
            try:
                successful, errors = bulk(self.elastic, actions, stats_only=True)
            except TransportError as e:
                # Best effort: log the failure and still advance past this batch.
                self.logger.exception(e)
            self.mds.set_value(LAST_PROCESSED_TORRENT_ID, str(end))
            self.logger.info(f'Successful: {successful}, errors: {errors}')

            is_finished = end >= max_row_id
            if is_finished:
                self.logger.info('Finish batch processing, cancel process_batch task')
            else:
                self._task = asyncio.get_event_loop().create_task(self.process_batch())

    @staticmethod
    def create_doc(infohash: bytes, name: Optional[str], metadata_id: int) -> Optional[dict]:
        """Build an Elasticsearch bulk action for one torrent.

        Returns None when the torrent has no title (nothing to index).
        """
        if not name:
            return None
        return {
            '_index': INDEX,
            'title': name,
            # bytes.hex() produces the same lowercase hex string as tribler's hexlify().
            'infohash': infohash.hex(),
            'metadata_id': metadata_id,
        }
class TitleProcessorComponent(Component):
    """Tribler component that starts background title indexing into Elasticsearch."""

    async def run(self):
        """Wire the metadata store to an Elasticsearch client and kick off indexing."""
        await super().run()
        mds_component = await self.require_component(MetadataStoreComponent)
        client = Elasticsearch(ELASTIC_HOST, request_timeout=120)
        TitlesProcessor(mds=mds_component.mds, elastic=client).start()

    async def shutdown(self):
        """Nothing extra to tear down beyond the base component."""
        await super().shutdown()
if __name__ == "__main__": | |
logging.basicConfig(level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(name)s(%(lineno)d) - %(message)s') | |
elastic_transport_logger = logging.getLogger('elastic_transport') | |
elastic_transport_logger.setLevel(logging.WARNING) | |
service = TinyTriblerService( | |
state_dir=Path('./.Tribler'), | |
components=[ | |
TitleProcessorComponent(), MetadataStoreComponent(), KeyComponent(), Ipv8Component() | |
] | |
) | |
service.run(fragile=False) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment