# Standard-library imports inferred from the code below. Project-internal
# names used throughout (Node, NodeManager, Shard, EntityIdHashSystem,
# ShannonDBClientProtocol, aggregation, exceptions, insertion_where, the
# *_COMMAND constants and helper functions such as setup_insertion and
# execute_count) come from ShannonDB client modules not included in this gist.
import asyncio
import bisect
import configparser
import json
import logging
from operator import attrgetter

import snappy  # python-snappy


class ShannonDBClient:
    def __init__(self):
        self.zk_host = None
        self.node_manager = None
        self.data_centers = None
        self.hash_system = None
        self._initial_node_seeds = None
        self._all_shards_first_ids = None
        self.record_ids_src = None
        self.record_ids_dict = None
        self._all_shards = None
        self.cluster_config = None
        self._config = None
        self.next_column_creation_node_index = 0
        self.connections = {}
        self.connections_lock = asyncio.Lock()
        self.connection_pool_size = None
        self.general_client_config = None

    @classmethod
    async def create_client(cls, initial_node_seeds, zk_host=None,
                            insertion_type="kafka", connections_per_node=1,
                            origin="SHANNONDBMOCKORIGIN",
                            config_file_path=None):
        self = cls()
        self.general_client_config = load_shannondb_client_general_configs(
            config_file_path)
        self.connection_pool_size = self.general_client_config.get(
            "connection_pool_size", 2)
        self.record_ids_src = self.general_client_config.get(
            "record_ids_src", "rec_ids")
        if not zk_host:
            zk_host = self.general_client_config['zookeeper']['hosts']
        self._initial_node_seeds = Node.from_list_tuple(initial_node_seeds)
        self.zk_host = zk_host
        self.origin = origin
        self._public_chunk_size = 0
        self._connections_per_node = connections_per_node
        self.data_seeds = list()
        self.node_id_hostname_map = dict()
        await self._config_client()
        self.data_centers = {
            k.split(".")[2]: [int(n) for n in
                              self.cluster_config[k].split(",")]
            for k in self.cluster_config
            if k.startswith("shannondb.datacenter")}
        self.hash_system = EntityIdHashSystem()
        aggregation.setup_aggregation(self)
        setup_insertion(self, insertion_type=insertion_type)
        self.node_manager = NodeManager(self.zk_host, self)
        return self
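
    # Illustration (hypothetical values) of the data_centers mapping built in
    # create_client: cluster_config keys of the form
    # "shannondb.datacenter.<name>" carry comma-separated node ids, so
    # {"shannondb.datacenter.DC01": "0,1", "shannondb.datacenter.DC02": "2,3"}
    # becomes {"DC01": [0, 1], "DC02": [2, 3]}.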

    def get_data_center_by_node(self, node):
        for (data_center, nodes) in self.data_centers.items():
            if node.id in nodes:
                return data_center
        raise Exception("Could not find data center for node %s" % node)

    async def _load_cluster_config(self, seeds):
        response = {}

        def callback(result):
            if result['status'] == 'error':
                raise exceptions.SearchException(result.get('code'),
                                                 result.get('msg'))
            else:
                response["cluster_config"] = result["cluster_config"]

        # Try each seed in turn until one returns a cluster config.
        for seed in seeds:
            logging.info("fetching cluster config from %s", seed)
            try:
                message = {"type": "cluster_config"}
                tmp_result = await self.execute_single_node(
                    seed, CLUSTER_INFO_COMMAND, message)
                callback(tmp_result)
                if response:
                    return response["cluster_config"]
            except Exception as error:
                logging.warning("Error fetching cluster config from seed %s.",
                                seed, exc_info=True)
                logging.debug(error)
        raise exceptions.ConnectionException(
            message="Could not fetch cluster config from seeds %s" % seeds)

    def get_client_config(self):
        return self._config

    async def _config_client(self):
        self._config = configparser.ConfigParser()
        self._config.add_section("client")
        self._config.set("client", "zk_host", self.zk_host)
        if not self.cluster_config:
            self.cluster_config = await self._load_cluster_config(
                self._initial_node_seeds)
        for key_, value_ in self.cluster_config.items():
            # ConfigParser only accepts string values.
            value = value_ if isinstance(value_, str) else str(value_)
            self._config.set("client", key_, value)
        for node_id, data_node in enumerate(
                self.cluster_config["shannondb.nodes"]):
            hostname_port = data_node if ":" in data_node \
                else "%s:%i" % (data_node, self.cluster_config["port"])
            if hostname_port not in self.data_seeds:
                self.data_seeds.append(hostname_port)
            self.node_id_hostname_map[node_id] = hostname_port

    async def execute_on_cluster(self, worker, command, timeout=None):
        message = "%s %s" % (worker, json.dumps(command))
        compressed_message = bytearray(snappy.compress(message))
        if timeout is None:
            timeout = self.general_client_config['global_timeout']
        online_nodes = set(self.node_manager.get_online_nodes())
        futures = [self._send_shannondb_message(node, compressed_message)
                   for node in online_nodes]
        if timeout == -1:
            # timeout=-1 means "wait indefinitely" for asyncio.wait.
            timeout = None
        if futures:
            done, _ = await asyncio.wait(futures, timeout=timeout)
            return online_nodes, done
        return online_nodes, []

    async def execute_single_node(self, node, worker, command, timeout=None):
        message = "%s %s" % (worker, json.dumps(command))
        compressed_message = bytearray(snappy.compress(message))
        if timeout is None:
            timeout = self.general_client_config['global_timeout']
        if timeout == -1:
            result = await self._send_shannondb_message(node,
                                                        compressed_message)
        else:
            try:
                result = await asyncio.wait_for(
                    self._send_shannondb_message(node, compressed_message),
                    timeout=timeout)
            except asyncio.TimeoutError:
                raise Exception("Timeout occurred while processing this query")
        return result
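
    # Timeout contract shared by execute_on_cluster and execute_single_node
    # (inferred from the code above): timeout=None falls back to the
    # "global_timeout" entry of the general client config, while timeout=-1
    # disables the deadline entirely (asyncio.wait / wait_for get None).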

    async def execute_single_node_id(self, node, worker, command):
        """Returns the node_id alongside the result from server."""
        try:
            response = await self.execute_single_node(node, worker, command)
            return node.id, response
        except Exception as err:
            logging.error("Error when getting result/score page: %s", err)
            return node.id, None

    async def create_connection(self, node):
        loop = asyncio.get_event_loop()
        _, protocol = await loop.create_connection(ShannonDBClientProtocol,
                                                   *node.address)
        return protocol

    async def get_usage_info(self, group):
        query = {
            "billing": {
                "group": group
            }
        }
        _, responses = await self.execute_on_cluster(ADMIN_COMMAND, query)
        final_response = {}
        offline_shards = 0
        for response in responses:
            try:
                # If a shard failed to respond, count it as offline and skip.
                response = response.result()
                response_msg = json.loads(response['msg'])
                result = response_msg["result"]
            except Exception:
                offline_shards += 1
                logging.info("Found an offline shard during billing_info "
                             "request")
                continue
            for key, value in result.items():
                if key not in final_response:
                    final_response[key] = value
                else:
                    final_response[key]["value"] += value["value"]
        if offline_shards == 0:
            final_response["status"] = "ok"
        else:
            final_response["status"] = "{} shards failed to respond".format(
                offline_shards)
        return final_response
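
    # Merge sketch for get_usage_info (hypothetical payload shape, inferred
    # from the "value" accesses above): if two shards each report
    # {"disk_usage": {"value": 10}}, the merged result is
    # {"disk_usage": {"value": 20}, "status": "ok"}.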

    async def remove_database_data(self, query, node_id=-1, **kwargs):
        response = {
            "msg": ""
        }

        def format_message(result):
            if result['status'] == 'error':
                raise exceptions.SearchException(result.get('code'),
                                                 result.get('msg'))
            elif ("msg" in result and result["msg"] == "Columns removed") or \
                    response["msg"] == "":
                response["msg"] = result["msg"]
                response["status"] = result["status"]

        retry_count = 0
        while True:
            try:
                up_nodes = self.get_online_nodes()
                if node_id >= 0:
                    result = self.execute_single_node(node_id, REMOVE_COMMAND,
                                                      query)
                    format_message(await result)
                else:
                    for node in up_nodes:
                        result = self.execute_single_node(node.id,
                                                          REMOVE_COMMAND,
                                                          query)
                        format_message(await result)
                break
            except Exception as error:
                if retry_count < self.get_max_retry_count():
                    logging.info("Got error when executing command, trying"
                                 " again... error: %s", error)
                    await self.reconnect()
                    retry_count += 1
                else:
                    raise error
        return response

    async def _send_shannondb_message(self, node, compressed_message):
        if isinstance(node, int):
            node = self.get_node_by_id(node)
        if node in self.connections:
            node_connections_index, node_connections = self.connections[node]
        else:
            async with self.connections_lock:
                # Re-check under the lock: another task may have created the
                # pool for this node in the meantime.
                if node in self.connections:
                    node_connections_index, node_connections = \
                        self.connections[node]
                else:
                    # One connection per pool slot.
                    node_connections = [
                        await self.create_connection(node)
                        for _ in range(self.connection_pool_size)]
                    node_connections_index = 0
                    self.connections[node] = [0, node_connections]
        # Advance the round-robin index for the next caller.
        self.connections[node][0] = \
            (node_connections_index + 1) % self.connection_pool_size
        try:
            connection = node_connections[node_connections_index]
        except IndexError:
            connection = await self.create_connection(node)
        while True:
            try:
                response = await connection.send_message(compressed_message)
                break
            except asyncio.CancelledError:
                logging.info(
                    "CancelledError: Error while sending request to ShannonDB")
            except Exception:
                # Replace the broken connection in its pool slot and retry.
                if connection:
                    try:
                        connection.close()
                    except Exception as e:
                        logging.exception(e)
                connection = await self.create_connection(node)
                node_connections[node_connections_index] = connection
        return response
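
    # Pooling notes (inferred from _send_shannondb_message above): each node
    # gets connection_pool_size connections, selected round-robin through the
    # index stored in self.connections[node][0]; a failed send closes the
    # broken connection, rebuilds that pool slot, and retries the same
    # message on the fresh connection.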

    def get_node_by_id(self, node_id):
        matching_nodes = [n for n in self.node_manager.get_online_nodes()
                          if n.id == node_id]
        return matching_nodes[0]

    async def delete_command(self, query_):
        return await delete_command(self, query_)

    async def update_command(self, query_):
        return await update_command(self, query_)

    async def execute_count(self, query_, is_table_type_dimension=False):
        return await execute_count(self, query_, is_table_type_dimension)

    async def execute_aggregation_query(self, query,
                                        order_by_aggregation=None,
                                        get_value_hashes_topk=False,
                                        get_value_hashes_only=False,
                                        lazy_translation=False,
                                        **kwargs):
        return await aggregation.execute_aggregation_query(
            self, query, order_by_aggregation=order_by_aggregation,
            get_value_hashes_topk=get_value_hashes_topk,
            get_value_hashes_only=get_value_hashes_only,
            lazy_translation=lazy_translation, **kwargs)

    async def insertion_from_query(self, query):
        return await insertion_where.insertion_from_query(self, query)

    async def is_insertion_from_query_command_finished(self, query):
        return await insertion_where.is_insertion_from_query_command_finished(
            self, query)

    async def execute_aggregation(self, query, **kwargs):
        return await aggregation.execute_aggregation(self, query, **kwargs)

    async def get_shards_for_column(self, column):
        return await get_shards_for_column(self, column)

    async def generate_shannondb_ids(self, total_to_gen, groups):
        return await generate_shannondb_ids(self, total_to_gen, groups)

    def _get_all_shards(self):
        if not self._all_shards:
            self._all_shards = sorted(
                self.get_shards_by_registered_node(),
                key=attrgetter('first_id'))
        return self._all_shards

    def _get_public_chunk_size(self):
        if not self._public_chunk_size:
            all_shards = self._get_all_shards()
            self._public_chunk_size = (all_shards[-1].first_id
                                       + all_shards[-1].capacity)
        return self._public_chunk_size

    def _get_all_shards_first_ids(self):
        if not self._all_shards_first_ids:
            self._all_shards_first_ids = [s.first_id for s in
                                          self._get_all_shards()]
        return self._all_shards_first_ids

    def get_shard_id(self, id_):
        original_id = id_
        all_shards = self._get_all_shards()
        try:
            public_chunk_size = self._get_public_chunk_size()
            id_ %= public_chunk_size
        except Exception:
            # public_chunk_size may be 0 (no shards yet); use the raw id.
            pass
        shards_first_ids = self._get_all_shards_first_ids()
        # Binary search for the shard whose [first_id, first_id + capacity)
        # range contains id_.
        i = bisect.bisect_left(shards_first_ids, id_) - 1
        if i == -1:
            i = 0
        elif all_shards[i].first_id + all_shards[i].capacity == id_:
            i += 1
        shard = all_shards[i]
        if shard.first_id > id_ or shard.first_id + shard.capacity <= id_:
            raise Exception("Couldn't find shard for ID %s" % original_id)
        return shard
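
    # Worked example of get_shard_id (hypothetical shard layout): with shards
    # whose first_ids are [0, 100, 200] and capacity 100 each, the public
    # chunk size is 300. id 250 bisects to the shard at 200; id 350 first
    # wraps to 350 % 300 == 50 and lands in the shard at 0; id 200 hits the
    # "previous shard ends exactly here" branch and is advanced to the shard
    # at 200.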

    def get_auto_generated_recid(self, shard):
        return get_auto_generated_recid(self, shard)

    def insertion_command(self, param_list, is_table_type_dimension=False):
        return insertion_command(self, param_list, is_table_type_dimension)

    async def node_insertion(self, param_list, is_table_type_dimension=False):
        shard_to_bytes = {}
        insertion_command(self, param_list, is_table_type_dimension,
                          shard_to_bytes=shard_to_bytes)
        for shard, message_bytes in shard_to_bytes.items():
            result = await send_insertion_command_to_shard(
                self, shard, messages_bytes=message_bytes)
            if not result:
                raise Exception('Error during direct insertion')
        return {'status': 'OK'}

    def force_kafka_flush(self):
        return force_kafka_flush(self)

    async def close(self):
        self.force_kafka_flush()
        write_record_ids_data(self)
        self.node_manager.close_zookeeper()

    def get_node_by_shard(self, shard):
        return self.node_manager.get_node_by_shard(shard)

    def get_replica_node_by_shard(self, shard):
        return self.node_manager.get_replica_node_by_shard(shard)

    def get_shards_by_registered_node(self):
        return self.node_manager.shards_by_registered_node

    async def translate_hash_to_string(self, column_to_hashes, groups,
                                       column_info=None):
        return await translate_hash_to_string(self, column_to_hashes, groups,
                                              column_info)

    async def create_triggers_in_shannondb(self, database_path, triggers_list):
        return await create_triggers_in_shannodb(self, database_path,
                                                 triggers_list)

    async def update_relation(self, query):
        return await update_relation(self, query)

    async def create_single_column(self, query, validated=False):
        return await create_single_column(self, query, validated)

    async def create_database(self, query):
        return await create_database(self, query)

    async def send_dictionary_message(self, values, column_name, groups):
        node = get_node_to_send_column_command(self)
        logging.info("sending column creation command to node: %s", node.id)
        query = {
            "values": values,
            "column_name": column_name,
            "groups": groups
        }
        result = await self.execute_single_node(node, DICTIONARY_ADD_COMMAND,
                                                query)
        if result['status'] == 'error':
            raise exceptions.ColumnException(result['code'],
                                             result.get('msg', ''))

    async def create_columns_in_bulk(self, columns):
        return await create_columns_in_bulk(self, columns)

    async def delete_columns_in_bulk(self, columns=None, group=None, **kwargs):
        return await delete_columns_in_bulk(self, columns, group, **kwargs)

    async def get_list_of_columns(self):
        return await get_list_of_columns(self)

    def execute_result_or_score(self, worker, query, columns=None,
                                page_size=128, should_prepare=True,
                                checkpoint=None, order_by=None,
                                is_table_type_dimension=False):
        return execute_result_or_score(
            self, worker, query,
            columns=columns,
            page_size=page_size,
            should_prepare=should_prepare,
            checkpoint=checkpoint,
            order_by=order_by,
            is_table_type_dimension=is_table_type_dimension)

    async def execute_event_count(self, query):
        return await execute_event_count(self, query)

    def get_max_retry_count(self):
        try:
            return self.general_client_config.get("max_retry_count", 3)
        except Exception:
            return 3

    async def reconnect(self):
        await self._config_client()

    def get_preferred_datacenter_for_column_creation(self):
        try:
            return self.general_client_config.get(
                "preferred_datacenter_for_column_creation", 'DC01')
        except Exception:
            return 'DC01'

    def get_registered_nodes(self):
        return self.node_manager.get_registered_nodes()

    def get_online_nodes(self):
        return self.node_manager.get_online_nodes()

    def get_shard_for_entity_id(self, entity_id):
        return self.hash_system.get_shard(entity_id)

    async def sql_query(self, sql, database_path, query_flag,
                        column_list=None):
        query = {
            "sql": sql,
            "columns": column_list or []
        }
        nodes, responses = await self.execute_on_cluster("sql", query)
        final_response = await get_sql_query_response(self, database_path,
                                                      nodes, responses,
                                                      query_flag)
        return final_response

    async def get_entity_ids(self, entity_ids, groups, column_name,
                             also_add=False):
        # Group the requested entity ids by the shard that owns them.
        shard_to_entity_id = {}
        for entity_id in entity_ids:
            shard_id = self.get_shard_for_entity_id(entity_id)
            shard_to_entity_id.setdefault(shard_id, []).append(entity_id)
        # Build one command per node, covering all of that node's shards.
        node_to_command = {}
        for shard_id, entities in shard_to_entity_id.items():
            shard = Shard.get_shard_from_shannon_id(shard_id)
            node = self.get_node_by_shard(shard)
            if node in node_to_command:
                node_to_command[node]["shards"].update({shard_id: entities})
            else:
                node_to_command[node] = {"shards": {shard_id: entities},
                                         "groups": groups,
                                         "column_name": column_name,
                                         "type": "get_entities",
                                         "also_add": also_add}
        tasks = [asyncio.ensure_future(
                     self.execute_single_node(node.id,
                                              ENTITYID_INSERT_COMMAND,
                                              command))
                 for node, command in node_to_command.items()]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
        final_response = {}
        new_entities = []
        for response in responses:
            if isinstance(response, Exception):
                raise exceptions.IndexException(1, "Some nodes are unavailable "
                                                   "now, error %s" % response)
            if "status" in response:
                status = response["status"]
                if "error" in status:
                    raise exceptions.IndexException(response.get('code'),
                                                    response.get('msg'))
            if "result" in response:
                result = response["result"]
                final_response.update(result["entities"])
                new_entities.extend(result["new-entities"])
        if not final_response:
            # No shard returned entities; mark every requested id as UNSET.
            for entity in entity_ids:
                final_response[entity] = "UNSET"
        return final_response, new_entities
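

# Usage sketch: a minimal driver, assuming a reachable ShannonDB cluster and
# that initial_node_seeds is a list of (host, port) tuples, as suggested by
# the Node.from_list_tuple call in create_client. The seed address below is
# hypothetical.
if __name__ == "__main__":
    async def _demo():
        client = await ShannonDBClient.create_client(
            initial_node_seeds=[("shannondb-node-1", 9090)])
        try:
            # List the columns known to the cluster.
            print(await client.get_list_of_columns())
        finally:
            await client.close()

    asyncio.run(_demo())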