feat: test progress
This commit is contained in:
@ -72,7 +72,7 @@ class ReplicatedDatablock(object):
|
|||||||
"""
|
"""
|
||||||
uuid = None # uuid used as key (string)
|
uuid = None # uuid used as key (string)
|
||||||
pointer = None # dcc data ref (DCC type)
|
pointer = None # dcc data ref (DCC type)
|
||||||
buffer = None # data blob (json)
|
buffer = None # raw data (json)
|
||||||
str_type = None # data type name (string)
|
str_type = None # data type name (string)
|
||||||
deps = [None] # dependencies array (string)
|
deps = [None] # dependencies array (string)
|
||||||
owner = None # Data owner (string)
|
owner = None # Data owner (string)
|
||||||
@ -82,8 +82,13 @@ class ReplicatedDatablock(object):
|
|||||||
self.uuid = uuid if uuid else str(uuid4())
|
self.uuid = uuid if uuid else str(uuid4())
|
||||||
assert(owner)
|
assert(owner)
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.pointer = data
|
|
||||||
self.buffer = buffer if buffer else None
|
if data:
|
||||||
|
self.pointer = data
|
||||||
|
elif buffer:
|
||||||
|
self.buffer = self.deserialize(buffer)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not enought parameter in constructor")
|
||||||
self.str_type = type(self).__name__
|
self.str_type = type(self).__name__
|
||||||
|
|
||||||
def push(self, socket):
|
def push(self, socket):
|
||||||
@ -112,7 +117,6 @@ class ReplicatedDatablock(object):
|
|||||||
str_type = str_type.decode()
|
str_type = str_type.decode()
|
||||||
owner = owner.decode()
|
owner = owner.decode()
|
||||||
uuid = uuid.decode()
|
uuid = uuid.decode()
|
||||||
data = self.deserialize(data)
|
|
||||||
|
|
||||||
instance = factory.construct_from_net(str_type)(owner=owner, uuid=uuid, buffer=data)
|
instance = factory.construct_from_net(str_type)(owner=owner, uuid=uuid, buffer=data)
|
||||||
|
|
||||||
@ -176,7 +180,7 @@ class RepCommand(ReplicatedDatablock):
|
|||||||
return pickle.dumps(data)
|
return pickle.dumps(data)
|
||||||
|
|
||||||
def deserialize(self,data):
|
def deserialize(self,data):
|
||||||
return pickle.load(data)
|
return pickle.loads(data)
|
||||||
|
|
||||||
def apply(self,data,target):
|
def apply(self,data,target):
|
||||||
target = data
|
target = data
|
||||||
|
@ -13,13 +13,14 @@ STATE_ACTIVE = 2
|
|||||||
|
|
||||||
|
|
||||||
class Client(object):
|
class Client(object):
|
||||||
def __init__(self,factory=None, config=None):
|
def __init__(self,factory=None, id='default'):
|
||||||
assert(factory)
|
assert(factory)
|
||||||
|
|
||||||
self._rep_store = {}
|
self._rep_store = {}
|
||||||
self._net_client = ClientNetService(
|
self._net_client = ClientNetService(
|
||||||
store_reference=self._rep_store,
|
store_reference=self._rep_store,
|
||||||
factory=factory)
|
factory=factory,
|
||||||
|
id=id)
|
||||||
self._factory = factory
|
self._factory = factory
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
@ -58,7 +59,7 @@ class Client(object):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class ClientNetService(threading.Thread):
|
class ClientNetService(threading.Thread):
|
||||||
def __init__(self,store_reference=None, factory=None):
|
def __init__(self,store_reference=None, factory=None,id="default"):
|
||||||
|
|
||||||
# Threading
|
# Threading
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
@ -68,6 +69,7 @@ class ClientNetService(threading.Thread):
|
|||||||
self._exit_event = threading.Event()
|
self._exit_event = threading.Event()
|
||||||
self._factory = factory
|
self._factory = factory
|
||||||
self._store_reference = store_reference
|
self._store_reference = store_reference
|
||||||
|
self._id = id
|
||||||
|
|
||||||
assert(self._factory)
|
assert(self._factory)
|
||||||
|
|
||||||
@ -75,7 +77,7 @@ class ClientNetService(threading.Thread):
|
|||||||
self.context = zmq.Context.instance()
|
self.context = zmq.Context.instance()
|
||||||
|
|
||||||
self.snapshot = self.context.socket(zmq.DEALER)
|
self.snapshot = self.context.socket(zmq.DEALER)
|
||||||
self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER')
|
self.snapshot.setsockopt(zmq.IDENTITY, self._id.encode())
|
||||||
self.snapshot.connect("tcp://127.0.0.1:5560")
|
self.snapshot.connect("tcp://127.0.0.1:5560")
|
||||||
|
|
||||||
self.subscriber = self.context.socket(zmq.SUB)
|
self.subscriber = self.context.socket(zmq.SUB)
|
||||||
@ -90,7 +92,7 @@ class ClientNetService(threading.Thread):
|
|||||||
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
logger.info("Client is online")
|
logger.info("{} online".format(self._id))
|
||||||
poller = zmq.Poller()
|
poller = zmq.Poller()
|
||||||
poller.register(self.snapshot, zmq.POLLIN)
|
poller.register(self.snapshot, zmq.POLLIN)
|
||||||
poller.register(self.subscriber, zmq.POLLIN)
|
poller.register(self.subscriber, zmq.POLLIN)
|
||||||
@ -103,6 +105,7 @@ class ClientNetService(threading.Thread):
|
|||||||
ACTIVE : Do nothing
|
ACTIVE : Do nothing
|
||||||
"""
|
"""
|
||||||
if self.state == STATE_INITIAL:
|
if self.state == STATE_INITIAL:
|
||||||
|
logger.debug('{} : request snapshot'.format(self._id))
|
||||||
self.snapshot.send(b"SNAPSHOT_REQUEST")
|
self.snapshot.send(b"SNAPSHOT_REQUEST")
|
||||||
self.state = STATE_SYNCING
|
self.state = STATE_SYNCING
|
||||||
|
|
||||||
@ -112,19 +115,22 @@ class ClientNetService(threading.Thread):
|
|||||||
SYNCING : Ask for snapshots
|
SYNCING : Ask for snapshots
|
||||||
ACTIVE : Do nothing
|
ACTIVE : Do nothing
|
||||||
"""
|
"""
|
||||||
items = dict(poller.poll(1))
|
items = dict(poller.poll(10))
|
||||||
|
|
||||||
if self.snapshot in items:
|
if self.snapshot in items:
|
||||||
if self.state == STATE_SYNCING:
|
if self.state == STATE_SYNCING:
|
||||||
datablock = ReplicatedDatablock.pull(self.snapshot, self._factory)
|
datablock = ReplicatedDatablock.pull(self.snapshot, self._factory)
|
||||||
|
|
||||||
if isinstance(datablock, RepCommand):
|
if datablock.buffer == 'SNAPSHOT_END':
|
||||||
|
self.state = STATE_ACTIVE
|
||||||
|
logger.debug('{} : snapshot done'.format(self._id))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# We receive updates from the server !
|
# We receive updates from the server !
|
||||||
if self.subscriber in items:
|
if self.subscriber in items:
|
||||||
if self.state == STATE_ACTIVE:
|
if self.state == STATE_ACTIVE:
|
||||||
logger.debug("Receiving changes from server")
|
logger.debug("{} : Receiving changes from server".format(self._id))
|
||||||
datablock = ReplicatedDatablock.pull(self.subscriber, self._factory)
|
datablock = ReplicatedDatablock.pull(self.subscriber, self._factory)
|
||||||
datablock.store(self._store_reference)
|
datablock.store(self._store_reference)
|
||||||
|
|
||||||
@ -216,7 +222,7 @@ class ServerNetService(threading.Thread):
|
|||||||
|
|
||||||
while not self._exit_event.is_set():
|
while not self._exit_event.is_set():
|
||||||
# Non blocking poller
|
# Non blocking poller
|
||||||
socks = dict(poller.poll())
|
socks = dict(poller.poll(10))
|
||||||
|
|
||||||
# Snapshot system for late join (Server - Client)
|
# Snapshot system for late join (Server - Client)
|
||||||
if self.snapshot in socks:
|
if self.snapshot in socks:
|
||||||
|
@ -26,57 +26,70 @@ class RepSampleData(ReplicatedDatablock):
|
|||||||
def deserialize(self,data):
|
def deserialize(self,data):
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
return pickle.load(data)
|
return pickle.loads(data)
|
||||||
|
|
||||||
class TestDataReplication(unittest.TestCase):
|
|
||||||
|
|
||||||
|
class TestDataFactory(unittest.TestCase):
|
||||||
def test_data_factory(self):
|
def test_data_factory(self):
|
||||||
factory = ReplicatedDataFactory()
|
factory = ReplicatedDataFactory()
|
||||||
factory.register_type(SampleData, RepSampleData)
|
factory.register_type(SampleData, RepSampleData)
|
||||||
data_sample = SampleData()
|
data_sample = SampleData()
|
||||||
rep_sample = factory.construct_from_dcc(data_sample)(owner="toto")
|
rep_sample = factory.construct_from_dcc(data_sample)(owner="toto", data=data_sample)
|
||||||
|
|
||||||
self.assertEqual(isinstance(rep_sample,RepSampleData), True)
|
self.assertEqual(isinstance(rep_sample,RepSampleData), True)
|
||||||
|
|
||||||
def test_basic_client_start(self):
|
|
||||||
factory = ReplicatedDataFactory()
|
|
||||||
factory.register_type(SampleData, RepSampleData)
|
|
||||||
|
|
||||||
server = Server(factory=factory)
|
class TestDataReplication(unittest.TestCase):
|
||||||
server.serve()
|
def __init__(self,methodName='runTest'):
|
||||||
|
unittest.TestCase.__init__(self, methodName)
|
||||||
|
|
||||||
client = Client(factory=factory)
|
self.factory = ReplicatedDataFactory()
|
||||||
client.connect()
|
self.factory.register_type(SampleData, RepSampleData)
|
||||||
|
|
||||||
time.sleep(1)
|
self.server = Server(factory=self.factory)
|
||||||
|
self.server.serve()
|
||||||
|
|
||||||
self.assertEqual(client.state(), 2)
|
self.client = Client(factory=self.factory, id="client_1")
|
||||||
|
self.client.connect()
|
||||||
|
|
||||||
# def test_register_client_data(self):
|
self.client2 = Client(factory=self.factory, id="client_2")
|
||||||
# # Setup data factory
|
self.client2.connect()
|
||||||
# factory = ReplicatedDataFactory()
|
|
||||||
# factory.register_type(SampleData, RepSampleData)
|
|
||||||
|
|
||||||
# server = Server(factory=factory)
|
|
||||||
# server.serve()
|
|
||||||
|
|
||||||
# client = Client(factory=factory)
|
def test_register_client_data(self):
|
||||||
# client.connect()
|
data_sample_key = self.client.register(SampleData())
|
||||||
|
|
||||||
# client2 = Client(factory=factory)
|
#Waiting for server to receive the datas
|
||||||
# client2.connect()
|
time.sleep(2)
|
||||||
|
|
||||||
# data_sample_key = client.register(SampleData())
|
test_key = self.client2._rep_store[data_sample_key]
|
||||||
|
|
||||||
# #Waiting for server to receive the datas
|
#Check if the server receive them
|
||||||
# time.sleep(1)
|
self.assertNotEqual(test_key, None)
|
||||||
|
|
||||||
# #Check if the server receive them
|
def test_register_client_data2(self):
|
||||||
# self.assertNotEqual(client2._rep_store[data_sample_key],None)
|
data_sample_key = self.client.register(SampleData())
|
||||||
|
|
||||||
|
#Waiting for server to receive the datas
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
test_key = self.client2._rep_store[data_sample_key]
|
||||||
|
|
||||||
|
#Check if the server receive them
|
||||||
|
self.assertNotEqual(test_key, None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def suite():
|
||||||
|
suite = unittest.TestSuite()
|
||||||
|
suite.addTest(TestDataFactory('test_data_factory'))
|
||||||
|
suite.addTest(TestDataReplication('test_register_client_data'))
|
||||||
|
suite.addTest(TestDataReplication('test_register_client_data2'))
|
||||||
|
|
||||||
|
return suite
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
# unittest.main()
|
||||||
|
runner = unittest.TextTestRunner(failfast=True)
|
||||||
|
runner.run(suite())
|
Reference in New Issue
Block a user