Skip to content

Instantly share code, notes, and snippets.

@mushkevych
Forked from awestendorf/mongo_rebalance.py
Created November 15, 2011 00:48
Show Gist options
  • Save mushkevych/1365738 to your computer and use it in GitHub Desktop.
Save mushkevych/1365738 to your computer and use it in GitHub Desktop.
An example of rebalancing a pymongo MasterSlaveConnection
"""
Created on 2011-04-23
@author: Bohdan Mushkevych
@author: Aaron Westendorf
"""
import functools
import time
from pymongo.errors import AutoReconnect
from pymongo.connection import Connection as MongoConnection
from pymongo.uri_parser import parse_host
from pymongo.master_slave_connection import MasterSlaveConnection
# MongoDB database name
mongo_db_name = 'example_db'

# Replica set - list of "host:port" nodes.
# NOTE: there must be no trailing comma after the closing bracket: it would turn
# the value into a one-element tuple wrapping the list, and ClusterConnection
# would then hand the whole list to parse_host() as if it were a single host.
replica_set_example_1 = ['mongo-11.production.ec2.company.com:27017',
                         'mongo-12.production.ec2.company.com:27017',
                         'mongo-13.production.ec2.company.com:27017']
replica_set_example_2 = ['mongo-21.production.ec2.company.com:27017',
                         'mongo-22.production.ec2.company.com:27017',
                         'mongo-23.production.ec2.company.com:27017']

# Replica set - names
REPLICA_SET_EXAMPLE_1 = 'rs_example_1'
REPLICA_SET_EXAMPLE_2 = 'rs_example_2'

# Collections examples
COLLECTION_EXAMPLE_1 = 'example_1_collection'
COLLECTION_EXAMPLE_2 = 'example_2_collection'
def with_reconnect(func):
    """
    Decorator: retry *func* while pymongo raises AutoReconnect.

    AutoReconnect is the standard error raised by pymongo for everything from
    "host disconnected" to "couldn't connect to host" and more.

    The retry handles the edge case when the state of a replica set changes and
    the cursor raises AutoReconnect because the master may have changed. It can
    take some time for the replica set to stop raising this exception, so the
    wrapped call is attempted up to 20 times with a 0.25 s pause between
    attempts (just under 5 seconds in total). After that, the last
    AutoReconnect is re-raised. Any other exception propagates immediately.
    """
    max_attempts = 20
    retry_delay = 0.250  # seconds between attempts

    @functools.wraps(func)
    def _reconnector(*args, **kwargs):
        for attempt in range(1, max_attempts + 1):
            try:
                return func(*args, **kwargs)
            except AutoReconnect:
                # Out of attempts: re-raise while the exception is still active.
                # A bare ``raise`` placed after the loop only worked on Python 2;
                # on Python 3 it fails with "No active exception to reraise".
                if attempt == max_attempts:
                    raise
                time.sleep(retry_delay)
    return _reconnector
class ClusterConnection(MasterSlaveConnection):
    """ MasterSlaveConnection spanning the nodes of one MongoDB ReplicaSet.
    - tz_aware is always False (UTC-friendly: datetimes are not local-aware)
    - reads are meant for ReplicaSet slaves; the master is kept out of the slave
      list unless it is the only node
    - writes go to the ReplicaSet master
    - validate_slaves() re-checks the slave connections at most every
      VALIDATE_INTERVAL seconds, dropping stale ones and (re)creating missing ones
    """
    VALIDATE_INTERVAL = 300  # seconds (5 minutes) between slave re-validations

    def __init__(self, logger, host_list):
        """
        @param logger: logger used to report the master/slave layout
        @param host_list: initial list of "host:port" nodes in ReplicaSet
            (ReplicaSet membership can change during the life time)
        """
        self.logger = logger
        self.host_list = host_list
        master_connection = MongoConnection(self.host_list)
        master_host_port = self._host_port(master_connection)
        slave_log_list = []
        slave_connections = []
        for host in self.host_list:
            slave_host, slave_port = parse_host(host)
            # remove master from list of slaves, so no reads are going its way
            # however, allow master to handle reads if it's the only node in ReplicaSet
            if len(self.host_list) > 1 and (slave_host, slave_port) == master_host_port:
                continue
            slave_log_list.append('%s:%r' % (slave_host, slave_port))
            slave_connections.append(MongoConnection(host=slave_host, port=slave_port,
                                                     slave_okay=True, _connect=False))
        self.logger.info('ClusterConnection.init: master %r, slaves: %r' % (master_connection, slave_log_list))
        super(ClusterConnection, self).__init__(master=master_connection, slaves=slave_connections)
        self._last_validate_time = time.time()

    @staticmethod
    def _host_port(connection):
        """ @return (host, port) tuple stored in the pymongo Connection's private attributes.
        The rest of this class treats (None, None) as a connection with no resolved node. """
        return connection._Connection__host, connection._Connection__port

    @property
    def tz_aware(self):
        """ True stands for local-aware timezone, False for UTC """
        return False

    def get_w_number(self):
        """ number of nodes to replicate highly_important data on insert/update:
        the master plus every slave connection that is neither the master
        nor unresolved ((None, None) host/port) """
        excluded = (self.get_master_host_port(), (None, None))
        replicas = sum(1 for slave in self.slaves if self._host_port(slave) not in excluded)
        return 1 + replicas

    def _drop_slave(self, slave):
        """ Disconnect *slave* and remove that exact object from self.slaves.
        Removal is by identity: list.remove() matches by ==, and if connection objects
        compare equal by host/port (verify against the pymongo version in use) it could
        drop a different connection than the one that was disconnected. """
        slave.disconnect()
        for index, candidate in enumerate(self.slaves):
            if candidate is slave:
                del self.slaves[index]
                break

    def validate_slaves(self):
        """
        At most once per VALIDATE_INTERVAL seconds:
        1. drop slave connections that point at the master (unless it is the only
           slave) or at a host that is not configured in host_list
        2. create a new slave connection for every configured host that has none,
           except the master when other slave connections exist
        """
        if time.time() - self._last_validate_time < self.VALIDATE_INTERVAL:
            return
        master_host_port = self.get_master_host_port()
        hosts_ports = [parse_host(uri) for uri in self.host_list]
        # Iterate over a snapshot: removing from the live list while iterating it
        # would silently skip the element following each removed one.
        for slave in list(self.slaves):
            host_port = self._host_port(slave)
            if host_port == master_host_port:
                # use case: master connection is among slaves
                if len(self.slaves) > 1:
                    # remove master from list of slaves, so no reads are going its way
                    # however, allow master to handle reads if it's the only node in ReplicaSet
                    self._drop_slave(slave)
                    # the master may be absent from host_list (e.g. reached via another name)
                    if master_host_port in hosts_ports:
                        hosts_ports.remove(master_host_port)
            elif host_port not in hosts_ports:
                # NOTE(review): a slave still at (None, None) is never in hosts_ports,
                # so it is dropped here and recreated below - confirm this is intended.
                self._drop_slave(slave)
            else:
                hosts_ports.remove(host_port)
        # use case: remove master URI from "re-connection" list, if there are other active slave connections
        if self.slaves and master_host_port in hosts_ports:
            # if at least one slave is active - do not try to (re)connect to master
            hosts_ports.remove(master_host_port)
        # For all hosts where there wasn't an existing connection, create one
        for host, port in hosts_ports:
            self.slaves.append(MongoConnection(host=host, port=port, slave_okay=True, _connect=False))
        self.logger.info('ClusterConnection.validate: master %r, slaves: %r' % (self.master, self.slaves))
        self._last_validate_time = time.time()

    def get_master_host_port(self):
        """ @return current host and port of the master node in Replica Set"""
        return self._host_port(self.master)
class ReplicaSetContext:
    """ Registry of known ReplicaSets plus a class-level cache of their ClusterConnections. """
    _DB_HOST_LIST = '_db_host_list'
    # replica set name -> ClusterConnection; shared by every caller of get_connection
    connection_pool = {}
    REPLICA_SET_CONTEXT = {
        REPLICA_SET_EXAMPLE_1: {_DB_HOST_LIST: replica_set_example_1},
        REPLICA_SET_EXAMPLE_2: {_DB_HOST_LIST: replica_set_example_2},
    }

    @classmethod
    def get_connection(cls, logger, replica_set):
        """ Return the pooled ClusterConnection for *replica_set*.
        The first call creates and caches it; later calls re-validate its slaves.
        Raises KeyError if *replica_set* is not listed in REPLICA_SET_CONTEXT. """
        settings = cls.REPLICA_SET_CONTEXT[replica_set]
        try:
            pooled = cls.connection_pool[replica_set]
        except KeyError:
            pooled = ClusterConnection(logger, settings[cls._DB_HOST_LIST])
            cls.connection_pool[replica_set] = pooled
        else:
            pooled.validate_slaves()
        return pooled
class CollectionContext:
    """ Maps collection names to the ReplicaSet hosting them and hands out pymongo objects.
    Every public method logs any exception it encounters and returns None in that case. """
    _REPLICA_SET = 'replica_set'
    COLLECTION_CONTEXT = {
        COLLECTION_EXAMPLE_1: {_REPLICA_SET: REPLICA_SET_EXAMPLE_1},
        COLLECTION_EXAMPLE_2: {_REPLICA_SET: REPLICA_SET_EXAMPLE_2},
    }

    @classmethod
    def _get_replica_set_connection(cls, logger, collection_name):
        """ Return the pooled ClusterConnection of the ReplicaSet hosting *collection_name*.
        Raises KeyError when the collection is not registered in COLLECTION_CONTEXT. """
        rs = cls.COLLECTION_CONTEXT[collection_name][cls._REPLICA_SET]
        return ReplicaSetContext.get_connection(logger, rs)

    @classmethod
    def get_fixed_connection(cls, logger, collection_name, slave_ok=True):
        """ Method retrieves non-balancing connection from ReplicaSetContext.
        Such connection is locked to one slave node, and will not handle its unavailability.
        @param slave_ok: when True, use a slave node if one has a resolved host/port;
            otherwise (or when no such slave exists) use the master
        @return fully specified connection to collection, or None on error (logged)"""
        try:
            rs_connection = cls._get_replica_set_connection(logger, collection_name)
            fixed_connection = None
            if slave_ok:
                # case A: client requests slave-tolerant connection;
                # the last slave with a resolved (host, port) is selected
                for slave in rs_connection.slaves:
                    if (slave._Connection__host, slave._Connection__port) != (None, None):
                        fixed_connection = slave
            if fixed_connection is None:
                # case B: ReplicaSet has no valid slave connection, or master connection was requested
                fixed_connection = rs_connection.master
            database = fixed_connection[mongo_db_name]
            return database[collection_name]
        except Exception:
            logger.error('CollectionContext error: %r' % collection_name, exc_info=True)
            return None

    @classmethod
    def get_collection(cls, logger, collection_name):
        """ method retrieves connection from ReplicaSetContext and
        links it to the collection name. Returns fully specified connection to collection,
        or None on error (logged).
        Avoid pooling at this level, as it blocks ClusterConnection load balancing"""
        try:
            db_connection = cls._get_replica_set_connection(logger, collection_name)
            database = db_connection[mongo_db_name]
            return database[collection_name]
        except Exception:
            logger.error('CollectionContext error: %s' % collection_name, exc_info=True)
            return None

    @classmethod
    def get_w_number(cls, logger, collection_name):
        """ w number indicates number of nodes to replicate _highly_important_ data on insert/update
        replication of write shall be used only for System collections.
        @return the ClusterConnection's w number, or None on error (logged) """
        try:
            db_connection = cls._get_replica_set_connection(logger, collection_name)
            return db_connection.get_w_number()
        except Exception:
            logger.error('CollectionContext error: %s' % collection_name, exc_info=True)
            return None

    @classmethod
    def get_master_host_port(cls, logger, collection_name):
        """ @return current host and port of the master node in Replica Set,
        or None on error (logged) """
        try:
            db_connection = cls._get_replica_set_connection(logger, collection_name)
            return db_connection.get_master_host_port()
        except Exception:
            logger.error('CollectionContext error: %s' % collection_name, exc_info=True)
            return None
@mushkevych
Copy link
Author

Updated the code for the pymongo 2.x drivers, where "pymongo.connection._str_to_host" was replaced by "pymongo.uri_parser.parse_host".

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment