From fe3a3e22f785231b0d41aa802d980e2db0ab8b13 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 07:25:49 -0600 Subject: [PATCH] Expose Channel on Link Separates channel interface from link Also added: allow multiple message handlers --- RNS/Channel.py | 63 ++++++++++++++++++++++++------------------------ RNS/Link.py | 18 +------------- RNS/__init__.py | 1 + tests/channel.py | 2 +- tests/link.py | 12 ++++++--- 5 files changed, 43 insertions(+), 53 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index 2594d37..0aebb4d 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -4,22 +4,22 @@ import enum import threading import time from types import TracebackType -from typing import Type, Callable, TypeVar +from typing import Type, Callable, TypeVar, Generic, NewType import abc import contextlib import struct import RNS from abc import ABC, abstractmethod -_TPacket = TypeVar("_TPacket") +TPacket = TypeVar("TPacket") -class ChannelOutletBase(ABC): +class ChannelOutletBase(ABC, Generic[TPacket]): @abstractmethod - def send(self, raw: bytes) -> _TPacket: + def send(self, raw: bytes) -> TPacket: raise NotImplemented() @abstractmethod - def resend(self, packet: _TPacket) -> _TPacket: + def resend(self, packet: TPacket) -> TPacket: raise NotImplemented() @property @@ -38,7 +38,7 @@ class ChannelOutletBase(ABC): raise NotImplemented() @abstractmethod - def get_packet_state(self, packet: _TPacket) -> MessageState: + def get_packet_state(self, packet: TPacket) -> MessageState: raise NotImplemented() @abstractmethod @@ -50,16 +50,16 @@ class ChannelOutletBase(ABC): raise NotImplemented() @abstractmethod - def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None, + def set_packet_timeout_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None, timeout: float | None = None): raise NotImplemented() @abstractmethod - def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None): + def set_packet_delivered_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None): raise NotImplemented() @abstractmethod - def get_packet_id(self, packet: _TPacket) -> any: + def get_packet_id(self, packet: TPacket) -> any: raise NotImplemented() @@ -97,6 +97,9 @@ class MessageBase(abc.ABC): raise NotImplemented() +MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], bool]) + + class Envelope: def unpack(self, message_factories: dict[int, Type]) -> MessageBase: msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) @@ -120,7 +123,7 @@ class Envelope: self.id = id(self) self.message = message self.raw = raw - self.packet: _TPacket = None + self.packet: TPacket = None self.sequence = sequence self.outlet = outlet self.tries = 0 @@ -133,9 +136,9 @@ class Channel(contextlib.AbstractContextManager): self._lock = threading.RLock() self._tx_ring: collections.deque[Envelope] = collections.deque() self._rx_ring: collections.deque[Envelope] = collections.deque() - self._message_callback: Callable[[MessageBase], None] | None = None + self._message_callbacks: [MessageCallbackType] = [] self._next_sequence = 0 - self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors() + self._message_factories: dict[int, Type[MessageBase]] = {} self._max_tries = 5 def __enter__(self) -> Channel: @@ -146,16 +149,6 @@ class Channel(contextlib.AbstractContextManager): self.shutdown() return False - @staticmethod - def _get_msg_constructors() -> (int, Type[MessageBase]): - subclass_tuples = [] - for subclass in MessageBase.__subclasses__(): - with contextlib.suppress(Exception): - subclass() # verify constructor works with no arguments, needed for unpacking - subclass_tuples.append((subclass.MSGTYPE, subclass)) - message_factories = dict(subclass_tuples) - return message_factories - def register_message_type(self, message_class: Type[MessageBase]): if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") @@ -169,10 +162,15 @@ class Channel(contextlib.AbstractContextManager): self._message_factories[message_class.MSGTYPE] = message_class - def set_message_callback(self, callback: Callable[[MessageBase], None]): - self._message_callback = callback + def add_message_callback(self, callback: MessageCallbackType): + if callback not in self._message_callbacks: + self._message_callbacks.append(callback) + + def remove_message_callback(self, callback: MessageCallbackType): + self._message_callbacks.remove(callback) def shutdown(self): + self._message_callbacks.clear() self.clear_rings() def clear_rings(self): @@ -218,9 +216,8 @@ class Channel(contextlib.AbstractContextManager): RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) - if self._message_callback: - threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\ - .start() + for cb in self._message_callbacks: + threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start() except Exception as ex: RNS.log(f"Channel: Error receiving data: {ex}") @@ -237,7 +234,7 @@ class Channel(contextlib.AbstractContextManager): return False return True - def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]): + def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]): with self._lock: envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), self._tx_ring), None) @@ -250,10 +247,10 @@ class Channel(contextlib.AbstractContextManager): if not envelope: RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME) - def _packet_delivered(self, packet: _TPacket): + def _packet_delivered(self, packet: TPacket): self._packet_tx_op(packet, lambda env: True) - def _packet_timeout(self, packet: _TPacket): + def _packet_timeout(self, packet: TPacket): def retry_envelope(envelope: Envelope) -> bool: if envelope.tries >= self._max_tries: RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) @@ -286,6 +283,10 @@ class Channel(contextlib.AbstractContextManager): self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) return envelope + @property + def MDU(self): + return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) + class LinkChannelOutlet(ChannelOutletBase): def __init__(self, link: RNS.Link): @@ -313,7 +314,7 @@ class LinkChannelOutlet(ChannelOutletBase): def is_usable(self): return True # had issues looking at Link.status - def get_packet_state(self, packet: _TPacket) -> MessageState: + def get_packet_state(self, packet: TPacket) -> MessageState: status = packet.receipt.get_status() if status == RNS.PacketReceipt.SENT: return MessageState.MSGSTATE_SENT diff --git a/RNS/Link.py b/RNS/Link.py index 4fd58f9..5f137d4 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -645,27 +645,11 @@ class Link: if pending_request.request_id == resource.request_id: pending_request.request_timed_out(None) - def _ensure_channel(self): + def get_channel(self): if self._channel is None: self._channel = Channel(LinkChannelOutlet(self)) return self._channel - def set_message_callback(self, callback, message_types=None): - if not callback: - if self._channel: - self._channel.set_message_callback(None) - return - - self._ensure_channel() - - if message_types: - for msg_type in message_types: - self._channel.register_message_type(msg_type) - self._channel.set_message_callback(callback) - - def send_message(self, message: RNS.Channel.MessageBase): - self._ensure_channel().send(message) - def receive(self, packet): self.watchdog_lock = True if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): diff --git a/RNS/__init__.py b/RNS/__init__.py index 3c2f2ae..0ec2140 100755 --- a/RNS/__init__.py +++ b/RNS/__init__.py @@ -32,6 +32,7 @@ from ._version import __version__ from .Reticulum import Reticulum from .Identity import Identity from .Link import Link, RequestReceipt +from .Channel import MessageBase from .Transport import Transport from .Destination import Destination from .Packet import Packet diff --git a/tests/channel.py b/tests/channel.py index eb57966..260f037 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -251,7 +251,7 @@ class TestChannel(unittest.TestCase): def handle_message(message: MessageBase): decoded.append(message) - self.h.channel.set_message_callback(handle_message) + self.h.channel.add_message_callback(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) diff --git a/tests/link.py b/tests/link.py index 5b22bb2..8322b49 100644 --- a/tests/link.py +++ b/tests/link.py @@ -380,8 +380,10 @@ class TestLink(unittest.TestCase): test_message = MessageTest() test_message.data = "Hello" - l1.set_message_callback(handle_message) - l1.send_message(test_message) + channel = l1.get_channel() + channel.register_message_type(MessageTest) + channel.add_message_callback(handle_message) + channel.send(test_message) time.sleep(0.5) @@ -458,11 +460,13 @@ def targets(yp=False): link.set_resource_strategy(RNS.Link.ACCEPT_ALL) link.set_resource_started_callback(resource_started) link.set_resource_concluded_callback(resource_concluded) + channel = link.get_channel() def handle_message(message): message.data = message.data + " back" - link.send_message(message) - link.set_message_callback(handle_message, [MessageTest]) + channel.send(message) + channel.register_message_type(MessageTest) + channel.add_message_callback(handle_message) m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))