feat: factory get implementation by type name

(not sure about this...)
This commit is contained in:
Swann Martinez
2019-07-17 18:26:30 +02:00
parent 7b3d5d2334
commit 5f3badc81e
3 changed files with 16 additions and 12 deletions

View File

@ -35,6 +35,11 @@ class ReplicatedDataFactory(object):
print("type not supported for replication") print("type not supported for replication")
raise NotImplementedError raise NotImplementedError
def match_type_by_name(self,type_name):
for stypes, implementation in self.supported_types:
if type_name == implementation.__class__.__name__:
return implementation
def construct_from_dcc(self,data): def construct_from_dcc(self,data):
implementation = self.match_type_by_instance(data) implementation = self.match_type_by_instance(data)
return implementation return implementation
@ -61,9 +66,7 @@ class ReplicatedDatablock(object):
assert(owner) assert(owner)
self.owner = owner self.owner = owner
self.pointer = data self.pointer = data
self.str_type = self.data.__class__.__name__
if data:
self.str_type = self.data.__class__.__name__
def push(self, socket): def push(self, socket):
@ -76,8 +79,9 @@ class ReplicatedDatablock(object):
assert(isinstance(data, bytes)) assert(isinstance(data, bytes))
owner = self.owner.encode() owner = self.owner.encode()
key = self.uuid.encode() key = self.uuid.encode()
type = self.str_type.encode()
socket.send_multipart([key,owner,data]) socket.send_multipart([key,owner,str_type,data])
@classmethod @classmethod
def pull(cls, socket, factory): def pull(cls, socket, factory):
@ -86,9 +90,9 @@ class ReplicatedDatablock(object):
- read data from the socket - read data from the socket
- reconstruct an instance - reconstruct an instance
""" """
uuid, owner, data = socket.recv_multipart(zmq.NOBLOCK) uuid, owner,str_type, data = socket.recv_multipart(zmq.NOBLOCK)
instance = factory.construct_from_net(data)(owner=owner.decode(), uuid=uuid.decode()) instance = factory.construct_from_net(str_type.decode())(owner=owner.decode(), uuid=uuid.decode())
instance.data = instance.deserialize(data) instance.data = instance.deserialize(data)
return instance return instance

View File

@ -31,7 +31,7 @@ class Client(object):
""" """
assert(object) assert(object)
new_item = self._factory.construct_from_dcc(object)(owner="client") new_item = self._factory.construct_from_dcc(object)(owner="client", data=object)
if new_item: if new_item:
log.info("Registering {} on {}".format(object,new_item.uuid)) log.info("Registering {} on {}".format(object,new_item.uuid))
@ -107,9 +107,9 @@ class ClientNetService(threading.Thread):
class Server(): class Server():
def __init__(self,config=None): def __init__(self,config=None, factory=None):
self.rep_store = {} self.rep_store = {}
self.net = ServerNetService(self.rep_store) self.net = ServerNetService(store_reference=self.rep_store, factory=factory)
# self.serve() # self.serve()
def serve(self): def serve(self):
@ -205,7 +205,7 @@ class ServerNetService(threading.Thread):
# Regular update routing (Clients / Client) # Regular update routing (Clients / Client)
if self.pull in socks: if self.pull in socks:
log.info("Receiving changes from client") log.info("Receiving changes from client")
msg = ReplicatedDatablock.pull(self.pull) msg = ReplicatedDatablock.pull(self.pull, self.factory)
msg.store(self.store) msg.store(self.store)
# msg = message.Message.recv(self.collector_sock) # msg = message.Message.recv(self.collector_sock)

View File

@ -72,7 +72,7 @@ class TestDataReplication(unittest.TestCase):
factory = ReplicatedDataFactory() factory = ReplicatedDataFactory()
factory.register_type(SampleData, RepSampleData) factory.register_type(SampleData, RepSampleData)
server_api = Server() server_api = Server(factory=factory)
server_api.serve() server_api.serve()
client_api = Client(factory=factory) client_api = Client(factory=factory)
client_api.connect() client_api.connect()