refacor: code formatting

This commit is contained in:
Swann
2019-07-23 20:18:51 +02:00
parent ea645044c6
commit 7ff01273be
3 changed files with 116 additions and 126 deletions

View File

@ -1,19 +1,14 @@
import logging
from uuid import uuid4
import json
try:
from .libs import umsgpack
except:
# Server import
from libs import umsgpack
import zmq
import logging
import pickle
from enum import Enum
from uuid import uuid4
import zmq
logger = logging.getLogger(__name__)
class RepState(Enum):
ADDED = 0
COMMITED = 1
@ -29,9 +24,8 @@ class ReplicatedDataFactory(object):
self.supported_types = []
# Default registered types
self.register_type(str,RepCommand)
self.register_type(str, RepCommand)
self.register_type(RepDeleteCommand, RepDeleteCommand)
def register_type(self, dtype, implementation):
"""
@ -75,7 +69,7 @@ class ReplicatedDatablock(object):
uuid = None # uuid used as key (string)
pointer = None # dcc data ref (DCC type)
buffer = None # raw data (json)
str_type = None # data type name (string)
str_type = None # data type name (string)
deps = [None] # dependencies array (string)
owner = None # Data owner (string)
state = None # Data state (RepState)
@ -123,11 +117,10 @@ class ReplicatedDatablock(object):
uuid = uuid.decode()
instance = factory.construct_from_net(str_type)(owner=owner, uuid=uuid)
instance.buffer = instance.deserialize(data)
instance.buffer = instance.deserialize(data)
return instance
def store(self, dict, persistent=False):
"""
I want to store my replicated data. Persistent means into the disk
@ -142,20 +135,17 @@ class ReplicatedDatablock(object):
return self.uuid
def deserialize(self, data):
"""
BUFFER -> JSON
"""
raise NotImplementedError
def serialize(self, data):
"""
JSON -> BUFFER
"""
raise NotImplementedError
def dump(self):
"""
@ -165,13 +155,11 @@ class ReplicatedDatablock(object):
return json.dumps(self.pointer)
def load(self,target=None):
def load(self, target=None):
"""
JSON -> DCC
"""
raise NotImplementedError
def resolve(self):
"""
@ -181,39 +169,39 @@ class ReplicatedDatablock(object):
"""
raise NotImplementedError
def __repr__(self):
return "{uuid} - owner: {owner} - type: {type}".format(
uuid=self.uuid,
owner=self.owner,
type=self.str_type
)
)
class RepCommand(ReplicatedDatablock):
def serialize(self,data):
def serialize(self, data):
return pickle.dumps(data)
def deserialize(self,data):
def deserialize(self, data):
return pickle.loads(data)
def load(self,target):
def load(self, target):
target = self.pointer
class RepDeleteCommand(ReplicatedDatablock):
def serialize(self,data):
def serialize(self, data):
return pickle.dumps(data)
def deserialize(self,data):
def deserialize(self, data):
return pickle.loads(data)
def store(self,rep_store):
def store(self, rep_store):
assert(self.buffer)
if rep_store and self.buffer in rep_store.keys():
del rep_store[self.buffer]
# class RepObject(ReplicatedDatablock):
# def deserialize(self):
# try:

View File

@ -1,8 +1,10 @@
import threading
import logging
import zmq
import threading
import time
from replication import ReplicatedDatablock, RepCommand,RepDeleteCommand
import zmq
from replication import RepCommand, RepDeleteCommand, ReplicatedDatablock
from replication_graph import ReplicationGraph
logger = logging.getLogger(__name__)
@ -13,7 +15,7 @@ STATE_ACTIVE = 2
class Client(object):
def __init__(self,factory=None, id='default'):
def __init__(self, factory=None, id='default'):
assert(factory)
self._rep_store = ReplicationGraph()
@ -23,11 +25,11 @@ class Client(object):
id=id)
self._factory = factory
def connect(self,address="127.0.0.1",port=5560):
def connect(self, address="127.0.0.1", port=5560):
"""
Connect to the server
"""
self._net_client.connect(address=address,port=port)
self._net_client.connect(address=address, port=port)
def disconnect(self):
"""
@ -52,22 +54,23 @@ class Client(object):
find a better way to handle replication behavior
"""
assert(object)
# Construct the coresponding replication type
new_item = self._factory.construct_from_dcc(object)(owner="client", pointer=object)
new_item = self._factory.construct_from_dcc(
object)(owner="client", pointer=object)
if new_item:
logger.info("Registering {} on {}".format(object,new_item.uuid))
logger.info("Registering {} on {}".format(object, new_item.uuid))
new_item.store(self._rep_store)
logger.info("Pushing new registered value")
new_item.push(self._net_client.publish)
return new_item.uuid
else:
raise TypeError("Type not supported")
def unregister(self,object_uuid,clean=False):
def unregister(self, object_uuid, clean=False):
"""
Unregister for replication the given
object.
@ -76,27 +79,30 @@ class Client(object):
"""
if object_uuid in self._rep_store.keys():
delete_command = RepDeleteCommand(owner='client', buffer=object_uuid)
delete_command = RepDeleteCommand(
owner='client', buffer=object_uuid)
# remove the key from our store
delete_command.store(self._rep_store)
delete_command.push(self._net_client.publish)
else:
raise KeyError("Cannot unregister key")
def pull(self,object=None):
def pull(self, object=None):
"""
Asynchonous pull
Here we want to pull all waiting changes and apply them
"""
pass
class ClientNetService(threading.Thread):
def __init__(self,store_reference=None, factory=None,id="default"):
def __init__(self, store_reference=None, factory=None, id="default"):
# Threading
threading.Thread.__init__(self)
self.name = "ClientNetLink"
self.daemon = True
self._exit_event = threading.Event()
self._factory = factory
self._store_reference = store_reference
@ -108,16 +114,16 @@ class ClientNetService(threading.Thread):
self.context = zmq.Context.instance()
self.state = STATE_INITIAL
def connect(self,address='127.0.0.1', port=5560):
def connect(self, address='127.0.0.1', port=5560):
"""
Network socket setup
"""
if self.state == STATE_INITIAL:
logger.debug("connecting on {}:{}".format(address,port))
logger.debug("connecting on {}:{}".format(address, port))
self.snapshot = self.context.socket(zmq.DEALER)
self.snapshot.setsockopt(zmq.IDENTITY, self._id.encode())
self.snapshot.connect("tcp://{}:{}".format(address, port))
self.subscriber = self.context.socket(zmq.SUB)
self.subscriber.setsockopt_string(zmq.SUBSCRIBE, '')
# self.subscriber.setsockopt(zmq.IDENTITY, self._id.encode())
@ -146,7 +152,6 @@ class ClientNetService(threading.Thread):
logger.debug('{} : request snapshot'.format(self._id))
self.snapshot.send(b"SNAPSHOT_REQUEST")
self.state = STATE_SYNCING
"""NET IN
Given the net state we do something:
@ -157,7 +162,8 @@ class ClientNetService(threading.Thread):
if self.snapshot in items:
if self.state == STATE_SYNCING:
datablock = ReplicatedDatablock.pull(self.snapshot, self._factory)
datablock = ReplicatedDatablock.pull(
self.snapshot, self._factory)
if 'SNAPSHOT_END' in datablock.buffer:
self.state = STATE_ACTIVE
@ -168,37 +174,38 @@ class ClientNetService(threading.Thread):
# We receive updates from the server !
if self.subscriber in items:
if self.state == STATE_ACTIVE:
logger.debug("{} : Receiving changes from server".format(self._id))
datablock = ReplicatedDatablock.pull(self.subscriber, self._factory)
logger.debug(
"{} : Receiving changes from server".format(self._id))
datablock = ReplicatedDatablock.pull(
self.subscriber, self._factory)
datablock.store(self._store_reference)
if not items:
logger.error("No request ")
self.snapshot.close()
self.subscriber.close()
self.publish.close()
self._exit_event.clear()
def stop(self):
self._exit_event.set()
#Wait the end of the run
# Wait the end of the run
while self._exit_event.is_set():
time.sleep(.1)
self.state = 0
class Server():
def __init__(self,config=None, factory=None):
def __init__(self, config=None, factory=None):
self._rep_store = {}
self._net = ServerNetService(store_reference=self._rep_store, factory=factory)
self._net = ServerNetService(
store_reference=self._rep_store, factory=factory)
def serve(self,port=5560):
def serve(self, port=5560):
self._net.listen(port=port)
def state(self):
@ -209,7 +216,7 @@ class Server():
class ServerNetService(threading.Thread):
def __init__(self,store_reference=None, factory=None):
def __init__(self, store_reference=None, factory=None):
# Threading
threading.Thread.__init__(self)
self.name = "ServerNetLink"
@ -217,7 +224,7 @@ class ServerNetService(threading.Thread):
self._exit_event = threading.Event()
# Networking
self._rep_store = store_reference
self._rep_store = store_reference
self.context = zmq.Context.instance()
self.snapshot = None
@ -227,7 +234,6 @@ class ServerNetService(threading.Thread):
self.factory = factory
self.clients = {}
def listen(self, port=5560):
try:
# Update request
@ -249,7 +255,7 @@ class ServerNetService(threading.Thread):
self.pull.setsockopt(zmq.RCVHWM, 60)
self.pull.bind("tcp://*:{}".format(port+2))
self.start()
self.start()
except zmq.error.ZMQError:
logger.error("Address already in use, change net config")
@ -285,39 +291,38 @@ class ServerNetService(threading.Thread):
for key, item in self._rep_store.items():
self.snapshot.send(identity, zmq.SNDMORE)
item.push(self.snapshot)
# Snapshot end
self.snapshot.send(identity, zmq.SNDMORE)
RepCommand(owner='server',pointer='SNAPSHOT_END').push(self.snapshot)
RepCommand(owner='server', pointer='SNAPSHOT_END').push(
self.snapshot)
# Regular update routing (Clients / Server / Clients)
if self.pull in socks:
logger.debug("SERVER: Receiving changes from client")
datablock = ReplicatedDatablock.pull(self.pull, self.factory)
datablock.store(self._rep_store)
# Update all clients
# for cli_name,cli_id in self.clients.items():
# logger.debug("SERVER: Broadcast changes to {}".format(cli_name))
# self.publisher.send(cli_id, zmq.SNDMORE)
# datablock.push(self.publisher)
datablock.push(self.publisher)
self.snapshot.close()
self.pull.close()
self.publisher.close()
self._exit_event.clear()
def stop(self):
self._exit_event.set()
#Wait the end of the run
# Wait the end of the run
while self._exit_event.is_set():
time.sleep(.1)
self.state = 0
self.state = 0

View File

@ -1,45 +1,46 @@
import unittest
from replication import ReplicatedDatablock, ReplicatedDataFactory
import umsgpack
import logging
from replication_client import Client, Server
import time
import cProfile
import logging
import re
import time
import unittest
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)
from replication import ReplicatedDatablock, ReplicatedDataFactory
from replication_client import Client, Server
class SampleData():
def __init__(self, map={"sample":"data"}):
def __init__(self, map={"sample": "data"}):
self.map = map
class RepSampleData(ReplicatedDatablock):
def serialize(self,data):
def serialize(self, data):
import pickle
return pickle.dumps(data)
def deserialize(self,data):
def deserialize(self, data):
import pickle
return pickle.loads(data)
def dump(self):
import json
output = {}
output['map'] = json.dumps(self.pointer.map)
return output
def load(self,target=None):
def load(self, target=None):
import json
if target is None:
target = SampleData()
target = SampleData()
target.map = json.loads(self.buffer['map'])
@ -49,13 +50,14 @@ class TestDataFactory(unittest.TestCase):
factory = ReplicatedDataFactory()
factory.register_type(SampleData, RepSampleData)
data_sample = SampleData()
rep_sample = factory.construct_from_dcc(data_sample)(owner="toto", pointer=data_sample)
self.assertEqual(isinstance(rep_sample,RepSampleData), True)
rep_sample = factory.construct_from_dcc(
data_sample)(owner="toto", pointer=data_sample)
self.assertEqual(isinstance(rep_sample, RepSampleData), True)
class TestClient(unittest.TestCase):
def __init__(self,methodName='runTest'):
def __init__(self, methodName='runTest'):
unittest.TestCase.__init__(self, methodName)
def test_empty_snapshot(self):
@ -70,7 +72,7 @@ class TestClient(unittest.TestCase):
client.connect(port=5570)
test_state = client.state
server.stop()
client.disconnect()
@ -84,14 +86,13 @@ class TestClient(unittest.TestCase):
server = Server(factory=factory)
client = Client(factory=factory, id="cli_test_filled_snapshot")
client2 = Client(factory=factory, id="client_2")
server.serve(port=5575)
client.connect(port=5575)
# Test the key registering
data_sample_key = client.register(SampleData())
client2.connect(port=5575)
time.sleep(0.2)
rep_test_key = client2._rep_store[data_sample_key].uuid
@ -104,7 +105,7 @@ class TestClient(unittest.TestCase):
def test_register_client_data(self):
# Setup environment
factory = ReplicatedDataFactory()
factory.register_type(SampleData, RepSampleData)
@ -116,21 +117,18 @@ class TestClient(unittest.TestCase):
client2 = Client(factory=factory, id="cli2_test_register_client_data")
client2.connect(port=5560)
# Test the key registering
data_sample_key = client.register(SampleData())
time.sleep(0.3)
#Waiting for server to receive the datas
# Waiting for server to receive the datas
rep_test_key = client2._rep_store[data_sample_key].uuid
client.disconnect()
client2.disconnect()
server.stop()
self.assertEqual(rep_test_key, data_sample_key)
def test_client_data_intergity(self):
@ -146,23 +144,21 @@ class TestClient(unittest.TestCase):
client2 = Client(factory=factory, id="cli2_test_client_data_intergity")
client2.connect(port=5560)
test_map = {"toto":"test"}
test_map = {"toto": "test"}
# Test the key registering
data_sample_key = client.register(SampleData(map=test_map))
test_map_result = SampleData()
#Waiting for server to receive the datas
# Waiting for server to receive the datas
time.sleep(1)
client2._rep_store[data_sample_key].load(target=test_map_result)
client.disconnect()
client2.disconnect()
server.stop()
self.assertEqual(test_map_result.map["toto"], test_map["toto"])
def test_client_unregister_key(self):
@ -178,18 +174,18 @@ class TestClient(unittest.TestCase):
client2 = Client(factory=factory, id="cli2_test_client_data_intergity")
client2.connect(port=5560)
test_map = {"toto":"test"}
test_map = {"toto": "test"}
# Test the key registering
data_sample_key = client.register(SampleData(map=test_map))
test_map_result = SampleData()
#Waiting for server to receive the datas
# Waiting for server to receive the datas
time.sleep(.1)
client2._rep_store[data_sample_key].load(target=test_map_result)
client.unregister(data_sample_key)
time.sleep(.1)
@ -205,7 +201,7 @@ class TestClient(unittest.TestCase):
server.stop()
self.assertFalse(data_sample_key in client._rep_store)
def test_client_disconnect(self):
pass
@ -223,7 +219,7 @@ class TestStressClient(unittest.TestCase):
server = Server(factory=factory)
client = Client(factory=factory, id="cli_test_filled_snapshot")
client2 = Client(factory=factory, id="client_2")
server.serve(port=5575)
client.connect(port=5575)
client2.connect(port=5575)
@ -234,7 +230,7 @@ class TestStressClient(unittest.TestCase):
while len(client2._rep_store.keys()) < 10000:
time.sleep(0.00001)
total_time+=0.00001
total_time += 0.00001
# test_num_items = len(client2._rep_store.keys())
server.stop()
@ -242,26 +238,27 @@ class TestStressClient(unittest.TestCase):
client2.disconnect()
logger.info("{} s for 10000 values".format(total_time))
self.assertLess(total_time,1)
self.assertLess(total_time, 1)
def suite():
suite = unittest.TestSuite()
# Data factory
suite.addTest(TestDataFactory('test_data_factory'))
# Client
# Client
suite.addTest(TestClient('test_empty_snapshot'))
suite.addTest(TestClient('test_filled_snapshot'))
suite.addTest(TestClient('test_register_client_data'))
suite.addTest(TestClient('test_client_data_intergity'))
# Stress test
suite.addTest(TestStressClient('test_stress_register'))
return suite
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(suite())