feat: test progress

This commit is contained in:
Swann Martinez
2019-07-18 18:15:01 +02:00
parent 51093b1307
commit 672d215b40
3 changed files with 67 additions and 44 deletions

View File

@ -72,7 +72,7 @@ class ReplicatedDatablock(object):
""" """
uuid = None # uuid used as key (string) uuid = None # uuid used as key (string)
pointer = None # dcc data ref (DCC type) pointer = None # dcc data ref (DCC type)
buffer = None # data blob (json) buffer = None # raw data (json)
str_type = None # data type name (string) str_type = None # data type name (string)
deps = [None] # dependencies array (string) deps = [None] # dependencies array (string)
owner = None # Data owner (string) owner = None # Data owner (string)
@ -82,8 +82,13 @@ class ReplicatedDatablock(object):
self.uuid = uuid if uuid else str(uuid4()) self.uuid = uuid if uuid else str(uuid4())
assert(owner) assert(owner)
self.owner = owner self.owner = owner
self.pointer = data
self.buffer = buffer if buffer else None if data:
self.pointer = data
elif buffer:
self.buffer = self.deserialize(buffer)
else:
raise ValueError("Not enought parameter in constructor")
self.str_type = type(self).__name__ self.str_type = type(self).__name__
def push(self, socket): def push(self, socket):
@ -112,7 +117,6 @@ class ReplicatedDatablock(object):
str_type = str_type.decode() str_type = str_type.decode()
owner = owner.decode() owner = owner.decode()
uuid = uuid.decode() uuid = uuid.decode()
data = self.deserialize(data)
instance = factory.construct_from_net(str_type)(owner=owner, uuid=uuid, buffer=data) instance = factory.construct_from_net(str_type)(owner=owner, uuid=uuid, buffer=data)
@ -176,7 +180,7 @@ class RepCommand(ReplicatedDatablock):
return pickle.dumps(data) return pickle.dumps(data)
def deserialize(self,data): def deserialize(self,data):
return pickle.load(data) return pickle.loads(data)
def apply(self,data,target): def apply(self,data,target):
target = data target = data

View File

@ -13,13 +13,14 @@ STATE_ACTIVE = 2
class Client(object): class Client(object):
def __init__(self,factory=None, config=None): def __init__(self,factory=None, id='default'):
assert(factory) assert(factory)
self._rep_store = {} self._rep_store = {}
self._net_client = ClientNetService( self._net_client = ClientNetService(
store_reference=self._rep_store, store_reference=self._rep_store,
factory=factory) factory=factory,
id=id)
self._factory = factory self._factory = factory
def connect(self): def connect(self):
@ -58,7 +59,7 @@ class Client(object):
pass pass
class ClientNetService(threading.Thread): class ClientNetService(threading.Thread):
def __init__(self,store_reference=None, factory=None): def __init__(self,store_reference=None, factory=None,id="default"):
# Threading # Threading
threading.Thread.__init__(self) threading.Thread.__init__(self)
@ -68,6 +69,7 @@ class ClientNetService(threading.Thread):
self._exit_event = threading.Event() self._exit_event = threading.Event()
self._factory = factory self._factory = factory
self._store_reference = store_reference self._store_reference = store_reference
self._id = id
assert(self._factory) assert(self._factory)
@ -75,7 +77,7 @@ class ClientNetService(threading.Thread):
self.context = zmq.Context.instance() self.context = zmq.Context.instance()
self.snapshot = self.context.socket(zmq.DEALER) self.snapshot = self.context.socket(zmq.DEALER)
self.snapshot.setsockopt(zmq.IDENTITY, b'SERVER') self.snapshot.setsockopt(zmq.IDENTITY, self._id.encode())
self.snapshot.connect("tcp://127.0.0.1:5560") self.snapshot.connect("tcp://127.0.0.1:5560")
self.subscriber = self.context.socket(zmq.SUB) self.subscriber = self.context.socket(zmq.SUB)
@ -90,7 +92,7 @@ class ClientNetService(threading.Thread):
def run(self): def run(self):
logger.info("Client is online") logger.info("{} online".format(self._id))
poller = zmq.Poller() poller = zmq.Poller()
poller.register(self.snapshot, zmq.POLLIN) poller.register(self.snapshot, zmq.POLLIN)
poller.register(self.subscriber, zmq.POLLIN) poller.register(self.subscriber, zmq.POLLIN)
@ -103,6 +105,7 @@ class ClientNetService(threading.Thread):
ACTIVE : Do nothing ACTIVE : Do nothing
""" """
if self.state == STATE_INITIAL: if self.state == STATE_INITIAL:
logger.debug('{} : request snapshot'.format(self._id))
self.snapshot.send(b"SNAPSHOT_REQUEST") self.snapshot.send(b"SNAPSHOT_REQUEST")
self.state = STATE_SYNCING self.state = STATE_SYNCING
@ -112,19 +115,22 @@ class ClientNetService(threading.Thread):
SYNCING : Ask for snapshots SYNCING : Ask for snapshots
ACTIVE : Do nothing ACTIVE : Do nothing
""" """
items = dict(poller.poll(1)) items = dict(poller.poll(10))
if self.snapshot in items: if self.snapshot in items:
if self.state == STATE_SYNCING: if self.state == STATE_SYNCING:
datablock = ReplicatedDatablock.pull(self.snapshot, self._factory) datablock = ReplicatedDatablock.pull(self.snapshot, self._factory)
if isinstance(datablock, RepCommand): if datablock.buffer == 'SNAPSHOT_END':
self.state = STATE_ACTIVE
logger.debug('{} : snapshot done'.format(self._id))
# We receive updates from the server ! # We receive updates from the server !
if self.subscriber in items: if self.subscriber in items:
if self.state == STATE_ACTIVE: if self.state == STATE_ACTIVE:
logger.debug("Receiving changes from server") logger.debug("{} : Receiving changes from server".format(self._id))
datablock = ReplicatedDatablock.pull(self.subscriber, self._factory) datablock = ReplicatedDatablock.pull(self.subscriber, self._factory)
datablock.store(self._store_reference) datablock.store(self._store_reference)
@ -216,7 +222,7 @@ class ServerNetService(threading.Thread):
while not self._exit_event.is_set(): while not self._exit_event.is_set():
# Non blocking poller # Non blocking poller
socks = dict(poller.poll()) socks = dict(poller.poll(10))
# Snapshot system for late join (Server - Client) # Snapshot system for late join (Server - Client)
if self.snapshot in socks: if self.snapshot in socks:

View File

@ -26,57 +26,70 @@ class RepSampleData(ReplicatedDatablock):
def deserialize(self,data): def deserialize(self,data):
import pickle import pickle
return pickle.load(data) return pickle.loads(data)
class TestDataReplication(unittest.TestCase):
class TestDataFactory(unittest.TestCase):
def test_data_factory(self): def test_data_factory(self):
factory = ReplicatedDataFactory() factory = ReplicatedDataFactory()
factory.register_type(SampleData, RepSampleData) factory.register_type(SampleData, RepSampleData)
data_sample = SampleData() data_sample = SampleData()
rep_sample = factory.construct_from_dcc(data_sample)(owner="toto") rep_sample = factory.construct_from_dcc(data_sample)(owner="toto", data=data_sample)
self.assertEqual(isinstance(rep_sample,RepSampleData), True) self.assertEqual(isinstance(rep_sample,RepSampleData), True)
def test_basic_client_start(self):
factory = ReplicatedDataFactory()
factory.register_type(SampleData, RepSampleData)
server = Server(factory=factory) class TestDataReplication(unittest.TestCase):
server.serve() def __init__(self,methodName='runTest'):
unittest.TestCase.__init__(self, methodName)
client = Client(factory=factory) self.factory = ReplicatedDataFactory()
client.connect() self.factory.register_type(SampleData, RepSampleData)
time.sleep(1) self.server = Server(factory=self.factory)
self.server.serve()
self.assertEqual(client.state(), 2) self.client = Client(factory=self.factory, id="client_1")
self.client.connect()
# def test_register_client_data(self): self.client2 = Client(factory=self.factory, id="client_2")
# # Setup data factory self.client2.connect()
# factory = ReplicatedDataFactory()
# factory.register_type(SampleData, RepSampleData)
# server = Server(factory=factory)
# server.serve()
# client = Client(factory=factory) def test_register_client_data(self):
# client.connect() data_sample_key = self.client.register(SampleData())
# client2 = Client(factory=factory) #Waiting for server to receive the datas
# client2.connect() time.sleep(2)
# data_sample_key = client.register(SampleData()) test_key = self.client2._rep_store[data_sample_key]
# #Waiting for server to receive the datas #Check if the server receive them
# time.sleep(1) self.assertNotEqual(test_key, None)
def test_register_client_data2(self):
data_sample_key = self.client.register(SampleData())
# #Check if the server receive them #Waiting for server to receive the datas
# self.assertNotEqual(client2._rep_store[data_sample_key],None) time.sleep(2)
test_key = self.client2._rep_store[data_sample_key]
#Check if the server receive them
self.assertNotEqual(test_key, None)
def suite():
suite = unittest.TestSuite()
suite.addTest(TestDataFactory('test_data_factory'))
suite.addTest(TestDataReplication('test_register_client_data'))
suite.addTest(TestDataReplication('test_register_client_data2'))
return suite
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() # unittest.main()
runner = unittest.TextTestRunner(failfast=True)
runner.run(suite())