From 68cb4a67405f15924ce56cfe2d32c7bdaea13f2e Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sat, 25 Feb 2023 18:23:25 -0600 Subject: [PATCH 01/17] Initial work on Channel --- RNS/Channel.py | 350 +++++++++++++++++++++++++++++++++++++++++++++++ RNS/Link.py | 34 ++++- RNS/Packet.py | 1 + tests/channel.py | 316 ++++++++++++++++++++++++++++++++++++++++++ tests/link.py | 59 ++++++++ 5 files changed, 759 insertions(+), 1 deletion(-) create mode 100644 RNS/Channel.py create mode 100644 tests/channel.py diff --git a/RNS/Channel.py b/RNS/Channel.py new file mode 100644 index 0000000..2594d37 --- /dev/null +++ b/RNS/Channel.py @@ -0,0 +1,350 @@ +from __future__ import annotations +import collections +import enum +import threading +import time +from types import TracebackType +from typing import Type, Callable, TypeVar +import abc +import contextlib +import struct +import RNS +from abc import ABC, abstractmethod +_TPacket = TypeVar("_TPacket") + + +class ChannelOutletBase(ABC): + @abstractmethod + def send(self, raw: bytes) -> _TPacket: + raise NotImplemented() + + @abstractmethod + def resend(self, packet: _TPacket) -> _TPacket: + raise NotImplemented() + + @property + @abstractmethod + def mdu(self): + raise NotImplemented() + + @property + @abstractmethod + def rtt(self): + raise NotImplemented() + + @property + @abstractmethod + def is_usable(self): + raise NotImplemented() + + @abstractmethod + def get_packet_state(self, packet: _TPacket) -> MessageState: + raise NotImplemented() + + @abstractmethod + def timed_out(self): + raise NotImplemented() + + @abstractmethod + def __str__(self): + raise NotImplemented() + + @abstractmethod + def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None, + timeout: float | None = None): + raise NotImplemented() + + @abstractmethod + def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None): + raise NotImplemented() + + @abstractmethod + def get_packet_id(self, packet: _TPacket) -> any: + raise NotImplemented() + + +class CEType(enum.IntEnum): + ME_NO_MSG_TYPE = 0 + ME_INVALID_MSG_TYPE = 1 + ME_NOT_REGISTERED = 2 + ME_LINK_NOT_READY = 3 + ME_ALREADY_SENT = 4 + ME_TOO_BIG = 5 + + +class ChannelException(Exception): + def __init__(self, ce_type: CEType, *args): + super().__init__(args) + self.type = ce_type + + +class MessageState(enum.IntEnum): + MSGSTATE_NEW = 0 + MSGSTATE_SENT = 1 + MSGSTATE_DELIVERED = 2 + MSGSTATE_FAILED = 3 + + +class MessageBase(abc.ABC): + MSGTYPE = None + + @abstractmethod + def pack(self) -> bytes: + raise NotImplemented() + + @abstractmethod + def unpack(self, raw): + raise NotImplemented() + + +class Envelope: + def unpack(self, message_factories: dict[int, Type]) -> MessageBase: + msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) + raw = self.raw[6:] + ctor = message_factories.get(msgtype, None) + if ctor is None: + raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}") + message = ctor() + message.unpack(raw) + return message + + def pack(self) -> bytes: + if self.message.__class__.MSGTYPE is None: + raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE") + data = self.message.pack() + self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data + return self.raw + + def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None): + self.ts = time.time() + self.id = id(self) + self.message = message + self.raw = raw + self.packet: _TPacket = None + self.sequence = sequence + self.outlet = outlet + self.tries = 0 + self.tracked = False + + +class Channel(contextlib.AbstractContextManager): + def __init__(self, outlet: ChannelOutletBase): + self._outlet = outlet + self._lock = threading.RLock() + self._tx_ring: collections.deque[Envelope] = collections.deque() + self._rx_ring: collections.deque[Envelope] = collections.deque() + self._message_callback: Callable[[MessageBase], None] | None = None + self._next_sequence = 0 + self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors() + self._max_tries = 5 + + def __enter__(self) -> Channel: + return self + + def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, + __traceback: TracebackType | None) -> bool | None: + self.shutdown() + return False + + @staticmethod + def _get_msg_constructors() -> (int, Type[MessageBase]): + subclass_tuples = [] + for subclass in MessageBase.__subclasses__(): + with contextlib.suppress(Exception): + subclass() # verify constructor works with no arguments, needed for unpacking + subclass_tuples.append((subclass.MSGTYPE, subclass)) + message_factories = dict(subclass_tuples) + return message_factories + + def register_message_type(self, message_class: Type[MessageBase]): + if not issubclass(message_class, MessageBase): + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") + if message_class.MSGTYPE is None: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") + try: + message_class() + except Exception as ex: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} raised an exception when constructed with no arguments: {ex}") + + self._message_factories[message_class.MSGTYPE] = message_class + + def set_message_callback(self, callback: Callable[[MessageBase], None]): + self._message_callback = callback + + def shutdown(self): + self.clear_rings() + + def clear_rings(self): + with self._lock: + for envelope in self._tx_ring: + if envelope.packet is not None: + self._outlet.set_packet_timeout_callback(envelope.packet, None) + self._outlet.set_packet_delivered_callback(envelope.packet, None) + self._tx_ring.clear() + self._rx_ring.clear() + + def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: + with self._lock: + i = 0 + for env in ring: + if env.sequence < envelope.sequence: + ring.insert(i, envelope) + return True + if env.sequence == envelope.sequence: + RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) + return False + i += 1 + envelope.tracked = True + ring.append(envelope) + return True + + def prune_rx_ring(self): + with self._lock: + # Implementation for fixed window = 1 + stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] + for env in stale: + env.tracked = False + self._rx_ring.remove(env) + + def receive(self, raw: bytes): + try: + envelope = Envelope(outlet=self._outlet, raw=raw) + message = envelope.unpack(self._message_factories) + with self._lock: + is_new = self.emplace_envelope(envelope, self._rx_ring) + self.prune_rx_ring() + if not is_new: + RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) + return + RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) + if self._message_callback: + threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\ + .start() + except Exception as ex: + RNS.log(f"Channel: Error receiving data: {ex}") + + def is_ready_to_send(self) -> bool: + if not self._outlet.is_usable: + RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) + return False + + with self._lock: + for envelope in self._tx_ring: + if envelope.outlet == self._outlet and (not envelope.packet + or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT): + RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME) + return False + return True + + def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]): + with self._lock: + envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), + self._tx_ring), None) + if envelope and op(envelope): + envelope.tracked = False + if envelope in self._tx_ring: + self._tx_ring.remove(envelope) + else: + RNS.log("Channel: Envelope not found in TX ring", RNS.LOG_DEBUG) + if not envelope: + RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME) + + def _packet_delivered(self, packet: _TPacket): + self._packet_tx_op(packet, lambda env: True) + + def _packet_timeout(self, packet: _TPacket): + def retry_envelope(envelope: Envelope) -> bool: + if envelope.tries >= self._max_tries: + RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) + self.shutdown() # start on separate thread? + self._outlet.timed_out() + return True + envelope.tries += 1 + self._outlet.resend(envelope.packet) + return False + + self._packet_tx_op(packet, retry_envelope) + + def send(self, message: MessageBase) -> Envelope: + envelope: Envelope | None = None + with self._lock: + if not self.is_ready_to_send(): + raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") + envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) + self._next_sequence = (self._next_sequence + 1) % 0x10000 + self.emplace_envelope(envelope, self._tx_ring) + if envelope is None: + raise BlockingIOError() + + envelope.pack() + if len(envelope.raw) > self._outlet.mdu: + raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") + envelope.packet = self._outlet.send(envelope.raw) + envelope.tries += 1 + self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) + return envelope + + +class LinkChannelOutlet(ChannelOutletBase): + def __init__(self, link: RNS.Link): + self.link = link + + def send(self, raw: bytes) -> RNS.Packet: + packet = RNS.Packet(self.link, raw, context=RNS.Packet.CHANNEL) + packet.send() + return packet + + def resend(self, packet: RNS.Packet) -> RNS.Packet: + if not packet.resend(): + RNS.log("Failed to resend packet", RNS.LOG_ERROR) + return packet + + @property + def mdu(self): + return self.link.MDU + + @property + def rtt(self): + return self.link.rtt + + @property + def is_usable(self): + return True # had issues looking at Link.status + + def get_packet_state(self, packet: _TPacket) -> MessageState: + status = packet.receipt.get_status() + if status == RNS.PacketReceipt.SENT: + return MessageState.MSGSTATE_SENT + if status == RNS.PacketReceipt.DELIVERED: + return MessageState.MSGSTATE_DELIVERED + if status == RNS.PacketReceipt.FAILED: + return MessageState.MSGSTATE_FAILED + else: + raise Exception(f"Unexpected receipt state: {status}") + + def timed_out(self): + self.link.teardown() + + def __str__(self): + return f"{self.__class__.__name__}({self.link})" + + def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None, + timeout: float | None = None): + if timeout: + packet.receipt.set_timeout(timeout) + + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + packet.receipt.set_timeout_callback(inner if callback else None) + + def set_packet_delivered_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None): + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + packet.receipt.set_delivery_callback(inner if callback else None) + + def get_packet_id(self, packet: RNS.Packet) -> any: + return packet.get_hash() diff --git a/RNS/Link.py b/RNS/Link.py index 00cca34..4fd58f9 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -22,7 +22,7 @@ from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import Fernet - +from RNS.Channel import Channel, LinkChannelOutlet from time import sleep from .vendor import umsgpack as umsgpack import threading @@ -163,6 +163,7 @@ class Link: self.destination = destination self.attached_interface = None self.__remote_identity = None + self._channel = None if self.destination == None: self.initiator = False self.prv = X25519PrivateKey.generate() @@ -462,6 +463,8 @@ class Link: resource.cancel() for resource in self.outgoing_resources: resource.cancel() + if self._channel: + self._channel.shutdown() self.prv = None self.pub = None @@ -642,6 +645,27 @@ class Link: if pending_request.request_id == resource.request_id: pending_request.request_timed_out(None) + def _ensure_channel(self): + if self._channel is None: + self._channel = Channel(LinkChannelOutlet(self)) + return self._channel + + def set_message_callback(self, callback, message_types=None): + if not callback: + if self._channel: + self._channel.set_message_callback(None) + return + + self._ensure_channel() + + if message_types: + for msg_type in message_types: + self._channel.register_message_type(msg_type) + self._channel.set_message_callback(callback) + + def send_message(self, message: RNS.Channel.MessageBase): + self._ensure_channel().send(message) + def receive(self, packet): self.watchdog_lock = True if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): @@ -788,6 +812,14 @@ class Link: for resource in self.incoming_resources: resource.receive_part(packet) + elif packet.context == RNS.Packet.CHANNEL: + if not self._channel: + RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) + else: + plaintext = self.decrypt(packet.data) + self._channel.receive(plaintext) + packet.prove() + elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: resource_hash = packet.data[0:RNS.Identity.HASHLENGTH//8] diff --git a/RNS/Packet.py b/RNS/Packet.py index cdc476c..a105dec 100755 --- a/RNS/Packet.py +++ b/RNS/Packet.py @@ -75,6 +75,7 @@ class Packet: PATH_RESPONSE = 0x0B # Packet is a response to a path request COMMAND = 0x0C # Packet is a command COMMAND_STATUS = 0x0D # Packet is a status of an executed command + CHANNEL = 0x0E # Packet contains link channel data KEEPALIVE = 0xFA # Packet is a keepalive packet LINKIDENTIFY = 0xFB # Packet is a link peer identification proof LINKCLOSE = 0xFC # Packet is a link close message diff --git a/tests/channel.py b/tests/channel.py new file mode 100644 index 0000000..eb57966 --- /dev/null +++ b/tests/channel.py @@ -0,0 +1,316 @@ +from __future__ import annotations +import threading +import RNS +from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase +from RNS.vendor import umsgpack +from typing import Callable +import contextlib +import typing +import types +import time +import uuid +import unittest + + +class Packet: + timeout = 1.0 + + def __init__(self, raw: bytes): + self.state = MessageState.MSGSTATE_NEW + self.raw = raw + self.packet_id = uuid.uuid4() + self.tries = 0 + self.timeout_id = None + self.lock = threading.RLock() + self.instances = 0 + self.timeout_callback: Callable[[Packet], None] | None = None + self.delivered_callback: Callable[[Packet], None] | None = None + + def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float): + with self.lock: + if timeout is not None: + self.timeout = timeout + self.timeout_callback = callback + + + def send(self): + self.tries += 1 + self.state = MessageState.MSGSTATE_SENT + + def elapsed(timeout: float, timeout_id: uuid.uuid4): + with self.lock: + self.instances += 1 + try: + time.sleep(timeout) + with self.lock: + if self.timeout_id == timeout_id: + self.timeout_id = None + self.state = MessageState.MSGSTATE_FAILED + if self.timeout_callback: + self.timeout_callback(self) + finally: + with self.lock: + self.instances -= 1 + + self.timeout_id = uuid.uuid4() + threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id], + daemon=True).start() + + def clear_timeout(self): + self.timeout_id = None + + def set_delivered_callback(self, callback: Callable[[Packet], None]): + self.delivered_callback = callback + + def delivered(self): + with self.lock: + self.state = MessageState.MSGSTATE_DELIVERED + self.timeout_id = None + if self.delivered_callback: + self.delivered_callback(self) + + +class ChannelOutletTest(ChannelOutletBase): + def get_packet_state(self, packet: Packet) -> MessageState: + return packet.state + + def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None, + timeout: float | None = None): + packet.set_timeout(callback, timeout) + + def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None): + packet.set_delivered_callback(callback) + + def get_packet_id(self, packet: Packet) -> any: + return packet.packet_id + + def __init__(self, mdu: int, rtt: float): + self.link_id = uuid.uuid4() + self.timeout_callbacks = 0 + self._mdu = mdu + self._rtt = rtt + self._usable = True + self.packets = [] + self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None + + def send(self, raw: bytes) -> Packet: + packet = Packet(raw) + packet.send() + self.packets.append(packet) + return packet + + def resend(self, packet: Packet) -> Packet: + packet.send() + return packet + + @property + def mdu(self): + return self._mdu + + @property + def rtt(self): + return self._rtt + + @property + def is_usable(self): + return self._usable + + def timed_out(self): + self.timeout_callbacks += 1 + + def __str__(self): + return str(self.link_id) + + +class MessageTest(MessageBase): + MSGTYPE = 0xabcd + + def __init__(self): + self.id = str(uuid.uuid4()) + self.data = "test" + self.not_serialized = str(uuid.uuid4()) + + def pack(self) -> bytes: + return umsgpack.packb((self.id, self.data)) + + def unpack(self, raw): + self.id, self.data = umsgpack.unpackb(raw) + + +class ProtocolHarness(contextlib.AbstractContextManager): + def __init__(self, rtt: float): + self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) + self.channel = Channel(self.outlet) + + def cleanup(self): + self.channel.shutdown() + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") + self.cleanup() + return False + + +class TestChannel(unittest.TestCase): + def setUp(self) -> None: + self.rtt = 0.001 + self.retry_interval = self.rtt * 150 + Packet.timeout = self.retry_interval + self.h = ProtocolHarness(self.rtt) + + def tearDown(self) -> None: + self.h.cleanup() + + def test_send_one_retry(self): + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 1.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(2, envelope.tries) + self.assertEqual(2, packet.tries) + self.assertEqual(1, packet.instances) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(self.h.outlet.packets[0], packet) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + def test_send_timeout(self): + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 7.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(5, envelope.tries) + self.assertEqual(5, packet.tries) + self.assertEqual(0, packet.instances) + self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) + self.assertFalse(envelope.tracked) + + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): + decoded: [MessageBase] = [] + + def handle_message(message: MessageBase): + decoded.append(message) + + self.h.channel.set_message_callback(handle_message) + self.assertEqual(len(self.h.outlet.packets), 0) + + envelope = self.h.channel.send(message) + time.sleep(self.retry_interval * 0.5) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval * 2) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + self.assertEqual(len(self.h.outlet.packets), 1) + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + self.assertFalse(envelope.tracked) + self.assertEqual(0, len(decoded)) + + self.h.channel.receive(packet.raw) + + self.assertEqual(1, len(decoded)) + + rx_message = decoded[0] + + self.assertIsNotNone(rx_message) + self.assertIsInstance(rx_message, message.__class__) + checker(rx_message) + + def test_send_receive_message_test(self): + message = MessageTest() + + def check(rx_message: MessageBase): + self.assertIsInstance(rx_message, message.__class__) + self.assertEqual(message.id, rx_message.id) + self.assertEqual(message.data, rx_message.data) + self.assertNotEqual(message.not_serialized, rx_message.not_serialized) + + self.eat_own_dog_food(message, check) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/link.py b/tests/link.py index 5368858..5b22bb2 100644 --- a/tests/link.py +++ b/tests/link.py @@ -6,6 +6,9 @@ import threading import time import RNS import os +from tests.channel import MessageTest +from RNS.Channel import MessageBase + APP_NAME = "rns_unit_tests" @@ -46,6 +49,11 @@ def close_rns(): global c_rns if c_rns != None: c_rns.m_proc.kill() + # stdout, stderr = c_rns.m_proc.communicate() + # if stdout: + # print(stdout.decode("utf-8")) + # if stderr: + # print(stderr.decode("utf-8")) class TestLink(unittest.TestCase): def setUp(self): @@ -346,6 +354,52 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + def test_10_channel_round_trip(self): + global c_rns + init_rns(self) + print("") + print("Channel round trip test") + + # TODO: Load this from public bytes only + id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) + self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1])) + + dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish") + + self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d")) + + l1 = RNS.Link(dest) + time.sleep(1) + self.assertEqual(l1.status, RNS.Link.ACTIVE) + + received = [] + + def handle_message(message: MessageBase): + received.append(message) + + test_message = MessageTest() + test_message.data = "Hello" + + l1.set_message_callback(handle_message) + l1.send_message(test_message) + + time.sleep(0.5) + + self.assertEqual(1, len(received)) + + rx_message = received[0] + + self.assertIsInstance(rx_message, MessageTest) + self.assertEqual("Hello back", rx_message.data) + self.assertEqual(test_message.id, rx_message.id) + self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized) + self.assertEqual(1, len(l1._channel._rx_ring)) + + l1.teardown() + time.sleep(0.5) + self.assertEqual(l1.status, RNS.Link.CLOSED) + self.assertEqual(0, len(l1._channel._rx_ring)) + def size_str(self, num, suffix='B'): units = ['','K','M','G','T','P','E','Z'] @@ -405,6 +459,11 @@ def targets(yp=False): link.set_resource_started_callback(resource_started) link.set_resource_concluded_callback(resource_concluded) + def handle_message(message): + message.data = message.data + " back" + link.send_message(message) + link.set_message_callback(handle_message, [MessageTest]) + m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish") From fe3a3e22f785231b0d41aa802d980e2db0ab8b13 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 07:25:49 -0600 Subject: [PATCH 02/17] Expose Channel on Link Separates channel interface from link Also added: allow multiple message handlers --- RNS/Channel.py | 63 ++++++++++++++++++++++++------------------------ RNS/Link.py | 18 +------------- RNS/__init__.py | 1 + tests/channel.py | 2 +- tests/link.py | 12 ++++++--- 5 files changed, 43 insertions(+), 53 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index 2594d37..0aebb4d 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -4,22 +4,22 @@ import enum import threading import time from types import TracebackType -from typing import Type, Callable, TypeVar +from typing import Type, Callable, TypeVar, Generic, NewType import abc import contextlib import struct import RNS from abc import ABC, abstractmethod -_TPacket = TypeVar("_TPacket") +TPacket = TypeVar("TPacket") -class ChannelOutletBase(ABC): +class ChannelOutletBase(ABC, Generic[TPacket]): @abstractmethod - def send(self, raw: bytes) -> _TPacket: + def send(self, raw: bytes) -> TPacket: raise NotImplemented() @abstractmethod - def resend(self, packet: _TPacket) -> _TPacket: + def resend(self, packet: TPacket) -> TPacket: raise NotImplemented() @property @@ -38,7 +38,7 @@ class ChannelOutletBase(ABC): raise NotImplemented() @abstractmethod - def get_packet_state(self, packet: _TPacket) -> MessageState: + def get_packet_state(self, packet: TPacket) -> MessageState: raise NotImplemented() @abstractmethod @@ -50,16 +50,16 @@ class ChannelOutletBase(ABC): raise NotImplemented() @abstractmethod - def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None, + def set_packet_timeout_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None, timeout: float | None = None): raise NotImplemented() @abstractmethod - def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None): + def set_packet_delivered_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None): raise NotImplemented() @abstractmethod - def get_packet_id(self, packet: _TPacket) -> any: + def get_packet_id(self, packet: TPacket) -> any: raise NotImplemented() @@ -97,6 +97,9 @@ class MessageBase(abc.ABC): raise NotImplemented() +MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], bool]) + + class Envelope: def unpack(self, message_factories: dict[int, Type]) -> MessageBase: msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) @@ -120,7 +123,7 @@ class Envelope: self.id = id(self) self.message = message self.raw = raw - self.packet: _TPacket = None + self.packet: TPacket = None self.sequence = sequence self.outlet = outlet self.tries = 0 @@ -133,9 +136,9 @@ class Channel(contextlib.AbstractContextManager): self._lock = threading.RLock() self._tx_ring: collections.deque[Envelope] = collections.deque() self._rx_ring: collections.deque[Envelope] = collections.deque() - self._message_callback: Callable[[MessageBase], None] | None = None + self._message_callbacks: [MessageCallbackType] = [] self._next_sequence = 0 - self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors() + self._message_factories: dict[int, Type[MessageBase]] = {} self._max_tries = 5 def __enter__(self) -> Channel: @@ -146,16 +149,6 @@ class Channel(contextlib.AbstractContextManager): self.shutdown() return False - @staticmethod - def _get_msg_constructors() -> (int, Type[MessageBase]): - subclass_tuples = [] - for subclass in MessageBase.__subclasses__(): - with contextlib.suppress(Exception): - subclass() # verify constructor works with no arguments, needed for unpacking - subclass_tuples.append((subclass.MSGTYPE, subclass)) - message_factories = dict(subclass_tuples) - return message_factories - def register_message_type(self, message_class: Type[MessageBase]): if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") @@ -169,10 +162,15 @@ class Channel(contextlib.AbstractContextManager): self._message_factories[message_class.MSGTYPE] = message_class - def set_message_callback(self, callback: Callable[[MessageBase], None]): - self._message_callback = callback + def add_message_callback(self, callback: MessageCallbackType): + if callback not in self._message_callbacks: + self._message_callbacks.append(callback) + + def remove_message_callback(self, callback: MessageCallbackType): + self._message_callbacks.remove(callback) def shutdown(self): + self._message_callbacks.clear() self.clear_rings() def clear_rings(self): @@ -218,9 +216,8 @@ class Channel(contextlib.AbstractContextManager): RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) - if self._message_callback: - threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\ - .start() + for cb in self._message_callbacks: + threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start() except Exception as ex: RNS.log(f"Channel: Error receiving data: {ex}") @@ -237,7 +234,7 @@ class Channel(contextlib.AbstractContextManager): return False return True - def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]): + def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]): with self._lock: envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), self._tx_ring), None) @@ -250,10 +247,10 @@ class Channel(contextlib.AbstractContextManager): if not envelope: RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME) - def _packet_delivered(self, packet: _TPacket): + def _packet_delivered(self, packet: TPacket): self._packet_tx_op(packet, lambda env: True) - def _packet_timeout(self, packet: _TPacket): + def _packet_timeout(self, packet: TPacket): def retry_envelope(envelope: Envelope) -> bool: if envelope.tries >= self._max_tries: RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) @@ -286,6 +283,10 @@ class Channel(contextlib.AbstractContextManager): self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) return envelope + @property + def MDU(self): + return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) + class LinkChannelOutlet(ChannelOutletBase): def __init__(self, link: RNS.Link): @@ -313,7 +314,7 @@ class LinkChannelOutlet(ChannelOutletBase): def is_usable(self): return True # had issues looking at Link.status - def get_packet_state(self, packet: _TPacket) -> MessageState: + def get_packet_state(self, packet: TPacket) -> MessageState: status = packet.receipt.get_status() if status == RNS.PacketReceipt.SENT: return MessageState.MSGSTATE_SENT diff --git a/RNS/Link.py b/RNS/Link.py index 4fd58f9..5f137d4 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -645,27 +645,11 @@ class Link: if pending_request.request_id == resource.request_id: pending_request.request_timed_out(None) - def _ensure_channel(self): + def get_channel(self): if self._channel is None: self._channel = Channel(LinkChannelOutlet(self)) return self._channel - def set_message_callback(self, callback, message_types=None): - if not callback: - if self._channel: - self._channel.set_message_callback(None) - return - - self._ensure_channel() - - if message_types: - for msg_type in message_types: - self._channel.register_message_type(msg_type) - self._channel.set_message_callback(callback) - - def send_message(self, message: RNS.Channel.MessageBase): - self._ensure_channel().send(message) - def receive(self, packet): self.watchdog_lock = True if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): diff --git a/RNS/__init__.py b/RNS/__init__.py index 3c2f2ae..0ec2140 100755 --- a/RNS/__init__.py +++ b/RNS/__init__.py @@ -32,6 +32,7 @@ from ._version import __version__ from .Reticulum import Reticulum from .Identity import Identity from .Link import Link, RequestReceipt +from .Channel import MessageBase from .Transport import Transport from .Destination import Destination from .Packet import Packet diff --git a/tests/channel.py b/tests/channel.py index eb57966..260f037 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -251,7 +251,7 @@ class TestChannel(unittest.TestCase): def handle_message(message: MessageBase): decoded.append(message) - self.h.channel.set_message_callback(handle_message) + self.h.channel.add_message_callback(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) diff --git a/tests/link.py b/tests/link.py index 5b22bb2..8322b49 100644 --- a/tests/link.py +++ b/tests/link.py @@ -380,8 +380,10 @@ class TestLink(unittest.TestCase): test_message = MessageTest() test_message.data = "Hello" - l1.set_message_callback(handle_message) - l1.send_message(test_message) + channel = l1.get_channel() + channel.register_message_type(MessageTest) + channel.add_message_callback(handle_message) + channel.send(test_message) time.sleep(0.5) @@ -458,11 +460,13 @@ def targets(yp=False): link.set_resource_strategy(RNS.Link.ACCEPT_ALL) link.set_resource_started_callback(resource_started) link.set_resource_concluded_callback(resource_concluded) + channel = link.get_channel() def handle_message(message): message.data = message.data + " back" - link.send_message(message) - link.set_message_callback(handle_message, [MessageTest]) + channel.send(message) + channel.register_message_type(MessageTest) + channel.add_message_callback(handle_message) m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) From a61b15cf6ad576f306ead77ac4160b9dad7971e0 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 07:26:12 -0600 Subject: [PATCH 03/17] Added channel example --- Examples/Channel.py | 375 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 375 insertions(+) create mode 100644 Examples/Channel.py diff --git a/Examples/Channel.py b/Examples/Channel.py new file mode 100644 index 0000000..dfad943 --- /dev/null +++ b/Examples/Channel.py @@ -0,0 +1,375 @@ +########################################################## +# This RNS example demonstrates how to set up a link to # +# a destination, and pass structuredmessages over it # +# using a channel. # +########################################################## + +import os +import sys +import time +import argparse +from datetime import datetime + +import RNS +from RNS.vendor import umsgpack + +# Let's define an app name. We'll use this for all +# destinations we create. Since this echo example +# is part of a range of example utilities, we'll put +# them all within the app namespace "example_utilities" +APP_NAME = "example_utilities" + +########################################################## +#### Shared Objects ###################################### +########################################################## + +# Channel data must be structured in a subclass of +# MessageBase. This ensures that the channel will be able +# to serialize and deserialize the object and multiplex it +# with other objects. Both ends of a link will need the +# same object definitions to be able to communicate over +# a channel. +# +# Note: The objects we wish to use over the channel must +# be registered with the channel, and each link has a +# different channel instance. See the client_connected +# and link_established functions in this example to see +# how message types are registered. + +# Let's make a simple message class called StringMessage +# that will convey a string with a timestamp. + +class StringMessage(RNS.MessageBase): + # The MSGTYPE class variable needs to be assigned a + # 2 byte integer value. This identifier allows the + # channel to look up your message's constructor when a + # message arrives over the channel. + # + # MSGTYPE must be unique across all message types we + # register with the channel + MSGTYPE = 0x0101 + + # The constructor of our object must be callable with + # no arguments. We can have parameters, but they must + # have a default assignment. + # + # This is needed so the channel can create an empty + # version of our message into which the incoming + # message can be unpacked. + def __init__(self, data=None): + self.data = data + self.timestamp = datetime.now() + + # Finally, our message needs to implement functions + # the channel can call to pack and unpack our message + # to/from the raw packet payload. We'll use the + # umsgpack package bundled with RNS. We could also use + # the struct package bundled with Python if we wanted + # more control over the structure of the packed bytes. + # + # Also note that packed message objects must fit + # entirely in one packet. The number of bytes + # available for message payloads can be queried from + # the channel using the Channel.MDU property. The + # channel MDU is slightly less than the link MDU due + # to encoding the message header. + + # The pack function encodes the message contents into + # a byte stream. + def pack(self) -> bytes: + return umsgpack.packb((self.data, self.timestamp)) + + # And the unpack function decodes a byte stream into + # the message contents. + def unpack(self, raw): + self.data, self.timestamp = umsgpack.unpackb(raw) + + +########################################################## +#### Server Part ######################################### +########################################################## + +# A reference to the latest client link that connected +latest_client_link = None + +# This initialisation is executed when the users chooses +# to run as a server +def server(configpath): + # We must first initialise Reticulum + reticulum = RNS.Reticulum(configpath) + + # Randomly create a new identity for our link example + server_identity = RNS.Identity() + + # We create a destination that clients can connect to. We + # want clients to create links to this destination, so we + # need to create a "single" destination type. + server_destination = RNS.Destination( + server_identity, + RNS.Destination.IN, + RNS.Destination.SINGLE, + APP_NAME, + "channelexample" + ) + + # We configure a function that will get called every time + # a new client creates a link to this destination. + server_destination.set_link_established_callback(client_connected) + + # Everything's ready! + # Let's Wait for client requests or user input + server_loop(server_destination) + +def server_loop(destination): + # Let the user know that everything is ready + RNS.log( + "Link example "+ + RNS.prettyhexrep(destination.hash)+ + " running, waiting for a connection." + ) + + RNS.log("Hit enter to manually send an announce (Ctrl-C to quit)") + + # We enter a loop that runs until the users exits. + # If the user hits enter, we will announce our server + # destination on the network, which will let clients + # know how to create messages directed towards it. + while True: + entered = input() + destination.announce() + RNS.log("Sent announce from "+RNS.prettyhexrep(destination.hash)) + +# When a client establishes a link to our server +# destination, this function will be called with +# a reference to the link. +def client_connected(link): + global latest_client_link + latest_client_link = link + + RNS.log("Client connected") + link.set_link_closed_callback(client_disconnected) + + # Register message types and add callback to channel + channel = link.get_channel() + channel.register_message_type(StringMessage) + channel.add_message_callback(server_message_received) + +def client_disconnected(link): + RNS.log("Client disconnected") + +def server_message_received(message): + global latest_client_link + + # When a message is received over any active link, + # the replies will all be directed to the last client + # that connected. + if isinstance(message, StringMessage): + RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") + + reply_message = StringMessage("I received \""+message.data+"\" over the link") + latest_client_link.get_channel().send(reply_message) + + +########################################################## +#### Client Part ######################################### +########################################################## + +# A reference to the server link +server_link = None + +# This initialisation is executed when the users chooses +# to run as a client +def client(destination_hexhash, configpath): + # We need a binary representation of the destination + # hash that was entered on the command line + try: + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 + if len(destination_hexhash) != dest_len: + raise ValueError( + "Destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2) + ) + + destination_hash = bytes.fromhex(destination_hexhash) + except: + RNS.log("Invalid destination entered. Check your input!\n") + exit() + + # We must first initialise Reticulum + reticulum = RNS.Reticulum(configpath) + + # Check if we know a path to the destination + if not RNS.Transport.has_path(destination_hash): + RNS.log("Destination is not yet known. Requesting path and waiting for announce to arrive...") + RNS.Transport.request_path(destination_hash) + while not RNS.Transport.has_path(destination_hash): + time.sleep(0.1) + + # Recall the server identity + server_identity = RNS.Identity.recall(destination_hash) + + # Inform the user that we'll begin connecting + RNS.log("Establishing link with server...") + + # When the server identity is known, we set + # up a destination + server_destination = RNS.Destination( + server_identity, + RNS.Destination.OUT, + RNS.Destination.SINGLE, + APP_NAME, + "channelexample" + ) + + # And create a link + link = RNS.Link(server_destination) + + # We set a callback that will get executed + # every time a packet is received over the + # link + link.set_packet_callback(client_message_received) + + # We'll also set up functions to inform the + # user when the link is established or closed + link.set_link_established_callback(link_established) + link.set_link_closed_callback(link_closed) + + # Everything is set up, so let's enter a loop + # for the user to interact with the example + client_loop() + +def client_loop(): + global server_link + + # Wait for the link to become active + while not server_link: + time.sleep(0.1) + + should_quit = False + while not should_quit: + try: + print("> ", end=" ") + text = input() + + # Check if we should quit the example + if text == "quit" or text == "q" or text == "exit": + should_quit = True + server_link.teardown() + + # If not, send the entered text over the link + if text != "": + message = StringMessage(text) + packed_size = len(message.pack()) + channel = server_link.get_channel() + if channel.is_ready_to_send(): + if packed_size <= channel.MDU: + channel.send(message) + else: + RNS.log( + "Cannot send this packet, the data size of "+ + str(packed_size)+" bytes exceeds the link packet MDU of "+ + str(channel.MDU)+" bytes", + RNS.LOG_ERROR + ) + else: + RNS.log("Channel is not ready to send, please wait for " + + "pending messages to complete.", RNS.LOG_ERROR) + + except Exception as e: + RNS.log("Error while sending data over the link: "+str(e)) + should_quit = True + server_link.teardown() + +# This function is called when a link +# has been established with the server +def link_established(link): + # We store a reference to the link + # instance for later use + global server_link + server_link = link + + # Register messages and add handler to channel + channel = link.get_channel() + channel.register_message_type(StringMessage) + channel.add_message_callback(client_message_received) + + # Inform the user that the server is + # connected + RNS.log("Link established with server, enter some text to send, or \"quit\" to quit") + +# When a link is closed, we'll inform the +# user, and exit the program +def link_closed(link): + if link.teardown_reason == RNS.Link.TIMEOUT: + RNS.log("The link timed out, exiting now") + elif link.teardown_reason == RNS.Link.DESTINATION_CLOSED: + RNS.log("The link was closed by the server, exiting now") + else: + RNS.log("Link closed, exiting now") + + RNS.Reticulum.exit_handler() + time.sleep(1.5) + os._exit(0) + +# When a packet is received over the link, we +# simply print out the data. +def client_message_received(message): + if isinstance(message, StringMessage): + RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") + print("> ", end=" ") + sys.stdout.flush() + + +########################################################## +#### Program Startup ##################################### +########################################################## + +# This part of the program runs at startup, +# and parses input of from the user, and then +# starts up the desired program mode. +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser(description="Simple link example") + + parser.add_argument( + "-s", + "--server", + action="store_true", + help="wait for incoming link requests from clients" + ) + + parser.add_argument( + "--config", + action="store", + default=None, + help="path to alternative Reticulum config directory", + type=str + ) + + parser.add_argument( + "destination", + nargs="?", + default=None, + help="hexadecimal hash of the server destination", + type=str + ) + + args = parser.parse_args() + + if args.config: + configarg = args.config + else: + configarg = None + + if args.server: + server(configarg) + else: + if (args.destination == None): + print("") + parser.print_help() + print("") + else: + client(args.destination, configarg) + + except KeyboardInterrupt: + print("") + exit() \ No newline at end of file From e00582615144e0ae692015884fafd440b59acfee Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 11:23:38 -0600 Subject: [PATCH 04/17] Allow channel message handlers to short circuit - a message handler can return logical True to prevent subsequent message handlers from running --- Examples/Channel.py | 4 ++-- RNS/Channel.py | 56 +++++++++++++++++++++++++++++---------------- tests/channel.py | 38 +++++++++++++++++++++++++++++- tests/link.py | 4 ++-- 4 files changed, 77 insertions(+), 25 deletions(-) diff --git a/Examples/Channel.py b/Examples/Channel.py index dfad943..4f1bd2c 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -152,7 +152,7 @@ def client_connected(link): # Register message types and add callback to channel channel = link.get_channel() channel.register_message_type(StringMessage) - channel.add_message_callback(server_message_received) + channel.add_message_handler(server_message_received) def client_disconnected(link): RNS.log("Client disconnected") @@ -290,7 +290,7 @@ def link_established(link): # Register messages and add handler to channel channel = link.get_channel() channel.register_message_type(StringMessage) - channel.add_message_callback(client_message_received) + channel.add_message_handler(client_message_received) # Inform the user that the server is # connected diff --git a/RNS/Channel.py b/RNS/Channel.py index 0aebb4d..0b023be 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -150,28 +150,34 @@ class Channel(contextlib.AbstractContextManager): return False def register_message_type(self, message_class: Type[MessageBase]): - if not issubclass(message_class, MessageBase): - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") - if message_class.MSGTYPE is None: - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") - try: - message_class() - except Exception as ex: - raise ChannelException(CEType.ME_INVALID_MSG_TYPE, - f"{message_class} raised an exception when constructed with no arguments: {ex}") + with self._lock: + if not issubclass(message_class, MessageBase): + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} is not a subclass of {MessageBase}.") + if message_class.MSGTYPE is None: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has invalid MSGTYPE class attribute.") + try: + message_class() + except Exception as ex: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} raised an exception when constructed with no arguments: {ex}") - self._message_factories[message_class.MSGTYPE] = message_class + self._message_factories[message_class.MSGTYPE] = message_class - def add_message_callback(self, callback: MessageCallbackType): - if callback not in self._message_callbacks: - self._message_callbacks.append(callback) + def add_message_handler(self, callback: MessageCallbackType): + with self._lock: + if callback not in self._message_callbacks: + self._message_callbacks.append(callback) - def remove_message_callback(self, callback: MessageCallbackType): - self._message_callbacks.remove(callback) + def remove_message_handler(self, callback: MessageCallbackType): + with self._lock: + self._message_callbacks.remove(callback) def shutdown(self): - self._message_callbacks.clear() - self.clear_rings() + with self._lock: + self._message_callbacks.clear() + self.clear_rings() def clear_rings(self): with self._lock: @@ -205,19 +211,29 @@ class Channel(contextlib.AbstractContextManager): env.tracked = False self._rx_ring.remove(env) + def _run_callbacks(self, message: MessageBase): + with self._lock: + cbs = self._message_callbacks.copy() + + for cb in cbs: + try: + if cb(message): + return + except Exception as ex: + RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) + def receive(self, raw: bytes): try: envelope = Envelope(outlet=self._outlet, raw=raw) - message = envelope.unpack(self._message_factories) with self._lock: + message = envelope.unpack(self._message_factories) is_new = self.emplace_envelope(envelope, self._rx_ring) self.prune_rx_ring() if not is_new: RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) - for cb in self._message_callbacks: - threading.Thread(target=cb, name="Message Callback", args=[message], daemon=True).start() + threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() except Exception as ex: RNS.log(f"Channel: Error receiving data: {ex}") diff --git a/tests/channel.py b/tests/channel.py index 260f037..c9a64b3 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -245,13 +245,49 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) self.assertFalse(envelope.tracked) + def test_multiple_handler(self): + handler1_called = 0 + handler1_return = True + handler2_called = 0 + + def handler1(msg: MessageBase): + nonlocal handler1_called, handler1_return + self.assertIsInstance(msg, MessageTest) + handler1_called += 1 + return handler1_return + + def handler2(msg: MessageBase): + nonlocal handler2_called + self.assertIsInstance(msg, MessageTest) + handler2_called += 1 + + message = MessageTest() + self.h.channel.register_message_type(MessageTest) + self.h.channel.add_message_handler(handler1) + self.h.channel.add_message_handler(handler2) + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) + raw = envelope.pack() + self.h.channel.receive(raw) + + self.assertEqual(1, handler1_called) + self.assertEqual(0, handler2_called) + + handler1_return = False + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) + raw = envelope.pack() + self.h.channel.receive(raw) + + self.assertEqual(2, handler1_called) + self.assertEqual(1, handler2_called) + + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): decoded: [MessageBase] = [] def handle_message(message: MessageBase): decoded.append(message) - self.h.channel.add_message_callback(handle_message) + self.h.channel.add_message_handler(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) diff --git a/tests/link.py b/tests/link.py index 8322b49..021eed0 100644 --- a/tests/link.py +++ b/tests/link.py @@ -382,7 +382,7 @@ class TestLink(unittest.TestCase): channel = l1.get_channel() channel.register_message_type(MessageTest) - channel.add_message_callback(handle_message) + channel.add_message_handler(handle_message) channel.send(test_message) time.sleep(0.5) @@ -466,7 +466,7 @@ def targets(yp=False): message.data = message.data + " back" channel.send(message) channel.register_message_type(MessageTest) - channel.add_message_callback(handle_message) + channel.add_message_handler(handle_message) m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) From c00b592ed978680fa8bf6cf485fcb020e5c8b03f Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 11:39:49 -0600 Subject: [PATCH 05/17] System-reserved channel message types - a message handler can return logical True to prevent subsequent message handlers from running - Message types >= 0xff00 are reserved for system/framework messages --- Examples/Channel.py | 3 ++- RNS/Channel.py | 5 ++++- tests/channel.py | 16 ++++++++++++++++ 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/Examples/Channel.py b/Examples/Channel.py index 4f1bd2c..f64e427 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -46,7 +46,8 @@ class StringMessage(RNS.MessageBase): # message arrives over the channel. # # MSGTYPE must be unique across all message types we - # register with the channel + # register with the channel. MSGTYPEs >= 0xff00 are + # reserved for the system. MSGTYPE = 0x0101 # The constructor of our object must be callable with diff --git a/RNS/Channel.py b/RNS/Channel.py index 0b023be..31aaf94 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -149,7 +149,7 @@ class Channel(contextlib.AbstractContextManager): self.shutdown() return False - def register_message_type(self, message_class: Type[MessageBase]): + def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): with self._lock: if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, @@ -157,6 +157,9 @@ class Channel(contextlib.AbstractContextManager): if message_class.MSGTYPE is None: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") + if message_class.MSGTYPE >= 0xff00 and not is_system_type: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has system-reserved message type.") try: message_class() except Exception as ex: diff --git a/tests/channel.py b/tests/channel.py index c9a64b3..03e3bd9 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -137,6 +137,16 @@ class MessageTest(MessageBase): self.id, self.data = umsgpack.unpackb(raw) +class SystemMessage(MessageBase): + MSGTYPE = 0xffff + + def pack(self) -> bytes: + return bytes() + + def unpack(self, raw): + pass + + class ProtocolHarness(contextlib.AbstractContextManager): def __init__(self, rtt: float): self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) @@ -280,6 +290,11 @@ class TestChannel(unittest.TestCase): self.assertEqual(2, handler1_called) self.assertEqual(1, handler2_called) + def test_system_message_check(self): + with self.assertRaises(RNS.Channel.ChannelException): + self.h.channel.register_message_type(SystemMessage) + self.h.channel.register_message_type(SystemMessage, is_system_type=True) + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): decoded: [MessageBase] = [] @@ -287,6 +302,7 @@ class TestChannel(unittest.TestCase): def handle_message(message: MessageBase): decoded.append(message) + self.h.channel.register_message_type(message.__class__) self.h.channel.add_message_handler(handle_message) self.assertEqual(len(self.h.outlet.packets), 0) From 44dc2d06c65f5ecc71d6a9d601f8f9ef159c713c Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sun, 26 Feb 2023 11:47:46 -0600 Subject: [PATCH 06/17] Add channel tests to all test suite Also print name in each test --- tests/all.py | 1 + tests/channel.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/tests/all.py b/tests/all.py index 27863a4..17556b1 100644 --- a/tests/all.py +++ b/tests/all.py @@ -4,6 +4,7 @@ from .hashes import TestSHA256 from .hashes import TestSHA512 from .identity import TestIdentity from .link import TestLink +from .channel import TestChannel if __name__ == '__main__': unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/channel.py b/tests/channel.py index 03e3bd9..245789a 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -164,6 +164,7 @@ class ProtocolHarness(contextlib.AbstractContextManager): class TestChannel(unittest.TestCase): def setUp(self) -> None: + print("") self.rtt = 0.001 self.retry_interval = self.rtt * 150 Packet.timeout = self.retry_interval @@ -173,6 +174,7 @@ class TestChannel(unittest.TestCase): self.h.cleanup() def test_send_one_retry(self): + print("Channel test one retry") message = MessageTest() self.assertEqual(0, len(self.h.outlet.packets)) @@ -224,6 +226,7 @@ class TestChannel(unittest.TestCase): self.assertFalse(envelope.tracked) def test_send_timeout(self): + print("Channel test retry count exceeded") message = MessageTest() self.assertEqual(0, len(self.h.outlet.packets)) @@ -256,6 +259,8 @@ class TestChannel(unittest.TestCase): self.assertFalse(envelope.tracked) def test_multiple_handler(self): + print("Channel test multiple handler short circuit") + handler1_called = 0 handler1_return = True handler2_called = 0 @@ -291,6 +296,7 @@ class TestChannel(unittest.TestCase): self.assertEqual(1, handler2_called) def test_system_message_check(self): + print("Channel test register system message") with self.assertRaises(RNS.Channel.ChannelException): self.h.channel.register_message_type(SystemMessage) self.h.channel.register_message_type(SystemMessage, is_system_type=True) @@ -353,6 +359,7 @@ class TestChannel(unittest.TestCase): checker(rx_message) def test_send_receive_message_test(self): + print("Channel test send and receive message") message = MessageTest() def check(rx_message: MessageBase): From 464dc23ff0201a9cf5cbb9beddbc6bb2e60f0867 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 17:36:04 -0600 Subject: [PATCH 07/17] Add some internal documenation --- Examples/Channel.py | 25 ++++++++-- RNS/Channel.py | 109 ++++++++++++++++++++++++++++++++++++++------ RNS/Link.py | 4 +- tests/channel.py | 8 ++-- 4 files changed, 122 insertions(+), 24 deletions(-) diff --git a/Examples/Channel.py b/Examples/Channel.py index f64e427..53b878c 100644 --- a/Examples/Channel.py +++ b/Examples/Channel.py @@ -1,6 +1,6 @@ ########################################################## # This RNS example demonstrates how to set up a link to # -# a destination, and pass structuredmessages over it # +# a destination, and pass structured messages over it # # using a channel. # ########################################################## @@ -46,7 +46,7 @@ class StringMessage(RNS.MessageBase): # message arrives over the channel. # # MSGTYPE must be unique across all message types we - # register with the channel. MSGTYPEs >= 0xff00 are + # register with the channel. MSGTYPEs >= 0xf000 are # reserved for the system. MSGTYPE = 0x0101 @@ -159,17 +159,36 @@ def client_disconnected(link): RNS.log("Client disconnected") def server_message_received(message): + """ + A message handler + @param message: An instance of a subclass of MessageBase + @return: True if message was handled + """ global latest_client_link - # When a message is received over any active link, # the replies will all be directed to the last client # that connected. + + # In a message handler, any deserializable message + # that arrives over the link's channel will be passed + # to all message handlers, unless a preceding handler indicates it + # has handled the message. + # + # if isinstance(message, StringMessage): RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") reply_message = StringMessage("I received \""+message.data+"\" over the link") latest_client_link.get_channel().send(reply_message) + # Incoming messages are sent to each message + # handler added to the channel, in the order they + # were added. + # If any message handler returns True, the message + # is considered handled and any subsequent + # handlers are skipped. + return True + ########################################################## #### Client Part ######################################### diff --git a/RNS/Channel.py b/RNS/Channel.py index 31aaf94..f6aff67 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -14,6 +14,13 @@ TPacket = TypeVar("TPacket") class ChannelOutletBase(ABC, Generic[TPacket]): + """ + An abstract transport layer interface used by Channel. + + DEPRECATED: This was created for testing; eventually + Channel will use Link or a LinkBase interface + directly. + """ @abstractmethod def send(self, raw: bytes) -> TPacket: raise NotImplemented() @@ -64,6 +71,9 @@ class ChannelOutletBase(ABC, Generic[TPacket]): class CEType(enum.IntEnum): + """ + ChannelException type codes + """ ME_NO_MSG_TYPE = 0 ME_INVALID_MSG_TYPE = 1 ME_NOT_REGISTERED = 2 @@ -73,12 +83,18 @@ class CEType(enum.IntEnum): class ChannelException(Exception): + """ + An exception thrown by Channel, with a type code. + """ def __init__(self, ce_type: CEType, *args): super().__init__(args) self.type = ce_type class MessageState(enum.IntEnum): + """ + Set of possible states for a Message + """ MSGSTATE_NEW = 0 MSGSTATE_SENT = 1 MSGSTATE_DELIVERED = 2 @@ -86,14 +102,29 @@ class MessageState(enum.IntEnum): class MessageBase(abc.ABC): + """ + Base type for any messages sent or received on a Channel. + Subclasses must define the two abstract methods as well as + the MSGTYPE class variable. + """ + # MSGTYPE must be unique within all classes sent over a + # channel. Additionally, MSGTYPE > 0xf000 are reserved. MSGTYPE = None @abstractmethod def pack(self) -> bytes: + """ + Create and return the binary representation of the message + @return: binary representation of message + """ raise NotImplemented() @abstractmethod def unpack(self, raw): + """ + Populate message from binary representation + @param raw: binary representation + """ raise NotImplemented() @@ -101,6 +132,10 @@ MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], boo class Envelope: + """ + Internal wrapper used to transport messages over a channel and + track its state within the channel framework. + """ def unpack(self, message_factories: dict[int, Type]) -> MessageBase: msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) raw = self.raw[6:] @@ -131,6 +166,12 @@ class Envelope: class Channel(contextlib.AbstractContextManager): + """ + Channel provides reliable delivery of messages over + a link. Channel is not meant to be instantiated + directly, but rather obtained from a Link using the + get_channel() function. + """ def __init__(self, outlet: ChannelOutletBase): self._outlet = outlet self._lock = threading.RLock() @@ -146,10 +187,14 @@ class Channel(contextlib.AbstractContextManager): def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None) -> bool | None: - self.shutdown() + self._shutdown() return False def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): + """ + Register a message class for reception over a channel. + @param message_class: Class to register. Must extend MessageBase. + """ with self._lock: if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, @@ -157,7 +202,7 @@ class Channel(contextlib.AbstractContextManager): if message_class.MSGTYPE is None: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") - if message_class.MSGTYPE >= 0xff00 and not is_system_type: + if message_class.MSGTYPE >= 0xf000 and not is_system_type: raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has system-reserved message type.") try: @@ -169,20 +214,34 @@ class Channel(contextlib.AbstractContextManager): self._message_factories[message_class.MSGTYPE] = message_class def add_message_handler(self, callback: MessageCallbackType): + """ + Add a handler for incoming messages. A handler + has the signature (message: MessageBase) -> bool. + Handlers are processed in the order they are + added. If any handler returns True, processing + of the message stops; handlers after the + returning handler will not be called. + @param callback: Function to call + @return: + """ with self._lock: if callback not in self._message_callbacks: self._message_callbacks.append(callback) def remove_message_handler(self, callback: MessageCallbackType): + """ + Remove a handler + @param callback: handler to remove + """ with self._lock: self._message_callbacks.remove(callback) - def shutdown(self): + def _shutdown(self): with self._lock: self._message_callbacks.clear() - self.clear_rings() + self._clear_rings() - def clear_rings(self): + def _clear_rings(self): with self._lock: for envelope in self._tx_ring: if envelope.packet is not None: @@ -191,14 +250,15 @@ class Channel(contextlib.AbstractContextManager): self._tx_ring.clear() self._rx_ring.clear() - def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: + def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: with self._lock: i = 0 - for env in ring: - if env.sequence < envelope.sequence: + for existing in ring: + if existing.sequence > envelope.sequence \ + and not existing.sequence // 2 > envelope.sequence: # account for overflow ring.insert(i, envelope) return True - if env.sequence == envelope.sequence: + if existing.sequence == envelope.sequence: RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) return False i += 1 @@ -206,7 +266,7 @@ class Channel(contextlib.AbstractContextManager): ring.append(envelope) return True - def prune_rx_ring(self): + def _prune_rx_ring(self): with self._lock: # Implementation for fixed window = 1 stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] @@ -225,13 +285,13 @@ class Channel(contextlib.AbstractContextManager): except Exception as ex: RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) - def receive(self, raw: bytes): + def _receive(self, raw: bytes): try: envelope = Envelope(outlet=self._outlet, raw=raw) with self._lock: message = envelope.unpack(self._message_factories) - is_new = self.emplace_envelope(envelope, self._rx_ring) - self.prune_rx_ring() + is_new = self._emplace_envelope(envelope, self._rx_ring) + self._prune_rx_ring() if not is_new: RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) return @@ -241,6 +301,10 @@ class Channel(contextlib.AbstractContextManager): RNS.log(f"Channel: Error receiving data: {ex}") def is_ready_to_send(self) -> bool: + """ + Check if Channel is ready to send. + @return: True if ready + """ if not self._outlet.is_usable: RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) return False @@ -273,7 +337,7 @@ class Channel(contextlib.AbstractContextManager): def retry_envelope(envelope: Envelope) -> bool: if envelope.tries >= self._max_tries: RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) - self.shutdown() # start on separate thread? + self._shutdown() # start on separate thread? self._outlet.timed_out() return True envelope.tries += 1 @@ -283,13 +347,18 @@ class Channel(contextlib.AbstractContextManager): self._packet_tx_op(packet, retry_envelope) def send(self, message: MessageBase) -> Envelope: + """ + Send a message. If a message send is attempted and + Channel is not ready, an exception is thrown. + @param message: an instance of a MessageBase subclass to send on the Channel + """ envelope: Envelope | None = None with self._lock: if not self.is_ready_to_send(): raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) self._next_sequence = (self._next_sequence + 1) % 0x10000 - self.emplace_envelope(envelope, self._tx_ring) + self._emplace_envelope(envelope, self._tx_ring) if envelope is None: raise BlockingIOError() @@ -304,10 +373,20 @@ class Channel(contextlib.AbstractContextManager): @property def MDU(self): + """ + Maximum Data Unit: the number of bytes available + for a message to consume in a single send. + @return: number of bytes available + """ return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) class LinkChannelOutlet(ChannelOutletBase): + """ + An implementation of ChannelOutletBase for RNS.Link. + Allows Channel to send packets over an RNS Link with + Packets. + """ def __init__(self, link: RNS.Link): self.link = link diff --git a/RNS/Link.py b/RNS/Link.py index 5f137d4..0f42388 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -464,7 +464,7 @@ class Link: for resource in self.outgoing_resources: resource.cancel() if self._channel: - self._channel.shutdown() + self._channel._shutdown() self.prv = None self.pub = None @@ -801,7 +801,7 @@ class Link: RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) else: plaintext = self.decrypt(packet.data) - self._channel.receive(plaintext) + self._channel._receive(plaintext) packet.prove() elif packet.packet_type == RNS.Packet.PROOF: diff --git a/tests/channel.py b/tests/channel.py index 245789a..b1097bf 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -153,7 +153,7 @@ class ProtocolHarness(contextlib.AbstractContextManager): self.channel = Channel(self.outlet) def cleanup(self): - self.channel.shutdown() + self.channel._shutdown() def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, __traceback: types.TracebackType) -> bool: @@ -282,7 +282,7 @@ class TestChannel(unittest.TestCase): self.h.channel.add_message_handler(handler2) envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) raw = envelope.pack() - self.h.channel.receive(raw) + self.h.channel._receive(raw) self.assertEqual(1, handler1_called) self.assertEqual(0, handler2_called) @@ -290,7 +290,7 @@ class TestChannel(unittest.TestCase): handler1_return = False envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) raw = envelope.pack() - self.h.channel.receive(raw) + self.h.channel._receive(raw) self.assertEqual(2, handler1_called) self.assertEqual(1, handler2_called) @@ -348,7 +348,7 @@ class TestChannel(unittest.TestCase): self.assertFalse(envelope.tracked) self.assertEqual(0, len(decoded)) - self.h.channel.receive(packet.raw) + self.h.channel._receive(packet.raw) self.assertEqual(1, len(decoded)) From 661964277f47063451122c0bba0e352ca5bde846 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 19:05:25 -0600 Subject: [PATCH 08/17] Fix up documentation for building --- RNS/Channel.py | 68 ++++++++++++++++++++++++++++++++------- RNS/Link.py | 5 +++ docs/source/reference.rst | 28 ++++++++++++++++ 3 files changed, 89 insertions(+), 12 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index f6aff67..a781697 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -1,3 +1,25 @@ +# MIT License +# +# Copyright (c) 2016-2023 Mark Qvist / unsigned.io and contributors. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + from __future__ import annotations import collections import enum @@ -105,17 +127,23 @@ class MessageBase(abc.ABC): """ Base type for any messages sent or received on a Channel. Subclasses must define the two abstract methods as well as - the MSGTYPE class variable. + the ``MSGTYPE`` class variable. """ # MSGTYPE must be unique within all classes sent over a # channel. Additionally, MSGTYPE > 0xf000 are reserved. MSGTYPE = None + """ + Defines a unique identifier for a message class. + ``MSGTYPE`` must be unique within all classes sent over a channel. + ``MSGTYPE`` must be < ``0xf000``. Values >= ``0xf000`` are reserved. + """ @abstractmethod def pack(self) -> bytes: """ Create and return the binary representation of the message - @return: binary representation of message + + :return: binary representation of message """ raise NotImplemented() @@ -123,7 +151,8 @@ class MessageBase(abc.ABC): def unpack(self, raw): """ Populate message from binary representation - @param raw: binary representation + + :param raw: binary representation """ raise NotImplemented() @@ -168,11 +197,19 @@ class Envelope: class Channel(contextlib.AbstractContextManager): """ Channel provides reliable delivery of messages over - a link. Channel is not meant to be instantiated + a link. + + Channel is not meant to be instantiated directly, but rather obtained from a Link using the get_channel() function. + + :param outlet: Outlet object to use for transport """ def __init__(self, outlet: ChannelOutletBase): + """ + + @param outlet: + """ self._outlet = outlet self._lock = threading.RLock() self._tx_ring: collections.deque[Envelope] = collections.deque() @@ -193,7 +230,8 @@ class Channel(contextlib.AbstractContextManager): def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): """ Register a message class for reception over a channel. - @param message_class: Class to register. Must extend MessageBase. + + :param message_class: Class to register. Must extend MessageBase. """ with self._lock: if not issubclass(message_class, MessageBase): @@ -216,13 +254,13 @@ class Channel(contextlib.AbstractContextManager): def add_message_handler(self, callback: MessageCallbackType): """ Add a handler for incoming messages. A handler - has the signature (message: MessageBase) -> bool. + has the signature ``(message: MessageBase) -> bool``. Handlers are processed in the order they are added. If any handler returns True, processing of the message stops; handlers after the returning handler will not be called. - @param callback: Function to call - @return: + + :param callback: Function to call """ with self._lock: if callback not in self._message_callbacks: @@ -231,7 +269,8 @@ class Channel(contextlib.AbstractContextManager): def remove_message_handler(self, callback: MessageCallbackType): """ Remove a handler - @param callback: handler to remove + + :param callback: handler to remove """ with self._lock: self._message_callbacks.remove(callback) @@ -303,7 +342,8 @@ class Channel(contextlib.AbstractContextManager): def is_ready_to_send(self) -> bool: """ Check if Channel is ready to send. - @return: True if ready + + :return: True if ready """ if not self._outlet.is_usable: RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) @@ -350,7 +390,8 @@ class Channel(contextlib.AbstractContextManager): """ Send a message. If a message send is attempted and Channel is not ready, an exception is thrown. - @param message: an instance of a MessageBase subclass to send on the Channel + + :param message: an instance of a MessageBase subclass to send on the Channel """ envelope: Envelope | None = None with self._lock: @@ -376,7 +417,8 @@ class Channel(contextlib.AbstractContextManager): """ Maximum Data Unit: the number of bytes available for a message to consume in a single send. - @return: number of bytes available + + :return: number of bytes available """ return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) @@ -386,6 +428,8 @@ class LinkChannelOutlet(ChannelOutletBase): An implementation of ChannelOutletBase for RNS.Link. Allows Channel to send packets over an RNS Link with Packets. + + :param link: RNS Link to wrap """ def __init__(self, link: RNS.Link): self.link = link diff --git a/RNS/Link.py b/RNS/Link.py index 0f42388..c380b10 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -646,6 +646,11 @@ class Link: pending_request.request_timed_out(None) def get_channel(self): + """ + Get the ``Channel`` for this link. + + :return: ``Channel`` object + """ if self._channel is None: self._channel = Channel(LinkChannelOutlet(self)) return self._channel diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 017cf8d..6c958aa 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -121,6 +121,34 @@ This chapter lists and explains all classes exposed by the Reticulum Network Sta .. autoclass:: RNS.Resource(data, link, advertise=True, auto_compress=True, callback=None, progress_callback=None, timeout=None) :members: +.. _api-channel: + +.. only:: html + + |start-h3| Channel |end-h3| + +.. only:: latex + + Channel + ------ + +.. autoclass:: RNS.Channel.Channel(outlet) + :members: + +.. _api-messsagebase: + +.. only:: html + + |start-h3| MessageBase |end-h3| + +.. only:: latex + + MessageBase + ------ + +.. autoclass:: RNS.MessageBase() + :members: + .. _api-transport: .. only:: html From 118acf77b8b20ac6451df1cdc470738c9c5ee2df Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 21:10:28 -0600 Subject: [PATCH 09/17] Fix up documentation even more --- RNS/Channel.py | 60 ++++++++++++++++++++++++++++----------- docs/source/examples.rst | 12 ++++++++ docs/source/reference.rst | 2 +- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index a781697..fba65e1 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -134,8 +134,9 @@ class MessageBase(abc.ABC): MSGTYPE = None """ Defines a unique identifier for a message class. - ``MSGTYPE`` must be unique within all classes sent over a channel. - ``MSGTYPE`` must be < ``0xf000``. Values >= ``0xf000`` are reserved. + + * Must be unique within all classes registered with a ``Channel`` + * Must be less than ``0xf000``. Values greater than or equal to ``0xf000`` are reserved. """ @abstractmethod @@ -148,7 +149,7 @@ class MessageBase(abc.ABC): raise NotImplemented() @abstractmethod - def unpack(self, raw): + def unpack(self, raw: bytes): """ Populate message from binary representation @@ -196,14 +197,29 @@ class Envelope: class Channel(contextlib.AbstractContextManager): """ - Channel provides reliable delivery of messages over + Provides reliable delivery of messages over a link. - Channel is not meant to be instantiated - directly, but rather obtained from a Link using the - get_channel() function. + ``Channel`` differs from ``Request`` and + ``Resource`` in some important ways: - :param outlet: Outlet object to use for transport + **Continuous** + Messages can be sent or received as long as + the ``Link`` is open. + **Bi-directional** + Messages can be sent in either direction on + the ``Link``; neither end is the client or + server. + **Size-constrained** + Messages must be encoded into a single packet. + + ``Channel`` is similar to ``Packet``, except that it + provides reliable delivery (automatic retries) as well + as a structure for exchanging several types of + messages over the ``Link``. + + ``Channel`` is not instantiated directly, but rather + obtained from a ``Link`` with ``get_channel()``. """ def __init__(self, outlet: ChannelOutletBase): """ @@ -227,12 +243,17 @@ class Channel(contextlib.AbstractContextManager): self._shutdown() return False - def register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): + def register_message_type(self, message_class: Type[MessageBase]): """ - Register a message class for reception over a channel. + Register a message class for reception over a ``Channel``. - :param message_class: Class to register. Must extend MessageBase. + Message classes must extend ``MessageBase``. + + :param message_class: Class to register """ + self._register_message_type(message_class, is_system_type=False) + + def _register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): with self._lock: if not issubclass(message_class, MessageBase): raise ChannelException(CEType.ME_INVALID_MSG_TYPE, @@ -254,7 +275,10 @@ class Channel(contextlib.AbstractContextManager): def add_message_handler(self, callback: MessageCallbackType): """ Add a handler for incoming messages. A handler - has the signature ``(message: MessageBase) -> bool``. + has the following signature: + + ``(message: MessageBase) -> bool`` + Handlers are processed in the order they are added. If any handler returns True, processing of the message stops; handlers after the @@ -268,7 +292,7 @@ class Channel(contextlib.AbstractContextManager): def remove_message_handler(self, callback: MessageCallbackType): """ - Remove a handler + Remove a handler added with ``add_message_handler``. :param callback: handler to remove """ @@ -341,7 +365,7 @@ class Channel(contextlib.AbstractContextManager): def is_ready_to_send(self) -> bool: """ - Check if Channel is ready to send. + Check if ``Channel`` is ready to send. :return: True if ready """ @@ -389,9 +413,9 @@ class Channel(contextlib.AbstractContextManager): def send(self, message: MessageBase) -> Envelope: """ Send a message. If a message send is attempted and - Channel is not ready, an exception is thrown. + ``Channel`` is not ready, an exception is thrown. - :param message: an instance of a MessageBase subclass to send on the Channel + :param message: an instance of a ``MessageBase`` subclass """ envelope: Envelope | None = None with self._lock: @@ -416,7 +440,9 @@ class Channel(contextlib.AbstractContextManager): def MDU(self): """ Maximum Data Unit: the number of bytes available - for a message to consume in a single send. + for a message to consume in a single send. This + value is adjusted from the ``Link`` MDU to accommodate + message header information. :return: number of bytes available """ diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 9b4428f..54c13f3 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -92,6 +92,18 @@ The *Request* example explores sendig requests and receiving responses. This example can also be found at ``_. +.. _example-channel: + +Channel +==== + +The *Channel* example explores using a ``Channel`` to send structured +data between peers of a ``Link``. + +.. literalinclude:: ../../Examples/Channel.py + +This example can also be found at ``_. + .. _example-filetransfer: Filetransfer diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 6c958aa..8d519a8 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -132,7 +132,7 @@ This chapter lists and explains all classes exposed by the Reticulum Network Sta Channel ------ -.. autoclass:: RNS.Channel.Channel(outlet) +.. autoclass:: RNS.Channel.Channel() :members: .. _api-messsagebase: From 42935c8238fcda31f69cb6dc84bc8be1f1a8a4c5 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 21:15:25 -0600 Subject: [PATCH 10/17] Make the PR have zero deletions --- RNS/Link.py | 1 + 1 file changed, 1 insertion(+) diff --git a/RNS/Link.py b/RNS/Link.py index c380b10..a7c3d5c 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -22,6 +22,7 @@ from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import Fernet + from RNS.Channel import Channel, LinkChannelOutlet from time import sleep from .vendor import umsgpack as umsgpack From 68f95cd80b811335df4b2dff2a017a68a5c7793e Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 21:30:13 -0600 Subject: [PATCH 11/17] Tidy up PR --- tests/link.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/link.py b/tests/link.py index 021eed0..3f36e70 100644 --- a/tests/link.py +++ b/tests/link.py @@ -49,11 +49,6 @@ def close_rns(): global c_rns if c_rns != None: c_rns.m_proc.kill() - # stdout, stderr = c_rns.m_proc.communicate() - # if stdout: - # print(stdout.decode("utf-8")) - # if stderr: - # print(stderr.decode("utf-8")) class TestLink(unittest.TestCase): def setUp(self): From d3c4928edab0df78e46ffe0dfb3de2ce94bacda6 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 21:31:41 -0600 Subject: [PATCH 12/17] Tidy up PR --- tests/link.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/link.py b/tests/link.py index 3f36e70..203e982 100644 --- a/tests/link.py +++ b/tests/link.py @@ -9,7 +9,6 @@ import os from tests.channel import MessageTest from RNS.Channel import MessageBase - APP_NAME = "rns_unit_tests" fixed_keys = [ From 8f0151fed61cd1e39146dc030c415c013ed191cc Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Mon, 27 Feb 2023 21:33:50 -0600 Subject: [PATCH 13/17] Tidy up PR --- RNS/Link.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RNS/Link.py b/RNS/Link.py index a7c3d5c..822e1fc 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -22,8 +22,8 @@ from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import Fernet - from RNS.Channel import Channel, LinkChannelOutlet + from time import sleep from .vendor import umsgpack as umsgpack import threading From 8168d9bb92482881650cc8aa25780efe29ff2902 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:13:07 -0600 Subject: [PATCH 14/17] Only send proof if link is still active --- RNS/Link.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RNS/Link.py b/RNS/Link.py index 822e1fc..9f95ba0 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -808,7 +808,8 @@ class Link: else: plaintext = self.decrypt(packet.data) self._channel._receive(plaintext) - packet.prove() + if self.status == Link.ACTIVE: + packet.prove() elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: From 72300cc82129a5b2565de099d7ebdafa4d4e2c57 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:24:13 -0600 Subject: [PATCH 15/17] Revert "Only send proof if link is still active" --- RNS/Link.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/RNS/Link.py b/RNS/Link.py index 9f95ba0..822e1fc 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -808,8 +808,7 @@ class Link: else: plaintext = self.decrypt(packet.data) self._channel._receive(plaintext) - if self.status == Link.ACTIVE: - packet.prove() + packet.prove() elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: From 9963cf37b860cade6c9424d7917a4060b7ac7039 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:38:23 -0600 Subject: [PATCH 16/17] Fix exceptions on Channel shutdown --- RNS/Channel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/RNS/Channel.py b/RNS/Channel.py index fba65e1..839bf27 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -507,13 +507,15 @@ class LinkChannelOutlet(ChannelOutletBase): def inner(receipt: RNS.PacketReceipt): callback(packet) - packet.receipt.set_timeout_callback(inner if callback else None) + if packet and packet.receipt: + packet.receipt.set_timeout_callback(inner if callback else None) def set_packet_delivered_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None): def inner(receipt: RNS.PacketReceipt): callback(packet) - packet.receipt.set_delivery_callback(inner if callback else None) + if packet and packet.receipt: + packet.receipt.set_delivery_callback(inner if callback else None) def get_packet_id(self, packet: RNS.Packet) -> any: return packet.get_hash() From d2d121d49f745cc1c5e74790775a1a5f72246ab1 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:38:36 -0600 Subject: [PATCH 17/17] Fix broken Channel test --- tests/channel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/channel.py b/tests/channel.py index b1097bf..3a00cbe 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -138,7 +138,7 @@ class MessageTest(MessageBase): class SystemMessage(MessageBase): - MSGTYPE = 0xffff + MSGTYPE = 0xf000 def pack(self) -> bytes: return bytes() @@ -299,7 +299,7 @@ class TestChannel(unittest.TestCase): print("Channel test register system message") with self.assertRaises(RNS.Channel.ChannelException): self.h.channel.register_message_type(SystemMessage) - self.h.channel.register_message_type(SystemMessage, is_system_type=True) + self.h.channel._register_message_type(SystemMessage, is_system_type=True) def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]):