feat: address and port connection parameter
fix: thread stop now correctly
This commit is contained in:
@ -23,12 +23,13 @@ class Client(object):
|
||||
id=id)
|
||||
self._factory = factory
|
||||
|
||||
def connect(self):
|
||||
self._net_client.start()
|
||||
def connect(self,address="127.0.0.1",port=5560):
|
||||
self._net_client.connect(address=address,port=port)
|
||||
|
||||
def disconnect(self):
|
||||
self._net_client.stop()
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self._net_client.state
|
||||
|
||||
@ -75,21 +76,24 @@ class ClientNetService(threading.Thread):
|
||||
|
||||
# Networking
|
||||
self.context = zmq.Context.instance()
|
||||
|
||||
self.snapshot = self.context.socket(zmq.DEALER)
|
||||
self.snapshot.setsockopt(zmq.IDENTITY, self._id.encode())
|
||||
self.snapshot.connect("tcp://127.0.0.1:5560")
|
||||
|
||||
self.subscriber = self.context.socket(zmq.SUB)
|
||||
self.subscriber.setsockopt_string(zmq.SUBSCRIBE, '')
|
||||
self.subscriber.connect("tcp://127.0.0.1:5561")
|
||||
self.subscriber.linger = 0
|
||||
|
||||
self.publish = self.context.socket(zmq.PUSH)
|
||||
self.publish.connect("tcp://127.0.0.1:5562")
|
||||
|
||||
self.state = STATE_INITIAL
|
||||
|
||||
def connect(self,address='127.0.0.1', port=5560):
|
||||
if self.state == STATE_INITIAL:
|
||||
logger.debug("connecting on {}:{}".format(address,port))
|
||||
self.snapshot = self.context.socket(zmq.DEALER)
|
||||
self.snapshot.setsockopt(zmq.IDENTITY, self._id.encode())
|
||||
self.snapshot.connect("tcp://{}:{}".format(address, port))
|
||||
|
||||
self.subscriber = self.context.socket(zmq.SUB)
|
||||
self.subscriber.setsockopt_string(zmq.SUBSCRIBE, '')
|
||||
self.subscriber.connect("tcp://{}:{}".format(address, port+1))
|
||||
self.subscriber.linger = 0
|
||||
|
||||
self.publish = self.context.socket(zmq.PUSH)
|
||||
self.publish.connect("tcp://{}:{}".format(address, port+2))
|
||||
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
logger.info("{} online".format(self._id))
|
||||
@ -138,7 +142,11 @@ class ClientNetService(threading.Thread):
|
||||
logger.error("No request ")
|
||||
|
||||
|
||||
time.sleep(1)
|
||||
self.snapshot.close()
|
||||
self.subscriber.close()
|
||||
self.publish.close()
|
||||
|
||||
self._exit_event.clear()
|
||||
|
||||
def setup(self,id="Client"):
|
||||
pass
|
||||
@ -146,9 +154,9 @@ class ClientNetService(threading.Thread):
|
||||
def stop(self):
|
||||
self._exit_event.set()
|
||||
|
||||
self.snapshot.close()
|
||||
self.subscriber.close()
|
||||
self.publish.close()
|
||||
#Wait the end of the run
|
||||
while self._exit_event.is_set():
|
||||
time.sleep(.1)
|
||||
|
||||
self.state = 0
|
||||
|
||||
@ -159,8 +167,8 @@ class Server():
|
||||
self._rep_store = {}
|
||||
self._net = ServerNetService(store_reference=self._rep_store, factory=factory)
|
||||
|
||||
def serve(self):
|
||||
self._net.start()
|
||||
def serve(self,port=5560):
|
||||
self._net.listen(port=port)
|
||||
|
||||
def state(self):
|
||||
return self._net.state
|
||||
@ -187,27 +195,27 @@ class ServerNetService(threading.Thread):
|
||||
self.state = 0
|
||||
self.factory = factory
|
||||
|
||||
self.bind_ports()
|
||||
|
||||
def bind_ports(self):
|
||||
def listen(self, port=5560):
|
||||
try:
|
||||
# Update request
|
||||
self.snapshot = self.context.socket(zmq.ROUTER)
|
||||
self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER')
|
||||
self.snapshot.setsockopt(zmq.RCVHWM, 60)
|
||||
self.snapshot.bind("tcp://*:5560")
|
||||
self.snapshot.bind("tcp://*:{}".format(port))
|
||||
|
||||
# Update all clients
|
||||
self.publisher = self.context.socket(zmq.PUB)
|
||||
self.publisher.setsockopt(zmq.SNDHWM, 60)
|
||||
self.publisher.bind("tcp://*:5561")
|
||||
self.publisher.bind("tcp://*:{}".format(port+1))
|
||||
time.sleep(0.2)
|
||||
|
||||
# Update collector
|
||||
self.pull = self.context.socket(zmq.PULL)
|
||||
self.pull.setsockopt(zmq.RCVHWM, 60)
|
||||
self.pull.bind("tcp://*:5562")
|
||||
self.pull.bind("tcp://*:{}".format(port+2))
|
||||
|
||||
self.start()
|
||||
except zmq.error.ZMQError:
|
||||
logger.error("Address already in use, change net config")
|
||||
|
||||
@ -222,7 +230,7 @@ class ServerNetService(threading.Thread):
|
||||
|
||||
while not self._exit_event.is_set():
|
||||
# Non blocking poller
|
||||
socks = dict(poller.poll(10))
|
||||
socks = dict(poller.poll(1))
|
||||
|
||||
# Snapshot system for late join (Server - Client)
|
||||
if self.snapshot in socks:
|
||||
@ -269,12 +277,17 @@ class ServerNetService(threading.Thread):
|
||||
# Update all clients
|
||||
datablock.push(self.publisher)
|
||||
|
||||
|
||||
def stop(self):
|
||||
self._exit_event.set()
|
||||
|
||||
self.snapshot.close()
|
||||
self.pull.close()
|
||||
self.publisher.close()
|
||||
|
||||
self._exit_event.clear()
|
||||
|
||||
def stop(self):
|
||||
self._exit_event.set()
|
||||
|
||||
#Wait the end of the run
|
||||
while self._exit_event.is_set():
|
||||
time.sleep(.1)
|
||||
|
||||
self.state = 0
|
@ -42,50 +42,60 @@ class TestDataReplication(unittest.TestCase):
|
||||
def __init__(self,methodName='runTest'):
|
||||
unittest.TestCase.__init__(self, methodName)
|
||||
|
||||
self.factory = ReplicatedDataFactory()
|
||||
self.factory.register_type(SampleData, RepSampleData)
|
||||
def test_empty_snapshot(self):
|
||||
# Setup
|
||||
factory = ReplicatedDataFactory()
|
||||
factory.register_type(SampleData, RepSampleData)
|
||||
|
||||
self.server = Server(factory=self.factory)
|
||||
self.server.serve()
|
||||
server = Server(factory=factory)
|
||||
client = Client(factory=factory, id="client_test_callback")
|
||||
|
||||
self.client = Client(factory=self.factory, id="client_1")
|
||||
self.client.connect()
|
||||
server.serve(port=5570)
|
||||
client.connect(port=5570)
|
||||
|
||||
self.client2 = Client(factory=self.factory, id="client_2")
|
||||
self.client2.connect()
|
||||
test_state = client.state
|
||||
|
||||
server.stop()
|
||||
client.disconnect()
|
||||
|
||||
self.assertNotEqual(test_state, 2)
|
||||
|
||||
|
||||
def test_register_client_data(self):
|
||||
data_sample_key = self.client.register(SampleData())
|
||||
# Setup
|
||||
factory = ReplicatedDataFactory()
|
||||
factory.register_type(SampleData, RepSampleData)
|
||||
|
||||
server = Server(factory=factory)
|
||||
server.serve(port=5560)
|
||||
|
||||
client = Client(factory=factory, id="client_1")
|
||||
client.connect(port=5560)
|
||||
|
||||
client2 = Client(factory=factory, id="client_2")
|
||||
client2.connect(port=5560)
|
||||
|
||||
time.sleep(1)
|
||||
data_sample_key = client.register(SampleData())
|
||||
|
||||
#Waiting for server to receive the datas
|
||||
time.sleep(2)
|
||||
|
||||
test_key = self.client2._rep_store[data_sample_key]
|
||||
test_key = client2._rep_store[data_sample_key]
|
||||
|
||||
|
||||
client.disconnect()
|
||||
client2.disconnect()
|
||||
server.stop()
|
||||
|
||||
#Check if the server receive them
|
||||
self.assertNotEqual(test_key, None)
|
||||
|
||||
def test_register_client_data2(self):
|
||||
data_sample_key = self.client.register(SampleData())
|
||||
|
||||
#Waiting for server to receive the datas
|
||||
time.sleep(2)
|
||||
|
||||
test_key = self.client2._rep_store[data_sample_key]
|
||||
|
||||
#Check if the server receive them
|
||||
self.assertNotEqual(test_key, None)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def suite():
|
||||
suite = unittest.TestSuite()
|
||||
suite.addTest(TestDataFactory('test_data_factory'))
|
||||
suite.addTest(TestDataReplication('test_empty_snapshot'))
|
||||
suite.addTest(TestDataReplication('test_register_client_data'))
|
||||
suite.addTest(TestDataReplication('test_register_client_data2'))
|
||||
|
||||
return suite
|
||||
|
||||
|
Reference in New Issue
Block a user