diff --git a/Examples/Channel.py b/Examples/Channel.py index dfad943..4f1bd2c 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -152,7 +152,7 @@ def client_connected(link): # Register message types and add callback to channel channel = link.get_channel() channel.register_message_type(StringMessage) - channel.add_message_callback(server_message_received) + channel.add_message_handler(server_message_received) def client_disconnected(link): RNS.log("Client disconnected") @@ -290,7 +290,7 @@ def link_established(link): # Register messages and add handler to channel channel = link.get_channel() channel.register_message_type(StringMessage) - channel.add_message_callback(client_message_received) + channel.add_message_handler(client_message_received) # Inform the user that the server is # connected diff --git a/RNS/Channel.py b/RNS/Channel.py index 0aebb4d..0b023be 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -150,28 +150,34 @@ class Channel(contextlib.AbstractContextManager): return False def register_message_type(self, message_class: Type[MessageBase]): - if not issubclass(message_class, MessageBase): - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") - if message_class.MSGTYPE is None: - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") - try: - message_class() - except Exception as ex: - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, - f"{message_class} raised an exception when constructed with no arguments: {ex}") + with self._lock: + if not issubclass(message_class, MessageBase): + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} is not a subclass of {MessageBase}.") + if message_class.MSGTYPE is None: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has invalid MSGTYPE class attribute.") + try: + message_class() + except Exception as ex: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} raised an exception when constructed with no arguments: {ex}") - self._message_factories[message_class.MSGTYPE] = message_class + self._message_factories[message_class.MSGTYPE] = message_class - def add_message_callback(self, callback: MessageCallbackType): - if callback not in self._message_callbacks: - self._message_callbacks.append(callback) + def add_message_handler(self, callback: MessageCallbackType): + with self._lock: + if callback not in self._message_callbacks: + self._message_callbacks.append(callback) - def remove_message_callback(self, callback: MessageCallbackType): - self._message_callbacks.remove(callback) + def remove_message_handler(self, callback: MessageCallbackType): + with self._lock: + self._message_callbacks.remove(callback) def shutdown(self): - self._message_callbacks.clear() - self.clear_rings() + with self._lock: + self._message_callbacks.clear() + self.clear_rings() def clear_rings(self): with self._lock: @@ -205,19 +211,29 @@ class Channel(contextlib.AbstractContextManager): env.tracked = False self._rx_ring.remove(env) + def _run_callbacks(self, message: MessageBase): + with self._lock: + cbs = self._message_callbacks.copy() + + for cb in cbs: + try: + if cb(message): + return + except Exception as ex: + RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) + def receive(self, raw: bytes): try: envelope = Envelope(outlet=self._outlet, raw=raw) - message = envelope.unpack(self._message_factories) with self._lock: + message = envelope.unpack(self._message_factories) is_new = self.emplace_envelope(envelope, self._rx_ring) self.prune_rx_ring() if not is_new: RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) - for cb in self._message_callbacks: - threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start() + threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() except Exception as ex: RNS.log(f"Channel: Error receiving data: {ex}") diff --git a/tests/channel.py b/tests/channel.py index 260f037..c9a64b3 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -245,13 +245,49 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) self.assertFalse(envelope.tracked) + def test_multiple_handler(self): + handler1_called = 0 + handler1_return = True + handler2_called = 0 + + def handler1(msg: MessageBase): + nonlocal handler1_called, handler1_return + self.assertIsInstance(msg, MessageTest) + handler1_called += 1 + return handler1_return + + def handler2(msg: MessageBase): + nonlocal handler2_called + self.assertIsInstance(msg, MessageTest) + handler2_called += 1 + + message = MessageTest() + self.h.channel.register_message_type(MessageTest) + self.h.channel.add_message_handler(handler1) + self.h.channel.add_message_handler(handler2) + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) + raw = envelope.pack() + self.h.channel.receive(raw) + + self.assertEqual(1, handler1_called) + self.assertEqual(0, handler2_called) + + handler1_return = False + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) + raw = envelope.pack() + self.h.channel.receive(raw) + + self.assertEqual(2, handler1_called) + self.assertEqual(1, handler2_called) + + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): decoded: [MessageBase] = [] def handle_message(message: MessageBase): decoded.append(message) - self.h.channel.add_message_callback(handle_message) + self.h.channel.add_message_handler(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) diff --git a/tests/link.py b/tests/link.py index 8322b49..021eed0 100644 --- a/tests/link.py +++ b/tests/link.py @@ -382,7 +382,7 @@ class TestLink(unittest.TestCase): channel = l1.get_channel() channel.register_message_type(MessageTest) - channel.add_message_callback(handle_message) + channel.add_message_handler(handle_message) channel.send(test_message) time.sleep(0.5) @@ -466,7 +466,7 @@ def targets(yp=False): message.data = message.data + " back" channel.send(message) channel.register_message_type(MessageTest) - channel.add_message_callback(handle_message) + channel.add_message_handler(handle_message) m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))