-
-
Save anthonylouisbsb/5e97040cc44e2f87bede2c2ab0b20830 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ShannonDBClient:
    """Async client for a ShannonDB cluster.

    Instances are normally built through the :meth:`create_client` factory,
    which loads configuration and wires up the node manager; the constructor
    only establishes the attribute slate.
    """

    def __init__(self):
        # Cluster topology / configuration (populated by create_client).
        self.zk_host = None
        self.node_manager = None
        self.data_centers = None
        self.hash_system = None
        self.cluster_config = None
        self._config = None
        self.general_client_config = None
        # Seed nodes and shard-layout caches.
        self._initial_node_seeds = None
        self._all_shards = None
        self._all_shards_first_ids = None
        # Record-id bookkeeping.
        self.record_ids_src = None
        self.record_ids_dict = None
        # Round-robin cursor for column-creation commands.
        self.next_column_creation_node_index = 0
        # Per-node connection pools, guarded by an asyncio lock.
        self.connections = {}
        self.connections_lock = asyncio.Lock()
        self.connection_pool_size = None
@classmethod | |
async def create_client(cls, initial_node_seeds, zk_host=None, | |
insertion_type="kafka", connections_per_node=1, | |
origin="SHANNONDBMOCKORIGIN", config_file_path=None): | |
self = ShannonDBClient() | |
self.general_client_config = load_shannondb_client_general_configs(config_file_path) | |
self.connection_pool_size = self.general_client_config.get("connection_pool_size", 2) | |
self.record_ids_src = self.general_client_config.get("record_ids_src", "rec_ids") | |
if not zk_host: | |
zk_host = self.general_client_config['zookeeper']['hosts'] | |
self._initial_node_seeds = Node.from_list_tuple(initial_node_seeds) | |
self.zk_host = zk_host | |
self.origin = origin | |
self._public_chunk_size = 0 | |
self._connections_per_node = connections_per_node | |
self.data_seeds = list() | |
self.node_id_hostname_map = dict() | |
await self._config_client() | |
self.data_centers = { | |
k.split(".")[2]: [int(n) for n in | |
self.cluster_config[k].split(",")] | |
for k in self.cluster_config if | |
k.startswith("shannondb.datacenter")} | |
self.hash_system = EntityIdHashSystem() | |
aggregation.setup_aggregation(self) | |
setup_insertion(self, insertion_type=insertion_type) | |
self.node_manager = NodeManager(self.zk_host, self) | |
return self | |
def get_data_center_by_node(self, node): | |
for (data_center, nodes) in self.data_centers.items(): | |
if node.id in nodes: | |
return data_center | |
raise "Could not find data center for node %s" % node | |
async def _load_cluster_config(self, seeds): | |
response = {} | |
def callback(result): | |
if result['status'] == 'error': | |
raise exceptions.SearchException(result.get('code'), | |
result.get('msg')) | |
else: | |
response["cluster_config"] = result["cluster_config"] | |
for seed in seeds: | |
try: | |
logging.info("fetching cluster config from {}".format(seed)) | |
except: | |
logging.warning("Error establishing connecting seed %s.", seed, | |
exc_info=True) | |
continue | |
try: | |
message = {"type": "cluster_config"} | |
tmp_result = await self.execute_single_node(seed, | |
CLUSTER_INFO_COMMAND, | |
message) | |
callback(tmp_result) | |
if response: | |
return response["cluster_config"] | |
except Exception as error: | |
logging.warning("Error waiting cluster info.", exc_info=True) | |
logging.debug(error) | |
raise exceptions.ConnectionException( | |
message="Could not fetch cluster config from seeds %s" % seeds) | |
    def get_client_config(self):
        """Return the ConfigParser assembled by :meth:`_config_client`."""
        return self._config
    async def _config_client(self):
        """Build the in-memory client ConfigParser and the data-seed maps.

        Loads the cluster configuration from the seeds on first use, mirrors
        every cluster key into the "client" section, and derives
        ``data_seeds`` / ``node_id_hostname_map`` from "shannondb.nodes".
        """
        self._config = configparser.ConfigParser()
        self._config.add_section("client")
        self._config.set("client", "zk_host", self.zk_host)
        if not self.cluster_config:
            self.cluster_config = await self._load_cluster_config(
                self._initial_node_seeds)
        for key_, value_ in self.cluster_config.items():
            # ConfigParser only accepts string values.
            value = value_
            if not isinstance(value_, str):
                value = str(value_)
            self._config.set("client", key_, value)
        node_id = 0
        for data_node in self.cluster_config["shannondb.nodes"]:
            # Entries may or may not carry an explicit port; default to the
            # cluster-wide "port" setting.
            hostname_port = data_node if ":" in data_node \
                else "%s:%i" % (data_node, self.cluster_config["port"])
            if hostname_port not in self.data_seeds:
                self.data_seeds.append(hostname_port)
            # NOTE(review): ids are assigned in iteration order for every
            # entry, even when the hostname was already seeded — confirm this
            # matches server-side node numbering.
            self.node_id_hostname_map[node_id] = hostname_port
            node_id += 1
async def execute_on_cluster(self, worker, command, timeout=None): | |
message = "%s %s" % (worker, json.dumps(command)) | |
compressed_message = bytearray(snappy.compress(message)) | |
if timeout is None: | |
timeout = self.general_client_config['global_timeout'] | |
online_nodes = set(self.node_manager.get_online_nodes()) | |
futures = [] | |
for node in online_nodes: | |
futures.append( | |
self._send_shannondb_message(node, compressed_message)) | |
if timeout == -1: | |
timeout = None | |
if futures: | |
done, _ = await asyncio.wait(futures, timeout=timeout) | |
return online_nodes, done | |
else: | |
return online_nodes, [] | |
async def execute_single_node(self, node, worker, command, timeout=None): | |
message = "%s %s" % (worker, json.dumps(command)) | |
compressed_message = bytearray(snappy.compress(message)) | |
if timeout is None: | |
timeout = self.general_client_config['global_timeout'] | |
if timeout is -1: | |
result = await self._send_shannondb_message(node, | |
compressed_message) | |
else: | |
try: | |
result = await asyncio.wait_for( | |
self._send_shannondb_message(node, compressed_message), | |
timeout=timeout) | |
except asyncio.TimeoutError: | |
raise Exception("Timeout occurred while processing this query") | |
return result | |
async def execute_single_node_id(self, node, worker, command): | |
"""Returns the node_id alongside the result from server""" | |
try: | |
response = await self.execute_single_node(node, worker, command) | |
return node.id, response | |
except Exception as err: | |
logging.error("Error when getting result/score page: " | |
"{}".format(str(err))) | |
return node.id, None | |
async def create_connection(self, node): | |
loop = asyncio.get_event_loop() | |
_, protocol = await loop.create_connection(ShannonDBClientProtocol, | |
*node.address) | |
return protocol | |
async def get_usage_info(self, group): | |
responses = [] | |
up_nodes = self.get_online_nodes() | |
query = { | |
"billing": { | |
"group": group | |
} | |
} | |
_, responses = await self.execute_on_cluster(ADMIN_COMMAND, query) | |
final_response = {} | |
offline_shards = 0 | |
for response in responses: | |
try: | |
# If shard fails to respond, return empty dict | |
response = response.result() | |
response_msg = json.loads(response['msg']) | |
result = response_msg["result"] | |
except: | |
offline_shards += 1 | |
logging.info("Found an offline shard billing_info request") | |
continue | |
for key, value in result.items(): | |
if key not in final_response: | |
final_response[key] = value | |
else: | |
final_response[key]["value"] += value["value"] | |
if offline_shards == 0: | |
final_response["status"] = "ok" | |
else: | |
final_response["status"] = "{} shards failed to respond".format( | |
offline_shards) | |
return final_response | |
async def remove_database_data(self, query, node_id=-1, **kwargs): | |
response = { | |
"msg": "" | |
} | |
def format_message(result): | |
if result['status'] == 'error': | |
raise exceptions.SearchException(result.get('code'), | |
result.get('msg')) | |
elif ("msg" in result and result["msg"] == "Columns removed") or \ | |
response["msg"] == "": | |
response["msg"] = result["msg"] | |
response["status"] = result["status"] | |
retry_count = 0 | |
while True: | |
try: | |
up_nodes = self.get_online_nodes() | |
if node_id >= 0: | |
result = self.execute_single_node(node_id, REMOVE_COMMAND, | |
query) | |
format_message(await result) | |
else: | |
for node in up_nodes: | |
result = self.execute_single_node(node.id, | |
REMOVE_COMMAND, query) | |
format_message(await result) | |
break | |
except Exception as error: | |
if retry_count < self.get_max_retry_count(): | |
logging.info( | |
"Got error when executing command, trying again... error: {}", | |
error) | |
self.reconnect() | |
retry_count += 1 | |
else: | |
raise error | |
return response | |
    async def _send_shannondb_message(self, node, compressed_message):
        """Send a compressed message over a pooled connection to ``node``.

        Rotates a per-node connection index; on send failure the broken
        connection is closed, replaced, and the send retried until it
        succeeds or reconnecting itself fails.
        """
        # Callers may pass either a Node object or a bare node id.
        if isinstance(node, int):
            node = self.get_node_by_id(node)
        if node in self.connections:
            node_connections_index, node_connections = self.connections[node]
        else:
            # Double-checked under the lock so only one coroutine builds the
            # pool for a given node.
            async with self.connections_lock:
                if node in self.connections:
                    node_connections_index, node_connections = self.connections[
                        node]
                else:
                    # NOTE(review): list * n repeats the SAME connection
                    # object pool_size times, so this "pool" holds one shared
                    # connection — confirm whether distinct connections were
                    # intended.
                    node_connections = [await self.create_connection(
                        node)] * self.connection_pool_size
                    node_connections_index = 0
                    self.connections[node] = [0, node_connections]
        # Advance the round-robin cursor for the next caller.
        self.connections[node][0] = (node_connections_index + 1) % self.connection_pool_size
        try:
            connection = node_connections[node_connections_index]
        except IndexError:
            connection = await self.create_connection(node)
        while True:
            try:
                response = await connection.send_message(compressed_message)
                break
            except asyncio.CancelledError:
                # Cancellation is logged and the send retried.
                logging.info(
                    "CancelledError: Error while sending request to ShannonDB")
            except Exception:
                # Any other failure: close the (presumably broken) connection,
                # open a new one in its pool slot, and retry.
                if connection:
                    try:
                        connection.close()
                    except Exception as e:
                        logging.exception(e)
                try:
                    connection = await self.create_connection(node)
                    node_connections[node_connections_index] = connection
                except Exception as e:
                    raise e
        return response
def get_node_by_id(self, node_id): | |
node_ids = [n for n in self.node_manager.get_online_nodes() if | |
n.id == node_id] | |
return node_ids[0] | |
    async def delete_command(self, query_):
        """Delegate to the module-level ``delete_command`` helper."""
        return await delete_command(self, query_)

    async def update_command(self, query_):
        """Delegate to the module-level ``update_command`` helper."""
        return await update_command(self, query_)

    async def execute_count(self, query_, is_table_type_dimension=False):
        """Run a count query via the module-level ``execute_count`` helper."""
        return await execute_count(self, query_, is_table_type_dimension)

    async def execute_aggregation_query(self, query,
                                        order_by_aggregation=None,
                                        get_value_hashes_topk=False,
                                        get_value_hashes_only=False,
                                        lazy_translation=False,
                                        **kwargs):
        """Delegate an aggregation query to the ``aggregation`` module."""
        return await aggregation.execute_aggregation_query(
            self, query, order_by_aggregation=order_by_aggregation,
            get_value_hashes_topk=get_value_hashes_topk,
            get_value_hashes_only=get_value_hashes_only,
            lazy_translation=lazy_translation, **kwargs)

    async def insertion_from_query(self, query):
        """Delegate to ``insertion_where.insertion_from_query``."""
        return await insertion_where.insertion_from_query(self, query)

    async def is_insertion_from_query_command_finished(self, query):
        """Poll ``insertion_where`` for completion of an insertion-from-query."""
        return await insertion_where.is_insertion_from_query_command_finished(self, query)

    async def execute_aggregation(self, query, **kwargs):
        """Delegate to ``aggregation.execute_aggregation``."""
        return await aggregation.execute_aggregation(
            self, query, **kwargs)

    async def get_shards_for_column(self, column):
        """Delegate to the module-level ``get_shards_for_column`` helper."""
        return await get_shards_for_column(self, column)

    async def generate_shannondb_ids(self, total_to_gen, groups):
        """Delegate to the module-level ``generate_shannondb_ids`` helper."""
        return await generate_shannondb_ids(self, total_to_gen, groups)
def _get_all_shards(self): | |
if not self._all_shards: | |
self._all_shards = sorted( | |
self.get_shards_by_registered_node(), | |
key=attrgetter('first_id')) | |
return self._all_shards | |
def _get_public_chunk_size(self): | |
if not self._public_chunk_size: | |
all_shards = self._get_all_shards() | |
self._public_chunk_size = all_shards[-1].first_id + all_shards[ | |
-1].capacity | |
return self._public_chunk_size | |
def _get_all_shards_first_ids(self): | |
if not self._all_shards_first_ids: | |
self._all_shards_first_ids = [s.first_id for s in | |
self._get_all_shards()] | |
return self._all_shards_first_ids | |
def get_shard_id(self, id_): | |
original_id = id_ | |
all_shards = self._get_all_shards() | |
try: | |
public_chunk_size = self._get_public_chunk_size() | |
id_ %= public_chunk_size | |
except: | |
# It may happen that public_chunk_size is 0 | |
pass | |
shards_first_ids = self._get_all_shards_first_ids() | |
i = bisect.bisect_left(shards_first_ids, id_) - 1 | |
if i == -1: | |
i = 0 | |
elif all_shards[i].first_id + all_shards[i].capacity == id_: | |
i += 1 | |
shard = all_shards[i] | |
if shard.first_id > id_ or shard.first_id + shard.capacity <= id_: | |
raise Exception("Couldn't find shard for ID %s" % original_id) | |
return shard | |
    def get_auto_generated_recid(self, shard):
        """Delegate to the module-level ``get_auto_generated_recid`` helper."""
        return get_auto_generated_recid(self, shard)

    def insertion_command(self, param_list, is_table_type_dimension=False):
        """Delegate to the module-level ``insertion_command`` helper."""
        return insertion_command(self, param_list, is_table_type_dimension)
async def node_insertion(self, param_list, is_table_type_dimension=False): | |
shard_to_bytes = {} | |
insertion_command(self, param_list, is_table_type_dimension, | |
shard_to_bytes=shard_to_bytes) | |
for shard, bytearray in shard_to_bytes.items(): | |
result = await send_insertion_command_to_shard(self, shard, | |
messages_bytes=bytearray) | |
if not result: | |
raise Exception('Error during direct insertion') | |
return {'status': 'OK'} | |
    def force_kafka_flush(self):
        """Flush any buffered Kafka insertions via the module-level helper."""
        return force_kafka_flush(self)

    async def close(self):
        """Flush pending work, persist record-id data, and close Zookeeper."""
        self.force_kafka_flush()
        write_record_ids_data(self)
        self.node_manager.close_zookeeper()

    def get_node_by_shard(self, shard):
        """Ask the node manager which node owns ``shard``."""
        return self.node_manager.get_node_by_shard(shard)

    def get_replica_node_by_shard(self, shard):
        """Ask the node manager for the replica node of ``shard``."""
        return self.node_manager.get_replica_node_by_shard(shard)

    def get_shards_by_registered_node(self):
        """Shards grouped by registered node, as tracked by the node manager."""
        return self.node_manager.shards_by_registered_node

    async def translate_hash_to_string(self, column_to_hashes, groups, column_info=None):
        """Delegate to the module-level ``translate_hash_to_string`` helper."""
        return await translate_hash_to_string(self, column_to_hashes, groups, column_info)
    async def create_triggers_in_shannondb(self, database_path, triggers_list):
        """Delegate trigger creation (the helper is spelled "shannodb" upstream)."""
        return await create_triggers_in_shannodb(self, database_path, triggers_list)

    async def update_relation(self, query):
        """Delegate to the module-level ``update_relation`` helper."""
        return await update_relation(self, query)

    async def create_single_column(self, query, validated=False):
        """Delegate to the module-level ``create_single_column`` helper."""
        return await create_single_column(self, query, validated)

    async def create_database(self, query):
        """Delegate to the module-level ``create_database`` helper."""
        return await create_database(self, query)
async def send_dictionary_message(self, values, column_name, groups): | |
node_id = get_node_to_send_column_command(self) | |
logging.info("sending column creation command to node: %s", node_id.id) | |
query = { | |
"values": values, | |
"column_name": column_name, | |
"groups": groups | |
} | |
result = await self.execute_single_node(node_id, DICTIONARY_ADD_COMMAND, | |
query) | |
if result['status'] == 'error': | |
raise exceptions.ColumnException(result['code'], result.get( | |
'msg', '')) | |
    async def create_columns_in_bulk(self, columns):
        """Delegate to the module-level ``create_columns_in_bulk`` helper."""
        return await create_columns_in_bulk(self, columns)

    async def delete_columns_in_bulk(self, columns=None, group=None, **kwargs):
        """Delegate to the module-level ``delete_columns_in_bulk`` helper."""
        return await delete_columns_in_bulk(self, columns, group, **kwargs)

    async def get_list_of_columns(self):
        """Delegate to the module-level ``get_list_of_columns`` helper."""
        return await get_list_of_columns(self)
    def execute_result_or_score(self, worker, query, columns=None,
                                page_size=128, should_prepare=True,
                                checkpoint=None, order_by=None,
                                is_table_type_dimension=False
                                ):
        """Delegate paginated result/score execution to the module helper."""
        return execute_result_or_score(self, worker, query,
                                       columns=columns,
                                       page_size=page_size,
                                       should_prepare=should_prepare,
                                       checkpoint=checkpoint,
                                       order_by=order_by,
                                       is_table_type_dimension=is_table_type_dimension
                                       )

    async def execute_event_count(self, query):
        """Delegate to the module-level ``execute_event_count`` helper."""
        return await execute_event_count(self, query)
def get_max_retry_count(self): | |
try: | |
return self.general_client_config.get("max_retry_count", 3) | |
except: | |
return 3 | |
    async def reconnect(self):
        """Rebuild the client configuration by re-running ``_config_client``."""
        await self._config_client()
def get_preferred_datacenter_for_column_creation(self): | |
try: | |
return self.general_client_config.get("preferred_datacenter_for_column_creation", | |
'DC01') | |
except: | |
return 'DC01' | |
    def get_registered_nodes(self):
        """All nodes registered with the node manager."""
        return self.node_manager.get_registered_nodes()

    def get_online_nodes(self):
        """Nodes currently reported online by the node manager."""
        return self.node_manager.get_online_nodes()

    def get_shard_for_entity_id(self, entity_id):
        """Shard id for ``entity_id`` according to the entity-id hash system."""
        return self.hash_system.get_shard(entity_id)
async def sql_query(self, sql, database_path, query_flag, column_list=[]): | |
query = { | |
"sql": sql, | |
"columns": column_list | |
} | |
nodes, responses = await self.execute_on_cluster("sql", query) | |
final_response = await get_sql_query_response(self, database_path, nodes, responses, query_flag) | |
return final_response | |
async def get_entity_ids(self, entity_ids, groups, column_name, | |
also_add=False): | |
shard_to_entity_id = {} | |
for entity_id in entity_ids: | |
shard_id = self.get_shard_for_entity_id(entity_id) | |
if shard_id in shard_to_entity_id: | |
shard_to_entity_id[shard_id].append(entity_id) | |
else: | |
shard_to_entity_id[shard_id] = [entity_id] | |
node_to_command = {} | |
for shard_id, entities in shard_to_entity_id.items(): | |
shard = Shard.get_shard_from_shannon_id(shard_id) | |
node_id = self.get_node_by_shard(shard) | |
if node_id in node_to_command: | |
node_to_command[node_id]["shards"].update({shard_id: entities}) | |
else: | |
node_to_command[node_id] = {"shards": {shard_id: entities}, | |
"groups": groups, | |
"column_name": column_name, | |
"type": "get_entities", | |
"also_add": also_add} | |
tasks = [] | |
for node, command in node_to_command.items(): | |
task = asyncio.ensure_future( | |
self.execute_single_node(node.id, ENTITYID_INSERT_COMMAND, | |
command)) | |
tasks.append(task) | |
responses = await asyncio.gather(*tasks, return_exceptions=True) | |
final_response = {} | |
new_entities = [] | |
for response in responses: | |
if isinstance(response, Exception): | |
raise exceptions.IndexException(1, "Some nodes are unavailable " | |
"now, error %s" % response) | |
if "status" in response: | |
status = response["status"] | |
if "error" in status: | |
raise exceptions.IndexException(response.get('code'), | |
response.get('msg')) | |
if "result" in response: | |
result = response["result"] | |
entities_map = result["entities"] | |
new = result["new-entities"] | |
final_response.update(entities_map) | |
new_entities.extend(new) | |
if not final_response: | |
for entity in entity_ids: | |
final_response[entity] = "UNSET" | |
return final_response, new_entities |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment