diff --git a/Examples/Channel.py b/Examples/Channel.py new file mode 100644 index 0000000..53b878c --- /dev/null +++ b/Examples/Channel.py @@ -0,0 +1,395 @@ +########################################################## +# This RNS example demonstrates how to set up a link to # +# a destination, and pass structured messages over it # +# using a channel. # +########################################################## + +import os +import sys +import time +import argparse +from datetime import datetime + +import RNS +from RNS.vendor import umsgpack + +# Let's define an app name. We'll use this for all +# destinations we create. Since this echo example +# is part of a range of example utilities, we'll put +# them all within the app namespace "example_utilities" +APP_NAME = "example_utilities" + +########################################################## +#### Shared Objects ###################################### +########################################################## + +# Channel data must be structured in a subclass of +# MessageBase. This ensures that the channel will be able +# to serialize and deserialize the object and multiplex it +# with other objects. Both ends of a link will need the +# same object definitions to be able to communicate over +# a channel. +# +# Note: The objects we wish to use over the channel must +# be registered with the channel, and each link has a +# different channel instance. See the client_connected +# and link_established functions in this example to see +# how message types are registered. + +# Let's make a simple message class called StringMessage +# that will convey a string with a timestamp. + +class StringMessage(RNS.MessageBase): + # The MSGTYPE class variable needs to be assigned a + # 2 byte integer value. This identifier allows the + # channel to look up your message's constructor when a + # message arrives over the channel. + # + # MSGTYPE must be unique across all message types we + # register with the channel. MSGTYPEs >= 0xf000 are + # reserved for the system. + MSGTYPE = 0x0101 + + # The constructor of our object must be callable with + # no arguments. We can have parameters, but they must + # have a default assignment. + # + # This is needed so the channel can create an empty + # version of our message into which the incoming + # message can be unpacked. + def __init__(self, data=None): + self.data = data + self.timestamp = datetime.now() + + # Finally, our message needs to implement functions + # the channel can call to pack and unpack our message + # to/from the raw packet payload. We'll use the + # umsgpack package bundled with RNS. We could also use + # the struct package bundled with Python if we wanted + # more control over the structure of the packed bytes. + # + # Also note that packed message objects must fit + # entirely in one packet. The number of bytes + # available for message payloads can be queried from + # the channel using the Channel.MDU property. The + # channel MDU is slightly less than the link MDU due + # to encoding the message header. + + # The pack function encodes the message contents into + # a byte stream. + def pack(self) -> bytes: + return umsgpack.packb((self.data, self.timestamp)) + + # And the unpack function decodes a byte stream into + # the message contents. + def unpack(self, raw): + self.data, self.timestamp = umsgpack.unpackb(raw) + + +########################################################## +#### Server Part ######################################### +########################################################## + +# A reference to the latest client link that connected +latest_client_link = None + +# This initialisation is executed when the users chooses +# to run as a server +def server(configpath): + # We must first initialise Reticulum + reticulum = RNS.Reticulum(configpath) + + # Randomly create a new identity for our link example + server_identity = RNS.Identity() + + # We create a destination that clients can connect to. We + # want clients to create links to this destination, so we + # need to create a "single" destination type. + server_destination = RNS.Destination( + server_identity, + RNS.Destination.IN, + RNS.Destination.SINGLE, + APP_NAME, + "channelexample" + ) + + # We configure a function that will get called every time + # a new client creates a link to this destination. + server_destination.set_link_established_callback(client_connected) + + # Everything's ready! + # Let's Wait for client requests or user input + server_loop(server_destination) + +def server_loop(destination): + # Let the user know that everything is ready + RNS.log( + "Link example "+ + RNS.prettyhexrep(destination.hash)+ + " running, waiting for a connection." + ) + + RNS.log("Hit enter to manually send an announce (Ctrl-C to quit)") + + # We enter a loop that runs until the users exits. + # If the user hits enter, we will announce our server + # destination on the network, which will let clients + # know how to create messages directed towards it. + while True: + entered = input() + destination.announce() + RNS.log("Sent announce from "+RNS.prettyhexrep(destination.hash)) + +# When a client establishes a link to our server +# destination, this function will be called with +# a reference to the link. +def client_connected(link): + global latest_client_link + latest_client_link = link + + RNS.log("Client connected") + link.set_link_closed_callback(client_disconnected) + + # Register message types and add callback to channel + channel = link.get_channel() + channel.register_message_type(StringMessage) + channel.add_message_handler(server_message_received) + +def client_disconnected(link): + RNS.log("Client disconnected") + +def server_message_received(message): + """ + A message handler + @param message: An instance of a subclass of MessageBase + @return: True if message was handled + """ + global latest_client_link + # When a message is received over any active link, + # the replies will all be directed to the last client + # that connected. + + # In a message handler, any deserializable message + # that arrives over the link's channel will be passed + # to all message handlers, unless a preceding handler indicates it + # has handled the message. + # + # + if isinstance(message, StringMessage): + RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") + + reply_message = StringMessage("I received \""+message.data+"\" over the link") + latest_client_link.get_channel().send(reply_message) + + # Incoming messages are sent to each message + # handler added to the channel, in the order they + # were added. + # If any message handler returns True, the message + # is considered handled and any subsequent + # handlers are skipped. + return True + + +########################################################## +#### Client Part ######################################### +########################################################## + +# A reference to the server link +server_link = None + +# This initialisation is executed when the users chooses +# to run as a client +def client(destination_hexhash, configpath): + # We need a binary representation of the destination + # hash that was entered on the command line + try: + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 + if len(destination_hexhash) != dest_len: + raise ValueError( + "Destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(hex=dest_len, byte=dest_len//2) + ) + + destination_hash = bytes.fromhex(destination_hexhash) + except: + RNS.log("Invalid destination entered. Check your input!\n") + exit() + + # We must first initialise Reticulum + reticulum = RNS.Reticulum(configpath) + + # Check if we know a path to the destination + if not RNS.Transport.has_path(destination_hash): + RNS.log("Destination is not yet known. Requesting path and waiting for announce to arrive...") + RNS.Transport.request_path(destination_hash) + while not RNS.Transport.has_path(destination_hash): + time.sleep(0.1) + + # Recall the server identity + server_identity = RNS.Identity.recall(destination_hash) + + # Inform the user that we'll begin connecting + RNS.log("Establishing link with server...") + + # When the server identity is known, we set + # up a destination + server_destination = RNS.Destination( + server_identity, + RNS.Destination.OUT, + RNS.Destination.SINGLE, + APP_NAME, + "channelexample" + ) + + # And create a link + link = RNS.Link(server_destination) + + # We set a callback that will get executed + # every time a packet is received over the + # link + link.set_packet_callback(client_message_received) + + # We'll also set up functions to inform the + # user when the link is established or closed + link.set_link_established_callback(link_established) + link.set_link_closed_callback(link_closed) + + # Everything is set up, so let's enter a loop + # for the user to interact with the example + client_loop() + +def client_loop(): + global server_link + + # Wait for the link to become active + while not server_link: + time.sleep(0.1) + + should_quit = False + while not should_quit: + try: + print("> ", end=" ") + text = input() + + # Check if we should quit the example + if text == "quit" or text == "q" or text == "exit": + should_quit = True + server_link.teardown() + + # If not, send the entered text over the link + if text != "": + message = StringMessage(text) + packed_size = len(message.pack()) + channel = server_link.get_channel() + if channel.is_ready_to_send(): + if packed_size <= channel.MDU: + channel.send(message) + else: + RNS.log( + "Cannot send this packet, the data size of "+ + str(packed_size)+" bytes exceeds the link packet MDU of "+ + str(channel.MDU)+" bytes", + RNS.LOG_ERROR + ) + else: + RNS.log("Channel is not ready to send, please wait for " + + "pending messages to complete.", RNS.LOG_ERROR) + + except Exception as e: + RNS.log("Error while sending data over the link: "+str(e)) + should_quit = True + server_link.teardown() + +# This function is called when a link +# has been established with the server +def link_established(link): + # We store a reference to the link + # instance for later use + global server_link + server_link = link + + # Register messages and add handler to channel + channel = link.get_channel() + channel.register_message_type(StringMessage) + channel.add_message_handler(client_message_received) + + # Inform the user that the server is + # connected + RNS.log("Link established with server, enter some text to send, or \"quit\" to quit") + +# When a link is closed, we'll inform the +# user, and exit the program +def link_closed(link): + if link.teardown_reason == RNS.Link.TIMEOUT: + RNS.log("The link timed out, exiting now") + elif link.teardown_reason == RNS.Link.DESTINATION_CLOSED: + RNS.log("The link was closed by the server, exiting now") + else: + RNS.log("Link closed, exiting now") + + RNS.Reticulum.exit_handler() + time.sleep(1.5) + os._exit(0) + +# When a packet is received over the link, we +# simply print out the data. +def client_message_received(message): + if isinstance(message, StringMessage): + RNS.log("Received data on the link: " + message.data + " (message created at " + str(message.timestamp) + ")") + print("> ", end=" ") + sys.stdout.flush() + + +########################################################## +#### Program Startup ##################################### +########################################################## + +# This part of the program runs at startup, +# and parses input of from the user, and then +# starts up the desired program mode. +if __name__ == "__main__": + try: + parser = argparse.ArgumentParser(description="Simple link example") + + parser.add_argument( + "-s", + "--server", + action="store_true", + help="wait for incoming link requests from clients" + ) + + parser.add_argument( + "--config", + action="store", + default=None, + help="path to alternative Reticulum config directory", + type=str + ) + + parser.add_argument( + "destination", + nargs="?", + default=None, + help="hexadecimal hash of the server destination", + type=str + ) + + args = parser.parse_args() + + if args.config: + configarg = args.config + else: + configarg = None + + if args.server: + server(configarg) + else: + if (args.destination == None): + print("") + parser.print_help() + print("") + else: + client(args.destination, configarg) + + except KeyboardInterrupt: + print("") + exit() \ No newline at end of file diff --git a/RNS/Channel.py b/RNS/Channel.py new file mode 100644 index 0000000..839bf27 --- /dev/null +++ b/RNS/Channel.py @@ -0,0 +1,521 @@ +# MIT License +# +# Copyright (c) 2016-2023 Mark Qvist / unsigned.io and contributors. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations +import collections +import enum +import threading +import time +from types import TracebackType +from typing import Type, Callable, TypeVar, Generic, NewType +import abc +import contextlib +import struct +import RNS +from abc import ABC, abstractmethod +TPacket = TypeVar("TPacket") + + +class ChannelOutletBase(ABC, Generic[TPacket]): + """ + An abstract transport layer interface used by Channel. + + DEPRECATED: This was created for testing; eventually + Channel will use Link or a LinkBase interface + directly. + """ + @abstractmethod + def send(self, raw: bytes) -> TPacket: + raise NotImplemented() + + @abstractmethod + def resend(self, packet: TPacket) -> TPacket: + raise NotImplemented() + + @property + @abstractmethod + def mdu(self): + raise NotImplemented() + + @property + @abstractmethod + def rtt(self): + raise NotImplemented() + + @property + @abstractmethod + def is_usable(self): + raise NotImplemented() + + @abstractmethod + def get_packet_state(self, packet: TPacket) -> MessageState: + raise NotImplemented() + + @abstractmethod + def timed_out(self): + raise NotImplemented() + + @abstractmethod + def __str__(self): + raise NotImplemented() + + @abstractmethod + def set_packet_timeout_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None, + timeout: float | None = None): + raise NotImplemented() + + @abstractmethod + def set_packet_delivered_callback(self, packet: TPacket, callback: Callable[[TPacket], None] | None): + raise NotImplemented() + + @abstractmethod + def get_packet_id(self, packet: TPacket) -> any: + raise NotImplemented() + + +class CEType(enum.IntEnum): + """ + ChannelException type codes + """ + ME_NO_MSG_TYPE = 0 + ME_INVALID_MSG_TYPE = 1 + ME_NOT_REGISTERED = 2 + ME_LINK_NOT_READY = 3 + ME_ALREADY_SENT = 4 + ME_TOO_BIG = 5 + + +class ChannelException(Exception): + """ + An exception thrown by Channel, with a type code. + """ + def __init__(self, ce_type: CEType, *args): + super().__init__(args) + self.type = ce_type + + +class MessageState(enum.IntEnum): + """ + Set of possible states for a Message + """ + MSGSTATE_NEW = 0 + MSGSTATE_SENT = 1 + MSGSTATE_DELIVERED = 2 + MSGSTATE_FAILED = 3 + + +class MessageBase(abc.ABC): + """ + Base type for any messages sent or received on a Channel. + Subclasses must define the two abstract methods as well as + the ``MSGTYPE`` class variable. + """ + # MSGTYPE must be unique within all classes sent over a + # channel. Additionally, MSGTYPE > 0xf000 are reserved. + MSGTYPE = None + """ + Defines a unique identifier for a message class. + + * Must be unique within all classes registered with a ``Channel`` + * Must be less than ``0xf000``. Values greater than or equal to ``0xf000`` are reserved. + """ + + @abstractmethod + def pack(self) -> bytes: + """ + Create and return the binary representation of the message + + :return: binary representation of message + """ + raise NotImplemented() + + @abstractmethod + def unpack(self, raw: bytes): + """ + Populate message from binary representation + + :param raw: binary representation + """ + raise NotImplemented() + + +MessageCallbackType = NewType("MessageCallbackType", Callable[[MessageBase], bool]) + + +class Envelope: + """ + Internal wrapper used to transport messages over a channel and + track its state within the channel framework. + """ + def unpack(self, message_factories: dict[int, Type]) -> MessageBase: + msgtype, self.sequence, length = struct.unpack(">HHH", self.raw[:6]) + raw = self.raw[6:] + ctor = message_factories.get(msgtype, None) + if ctor is None: + raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}") + message = ctor() + message.unpack(raw) + return message + + def pack(self) -> bytes: + if self.message.__class__.MSGTYPE is None: + raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE") + data = self.message.pack() + self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data + return self.raw + + def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None): + self.ts = time.time() + self.id = id(self) + self.message = message + self.raw = raw + self.packet: TPacket = None + self.sequence = sequence + self.outlet = outlet + self.tries = 0 + self.tracked = False + + +class Channel(contextlib.AbstractContextManager): + """ + Provides reliable delivery of messages over + a link. + + ``Channel`` differs from ``Request`` and + ``Resource`` in some important ways: + + **Continuous** + Messages can be sent or received as long as + the ``Link`` is open. + **Bi-directional** + Messages can be sent in either direction on + the ``Link``; neither end is the client or + server. + **Size-constrained** + Messages must be encoded into a single packet. + + ``Channel`` is similar to ``Packet``, except that it + provides reliable delivery (automatic retries) as well + as a structure for exchanging several types of + messages over the ``Link``. + + ``Channel`` is not instantiated directly, but rather + obtained from a ``Link`` with ``get_channel()``. + """ + def __init__(self, outlet: ChannelOutletBase): + """ + + @param outlet: + """ + self._outlet = outlet + self._lock = threading.RLock() + self._tx_ring: collections.deque[Envelope] = collections.deque() + self._rx_ring: collections.deque[Envelope] = collections.deque() + self._message_callbacks: [MessageCallbackType] = [] + self._next_sequence = 0 + self._message_factories: dict[int, Type[MessageBase]] = {} + self._max_tries = 5 + + def __enter__(self) -> Channel: + return self + + def __exit__(self, __exc_type: Type[BaseException] | None, __exc_value: BaseException | None, + __traceback: TracebackType | None) -> bool | None: + self._shutdown() + return False + + def register_message_type(self, message_class: Type[MessageBase]): + """ + Register a message class for reception over a ``Channel``. + + Message classes must extend ``MessageBase``. + + :param message_class: Class to register + """ + self._register_message_type(message_class, is_system_type=False) + + def _register_message_type(self, message_class: Type[MessageBase], *, is_system_type: bool = False): + with self._lock: + if not issubclass(message_class, MessageBase): + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} is not a subclass of {MessageBase}.") + if message_class.MSGTYPE is None: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has invalid MSGTYPE class attribute.") + if message_class.MSGTYPE >= 0xf000 and not is_system_type: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} has system-reserved message type.") + try: + message_class() + except Exception as ex: + raise ChannelException(CEType.ME_INVALID_MSG_TYPE, + f"{message_class} raised an exception when constructed with no arguments: {ex}") + + self._message_factories[message_class.MSGTYPE] = message_class + + def add_message_handler(self, callback: MessageCallbackType): + """ + Add a handler for incoming messages. A handler + has the following signature: + + ``(message: MessageBase) -> bool`` + + Handlers are processed in the order they are + added. If any handler returns True, processing + of the message stops; handlers after the + returning handler will not be called. + + :param callback: Function to call + """ + with self._lock: + if callback not in self._message_callbacks: + self._message_callbacks.append(callback) + + def remove_message_handler(self, callback: MessageCallbackType): + """ + Remove a handler added with ``add_message_handler``. + + :param callback: handler to remove + """ + with self._lock: + self._message_callbacks.remove(callback) + + def _shutdown(self): + with self._lock: + self._message_callbacks.clear() + self._clear_rings() + + def _clear_rings(self): + with self._lock: + for envelope in self._tx_ring: + if envelope.packet is not None: + self._outlet.set_packet_timeout_callback(envelope.packet, None) + self._outlet.set_packet_delivered_callback(envelope.packet, None) + self._tx_ring.clear() + self._rx_ring.clear() + + def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool: + with self._lock: + i = 0 + for existing in ring: + if existing.sequence > envelope.sequence \ + and not existing.sequence // 2 > envelope.sequence: # account for overflow + ring.insert(i, envelope) + return True + if existing.sequence == envelope.sequence: + RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) + return False + i += 1 + envelope.tracked = True + ring.append(envelope) + return True + + def _prune_rx_ring(self): + with self._lock: + # Implementation for fixed window = 1 + stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] + for env in stale: + env.tracked = False + self._rx_ring.remove(env) + + def _run_callbacks(self, message: MessageBase): + with self._lock: + cbs = self._message_callbacks.copy() + + for cb in cbs: + try: + if cb(message): + return + except Exception as ex: + RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) + + def _receive(self, raw: bytes): + try: + envelope = Envelope(outlet=self._outlet, raw=raw) + with self._lock: + message = envelope.unpack(self._message_factories) + is_new = self._emplace_envelope(envelope, self._rx_ring) + self._prune_rx_ring() + if not is_new: + RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) + return + RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) + threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() + except Exception as ex: + RNS.log(f"Channel: Error receiving data: {ex}") + + def is_ready_to_send(self) -> bool: + """ + Check if ``Channel`` is ready to send. + + :return: True if ready + """ + if not self._outlet.is_usable: + RNS.log("Channel: Link is not usable.", RNS.LOG_EXTREME) + return False + + with self._lock: + for envelope in self._tx_ring: + if envelope.outlet == self._outlet and (not envelope.packet + or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT): + RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME) + return False + return True + + def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]): + with self._lock: + envelope = next(filter(lambda e: self._outlet.get_packet_id(e.packet) == self._outlet.get_packet_id(packet), + self._tx_ring), None) + if envelope and op(envelope): + envelope.tracked = False + if envelope in self._tx_ring: + self._tx_ring.remove(envelope) + else: + RNS.log("Channel: Envelope not found in TX ring", RNS.LOG_DEBUG) + if not envelope: + RNS.log("Channel: Spurious message received.", RNS.LOG_EXTREME) + + def _packet_delivered(self, packet: TPacket): + self._packet_tx_op(packet, lambda env: True) + + def _packet_timeout(self, packet: TPacket): + def retry_envelope(envelope: Envelope) -> bool: + if envelope.tries >= self._max_tries: + RNS.log("Channel: Retry count exceeded, tearing down Link.", RNS.LOG_ERROR) + self._shutdown() # start on separate thread? + self._outlet.timed_out() + return True + envelope.tries += 1 + self._outlet.resend(envelope.packet) + return False + + self._packet_tx_op(packet, retry_envelope) + + def send(self, message: MessageBase) -> Envelope: + """ + Send a message. If a message send is attempted and + ``Channel`` is not ready, an exception is thrown. + + :param message: an instance of a ``MessageBase`` subclass + """ + envelope: Envelope | None = None + with self._lock: + if not self.is_ready_to_send(): + raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready") + envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence) + self._next_sequence = (self._next_sequence + 1) % 0x10000 + self._emplace_envelope(envelope, self._tx_ring) + if envelope is None: + raise BlockingIOError() + + envelope.pack() + if len(envelope.raw) > self._outlet.mdu: + raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}") + envelope.packet = self._outlet.send(envelope.raw) + envelope.tries += 1 + self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) + return envelope + + @property + def MDU(self): + """ + Maximum Data Unit: the number of bytes available + for a message to consume in a single send. This + value is adjusted from the ``Link`` MDU to accommodate + message header information. + + :return: number of bytes available + """ + return self._outlet.mdu - 6 # sizeof(msgtype) + sizeof(length) + sizeof(sequence) + + +class LinkChannelOutlet(ChannelOutletBase): + """ + An implementation of ChannelOutletBase for RNS.Link. + Allows Channel to send packets over an RNS Link with + Packets. + + :param link: RNS Link to wrap + """ + def __init__(self, link: RNS.Link): + self.link = link + + def send(self, raw: bytes) -> RNS.Packet: + packet = RNS.Packet(self.link, raw, context=RNS.Packet.CHANNEL) + packet.send() + return packet + + def resend(self, packet: RNS.Packet) -> RNS.Packet: + if not packet.resend(): + RNS.log("Failed to resend packet", RNS.LOG_ERROR) + return packet + + @property + def mdu(self): + return self.link.MDU + + @property + def rtt(self): + return self.link.rtt + + @property + def is_usable(self): + return True # had issues looking at Link.status + + def get_packet_state(self, packet: TPacket) -> MessageState: + status = packet.receipt.get_status() + if status == RNS.PacketReceipt.SENT: + return MessageState.MSGSTATE_SENT + if status == RNS.PacketReceipt.DELIVERED: + return MessageState.MSGSTATE_DELIVERED + if status == RNS.PacketReceipt.FAILED: + return MessageState.MSGSTATE_FAILED + else: + raise Exception(f"Unexpected receipt state: {status}") + + def timed_out(self): + self.link.teardown() + + def __str__(self): + return f"{self.__class__.__name__}({self.link})" + + def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None, + timeout: float | None = None): + if timeout: + packet.receipt.set_timeout(timeout) + + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + if packet and packet.receipt: + packet.receipt.set_timeout_callback(inner if callback else None) + + def set_packet_delivered_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None): + def inner(receipt: RNS.PacketReceipt): + callback(packet) + + if packet and packet.receipt: + packet.receipt.set_delivery_callback(inner if callback else None) + + def get_packet_id(self, packet: RNS.Packet) -> any: + return packet.get_hash() diff --git a/RNS/Link.py b/RNS/Link.py index 00cca34..822e1fc 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -22,6 +22,7 @@ from RNS.Cryptography import X25519PrivateKey, X25519PublicKey, Ed25519PrivateKey, Ed25519PublicKey from RNS.Cryptography import Fernet +from RNS.Channel import Channel, LinkChannelOutlet from time import sleep from .vendor import umsgpack as umsgpack @@ -163,6 +164,7 @@ class Link: self.destination = destination self.attached_interface = None self.__remote_identity = None + self._channel = None if self.destination == None: self.initiator = False self.prv = X25519PrivateKey.generate() @@ -462,6 +464,8 @@ class Link: resource.cancel() for resource in self.outgoing_resources: resource.cancel() + if self._channel: + self._channel._shutdown() self.prv = None self.pub = None @@ -642,6 +646,16 @@ class Link: if pending_request.request_id == resource.request_id: pending_request.request_timed_out(None) + def get_channel(self): + """ + Get the ``Channel`` for this link. + + :return: ``Channel`` object + """ + if self._channel is None: + self._channel = Channel(LinkChannelOutlet(self)) + return self._channel + def receive(self, packet): self.watchdog_lock = True if not self.status == Link.CLOSED and not (self.initiator and packet.context == RNS.Packet.KEEPALIVE and packet.data == bytes([0xFF])): @@ -788,6 +802,14 @@ class Link: for resource in self.incoming_resources: resource.receive_part(packet) + elif packet.context == RNS.Packet.CHANNEL: + if not self._channel: + RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) + else: + plaintext = self.decrypt(packet.data) + self._channel._receive(plaintext) + packet.prove() + elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: resource_hash = packet.data[0:RNS.Identity.HASHLENGTH//8] diff --git a/RNS/Packet.py b/RNS/Packet.py index cdc476c..a105dec 100755 --- a/RNS/Packet.py +++ b/RNS/Packet.py @@ -75,6 +75,7 @@ class Packet: PATH_RESPONSE = 0x0B # Packet is a response to a path request COMMAND = 0x0C # Packet is a command COMMAND_STATUS = 0x0D # Packet is a status of an executed command + CHANNEL = 0x0E # Packet contains link channel data KEEPALIVE = 0xFA # Packet is a keepalive packet LINKIDENTIFY = 0xFB # Packet is a link peer identification proof LINKCLOSE = 0xFC # Packet is a link close message diff --git a/RNS/__init__.py b/RNS/__init__.py index 3c2f2ae..0ec2140 100755 --- a/RNS/__init__.py +++ b/RNS/__init__.py @@ -32,6 +32,7 @@ from ._version import __version__ from .Reticulum import Reticulum from .Identity import Identity from .Link import Link, RequestReceipt +from .Channel import MessageBase from .Transport import Transport from .Destination import Destination from .Packet import Packet diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 9b4428f..54c13f3 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -92,6 +92,18 @@ The *Request* example explores sendig requests and receiving responses. This example can also be found at ``_. +.. _example-channel: + +Channel +==== + +The *Channel* example explores using a ``Channel`` to send structured +data between peers of a ``Link``. + +.. literalinclude:: ../../Examples/Channel.py + +This example can also be found at ``_. + .. _example-filetransfer: Filetransfer diff --git a/docs/source/reference.rst b/docs/source/reference.rst index 017cf8d..8d519a8 100644 --- a/docs/source/reference.rst +++ b/docs/source/reference.rst @@ -121,6 +121,34 @@ This chapter lists and explains all classes exposed by the Reticulum Network Sta .. autoclass:: RNS.Resource(data, link, advertise=True, auto_compress=True, callback=None, progress_callback=None, timeout=None) :members: +.. _api-channel: + +.. only:: html + + |start-h3| Channel |end-h3| + +.. only:: latex + + Channel + ------ + +.. autoclass:: RNS.Channel.Channel() + :members: + +.. _api-messsagebase: + +.. only:: html + + |start-h3| MessageBase |end-h3| + +.. only:: latex + + MessageBase + ------ + +.. autoclass:: RNS.MessageBase() + :members: + .. _api-transport: .. only:: html diff --git a/tests/all.py b/tests/all.py index 27863a4..17556b1 100644 --- a/tests/all.py +++ b/tests/all.py @@ -4,6 +4,7 @@ from .hashes import TestSHA256 from .hashes import TestSHA512 from .identity import TestIdentity from .link import TestLink +from .channel import TestChannel if __name__ == '__main__': unittest.main(verbosity=2) \ No newline at end of file diff --git a/tests/channel.py b/tests/channel.py new file mode 100644 index 0000000..3a00cbe --- /dev/null +++ b/tests/channel.py @@ -0,0 +1,375 @@ +from __future__ import annotations +import threading +import RNS +from RNS.Channel import MessageState, ChannelOutletBase, Channel, MessageBase +from RNS.vendor import umsgpack +from typing import Callable +import contextlib +import typing +import types +import time +import uuid +import unittest + + +class Packet: + timeout = 1.0 + + def __init__(self, raw: bytes): + self.state = MessageState.MSGSTATE_NEW + self.raw = raw + self.packet_id = uuid.uuid4() + self.tries = 0 + self.timeout_id = None + self.lock = threading.RLock() + self.instances = 0 + self.timeout_callback: Callable[[Packet], None] | None = None + self.delivered_callback: Callable[[Packet], None] | None = None + + def set_timeout(self, callback: Callable[[Packet], None] | None, timeout: float): + with self.lock: + if timeout is not None: + self.timeout = timeout + self.timeout_callback = callback + + + def send(self): + self.tries += 1 + self.state = MessageState.MSGSTATE_SENT + + def elapsed(timeout: float, timeout_id: uuid.uuid4): + with self.lock: + self.instances += 1 + try: + time.sleep(timeout) + with self.lock: + if self.timeout_id == timeout_id: + self.timeout_id = None + self.state = MessageState.MSGSTATE_FAILED + if self.timeout_callback: + self.timeout_callback(self) + finally: + with self.lock: + self.instances -= 1 + + self.timeout_id = uuid.uuid4() + threading.Thread(target=elapsed, name="Packet Timeout", args=[self.timeout, self.timeout_id], + daemon=True).start() + + def clear_timeout(self): + self.timeout_id = None + + def set_delivered_callback(self, callback: Callable[[Packet], None]): + self.delivered_callback = callback + + def delivered(self): + with self.lock: + self.state = MessageState.MSGSTATE_DELIVERED + self.timeout_id = None + if self.delivered_callback: + self.delivered_callback(self) + + +class ChannelOutletTest(ChannelOutletBase): + def get_packet_state(self, packet: Packet) -> MessageState: + return packet.state + + def set_packet_timeout_callback(self, packet: Packet, callback: Callable[[Packet], None] | None, + timeout: float | None = None): + packet.set_timeout(callback, timeout) + + def set_packet_delivered_callback(self, packet: Packet, callback: Callable[[Packet], None] | None): + packet.set_delivered_callback(callback) + + def get_packet_id(self, packet: Packet) -> any: + return packet.packet_id + + def __init__(self, mdu: int, rtt: float): + self.link_id = uuid.uuid4() + self.timeout_callbacks = 0 + self._mdu = mdu + self._rtt = rtt + self._usable = True + self.packets = [] + self.packet_callback: Callable[[ChannelOutletBase, bytes], None] | None = None + + def send(self, raw: bytes) -> Packet: + packet = Packet(raw) + packet.send() + self.packets.append(packet) + return packet + + def resend(self, packet: Packet) -> Packet: + packet.send() + return packet + + @property + def mdu(self): + return self._mdu + + @property + def rtt(self): + return self._rtt + + @property + def is_usable(self): + return self._usable + + def timed_out(self): + self.timeout_callbacks += 1 + + def __str__(self): + return str(self.link_id) + + +class MessageTest(MessageBase): + MSGTYPE = 0xabcd + + def __init__(self): + self.id = str(uuid.uuid4()) + self.data = "test" + self.not_serialized = str(uuid.uuid4()) + + def pack(self) -> bytes: + return umsgpack.packb((self.id, self.data)) + + def unpack(self, raw): + self.id, self.data = umsgpack.unpackb(raw) + + +class SystemMessage(MessageBase): + MSGTYPE = 0xf000 + + def pack(self) -> bytes: + return bytes() + + def unpack(self, raw): + pass + + +class ProtocolHarness(contextlib.AbstractContextManager): + def __init__(self, rtt: float): + self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) + self.channel = Channel(self.outlet) + + def cleanup(self): + self.channel._shutdown() + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + # self._log.debug(f"__exit__({__exc_type}, {__exc_value}, {__traceback})") + self.cleanup() + return False + + +class TestChannel(unittest.TestCase): + def setUp(self) -> None: + print("") + self.rtt = 0.001 + self.retry_interval = self.rtt * 150 + Packet.timeout = self.retry_interval + self.h = ProtocolHarness(self.rtt) + + def tearDown(self) -> None: + self.h.cleanup() + + def test_send_one_retry(self): + print("Channel test one retry") + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 1.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(2, envelope.tries) + self.assertEqual(2, packet.tries) + self.assertEqual(1, packet.instances) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(self.h.outlet.packets[0], packet) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(3, envelope.tries) + self.assertEqual(3, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + def test_send_timeout(self): + print("Channel test retry count exceeded") + message = MessageTest() + + self.assertEqual(0, len(self.h.outlet.packets)) + + envelope = self.h.channel.send(message) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + time.sleep(self.retry_interval * 7.5) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(5, envelope.tries) + self.assertEqual(5, packet.tries) + self.assertEqual(0, packet.instances) + self.assertEqual(MessageState.MSGSTATE_FAILED, packet.state) + self.assertFalse(envelope.tracked) + + def test_multiple_handler(self): + print("Channel test multiple handler short circuit") + + handler1_called = 0 + handler1_return = True + handler2_called = 0 + + def handler1(msg: MessageBase): + nonlocal handler1_called, handler1_return + self.assertIsInstance(msg, MessageTest) + handler1_called += 1 + return handler1_return + + def handler2(msg: MessageBase): + nonlocal handler2_called + self.assertIsInstance(msg, MessageTest) + handler2_called += 1 + + message = MessageTest() + self.h.channel.register_message_type(MessageTest) + self.h.channel.add_message_handler(handler1) + self.h.channel.add_message_handler(handler2) + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=0) + raw = envelope.pack() + self.h.channel._receive(raw) + + self.assertEqual(1, handler1_called) + self.assertEqual(0, handler2_called) + + handler1_return = False + envelope = RNS.Channel.Envelope(self.h.outlet, message, sequence=1) + raw = envelope.pack() + self.h.channel._receive(raw) + + self.assertEqual(2, handler1_called) + self.assertEqual(1, handler2_called) + + def test_system_message_check(self): + print("Channel test register system message") + with self.assertRaises(RNS.Channel.ChannelException): + self.h.channel.register_message_type(SystemMessage) + self.h.channel._register_message_type(SystemMessage, is_system_type=True) + + + def eat_own_dog_food(self, message: MessageBase, checker: typing.Callable[[MessageBase], None]): + decoded: [MessageBase] = [] + + def handle_message(message: MessageBase): + decoded.append(message) + + self.h.channel.register_message_type(message.__class__) + self.h.channel.add_message_handler(handle_message) + self.assertEqual(len(self.h.outlet.packets), 0) + + envelope = self.h.channel.send(message) + time.sleep(self.retry_interval * 0.5) + + self.assertIsNotNone(envelope) + self.assertIsNotNone(envelope.raw) + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertIsNotNone(envelope.packet) + self.assertTrue(envelope in self.h.channel._tx_ring) + self.assertTrue(envelope.tracked) + + packet = self.h.outlet.packets[0] + + self.assertEqual(envelope.packet, packet) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(1, packet.instances) + self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) + self.assertEqual(envelope.raw, packet.raw) + + packet.delivered() + + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + + time.sleep(self.retry_interval * 2) + + self.assertEqual(1, len(self.h.outlet.packets)) + self.assertEqual(1, envelope.tries) + self.assertEqual(1, packet.tries) + self.assertEqual(0, packet.instances) + self.assertFalse(envelope.tracked) + + self.assertEqual(len(self.h.outlet.packets), 1) + self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) + self.assertFalse(envelope.tracked) + self.assertEqual(0, len(decoded)) + + self.h.channel._receive(packet.raw) + + self.assertEqual(1, len(decoded)) + + rx_message = decoded[0] + + self.assertIsNotNone(rx_message) + self.assertIsInstance(rx_message, message.__class__) + checker(rx_message) + + def test_send_receive_message_test(self): + print("Channel test send and receive message") + message = MessageTest() + + def check(rx_message: MessageBase): + self.assertIsInstance(rx_message, message.__class__) + self.assertEqual(message.id, rx_message.id) + self.assertEqual(message.data, rx_message.data) + self.assertNotEqual(message.not_serialized, rx_message.not_serialized) + + self.eat_own_dog_food(message, check) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/tests/link.py b/tests/link.py index 5368858..203e982 100644 --- a/tests/link.py +++ b/tests/link.py @@ -6,6 +6,8 @@ import threading import time import RNS import os +from tests.channel import MessageTest +from RNS.Channel import MessageBase APP_NAME = "rns_unit_tests" @@ -346,6 +348,54 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + def test_10_channel_round_trip(self): + global c_rns + init_rns(self) + print("") + print("Channel round trip test") + + # TODO: Load this from public bytes only + id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) + self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1])) + + dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish") + + self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d")) + + l1 = RNS.Link(dest) + time.sleep(1) + self.assertEqual(l1.status, RNS.Link.ACTIVE) + + received = [] + + def handle_message(message: MessageBase): + received.append(message) + + test_message = MessageTest() + test_message.data = "Hello" + + channel = l1.get_channel() + channel.register_message_type(MessageTest) + channel.add_message_handler(handle_message) + channel.send(test_message) + + time.sleep(0.5) + + self.assertEqual(1, len(received)) + + rx_message = received[0] + + self.assertIsInstance(rx_message, MessageTest) + self.assertEqual("Hello back", rx_message.data) + self.assertEqual(test_message.id, rx_message.id) + self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized) + self.assertEqual(1, len(l1._channel._rx_ring)) + + l1.teardown() + time.sleep(0.5) + self.assertEqual(l1.status, RNS.Link.CLOSED) + self.assertEqual(0, len(l1._channel._rx_ring)) + def size_str(self, num, suffix='B'): units = ['','K','M','G','T','P','E','Z'] @@ -404,6 +454,13 @@ def targets(yp=False): link.set_resource_strategy(RNS.Link.ACCEPT_ALL) link.set_resource_started_callback(resource_started) link.set_resource_concluded_callback(resource_concluded) + channel = link.get_channel() + + def handle_message(message): + message.data = message.data + " back" + channel.send(message) + channel.register_message_type(MessageTest) + channel.add_message_handler(handle_message) m_rns = RNS.Reticulum("./tests/rnsconfig") id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))