From 68cb4a67405f15924ce56cfe2d32c7bdaea13f2e Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sat, 25 Feb 2023 18:23:25 -0600 Subject: [PATCH] Initial work on Channel --- RNS/Channel.py | 350 +++++++++++++++++++++++++++++++++++++++++++++++ RNS/Link.py | 34 ++++- RNS/Packet.py | 1 + tests/channel.py | 316 ++++++++++++++++++++++++++++++++++++++++++ tests/link.py | 59 ++++++++ 5 files changed, 759 insertions(+), 1 deletion(-) create mode 100644 RNS/Channel.py create mode 100644 tests/channel.py diff --git a/RNS/Channel.py b/RNS/Channel.py new file mode 100644 index 0000000..2594d37 --- /dev/null +++ b/RNS/Channel.py @@ -0,0 +1,350 @@ +from __future__ import annotations +import collections +import enum +import threading +import time +from types import TracebackType +from typing import Type, Callable, TypeVar +import abc +import contextlib +import struct +import RNS +from abc import ABC, abstractmethod +_TPacket = TypeVar("_TPacket") + + +class ChannelOutletBase(ABC): + @abstractmethod + def send(self, raw: bytes) -> _TPacket: + raise NotImplemented() + + @abstractmethod + def resend(self, packet: _TPacket) -> _TPacket: + raise NotImplemented() + + @property + @abstractmethod + def mdu(self): + raise NotImplemented() + + @property + @abstractmethod + def rtt(self): + raise NotImplemented() + + @property + @abstractmethod + def is_usable(self): + raise NotImplemented() + + @abstractmethod + def get_packet_state(self, packet: _TPacket) -> MessageState: + raise NotImplemented() + + @abstractmethod + def timed_out(self): + raise NotImplemented() + + @abstractmethod + def __str__(self): + raise NotImplemented() + + @abstractmethod + def set_packet_timeout_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None, + timeout: float | None = None): + raise NotImplemented() + + @abstractmethod + def set_packet_delivered_callback(self, packet: _TPacket, callback: Callable[[_TPacket], None] | None): + raise NotImplemented() + + @abstractmethod + def get_packet_id(self, packet: _TPacket) -> any: + raise NotImplemented() + + +class CEType(enum.IntEnum): + ME_NO_MSG_TYPE = 0 + ME_INVALID_MSG_TYPE = 1 + ME_NOT_REGISTERED = 2 + ME_LINK_NOT_READY = 3 + ME_ALREADY_SENT = 4 + ME_TOO_BIG = 5 + + +class ChannelException(Exception): + def __init__(self, ce_type: CEType, *args): + super().__init__(args) + self.type = ce_type + + +class MessageState(enum.IntEnum): + MSGSTATE_NEW = 0 + MSGSTATE_SENT = 1 + MSGSTATE_DELIVERED = 2 + MSGSTATE_FAILED = 3 + + +class MessageBase(abc.ABC): + MSGTYPE = None + + @abstractmethod + def pack(self) -> bytes: + raise NotImplemented() + + @abstractmethod + def unpack(self, raw): + raise NotImplemented() + + +class Envelope: + def unpack(self, message_factories: dict[int, Type]) -> MessageBase: + msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) + raw = self.raw[6:] + ctor = message_factories.get(msgtype, None) + if ctor is None: + raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}") + message = ctor() + message.unpack(raw) + return message + + def pack(self) -> bytes: + if self.message.__class__.MSGTYPE is None: + raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE") + data = self.message.pack() + self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data + return self.raw + + def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None): + self.ts = time.time() + self.id = id(self) + self.message = message + self.raw = raw + self.packet: _TPacket = None + self.sequence = sequence + self.outlet = outlet + self.tries = 0 + self.tracked = False + + +class Channel(contextlib.AbstractContextManager): + def __init__(self, outlet: ChannelOutletBase): + self._outlet = outlet + self._lock = threading.RLock() + self._tx_ring: collections.deque[Envelope] = collections.deque() + self._rx_ring: collections.deque[Envelope] = collections.deque() + self._message_callback: Callable[[MessageBase], None] | None = None + self._next_sequence = 0 + self._message_factories: dict[int, Type[MessageBase]] = self._get_msg_constructors() + self._max_tries = 5 + + def __enter__(self) -> Channel: + return self + + def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, + __traceback: TracebackType | None) -> bool | None: + self.shutdown() + return False + + @staticmethod + def _get_msg_constructors() -> (int, Type[MessageBase]): + subclass_tuples = [] + for subclass in MessageBase.__subclasses__(): + with contextlib.suppress(Exception): + subclass() # verify constructor works with no arguments, needed for unpacking + subclass_tuples.append((subclass.MSGTYPE, subclass)) + message_factories = dict(subclass_tuples) + return message_factories + + def register_message_type(self, message_class: Type[MessageBase]): + if not issubclass(message_class, MessageBase): + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} is not a subclass of {MessageBase}.") + if message_class.MSGTYPE is None: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, f"{message_class} has invalid MSGTYPE class attribute.") + try: + message_class() + except Exception as ex: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} raised an exception when constructed with no arguments: {ex}") + + self._message_factories[message_class.MSGTYPE] = message_class + + def set_message_callback(self, callback: Callable[[MessageBase], None]): + self._message_callback = callback + + def shutdown(self): + self.clear_rings() + + def clear_rings(self): + with self._lock: + for envelope in self._tx_ring: + if envelope.packet is not None: + self._outlet.set_packet_timeout_callback(envelope.packet, None) + self._outlet.set_packet_delivered_callback(envelope.packet, None) + self._tx_ring.clear() + self._rx_ring.clear() + + def emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: + with self._lock: + i = 0 + for env in ring: + if env.sequence < envelope.sequence: + ring.insert(i, envelope) + return True + if env.sequence == envelope.sequence: + RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) + return False + i += 1 + envelope.tracked = True + ring.append(envelope) + return True + + def prune_rx_ring(self): + with self._lock: + # Implementation for fixed window = 1 + stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] + for env in stale: + env.tracked = False + self._rx_ring.remove(env) + + def receive(self, raw: bytes): + try: + envelope = Envelope(outlet=self._outlet, raw=raw) + message = envelope.unpack(self._message_factories) + with self._lock: + is_new = self.emplace_envelope(envelope, self._rx_ring) + self.prune_rx_ring() + if not is_new: + RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) + return + RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) + if self._message_callback: + threading.Thread(target=self._message_callback, name="Message Callback", args=[message], daemon=True)\ + .start() + except Exception as ex: + RNS.log(f"Channel: Error receiving data: {ex}") + + def is_ready_to_send(self) -> bool: + if not self._outlet.is_usable: + RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) + return False + + with self._lock: + for envelope in self._tx_ring: + if envelope.outlet == self._outlet and (not envelope.packet + or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT): + RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME) + return False + return True + + def _packet_tx_op(self, packet: _TPacket, op: Callable[[_TPacket], bool]): + with self._lock: + envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), + self._tx_ring), None) + if envelope and op(envelope): + envelope.tracked = False + if envelope in self._tx_ring: + self._tx_ring.remove(envelope) + else: + RNS.log("Channel: Envelope not found in TX ring", RNS.LOG_DEBUG) + if not envelope: + RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME) + + def _packet_delivered(self, packet: _TPacket): + self._packet_tx_op(packet, lambda env: True) + + def _packet_timeout(self, packet: _TPacket): + def retry_envelope(envelope: Envelope) -> bool: + if envelope.tries >= self._max_tries: + RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) + self.shutdown() # start on separate thread? + self._outlet.timed_out() + return True + envelope.tries += 1 + self._outlet.resend(envelope.packet) + return False + + self._packet_tx_op(packet, retry_envelope) + + def send(self, message: MessageBase) -> Envelope: + envelope: Envelope | None = None + with self._lock: + if not self.is_ready_to_send(): + raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") + envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) + self._next_sequence = (self._next_sequence + 1) % 0x10000 + self.emplace_envelope(envelope, self._tx_ring) + if envelope is None: + raise BlockingIOError() + + envelope.pack() + if len(envelope.raw) > self._outlet.mdu: + raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") + envelope.packet = self._outlet.send(envelope.raw) + envelope.tries += 1 + self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) + return envelope + + +class LinkChannelOutlet(ChannelOutletBase): + def __init__(self, link: RNS.Link): + self.link = link + + def send(self, raw: bytes) -> RNS.Packet: + packet = RNS.Packet(self.link, raw, context=RNS.Packet.CHANNEL) + packet.send() + return packet + + def resend(self, packet: RNS.Packet) -> RNS.Packet: + if not packet.resend(): + RNS.log("Failed to resend packet", RNS.LOG_ERROR) + return packet + + @property + def mdu(self): + return self.link.MDU + + @property + def rtt(self): + return self.link.rtt + + @property + def is_usable(self): + return True # had issues looking at Link.status + + def get_packet_state(self, packet: _TPacket) -> MessageState: + status = packet.receipt.get_status() + if status == RNS.PacketReceipt.SENT: + return MessageState.MSGSTATE_SENT + if status == RNS.PacketReceipt.DELIVERED: + return MessageState.MSGSTATE_DELIVERED + if status == RNS.PacketReceipt.FAILED: + return MessageState.MSGSTATE_FAILED + else: + raise Exception(f"Unexpected receipt state: {status}") + + def timed_out(self): + self.link.teardown() + + def __str__(self): + return f"{self.__class__.__name__}({self.link})" + + def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None, + timeout: float | None = None): + if timeout: + packet.receipt.set_timeout(timeout) + + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + packet.receipt.set_timeout_callback(inner if callback else None) + + def set_packet_delivered_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None): + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + packet.receipt.set_delivery_callback(inner if callback else None) + + def get_packet_id(self, packet: RNS.Packet) -> any: + return packet.get_hash() diff --git a/RNS/Link.py b/RNS/Link.py index 00cca34..4fd58f9 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -22,7 +22,7 @@ from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import Fernet - +from RNS.Channel import Channel, LinkChannelOutlet from time import sleep from .vendor import umsgpack as umsgpack import threading @@ -163,6 +163,7 @@ class Link: self.destination = destination self.attached_interface = None self.__remote_identity = None + self._channel = None if self.destination == None: self.initiator = False self.prv = X25519PrivateKey.generate() @@ -462,6 +463,8 @@ class Link: resource.cancel() for resource in self.outgoing_resources: resource.cancel() + if self._channel: + self._channel.shutdown() self.prv = None self.pub = None @@ -642,6 +645,27 @@ class Link: if pending_request.request_id == resource.request_id: pending_request.request_timed_out(None) + def _ensure_channel(self): + if self._channel is None: + self._channel = Channel(LinkChannelOutlet(self)) + return self._channel + + def set_message_callback(self, callback, message_types=None): + if not callback: + if self._channel: + self._channel.set_message_callback(None) + return + + self._ensure_channel() + + if message_types: + for msg_type in message_types: + self._channel.register_message_type(msg_type) + self._channel.set_message_callback(callback) + + def send_message(self, message: RNS.Channel.MessageBase): + self._ensure_channel().send(message) + def receive(self, packet): self.watchdog_lock = True if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): @@ -788,6 +812,14 @@ class Link: for resource in self.incoming_resources: resource.receive_part(packet) + elif packet.context == RNS.Packet.CHANNEL: + if not self._channel: + RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) + else: + plaintext = self.decrypt(packet.data) + self._channel.receive(plaintext) + packet.prove() + elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: resource_hash = packet.data[0:RNS.Identity.HASHLENGTH//8] diff --git a/RNS/Packet.py b/RNS/Packet.py index cdc476c..a105dec 100755 --- a/RNS/Packet.py +++ b/RNS/Packet.py @@ -75,6 +75,7 @@ class Packet: PATH_RESPONSE = 0x0B # Packet is a response to a path request COMMAND = 0x0C # Packet is a command COMMAND_STATUS = 0x0D # Packet is a status of an executed command + CHANNEL = 0x0E # Packet contains link channel data KEEPALIVE = 0xFA # Packet is a keepalive packet LINKIDENTIFY = 0xFB # Packet is a link peer identification proof LINKCLOSE = 0xFC # Packet is a link close message diff --git a/tests/channel.py b/tests/channel.py new file mode 100644 index 0000000..eb57966 --- /dev/null +++ b/tests/channel.py @@ -0,0 +1,316 @@ +from __future__ import annotations +import threading +import RNS +from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase +from RNS.vendor import umsgpack +from typing import Callable +import contextlib +import typing +import types +import time +import uuid +import unittest + + +class Packet: + timeout = 1.0 + + def __init__(self, raw: bytes): + self.state = MessageState.MSGSTATE_NEW + self.raw = raw + self.packet_id = uuid.uuid4() + self.tries = 0 + self.timeout_id = None + self.lock = threading.RLock() + self.instances = 0 + self.timeout_callback: Callable[[Packet], None] | None = None + self.delivered_callback: Callable[[Packet], None] | None = None + + def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float): + with self.lock: + if timeout is not None: + self.timeout = timeout + self.timeout_callback = callback + + + def send(self): + self.tries += 1 + self.state = MessageState.MSGSTATE_SENT + + def elapsed(timeout: float, timeout_id: uuid.uuid4): + with self.lock: + self.instances += 1 + try: + time.sleep(timeout) + with self.lock: + if self.timeout_id == timeout_id: + self.timeout_id = None + self.state = MessageState.MSGSTATE_FAILED + if self.timeout_callback: + self.timeout_callback(self) + finally: + with self.lock: + self.instances -= 1 + + self.timeout_id = uuid.uuid4() + threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id], + daemon=True).start() + + def clear_timeout(self): + self.timeout_id = None + + def set_delivered_callback(self, callback: Callable[[Packet], None]): + self.delivered_callback = callback + + def delivered(self): + with self.lock: + self.state = MessageState.MSGSTATE_DELIVERED + self.timeout_id = None + if self.delivered_callback: + self.delivered_callback(self) + + +class ChannelOutletTest(ChannelOutletBase): + def get_packet_state(self, packet: Packet) -> MessageState: + return packet.state + + def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None, + timeout: float | None = None): + packet.set_timeout(callback, timeout) + + def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None): + packet.set_delivered_callback(callback) + + def get_packet_id(self, packet: Packet) -> any: + return packet.packet_id + + def __init__(self, mdu: int, rtt: float): + self.link_id = uuid.uuid4() + self.timeout_callbacks = 0 + self._mdu = mdu + self._rtt = rtt + self._usable = True + self.packets = [] + self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None + + def send(self, raw: bytes) -> Packet: + packet = Packet(raw) + packet.send() + self.packets.append(packet) + return packet + + def resend(self, packet: Packet) -> Packet: + packet.send() + return packet + + @property + def mdu(self): + return self._mdu + + @property + def rtt(self): + return self._rtt + + @property + def is_usable(self): + return self._usable + + def timed_out(self): + self.timeout_callbacks += 1 + + def __str__(self): + return str(self.link_id) + + +class MessageTest(MessageBase): + MSGTYPE = 0xabcd + + def __init__(self): + self.id = str(uuid.uuid4()) + self.data = "test" + self.not_serialized = str(uuid.uuid4()) + + def pack(self) -> bytes: + return umsgpack.packb((self.id, self.data)) + + def unpack(self, raw): + self.id, self.data = umsgpack.unpackb(raw) + + +class ProtocolHarness(contextlib.AbstractContextManager): + def __init__(self, rtt: float): + self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) + self.channel = Channel(self.outlet) + + def cleanup(self): + self.channel.shutdown() + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") + self.cleanup() + return False + + +class TestChannel(unittest.TestCase): + def setUp(self) -> None: + self.rtt = 0.001 + self.retry_interval = self.rtt * 150 + Packet.timeout = self.retry_interval + self.h = ProtocolHarness(self.rtt) + + def tearDown(self) -> None: + self.h.cleanup() + + def test_send_one_retry(self): + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 1.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(2, envelope.tries) + self.assertEqual(2, packet.tries) + self.assertEqual(1, packet.instances) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(self.h.outlet.packets[0], packet) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + def test_send_timeout(self): + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 7.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(5, envelope.tries) + self.assertEqual(5, packet.tries) + self.assertEqual(0, packet.instances) + self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) + self.assertFalse(envelope.tracked) + + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): + decoded: [MessageBase] = [] + + def handle_message(message: MessageBase): + decoded.append(message) + + self.h.channel.set_message_callback(handle_message) + self.assertEqual(len(self.h.outlet.packets), 0) + + envelope = self.h.channel.send(message) + time.sleep(self.retry_interval * 0.5) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval * 2) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + self.assertEqual(len(self.h.outlet.packets), 1) + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + self.assertFalse(envelope.tracked) + self.assertEqual(0, len(decoded)) + + self.h.channel.receive(packet.raw) + + self.assertEqual(1, len(decoded)) + + rx_message = decoded[0] + + self.assertIsNotNone(rx_message) + self.assertIsInstance(rx_message, message.__class__) + checker(rx_message) + + def test_send_receive_message_test(self): + message = MessageTest() + + def check(rx_message: MessageBase): + self.assertIsInstance(rx_message, message.__class__) + self.assertEqual(message.id, rx_message.id) + self.assertEqual(message.data, rx_message.data) + self.assertNotEqual(message.not_serialized, rx_message.not_serialized) + + self.eat_own_dog_food(message, check) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/link.py b/tests/link.py index 5368858..5b22bb2 100644 --- a/tests/link.py +++ b/tests/link.py @@ -6,6 +6,9 @@ import threading import time import RNS import os +from tests.channel import MessageTest +from RNS.Channel import MessageBase + APP_NAME = "rns_unit_tests" @@ -46,6 +49,11 @@ def close_rns(): global c_rns if c_rns != None: c_rns.m_proc.kill() + # stdout, stderr = c_rns.m_proc.communicate() + # if stdout: + # print(stdout.decode("utf-8")) + # if stderr: + # print(stderr.decode("utf-8")) class TestLink(unittest.TestCase): def setUp(self): @@ -346,6 +354,52 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + def test_10_channel_round_trip(self): + global c_rns + init_rns(self) + print("") + print("Channel round trip test") + + # TODO: Load this from public bytes only + id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) + self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1])) + + dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish") + + self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d")) + + l1 = RNS.Link(dest) + time.sleep(1) + self.assertEqual(l1.status, RNS.Link.ACTIVE) + + received = [] + + def handle_message(message: MessageBase): + received.append(message) + + test_message = MessageTest() + test_message.data = "Hello" + + l1.set_message_callback(handle_message) + l1.send_message(test_message) + + time.sleep(0.5) + + self.assertEqual(1, len(received)) + + rx_message = received[0] + + self.assertIsInstance(rx_message, MessageTest) + self.assertEqual("Hello back", rx_message.data) + self.assertEqual(test_message.id, rx_message.id) + self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized) + self.assertEqual(1, len(l1._channel._rx_ring)) + + l1.teardown() + time.sleep(0.5) + self.assertEqual(l1.status, RNS.Link.CLOSED) + self.assertEqual(0, len(l1._channel._rx_ring)) + def size_str(self, num, suffix='B'): units = ['','K','M','G','T','P','E','Z'] @@ -405,6 +459,11 @@ def targets(yp=False): link.set_resource_started_callback(resource_started) link.set_resource_concluded_callback(resource_concluded) + def handle_message(message): + message.data = message.data + " back" + link.send_message(message) + link.set_message_callback(handle_message, [MessageTest]) + m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish")