diff --git a/.gitignore b/.gitignore index ae7e68d..4902c93 100755 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,6 @@ docs/build rns*.egg-info profile.data tests/rnsconfig/storage +tests/rnsconfig/logfile *.data *.result diff --git a/RNS/Channel.py b/RNS/Channel.py index f75b529..16d41b4 100644 --- a/RNS/Channel.py +++ b/RNS/Channel.py @@ -223,6 +223,38 @@ class Channel(contextlib.AbstractContextManager): ``Channel`` is not instantiated directly, but rather obtained from a ``Link`` with ``get_channel()``. """ + + # The initial window size at channel setup + WINDOW = 2 + + # Absolute minimum window size + WINDOW_MIN = 1 + + # The maximum window size for transfers on slow links + WINDOW_MAX_SLOW = 10 + + # The maximum window size for transfers on fast links + WINDOW_MAX_FAST = 75 + + # For calculating maps and guard segments, this + # must be set to the global maximum window. + WINDOW_MAX = WINDOW_MAX_FAST + + # If the fast rate is sustained for this many request + # rounds, the fast link window size will be allowed. + FAST_RATE_THRESHOLD = WINDOW_MAX_SLOW - WINDOW - 2 + + # If the RTT rate is higher than this value, + # the max window size for fast links will be used. + # The default is 50 Kbps (the value is stored in + # bytes per second, hence the "/ 8"). + RATE_FAST = (50*1000) / 8 + + # The minimum allowed flexibility of the window size. + # The difference between window_max and window_min + # will never be smaller than this value. + WINDOW_FLEXIBILITY = 4 + def __init__(self, outlet: ChannelOutletBase): """ @@ -234,8 +266,10 @@ class Channel(contextlib.AbstractContextManager): self._rx_ring: collections.deque[Envelope] = collections.deque() self._message_callbacks: [MessageCallbackType] = [] self._next_sequence = 0 + self._next_rx_sequence = 0 self._message_factories: dict[int, Type[MessageBase]] = {} self._max_tries = 5 + self._max_outstanding = Channel.WINDOW def __enter__(self) -> Channel: return self @@ -325,21 +359,13 @@ class Channel(contextlib.AbstractContextManager): ring.insert(i, envelope) return True if existing.sequence == envelope.sequence: - RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME) + RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME) return False i += 1 envelope.tracked = True ring.append(envelope) return True - def _prune_rx_ring(self): - with self._lock: - # Implementation for fixed window = 1 - stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:] - for env in stale: - env.tracked = False - self._rx_ring.remove(env) - def _run_callbacks(self, message: MessageBase): with self._lock: cbs = self._message_callbacks.copy() @@ -349,24 +375,61 @@ class Channel(contextlib.AbstractContextManager): if cb(message): return except Exception as ex: - RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR) + RNS.log(f"Channel "+str(self)+" experienced an error while running message callback: {ex}", RNS.LOG_ERROR) def _receive(self, raw: bytes): try: envelope = Envelope(outlet=self._outlet, raw=raw) with self._lock: message = envelope.unpack(self._message_factories) - prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None - if prev_env and envelope.sequence != (prev_env.sequence + 1) % 0x10000: - RNS.log("Channel: Out of order packet received", RNS.LOG_EXTREME) + + # TODO: Test sequence overflow + if envelope.sequence < self._next_rx_sequence: + window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % 0x10000 + if window_overflow < self._next_rx_sequence: + if envelope.sequence > window_overflow: + RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG) + return + else: + if envelope.sequence < self._next_rx_sequence: + RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG) + return is_new = self._emplace_envelope(envelope, self._rx_ring) - self._prune_rx_ring() + if not is_new: - RNS.log("Channel: Duplicate message received", RNS.LOG_EXTREME) + RNS.log("Duplicate message received on channel "+str(self), RNS.LOG_EXTREME) return - RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) - threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() + else: + + # TODO: Remove + # rmsg = "RX Ring State:\n " + # for e in self._rx_ring: + # rmsg += "["+str(e.sequence)+"]" + # rmsg += "\n" + # RNS.log(rmsg) + # print(rmsg) + + with self._lock: + contigous = [] + for e in self._rx_ring: + if e.sequence == self._next_rx_sequence: + contigous.append(e) + self._next_rx_sequence = (self._next_rx_sequence + 1) % 0x10000 + + for e in contigous: + m = e.unpack(self._message_factories) + self._rx_ring.remove(e) + threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start() + + # TODO: Remove + # rmsg = "RX Ring State:\n " + # for e in self._rx_ring: + # rmsg += "["+str(e.sequence)+"]" + # rmsg += "\n" + # RNS.log(rmsg) + # print(rmsg) + except Exception as ex: RNS.log(f"Channel: Error receiving data: {ex}") @@ -381,14 +444,15 @@ class Channel(contextlib.AbstractContextManager): return False with self._lock: + outstanding = 0 for envelope in self._tx_ring: - if envelope.outlet == self._outlet and (not envelope.packet - or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT): - # TODO: Check if this should be enabled with some kind of - # rate limiting, since it currently floods log output when - # messages are waiting. - # RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME) - return False + if envelope.outlet == self._outlet: + if not envelope.packet or not self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED: + outstanding += 1 + + if outstanding >= self._max_outstanding: + return False + return True def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]): @@ -419,6 +483,7 @@ class Channel(contextlib.AbstractContextManager): return True envelope.tries += 1 self._outlet.resend(envelope.packet) + self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) return False @@ -449,6 +514,25 @@ class Channel(contextlib.AbstractContextManager): envelope.tries += 1 self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries)) + + # TODO: Remove + # try: + # tmsg = "TX Ring State:\n " + # for e in self._tx_ring: + # estat="u" + # status = e.packet.receipt.get_status() + # if status == RNS.PacketReceipt.SENT: + # estat="s" + # if status == RNS.PacketReceipt.DELIVERED: + # estat="d" + # if status == RNS.PacketReceipt.FAILED: + # estat="f" + # tmsg += "["+str(e.sequence)+estat+"]" + # print(tmsg) + # RNS.log(tmsg) + # except: + # pass + return envelope @property @@ -483,7 +567,8 @@ class LinkChannelOutlet(ChannelOutletBase): def resend(self, packet: RNS.Packet) -> RNS.Packet: RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG) - if not packet.resend(): + receipt = packet.resend() + if not receipt: RNS.log("Failed to resend packet", RNS.LOG_ERROR) return packet @@ -538,4 +623,7 @@ class LinkChannelOutlet(ChannelOutletBase): packet.receipt.set_delivery_callback(inner if callback else None) def get_packet_id(self, packet: RNS.Packet) -> any: - return packet.get_hash() + if packet and hasattr(packet, "get_hash") and callable(packet.get_hash): + return packet.get_hash() + else: + return None diff --git a/RNS/Link.py b/RNS/Link.py index 4abb9bc..493de47 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -809,6 +809,19 @@ class Link: if not self._channel: RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) else: + # TODO: Remove packet loss simulator ###### + # if not hasattr(self, "drop_counter"): + # self.drop_counter = 0 + # self.drop_counter += 1 + + # if self.drop_counter%6 == 0: + # RNS.log("Dropping channel packet for testing", RNS.LOG_DEBUG) + # else: + # packet.prove() + # plaintext = self.decrypt(packet.data) + # self._channel._receive(plaintext) + ############################################ + packet.prove() plaintext = self.decrypt(packet.data) self._channel._receive(plaintext) diff --git a/tests/link.py b/tests/link.py index ad5af45..f8c79eb 100644 --- a/tests/link.py +++ b/tests/link.py @@ -61,6 +61,7 @@ class TestLink(unittest.TestCase): def tearDownClass(cls): close_rns() + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_0_valid_announce(self): init_rns(self) print("") @@ -71,6 +72,7 @@ class TestLink(unittest.TestCase): ap.pack() self.assertEqual(RNS.Identity.validate_announce(ap), True) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_1_invalid_announce(self): init_rns(self) print("") @@ -86,6 +88,7 @@ class TestLink(unittest.TestCase): ap.send() self.assertEqual(RNS.Identity.validate_announce(ap), False) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_2_establish(self): init_rns(self) print("") @@ -105,6 +108,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_3_packets(self): init_rns(self) print("") @@ -171,6 +175,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_4_micro_resource(self): init_rns(self) print("") @@ -206,6 +211,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_5_mini_resource(self): init_rns(self) print("") @@ -241,6 +247,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_6_small_resource(self): init_rns(self) print("") @@ -276,6 +283,7 @@ class TestLink(unittest.TestCase): self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_7_medium_resource(self): if RNS.Cryptography.backend() == "internal": print("Skipping medium resource test...") @@ -314,6 +322,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_9_large_resource(self): if RNS.Cryptography.backend() == "internal": print("Skipping large resource test...") @@ -352,6 +361,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + #@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_10_channel_round_trip(self): global c_rns init_rns(self) @@ -393,13 +403,14 @@ class TestLink(unittest.TestCase): self.assertEqual("Hello back", rx_message.data) self.assertEqual(test_message.id, rx_message.id) self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized) - self.assertEqual(1, len(l1._channel._rx_ring)) + self.assertEqual(0, len(l1._channel._rx_ring)) l1.teardown() time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) self.assertEqual(0, len(l1._channel._rx_ring)) + # @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping") def test_11_buffer_round_trip(self): global c_rns init_rns(self) @@ -442,6 +453,7 @@ class TestLink(unittest.TestCase): time.sleep(0.5) self.assertEqual(l1.status, RNS.Link.CLOSED) + # @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping") def test_12_buffer_round_trip_big(self, local_bitrate = None): global c_rns init_rns(self) @@ -479,6 +491,8 @@ class TestLink(unittest.TestCase): buffer = None received = [] def handle_data(ready_bytes: int): + # TODO: Remove + RNS.log("Handling data") data = buffer.read(ready_bytes) received.append(data) @@ -487,50 +501,60 @@ class TestLink(unittest.TestCase): # try to make the message big enough to split across packets, but # small enough to make the test complete in a reasonable amount of time - seed_text = "0123456789" - message = seed_text*ceil(min(max(local_interface.bitrate / 8, - StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)), - 1000)) + # seed_text = "0123456789" + # message = seed_text*ceil(min(max(local_interface.bitrate / 8, + # StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)), + # 1000)) + + if local_interface.bitrate < 1000: + target_bytes = 3000 + else: + target_bytes = 16000 + + + message = os.urandom(target_bytes) + # the return message will have an appendage string " back at you" # for every StreamDataMessage that arrives. To verify, we need # to insert that string every MAX_DATA_LEN and also at the end. - expected_rx_message = "" + expected_rx_message = b"" for i in range(0, len(message)): if i > 0 and (i % StreamDataMessage.MAX_DATA_LEN) == 0: - expected_rx_message += " back at you" - expected_rx_message += message[i] - expected_rx_message += " back at you" + expected_rx_message += " back at you".encode("utf-8") + expected_rx_message += bytes([message[i]]) + expected_rx_message += " back at you".encode("utf-8") # since the segments will be received at max length for a # StreamDataMessage, the appended text will end up in a # separate packet. - expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2) + expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)-1 print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " + "expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes") - transfer_sleep = max(expected_chunk_count * 3 * c_rns.MTU / local_interface.bitrate * 8, 3) - print("Will take up to " + str(round(transfer_sleep, 0)) + " seconds to transfer") - expected_ready_time = time.time() + transfer_sleep - buffer.write(message.encode("utf-8")) + + buffer.write(message) buffer.flush() + # delay a reasonable time for the send and receive # a chunk each way plus a little more for a proof each way - while time.time() < expected_ready_time and len(received) < expected_chunk_count: - time.sleep(0.1) - # sleep for at least one more chunk round trip in case there - # are more chunks than expected - if time.time() < expected_ready_time: - time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1)) + # while time.time() < expected_ready_time and len(received) < expected_chunk_count: + # time.sleep(0.1) + # # sleep for at least one more chunk round trip in case there + # # are more chunks than expected + # if time.time() < expected_ready_time: + # time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1)) - time.sleep(0.25) - - # Why does this not always work out correctly? - # self.assertEqual(expected_chunk_count, len(received)) + timeout = time.time() + 10 + while len(received) < expected_chunk_count and not time.time() > timeout: + time.sleep(2) + print(f"Received {len(received)} out of {expected_chunk_count} chunks so far") + time.sleep(2) + print(f"Received {len(received)} out of {expected_chunk_count} chunks") data = bytearray() for rx in received: data.extend(rx) - rx_message = data.decode("utf-8") + rx_message = data self.assertEqual(len(expected_rx_message), len(rx_message)) for i in range(0, len(expected_rx_message)): @@ -548,7 +572,7 @@ class TestLink(unittest.TestCase): # RUN_SLOW_TESTS=1 python tests/link.py TestLink.test_13_buffer_round_trip_big_slow # Or # make RUN_SLOW_TESTS=1 test - @skipIf(int(os.getenv('RUN_SLOW_TESTS', 0)) < 1, "Not running slow tests") + @skipIf(os.getenv('RUN_SLOW_TESTS') == None, "Not running slow tests") def test_13_buffer_round_trip_big_slow(self): self.test_12_buffer_round_trip_big(local_bitrate=410) @@ -612,8 +636,10 @@ def targets(yp=False): channel = link.get_channel() def handle_message(message): - message.data = message.data + " back" - channel.send(message) + if isinstance(message, MessageTest): + message.data = message.data + " back" + channel.send(message) + channel.register_message_type(MessageTest) channel.add_message_handler(handle_message) @@ -621,12 +647,12 @@ def targets(yp=False): def handle_buffer(ready_bytes: int): data = buffer.read(ready_bytes) - buffer.write((data.decode("utf-8") + " back at you").encode("utf-8")) + buffer.write(data + " back at you".encode("utf-8")) buffer.flush() buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer) - m_rns = RNS.Reticulum("./tests/rnsconfig") + m_rns = RNS.Reticulum("./tests/rnsconfig", logdest=RNS.LOG_FILE, loglevel=RNS.LOG_EXTREME) id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0])) d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish") d1.set_proof_strategy(RNS.Destination.PROVE_ALL)