From 6d9d410a703d0d507d4884075bc9f131ab553719 Mon Sep 17 00:00:00 2001 From: Aaron Heise <5148966+acehoss@users.noreply.github.com> Date: Sat, 4 Mar 2023 23:37:58 -0600 Subject: [PATCH] Address multiple issues with Buffer and Channel - StreamDataMessage now packed by struct rather than umsgpack for a more predictable size - Added protected variable on LocalInterface to allow tests to simulate a low bandwidth connection - Retry timer now has exponential backoff and a more sane starting value - Link proves packet _before_ sending contents to Channel; this should help prevent spurious retries especially on half-duplex links - Prevent Transport packet filter from filtering out duplicate packets for Channel; handle duplicates in Channel to ensure the packet is reproven (in case the original proof packet was lost) - Fix up other tests broken by these changes --- RNS/Buffer.py | 50 +++++--------- RNS/Channel.py | 16 ++++- RNS/Interfaces/LocalInterface.py | 7 ++ RNS/Link.py | 2 +- RNS/Transport.py | 2 + tests/channel.py | 43 +++++++++--- tests/link.py | 111 +++++++++++++++++++++++++++++++ 7 files changed, 186 insertions(+), 45 deletions(-) diff --git a/RNS/Buffer.py b/RNS/Buffer.py index 55f1d98..bcba3b9 100644 --- a/RNS/Buffer.py +++ b/RNS/Buffer.py @@ -1,7 +1,8 @@ from __future__ import annotations import sys +import threading from threading import RLock -from RNS.vendor import umsgpack +import struct from RNS.Channel import Channel, MessageBase, SystemMessageTypes import RNS from io import RawIOBase, BufferedRWPair, BufferedReader, BufferedWriter @@ -16,22 +17,12 @@ class StreamDataMessage(MessageBase): uses a system-reserved message type. """ - STREAM_ID_MAX = 65535 + STREAM_ID_MAX = 0x7fff # 32767 """ - While not essential for the current message packing - method (umsgpack), the stream id is clamped to the - size of a UInt16 for future struct packing. + The stream id is limited to 2 bytes - 1 bit """ - OVERHEAD = 0 - """ - The number of bytes used by this messa - - When the Buffer package is imported, this value is - calculated based on the value of RNS.Link.MDU. - """ - - MAX_DATA_LEN = 0 + MAX_DATA_LEN = RNS.Link.MDU - 2 - 6 # 2 for stream data message header, 6 for channel envelope """ When the Buffer package is imported, this value is calculcated based on the value of OVERHEAD @@ -48,7 +39,7 @@ class StreamDataMessage(MessageBase): """ super().__init__() if stream_id is not None and stream_id > self.STREAM_ID_MAX: - raise ValueError("stream_id must be 0-65535") + raise ValueError("stream_id must be 0-32767") self.stream_id = stream_id self.data = data or bytes() self.eof = eof @@ -56,18 +47,14 @@ class StreamDataMessage(MessageBase): def pack(self) -> bytes: if self.stream_id is None: raise ValueError("stream_id") - return umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) + header_val = (0x7fff & self.stream_id) | (0x8000 if self.eof else 0x0000) + return bytes(struct.pack(">H", header_val) + (self.data if self.data else bytes())) def unpack(self, raw): - self.stream_id, self.eof, self.data = umsgpack.unpackb(raw) - - -_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8") -StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=StreamDataMessage.STREAM_ID_MAX, - data=_link_sized_bytes, - eof=True).pack()) - len(_link_sized_bytes) + 4 # TODO: Calculation was off by 10 bytes, why? -StreamDataMessage.MAX_DATA_LEN = RNS.Link.MDU - StreamDataMessage.OVERHEAD -_link_sized_bytes = None + self.stream_id = struct.unpack(">H", raw[:2])[0] + self.eof = (0x8000 & self.stream_id) > 0 + self.stream_id = self.stream_id & 0x7fff + self.data = raw[2:] class RawChannelReader(RawIOBase, AbstractContextManager): @@ -144,9 +131,9 @@ class RawChannelReader(RawIOBase, AbstractContextManager): def readinto(self, __buffer: bytearray) -> int | None: ready = self._read(len(__buffer)) - if ready: + if ready is not None: __buffer[:len(ready)] = ready - return len(ready) if ready else None + return len(ready) if ready is not None else None def writable(self) -> bool: return False @@ -198,11 +185,10 @@ class RawChannelWriter(RawIOBase, AbstractContextManager): def write(self, __b: bytes) -> int | None: try: - if self._channel.is_ready_to_send(): - chunk = __b[:StreamDataMessage.MAX_DATA_LEN] - message = StreamDataMessage(self._stream_id, chunk, self._eof) - self._channel.send(message) - return len(chunk) + chunk = bytes(__b[:StreamDataMessage.MAX_DATA_LEN]) + message = StreamDataMessage(self._stream_id, chunk, self._eof) + self._channel.send(message) + return len(chunk) except RNS.Channel.ChannelException as cex: if cex.type != RNS.Channel.CEType.ME_LINK_NOT_READY: raise diff --git a/RNS/Channel.py b/RNS/Channel.py index ff88669..c0f5442 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -356,6 +356,10 @@ class Channel(contextlib.AbstractContextManager): envelope = Envelope(outlet=self._outlet, raw=raw) with self._lock: message = envelope.unpack(self._message_factories) + prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None + if prev_env and envelope.sequence != prev_env.sequence + 1: + RNS.log("Channel: Out of order packet received", RNS.LOG_DEBUG) + return is_new = self._emplace_envelope(envelope, self._rx_ring) self._prune_rx_ring() if not is_new: @@ -403,6 +407,9 @@ class Channel(contextlib.AbstractContextManager): def _packet_delivered(self, packet: TPacket): self._packet_tx_op(packet, lambda env: True) + def _get_packet_timeout_time(self, tries: int) -> float: + return pow(2, tries - 1) * max(self._outlet.rtt, 0.01) * 5 + def _packet_timeout(self, packet: TPacket): def retry_envelope(envelope: Envelope) -> bool: if envelope.tries >= self._max_tries: @@ -412,9 +419,11 @@ class Channel(contextlib.AbstractContextManager): return True envelope.tries += 1 self._outlet.resend(envelope.packet) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) return False - self._packet_tx_op(packet, retry_envelope) + if self._outlet.get_packet_state(packet) != MessageState.MSGSTATE_DELIVERED: + self._packet_tx_op(packet, retry_envelope) def send(self, message: MessageBase) -> Envelope: """ @@ -439,7 +448,7 @@ class Channel(contextlib.AbstractContextManager): envelope.packet = self._outlet.send(envelope.raw) envelope.tries += 1 self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) - self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) + self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) return envelope @property @@ -473,6 +482,7 @@ class LinkChannelOutlet(ChannelOutletBase): return packet def resend(self, packet: RNS.Packet) -> RNS.Packet: + RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG) if not packet.resend(): RNS.log("Failed to resend packet", RNS.LOG_ERROR) return packet @@ -511,7 +521,7 @@ class LinkChannelOutlet(ChannelOutletBase): def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None, timeout: float | None = None): - if timeout: + if timeout and packet.receipt: packet.receipt.set_timeout(timeout) def inner(receipt: RNS.PacketReceipt): diff --git a/RNS/Interfaces/LocalInterface.py b/RNS/Interfaces/LocalInterface.py index 2710a2d..937a5d3 100644 --- a/RNS/Interfaces/LocalInterface.py +++ b/RNS/Interfaces/LocalInterface.py @@ -86,6 +86,8 @@ class LocalClientInterface(Interface): self.online = True self.writing = False + self._force_bitrate = False + self.announce_rate_target = None self.announce_rate_grace = None self.announce_rate_penalty = None @@ -137,6 +139,9 @@ class LocalClientInterface(Interface): def processIncoming(self, data): + if self._force_bitrate: + time.sleep(len(data) / self.bitrate * 8) + self.rxb += len(data) if hasattr(self, "parent_interface") and self.parent_interface != None: self.parent_interface.rxb += len(data) @@ -154,6 +159,8 @@ class LocalClientInterface(Interface): if self.online: try: self.writing = True + if self._force_bitrate: + time.sleep(len(data) / self.bitrate * 8) data = bytes([HDLC.FLAG])+HDLC.escape(data)+bytes([HDLC.FLAG]) self.socket.sendall(data) self.writing = False diff --git a/RNS/Link.py b/RNS/Link.py index 1624253..4abb9bc 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -809,9 +809,9 @@ class Link: if not self._channel: RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) else: + packet.prove() plaintext = self.decrypt(packet.data) self._channel._receive(plaintext) - packet.prove() elif packet.packet_type == RNS.Packet.PROOF: if packet.context == RNS.Packet.RESOURCE_PRF: diff --git a/RNS/Transport.py b/RNS/Transport.py index 7de6da4..3f54068 100755 --- a/RNS/Transport.py +++ b/RNS/Transport.py @@ -882,6 +882,8 @@ class Transport: return True if packet.context == RNS.Packet.CACHE_REQUEST: return True + if packet.context == RNS.Packet.CHANNEL: + return True if packet.destination_type == RNS.Destination.PLAIN: if packet.packet_type != RNS.Packet.ANNOUNCE: diff --git a/tests/channel.py b/tests/channel.py index a5c287d..697e420 100644 --- a/tests/channel.py +++ b/tests/channel.py @@ -155,6 +155,7 @@ class ProtocolHarness(contextlib.AbstractContextManager): def __init__(self, rtt: float): self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) self.channel = Channel(self.outlet) + Packet.timeout = self.channel._get_packet_timeout_time(1) def cleanup(self): self.channel._shutdown() @@ -169,9 +170,7 @@ class ProtocolHarness(contextlib.AbstractContextManager): class TestChannel(unittest.TestCase): def setUp(self) -> None: print("") - self.rtt = 0.001 - self.retry_interval = self.rtt * 150 - Packet.timeout = self.retry_interval + self.rtt = 0.01 self.h = ProtocolHarness(self.rtt) def tearDown(self) -> None: @@ -201,14 +200,14 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(envelope.raw, packet.raw) - time.sleep(self.retry_interval * 1.5) + time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(2, envelope.tries) self.assertEqual(2, packet.tries) self.assertEqual(1, packet.instances) - time.sleep(self.retry_interval) + time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(self.h.outlet.packets[0], packet) @@ -221,7 +220,7 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) - time.sleep(self.retry_interval) + time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(3, envelope.tries) @@ -253,7 +252,11 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(envelope.raw, packet.raw) - time.sleep(self.retry_interval * 7.5) + time.sleep(self.h.channel._get_packet_timeout_time(1)) + time.sleep(self.h.channel._get_packet_timeout_time(2)) + time.sleep(self.h.channel._get_packet_timeout_time(3)) + time.sleep(self.h.channel._get_packet_timeout_time(4)) + time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(5, envelope.tries) @@ -317,7 +320,7 @@ class TestChannel(unittest.TestCase): self.assertEqual(len(self.h.outlet.packets), 0) envelope = self.h.channel.send(message) - time.sleep(self.retry_interval * 0.5) + time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5) self.assertIsNotNone(envelope) self.assertIsNotNone(envelope.raw) @@ -339,7 +342,7 @@ class TestChannel(unittest.TestCase): self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) - time.sleep(self.retry_interval * 2) + time.sleep(self.h.channel._get_packet_timeout_time(1)) self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, envelope.tries) @@ -460,6 +463,7 @@ class TestChannel(unittest.TestCase): packet = self.h.outlet.packets[0] self.h.channel._receive(packet.raw) + packet.delivered() self.assertEqual(1, callbacks) self.assertEqual(len(data), last_cb_value) @@ -472,6 +476,27 @@ class TestChannel(unittest.TestCase): decoded = result.decode("utf-8") self.assertEqual(data, decoded) + self.assertEqual(1, len(self.h.outlet.packets)) + + result = reader.read(1) + + self.assertIsNone(result) + self.assertTrue(self.h.channel.is_ready_to_send()) + + writer.close() + + self.assertEqual(2, len(self.h.outlet.packets)) + + packet = self.h.outlet.packets[1] + self.h.channel._receive(packet.raw) + packet.delivered() + + result = reader.read(1) + + self.assertIsNotNone(result) + self.assertTrue(len(result) == 0) + + if __name__ == '__main__': diff --git a/tests/link.py b/tests/link.py index 818ac99..604b7c5 100644 --- a/tests/link.py +++ b/tests/link.py @@ -4,10 +4,14 @@ import subprocess import shlex import threading import time +from unittest import skipIf import RNS import os from tests.channel import MessageTest from RNS.Channel import MessageBase +from RNS.Buffer import StreamDataMessage +from RNS.Interfaces.LocalInterface import LocalClientInterface +from math import ceil APP_NAME = "rns_unit_tests" @@ -438,6 +442,113 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + def test_12_buffer_round_trip_big(self, local_bitrate = None): + global c_rns + init_rns(self) + print("") + print("Buffer round trip test") + + local_interface = next(filter(lambda iface: isinstance(iface, LocalClientInterface), RNS.Transport.interfaces), None) + self.assertIsNotNone(local_interface) + original_bitrate = local_interface.bitrate + + try: + if local_bitrate is not None: + local_interface.bitrate = local_bitrate + local_interface._force_bitrate = True + print("Forcing local bitrate of " + str(local_bitrate) + " bps (" + str(round(local_bitrate/8, 0)) + " B/s)") + + # TODO: Load this from public bytes only + id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) + self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1])) + + dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish") + + self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d")) + + l1 = RNS.Link(dest) + # delay a reasonable time for link to come up at current bitrate + link_sleep = max(RNS.Link.MDU * 3 / local_interface.bitrate * 8, 2) + timeout_at = time.time() + link_sleep + print("Waiting " + str(round(link_sleep, 1)) + " sec for link to come up") + while l1.status != RNS.Link.ACTIVE and time.time() < timeout_at: + time.sleep(0.01) + + self.assertEqual(l1.status, RNS.Link.ACTIVE) + + buffer = None + received = [] + def handle_data(ready_bytes: int): + data = buffer.read(ready_bytes) + received.append(data) + + channel = l1.get_channel() + buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_data) + + # try to make the message big enough to split across packets, but + # small enough to make the test complete in a reasonable amount of time + seed_text = "0123456789" + message = seed_text*ceil(min(max(local_interface.bitrate / 8, + StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)), + 1000)) + # the return message will have an appendage string " back at you" + # for every StreamDataMessage that arrives. To verify, we need + # to insert that string every MAX_DATA_LEN and also at the end. + expected_rx_message = "" + for i in range(0, len(message)): + if i > 0 and (i % StreamDataMessage.MAX_DATA_LEN) == 0: + expected_rx_message += " back at you" + expected_rx_message += message[i] + expected_rx_message += " back at you" + + # since the segments will be received at max length for a + # StreamDataMessage, the appended text will end up in a + # separate packet. + expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2) + print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " + + "expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes") + transfer_sleep = max(expected_chunk_count * 3 * c_rns.MTU / local_interface.bitrate * 8, 3) + print("Will take up to " + str(round(transfer_sleep, 0)) + " seconds to transfer") + expected_ready_time = time.time() + transfer_sleep + buffer.write(message.encode("utf-8")) + buffer.flush() + # delay a reasonable time for the send and receive + # a chunk each way plus a little more for a proof each way + while time.time() < expected_ready_time and len(received) < expected_chunk_count: + time.sleep(0.1) + # sleep for at least one more chunk round trip in case there + # are more chunks than expected + if time.time() < expected_ready_time: + time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1)) + + # Why does this not always work out correctly? + # self.assertEqual(expected_chunk_count, len(received)) + + data = bytearray() + for rx in received: + data.extend(rx) + + rx_message = data.decode("utf-8") + + self.assertEqual(len(expected_rx_message), len(rx_message)) + for i in range(0, len(expected_rx_message)): + self.assertEqual(expected_rx_message[i], rx_message[i]) + self.assertEqual(expected_rx_message, rx_message) + + l1.teardown() + time.sleep(0.5) + self.assertEqual(l1.status, RNS.Link.CLOSED) + finally: + local_interface.bitrate = original_bitrate + local_interface._force_bitrate = False + + # Run with + # RUN_SLOW_TESTS=1 python tests/link.py TestLink.test_13_buffer_round_trip_big_slow + # Or + # make RUN_SLOW_TESTS=1 test + @skipIf(int(os.getenv('RUN_SLOW_TESTS', 0)) < 1, "Not running slow tests") + def test_13_buffer_round_trip_big_slow(self): + self.test_12_buffer_round_trip_big(local_bitrate=410) def size_str(self, num, suffix='B'): units = ['','K','M','G','T','P','E','Z']