Merge pull request #2 from Parsely/bugfix/timeout
Add cycle function for active connections
aldraco authored Jan 22, 2019
2 parents e712e68 + fd13c43 commit 68f78df
Showing 6 changed files with 257 additions and 15 deletions.
3 changes: 2 additions & 1 deletion fluster/__init__.py
@@ -1,5 +1,6 @@
__version__ = '0.0.5'
__version__ = '0.1.0'

from .utils import round_controlled
from .cluster import FlusterCluster
from .exceptions import ClusterEmptyError
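
With this change round_controlled is re-exported at the package top level, so callers can import it alongside FlusterCluster, exactly as the new tests in tests/fluster/test_utils.py do:

from fluster import FlusterCluster, round_controlled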

62 changes: 54 additions & 8 deletions fluster/cluster.py
@@ -1,4 +1,5 @@
from collections import defaultdict
from itertools import cycle
import functools
import logging

@@ -22,6 +23,9 @@ class FlusterCluster(object):
Ideal cases for this are things like caches, where another copy of data
isn't a huge problem (provided expiries are respected).
The FlusterCluster instance can be iterated through, and only active
connections will be returned.
"""

@classmethod
@@ -38,8 +42,32 @@ def __init__(self,
multiplier=penalty_box_wait_multiplier)
self.active_clients = self._prep_clients(clients)
self.initial_clients = {c.pool_id: c for c in clients}
self.clients = cycle(self.initial_clients.values())
self._sort_clients()

def __iter__(self):
"""Updates active clients each time it's iterated through."""
self._prune_penalty_box()
return self

def __next__(self):
"""Always returns a client, or raises an Exception if none are available."""
# raise Exception if no clients are available
if len(self.active_clients) == 0:
raise ClusterEmptyError('All clients are down.')

# refresh connections if they're back up
self._prune_penalty_box()

# return the first client that's active
for client in self.clients:
if client in self.active_clients:
return client

def next(self):
"""Python 2/3 compatibility."""
return self.__next__()
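
A minimal usage sketch of the new iteration protocol; the ports and penalty-box setting mirror the test fixtures further down, and the incr call is only a stand-in for real work:

from redis import StrictRedis
from redis.exceptions import ConnectionError
from fluster import FlusterCluster

cluster = FlusterCluster([StrictRedis(port=p) for p in (10101, 10102, 10103)],
                         penalty_box_min_wait=0.5)

for idx, client in enumerate(cluster):  # yields only active clients, cycling forever
    try:
        client.incr('key', 1)           # a failed call sends this client to the penalty box
    except ConnectionError:
        continue                        # the cluster has already marked the client down
    if idx >= 10:                       # the cycle never ends on its own, so break explicitly
        break

Once every client is down, the next call to __next__ raises ClusterEmptyError instead of looping.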

def _sort_clients(self):
"""Make sure clients are sorted consistently for consistent results."""
self.active_clients.sort(key=lambda c: c.pool_id)
@@ -73,10 +101,7 @@ def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except (ConnectionError, TimeoutError): # TO THE PENALTY BOX!
if client in self.active_clients: # hasn't been removed yet
log.warning('%r marked down.', client)
self.active_clients.remove(client)
self.penalty_box.add(client)
self._penalize_client(client)
raise
return functools.update_wrapper(wrapper, fn)

@@ -92,11 +117,10 @@ def wrapper(*args, **kwargs):
log.debug('Wrapping %s', name)
setattr(client, name, wrap(obj))

def get_client(self, shard_key):
"""Get the client for a given shard, based on what's available.
def _prune_penalty_box(self):
"""Restores clients that have reconnected.
If the proper client isn't available, the next available client
is returned. If no clients are available, an exception is raised.
This function should be called first for every public method.
"""
added = False
for client in self.penalty_box.get():
@@ -106,6 +130,14 @@ def get_client(self, shard_key):
if added:
self._sort_clients()

def get_client(self, shard_key):
"""Get the client for a given shard, based on what's available.
If the proper client isn't available, the next available client
is returned. If no clients are available, an exception is raised.
"""
self._prune_penalty_box()

if len(self.active_clients) == 0:
raise ClusterEmptyError('All clients are down.')

@@ -125,11 +157,25 @@ def get_client(self, shard_key):
pos = hashed % len(self.active_clients)
return self.active_clients[pos]
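
For contrast with iteration, a hedged sketch of get_client; the shard key comes from the test fixtures and the counter name is purely illustrative:

client = cluster.get_client('hi')  # same key -> same client while that client stays active
client.incr('hits', 1)
# If that client is later marked down, get_client('hi') falls back to the next
# available client, and ClusterEmptyError is raised only when none remain.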

def _penalize_client(self, client):
"""Place client in the penalty box.
:param client: Client object
"""
if client in self.active_clients: # hasn't been removed yet
log.warning('%r marked down.', client)
self.active_clients.remove(client)
self.penalty_box.add(client)
else:
log.info("%r not in active client list.")

def zrevrange_with_int_score(self, key, max_score, min_score):
"""Get the zrevrangebyscore across the cluster.
Highest score for duplicate element is returned.
A faster method should be written if scores are not needed.
"""
self._prune_penalty_box()

if len(self.active_clients) == 0:
raise ClusterEmptyError('All clients are down.')

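The body of zrevrange_with_int_score is truncated here; judging from its docstring and the zrevrange test below, it merges zrevrangebyscore results from every active client and keeps the highest integer score per element. A hedged usage sketch, reusing the key and scores from that test:

for element, count in zip(('hi', 'redis', 'test'), (1.0, 2.0, 3.0)):
    cluster.get_client(element).zadd('foo', {element: count})

scores = cluster.zrevrange_with_int_score('foo', '+inf', 2)
# scores maps elements to int scores; the test asserts its values are {3, 2}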
4 changes: 3 additions & 1 deletion fluster/penalty_box.py
@@ -35,12 +35,14 @@ def get(self):
now = time.time()
while self._clients and self._clients[0][0] < now:
_, (client, last_wait) = heapq.heappop(self._clients)
connect_start = time.time()
try:
client.echo('test') # reconnected if this succeeds.
self._client_ids.remove(client.pool_id)
yield client
except (ConnectionError, TimeoutError):
timer = time.time() - connect_start
wait = min(int(last_wait * self._multiplier), self._max_wait)
heapq.heappush(self._clients,
(time.time() + wait, (client, wait)))
log.info('%r is still down. Retrying in %ss.', client, wait)
log.info('%r is still down after a %s second attempt to connect. Retrying in %ss.', client, timer, wait)
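
For context on the retry cadence in that log line: each failed echo multiplies the previous wait by the box's multiplier and caps it at its maximum. A quick illustration with hypothetical settings (initial wait 0.5s, multiplier 2, cap 30s; the real defaults are not shown in this diff):

wait = 0.5
for attempt in range(8):
    wait = min(int(wait * 2), 30)  # mirrors min(int(last_wait * self._multiplier), self._max_wait)
    print(attempt, wait)           # waits: 1, 2, 4, 8, 16, 30, 30, 30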
15 changes: 15 additions & 0 deletions fluster/utils.py
@@ -0,0 +1,15 @@
def round_controlled(cycled_iterable, rounds=1):
"""Raise StopIteration after <rounds> passes through a cycled iterable."""
round_start = None
rounds_completed = 0

for item in cycled_iterable:
if round_start is None:
round_start = item
elif item == round_start:
rounds_completed += 1

if rounds_completed == rounds:
return  # PEP 479: raising StopIteration inside a generator is a RuntimeError on Python 3.7+

yield item
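
A short sketch of round_controlled against a plain cycled iterable, matching what the new test in tests/fluster/test_utils.py asserts (two rounds over three items yields exactly six items):

from itertools import cycle
from fluster import round_controlled

items = ['a', 'b', 'c']
seen = list(round_controlled(cycle(items), rounds=2))
assert seen == ['a', 'b', 'c', 'a', 'b', 'c']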
134 changes: 129 additions & 5 deletions tests/fluster/test_cluster.py
@@ -7,7 +7,8 @@
from redis.exceptions import ConnectionError
from testinstances import RedisInstance

from fluster import FlusterCluster
from fluster import FlusterCluster, ClusterEmptyError
import redis


class FlusterClusterTests(unittest.TestCase):
@@ -96,14 +97,137 @@ def test_consistent_hashing(self):
# Bring it back up
self.instances[0] = RedisInstance(10101)

def test_cycle_clients(self):
# should cycle through clients indefinitely
returned_clients = set()
limit = 15

assert True

for idx, client in enumerate(self.cluster):
returned_clients.update([client])
assert client is not None
if idx >= limit:
break

assert idx == 15
assert len(returned_clients) == len(self.cluster.active_clients)

def test_cycle_clients_with_failures(self):
# should not include inactive nodes
self.instances[0].terminate()
limit = 6
counter = 0

for idx, client in enumerate(self.cluster):
assert client is not None
try:
client.incr('key', 1)
counter += 1
except Exception as e:
print("oops", client, e)
continue # exception handled by the cluster
if idx >= limit:
break

# Restart instance
self.instances[0] = RedisInstance(10101)
time.sleep(0.5)

assert counter == 6 # able to continue even when node is down
assert 2 == len(self.cluster.active_clients)
assert 2 == len(self.cluster.initial_clients.values()) - 1

# should add restarted nodes back to the list after reported failure
# calling __iter__ again checks the penalty box
counter = 0
for idx, client in enumerate(self.cluster):
if idx >= limit:
break
client.incr('key', 1)
counter += 1

assert counter == limit
assert len(self.cluster.active_clients) == 3 # to verify it added the node back

def test_long_running_iterations(self):
# long-running iterations should still add connections back to the cluster
drop_client = 3
restart_client = 10
client_available = restart_client + 1

for idx, client in enumerate(self.cluster):
# attempt to use each client
try:
client.incr('key', 1)
except Exception:
continue # exception handled by the cluster
# mimic connection dropping out and returning
if idx == drop_client:
self.instances[0].terminate()
elif idx == restart_client:
self.instances[0] = RedisInstance(10101)
# client should be visible after calling next() again
elif idx == client_available:
assert len(self.cluster.active_clients) == 3
break

def test_cycle_clients_tracking(self):
# should track separate cycle entry points for each instance
cluster_instance_1 = self.cluster
# connect to already-running testinstances, instead of making more,
# to mimic two FlusterCluster instances
redis_clients = [redis.StrictRedis(port=conn.port)
for conn in self.instances]
cluster_instance_2 = FlusterCluster([i for i in redis_clients],
penalty_box_min_wait=0.5)

# advance cluster instance one
next(cluster_instance_1)

# should not start at the same point
assert next(cluster_instance_1) != next(cluster_instance_2)

for temp_conn in redis_clients:
del temp_conn

def test_dropped_connections_while_iterating(self):
# dropped connections in the middle of an iteration should not cause an infinite loop
# and should raise an exception
limit = 21

assert len(self.cluster.active_clients) == 3

drop_at_idx = (5, 6, 7) # at these points, kill a connection
killed = 0
with self.assertRaises(ClusterEmptyError) as context:
for idx, client in enumerate(self.cluster):
if idx >= limit:
break # in case the test fails to stop
if idx in drop_at_idx:
self.instances[killed].terminate()
killed += 1
print('killed ', idx, killed)
try:
client.incr('key', 1)
except Exception:
pass # mimic err handling
self.assertTrue('All clients are down.' in str(context.exception))

assert idx == 8 # the next iteration after the last client was killed

# restart all the instances
for instance, port in enumerate(range(10101, 10104)):
self.instances[instance] = RedisInstance(port)

def test_zrevrange(self):
"""Add a sorted set, turn off the client, add to the set,
turn the client back on, check results
"""
key = 'foo'
for element, count in zip(self.keys, (1.0, 2.0, 3.0)):
client = self.cluster.get_client(element)
client.zadd(key, count, element)
client.zadd(key, {element: count})
revrange = self.cluster.zrevrange_with_int_score(key, '+inf', 2)
self.assertEqual(set([3, 2]), set(revrange.values()))

@@ -114,12 +238,12 @@ def test_zrevrange(self):
new_count = 5
client = self.cluster.get_client(dropped_element)
try:
client.zadd(key, new_count, dropped_element)
client.zadd(key, {dropped_element: new_count})
raise Exception("Should not get here, client was terminated")
except ConnectionError:
client = self.cluster.get_client(dropped_element)
print('replaced client', client)
client.zadd(key, new_count, dropped_element)
client.zadd(key, {dropped_element: new_count})
revrange = self.cluster.zrevrange_with_int_score(key, '+inf', 2)
self.assertEqual(set([new_count, 2]), set(revrange.values()))

@@ -131,6 +255,6 @@ def test_zrevrange(self):
self.assertEqual(set([new_count, 2]), set(revrange.values())) #restarted instance is empty in this case

client = self.cluster.get_client(dropped_element)
client.zadd(key, 3, dropped_element) #put original value back in
client.zadd(key, {dropped_element: 3}) #put original value back in
revrange = self.cluster.zrevrange_with_int_score(key, '+inf', 2)
self.assertEqual(set([new_count, 2]), set(revrange.values())) #max value found for duplicates is returned
54 changes: 54 additions & 0 deletions tests/fluster/test_utils.py
@@ -0,0 +1,54 @@
from __future__ import absolute_import, print_function
import unittest
import sys

from testinstances import RedisInstance
from fluster import FlusterCluster, round_controlled


class FlusterClusterTests(unittest.TestCase):
def assertCountEqual(self, a, b):
if sys.version_info > (3, 0):
super(FlusterClusterTests, self).assertCountEqual(a, b)
else:
self.assertItemsEqual(a, b)

@classmethod
def setUpClass(cls):
cls.instances = [RedisInstance(10101),
RedisInstance(10102),
RedisInstance(10103)]

@classmethod
def tearDownClass(cls):
for instance in cls.instances:
instance.terminate()

def setUp(self):
self.cluster = FlusterCluster([i.conn for i in self.instances],
penalty_box_min_wait=0.5)
self.keys = ['hi', 'redis', 'test'] # hashes to 3 separate values

def tearDown(self):
for instance in self.instances:
if hasattr(instance.conn, 'pool_id'):
delattr(instance.conn, 'pool_id')

def test_round_controller(self):
# the round controller should track rounds and limit iterations
repeated_sublist = list(range(0, 3))
lis = repeated_sublist * 5
desired_rounds = 4 # don't iterate through the whole list
for idx, item in enumerate(round_controlled(lis, rounds=desired_rounds)):
pass

assert idx == desired_rounds * len(repeated_sublist) - 1

# more specific application
desired_rounds = 3

for idx, conn in enumerate(round_controlled(self.cluster, rounds=desired_rounds)):
pass

# should raise stopiteration at appropriate time
assert idx == (desired_rounds * len(self.cluster.active_clients) - 1)
