Skip to content

Instantly share code, notes, and snippets.

@rgs1
Created April 25, 2014 22:33
Show Gist options
  • Save rgs1/11305394 to your computer and use it in GitHub Desktop.
commit ff99c46b4db900ca7710771a7f7cb5f65e3416e9
Author: Raul Gutierrez S <rgs@itevenworks.net>
Date: Fri Apr 25 14:06:20 2014 -0700
Re-register watches with the server (SetWatches) when entering CONNECTED state
Signed-off-by: Raul Gutierrez S <rgs@itevenworks.net>
diff --git a/kazoo/client.py b/kazoo/client.py
index fb32b97782a5..45908dfad167 100644
--- a/kazoo/client.py
+++ b/kazoo/client.py
@@ -38,6 +38,7 @@ from kazoo.protocol.serialization import (
SetACL,
GetData,
SetData,
+ SetWatches,
Sync,
Transaction
)
@@ -426,6 +427,9 @@ class KazooClient(object):
dead_state = self._state in LOST_STATES
self._state = state
+ if state == KeeperState.CONNECTED:
+ self._set_watches()
+
# If we were previously closed or had an expired session, and
# are now connecting, don't bother with the rest of the
# transitions since they only apply after
@@ -450,7 +454,18 @@ class KazooClient(object):
self._live.clear()
self._notify_pending(state)
self._make_state_change(KazooState.SUSPENDED)
- self._reset_watchers()
+ if state != KeeperState.CONNECTING:
+ self._reset_watchers()
+
def _set_watches(self):
    """Re-register this client's watches with the server after a reconnect.

    Builds a ZooKeeper ``SetWatches`` request from the paths currently in
    ``self._data_watchers`` and ``self._child_watches`` and hands it to
    ``self._call``.

    The request carries ``self.last_zxid`` as its relative zxid. Presumably
    this is the last zxid this client observed. The server uses it to fire,
    immediately, any watch whose node changed while the client was
    disconnected.

    :returns: the async result object passed to ``self._call`` along with
              the request.
    """
    # TODO: separate exist from data watches. Until they are tracked
    # separately, send the data watchers only as *data* watches and an
    # empty exist-watch list. The ZooKeeper server fires NodeCreated right
    # away for every existWatches path that already exists, so also listing
    # the data watchers there produced a spurious event for every existing
    # watched node on each reconnect.
    #
    # NOTE(review): if exists() watches are kept in _data_watchers (as the
    # TODO suggests), a watch on a node that does not exist still gets a
    # NodeDeleted event from the server on reconnect. That behavior is
    # unchanged by this fix.
    #
    # Snapshot the watcher paths with list() rather than passing live dict
    # key views. The request is presumably serialized later by the
    # connection machinery behind _call (not visible here). A view could
    # change size before then if a watcher is added or removed.
    async_result = self.handler.async_result()
    sw = SetWatches(self.last_zxid,
                    list(self._data_watchers),
                    [],
                    list(self._child_watchers))
    self._call(sw, async_result)
    return async_result
def _notify_pending(self, state):
"""Used to clear a pending response queue and request queue
diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py
index f44f49a3247e..ecc8f965ba63 100644
--- a/kazoo/protocol/serialization.py
+++ b/kazoo/protocol/serialization.py
@@ -14,6 +14,7 @@ int_int_struct = struct.Struct('!ii')
int_int_long_struct = struct.Struct('!iiq')
int_long_int_long_struct = struct.Struct('!iqiq')
+long_struct = struct.Struct('!q')
multiheader_struct = struct.Struct('!iBi')
reply_header_struct = struct.Struct('!iqi')
stat_struct = struct.Struct('!qqqqiiiqiiq')
@@ -53,6 +54,14 @@ def write_string(bytes):
return int_struct.pack(len(utf8_str)) + utf8_str
def write_string_vector(v):
    """Serialize a collection of strings as a length-prefixed vector.

    The output is an element count packed with ``int_struct``, followed by
    each element encoded with :func:`write_string`.

    ``None`` is written as a count of ``-1`` (a null vector). This matches
    how :func:`write_buffer` encodes a null buffer.

    :param v: any iterable of strings (list, tuple, dict keys view,
              generator, ...) or ``None``.
    :returns: a :class:`bytearray` holding the encoded vector.
    """
    b = bytearray()
    if v is None:
        b.extend(int_struct.pack(-1))
        return b
    # Materialize once. With separate len() and iteration passes, a dict
    # view whose dict changes size in between would produce a count that
    # disagrees with the strings actually written. Generators would fail
    # len() outright.
    items = list(v)
    b.extend(int_struct.pack(len(items)))
    for s in items:
        b.extend(write_string(s))
    return b
+
+
def write_buffer(bytes):
if bytes is None:
return int_struct.pack(-1)
@@ -360,6 +369,24 @@ class Auth(namedtuple('Auth', 'auth_type scheme auth')):
write_string(self.auth))
class SetWatches(
        namedtuple('SetWatches',
                   'relativeZxid, dataWatches, existWatches, childWatches')):
    """ZooKeeper ``SetWatches`` request (request type 101).

    Used to re-register a client's watches on a new connection.

    Fields:

    - ``relativeZxid``: zxid the server compares node changes against when
      deciding which watches to fire immediately.
    - ``dataWatches``: paths carrying data watches.
    - ``existWatches``: paths carrying existence watches.
    - ``childWatches``: paths carrying child watches.
    """
    type = 101

    def serialize(self):
        """Encode the body: the 64-bit zxid, then the three path vectors."""
        b = bytearray()
        b.extend(long_struct.pack(self.relativeZxid))
        # Order matters on the wire: data, exist, child.
        for paths in (self.dataWatches, self.existWatches, self.childWatches):
            b.extend(write_string_vector(paths))
        return b

    @classmethod
    def deserialize(cls, bytes, offset):
        """Parse the reply; the SetWatches reply has no payload to decode."""
        return True
+
+
class Watch(namedtuple('Watch', 'type state path')):
@classmethod
def deserialize(cls, bytes, offset):
diff --git a/kazoo/testing/harness.py b/kazoo/testing/harness.py
index 0a9079ac0d99..0ea9155a24de 100644
--- a/kazoo/testing/harness.py
+++ b/kazoo/testing/harness.py
@@ -2,8 +2,10 @@
import atexit
import logging
import os
+import socket
import uuid
import threading
+import time
import unittest
from kazoo.client import KazooClient
@@ -123,6 +125,27 @@ class KazooTestHarness(unittest.TestCase):
raise Exception("Failed to see client reconnect")
self.client.retry(self.client.get_async, '/')
+
def force_reconnect(self, timeout=10):
    """Drop the test client's socket and wait for the client to reconnect.

    First asserts that a SUSPENDED state change is seen within one second
    of the socket being shut down. Then polls until ``client.connected`` is
    true again, failing the test if that takes longer than *timeout* seconds.

    :param timeout: seconds to wait for the reconnect before failing.
    """
    client = self.client
    state_change_event = client.handler.event_object()

    def listener(state):
        if state is KazooState.SUSPENDED:
            state_change_event.set()

    client.add_listener(listener)
    try:
        # Shut the socket down underneath the client to simulate a dropped
        # connection.
        client._connection._socket.shutdown(socket.SHUT_RDWR)

        state_change_event.wait(1)
        self.assertTrue(state_change_event.is_set())

        # Wait until we are back, but not forever: a failed reconnect
        # should fail this test rather than hang the whole suite.
        deadline = time.time() + timeout
        while not client.connected:
            if time.time() > deadline:
                self.fail("Failed to see client reconnect")
            time.sleep(0.1)
    finally:
        # Don't leave this one-shot listener attached to the shared client.
        client.remove_listener(listener)
+
+
def setup_zookeeper(self, **client_options):
"""Create a ZK cluster and chrooted :class:`KazooClient`
diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py
index a2c670dac6ea..5a1022b09559 100644
--- a/kazoo/tests/test_client.py
+++ b/kazoo/tests/test_client.py
@@ -1,4 +1,3 @@
-import socket
import sys
import threading
import time
@@ -223,9 +222,9 @@ class TestConnection(KazooTestCase):
def test_add_auth_on_reconnect(self):
    """Credentials registered with add_auth() are still listed after a reconnect."""
    scheme, credential = "digest", "jsmith:jsmith"
    self.client.add_auth(scheme, credential)

    self.force_reconnect()

    self.assertTrue((scheme, credential) in self.client.auth_data)
def test_session_expire(self):
@@ -878,6 +877,31 @@ class TestClient(KazooTestCase):
finally:
self.cluster[0].run()
def test_set_watches_on_reconnect(self):
    """A child watch set before a reconnect is kept and still fires after it."""
    client = self.client
    watch_event = client.handler.event_object()
    seen = []

    client.create("/tacos")

    # Only record the event here and assert on it from the test thread.
    # The callback presumably runs on a client handler thread, so an
    # assertion raised inside it would leave the event unset rather than
    # failing this test with a useful message.
    def w(we):
        seen.append(we)
        watch_event.set()

    client.get_children("/tacos", watch=w)

    # force a reconnect
    self.force_reconnect()

    # watches should still be there
    self.assertEqual(len(client._child_watchers), 1)

    # ... and they should fire. kazoo's create() expects the node value as
    # a byte string; a plain "" only works on Python 2.
    client.create("/tacos/hello_", b"", ephemeral=True, sequence=True)

    watch_event.wait(1)
    self.assertTrue(watch_event.is_set())
    self.assertEqual(seen[0].path, "/tacos")
+
dummy_dict = {
'aversion': 1, 'ctime': 0, 'cversion': 1,
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment