diff --git a/replication.py b/replication.py index 1ec8b2f..6ec850c 100644 --- a/replication.py +++ b/replication.py @@ -22,7 +22,7 @@ class ReplicatedDataFactory(object): """ Register a new replicated datatype implementation """ - types.append((supported_types, implementation)) + self.supported_types.append((dtype, implementation)) def match_type(self,data): for stypes, implementation in self.supported_types: diff --git a/replication_client.py b/replication_client.py index b78fdac..8ae8e7c 100644 --- a/replication_client.py +++ b/replication_client.py @@ -7,13 +7,16 @@ logging.basicConfig(level=logging.DEBUG) log = logging.getLogger(__name__) class Client(object): - def __init__(self): + def __init__(self,config=None): self.rep_store = {} self.net = ClientNetService(self.rep_store) def connect(self): - self.net.run() - + self.net.start() + + def state(self): + return self.net.state + def stop(self): self.net.stop() @@ -36,10 +39,10 @@ class ClientNetService(threading.Thread): self.subscriber.connect("tcp://127.0.0.1:5561") self.subscriber.linger = 0 - self.publish = self.context.socket(zmq.PULL) - self.publish.bind("tcp://*:5562") + self.publish = self.context.socket(zmq.PUSH) + self.publish.connect("tcp://127.0.0.1:5562") - # For teststing purpose + self.state = 0 def run(self): @@ -50,18 +53,82 @@ class ClientNetService(threading.Thread): poller.register(self.subscriber, zmq.POLLIN) poller.register(self.publish, zmq.POLLOUT) + self.state = 1 + while not self.exit_event.is_set(): items = dict(poller.poll(10)) if not items: log.error("No request ") - def stop(self): self.exit_event.set() - #Wait the end of the run - while self.exit_event.is_set(): - time.sleep(.1) + self.state = 0 + +class Server(): + def __init__(self,config=None): + self.rep_store = {} + self.net = ServerNetService(self.rep_store) + + def serve(self): + self.net.start() + + def state(self): + return self.net.state + + def stop(self): + self.net.stop() + + +class ServerNetService(threading.Thread): + def __init__(self,store_reference=None): + # Threading + threading.Thread.__init__(self) + self.name = "NetLink" + self.daemon = True + self.exit_event = threading.Event() + self.store = store_reference + + self.context = zmq.Context.instance() + + # Update request + self.snapshot = self.context.socket(zmq.ROUTER) + self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER') + self.snapshot.setsockopt(zmq.RCVHWM, 60) + self.snapshot.bind("tcp://*:5560") + + # Update all clients + self.publisher = self.context.socket(zmq.PUB) + self.publisher.setsockopt(zmq.SNDHWM, 60) + self.publisher.bind("tcp://*:5561") + time.sleep(0.2) + + # Update collector + self.pull = self.context.socket(zmq.PULL) + self.pull.setsockopt(zmq.RCVHWM, 60) + self.pull.bind("tcp://*:5562") + + # poller for socket aggregation + self.poller = zmq.Poller() + self.poller.register(self.snapshot, zmq.POLLIN) + self.poller.register(self.pull, zmq.POLLIN) + + self.state = 0 + + def run(self): + log.debug("Running Server Net service") + + poller = zmq.Poller() + poller.register(self.snapshot, zmq.POLLIN) + poller.register(self.pull, zmq.POLLIN) + + self.state = 1 + + while not self.exit_event.is_set(): + items = dict(poller.poll(10)) + + if not items: + log.info("No request ") \ No newline at end of file diff --git a/test_replication.py b/test_replication.py index b390840..55a9534 100644 --- a/test_replication.py +++ b/test_replication.py @@ -2,7 +2,8 @@ import unittest from replication import ReplicatedDatablock, ReplicatedDataFactory import umsgpack import logging -from replication_client import Client +from replication_client import Client, Server +import time log = logging.getLogger(__name__) @@ -30,24 +31,41 @@ class RepSampleData(ReplicatedDatablock): class TestData(unittest.TestCase): def setUp(self): self.map = {} - self.client_api = Client() + self.server_api = Server() + + def test_server_launching(self): + self.server_api.serve() + time.sleep(1) + self.assertEqual(self.server_api.state(),1) + def test_setup_data_factory(self): factory = ReplicatedDataFactory() factory.register_type(SampleData, RepSampleData) - def test_run_client(self): + data_sample = SampleData() + rep_sample = factory.construct(data_sample)(owner="toto") + self.assertEqual(isinstance(rep_sample,RepSampleData), True) + + def test_client_connect(self): self.client_api.connect() + time.sleep(1) + self.assertEqual(self.client_api.state(),1) - def test_stop_client(self): + def test_client_stop(self): self.client_api.stop() + time.sleep(1) + self.assertEqual(self.client_api.state(),0) + + def test_client_add_rep(self): + pass + + def test_add_replicated_value(self): pass - def test_create_replicated_data(self): - self.assertNotEqual(self.sample_data.uuid,None) - +