Merge pull request #252 from acehoss/bugfix/buffer-missing-segments

Bugfix: buffer missing segments
This commit is contained in:
markqvist 2023-03-05 17:59:03 +01:00 committed by GitHub
commit f5d77a1dfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 186 additions and 45 deletions

View File

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
import threading
from threading import RLock from threading import RLock
from RNS.vendor import umsgpack import struct
from RNS.Channel import Channel, MessageBase, SystemMessageTypes from RNS.Channel import Channel, MessageBase, SystemMessageTypes
import RNS import RNS
from io import RawIOBase, BufferedRWPair, BufferedReader, BufferedWriter from io import RawIOBase, BufferedRWPair, BufferedReader, BufferedWriter
@ -16,22 +17,12 @@ class StreamDataMessage(MessageBase):
uses a system-reserved message type. uses a system-reserved message type.
""" """
STREAM_ID_MAX = 65535 STREAM_ID_MAX = 0x7fff # 32767
""" """
While not essential for the current message packing The stream id is limited to 2 bytes - 1 bit
method (umsgpack), the stream id is clamped to the
size of a UInt16 for future struct packing.
""" """
OVERHEAD = 0 MAX_DATA_LEN = RNS.Link.MDU - 2 - 6 # 2 for stream data message header, 6 for channel envelope
"""
The number of bytes used by this messa
When the Buffer package is imported, this value is
calculated based on the value of RNS.Link.MDU.
"""
MAX_DATA_LEN = 0
""" """
When the Buffer package is imported, this value is When the Buffer package is imported, this value is
calculcated based on the value of OVERHEAD calculcated based on the value of OVERHEAD
@ -48,7 +39,7 @@ class StreamDataMessage(MessageBase):
""" """
super().__init__() super().__init__()
if stream_id is not None and stream_id > self.STREAM_ID_MAX: if stream_id is not None and stream_id > self.STREAM_ID_MAX:
raise ValueError("stream_id must be 0-65535") raise ValueError("stream_id must be 0-32767")
self.stream_id = stream_id self.stream_id = stream_id
self.data = data or bytes() self.data = data or bytes()
self.eof = eof self.eof = eof
@ -56,18 +47,14 @@ class StreamDataMessage(MessageBase):
def pack(self) -> bytes: def pack(self) -> bytes:
if self.stream_id is None: if self.stream_id is None:
raise ValueError("stream_id") raise ValueError("stream_id")
return umsgpack.packb((self.stream_id, self.eof, bytes(self.data))) header_val = (0x7fff & self.stream_id) | (0x8000 if self.eof else 0x0000)
return bytes(struct.pack(">H", header_val) + (self.data if self.data else bytes()))
def unpack(self, raw): def unpack(self, raw):
self.stream_id, self.eof, self.data = umsgpack.unpackb(raw) self.stream_id = struct.unpack(">H", raw[:2])[0]
self.eof = (0x8000 & self.stream_id) > 0
self.stream_id = self.stream_id & 0x7fff
_link_sized_bytes = ("\0"*RNS.Link.MDU).encode("utf-8") self.data = raw[2:]
StreamDataMessage.OVERHEAD = len(StreamDataMessage(stream_id=StreamDataMessage.STREAM_ID_MAX,
data=_link_sized_bytes,
eof=True).pack()) - len(_link_sized_bytes) + 4 # TODO: Calculation was off by 10 bytes, why?
StreamDataMessage.MAX_DATA_LEN = RNS.Link.MDU - StreamDataMessage.OVERHEAD
_link_sized_bytes = None
class RawChannelReader(RawIOBase, AbstractContextManager): class RawChannelReader(RawIOBase, AbstractContextManager):
@ -144,9 +131,9 @@ class RawChannelReader(RawIOBase, AbstractContextManager):
def readinto(self, __buffer: bytearray) -> int | None: def readinto(self, __buffer: bytearray) -> int | None:
ready = self._read(len(__buffer)) ready = self._read(len(__buffer))
if ready: if ready is not None:
__buffer[:len(ready)] = ready __buffer[:len(ready)] = ready
return len(ready) if ready else None return len(ready) if ready is not None else None
def writable(self) -> bool: def writable(self) -> bool:
return False return False
@ -198,11 +185,10 @@ class RawChannelWriter(RawIOBase, AbstractContextManager):
def write(self, __b: bytes) -> int | None: def write(self, __b: bytes) -> int | None:
try: try:
if self._channel.is_ready_to_send(): chunk = bytes(__b[:StreamDataMessage.MAX_DATA_LEN])
chunk = __b[:StreamDataMessage.MAX_DATA_LEN] message = StreamDataMessage(self._stream_id, chunk, self._eof)
message = StreamDataMessage(self._stream_id, chunk, self._eof) self._channel.send(message)
self._channel.send(message) return len(chunk)
return len(chunk)
except RNS.Channel.ChannelException as cex: except RNS.Channel.ChannelException as cex:
if cex.type != RNS.Channel.CEType.ME_LINK_NOT_READY: if cex.type != RNS.Channel.CEType.ME_LINK_NOT_READY:
raise raise

View File

@ -356,6 +356,10 @@ class Channel(contextlib.AbstractContextManager):
envelope = Envelope(outlet=self._outlet, raw=raw) envelope = Envelope(outlet=self._outlet, raw=raw)
with self._lock: with self._lock:
message = envelope.unpack(self._message_factories) message = envelope.unpack(self._message_factories)
prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None
if prev_env and envelope.sequence != (prev_env.sequence + 1) % 0x10000:
RNS.log("Channel: Out of order packet received", RNS.LOG_DEBUG)
return
is_new = self._emplace_envelope(envelope, self._rx_ring) is_new = self._emplace_envelope(envelope, self._rx_ring)
self._prune_rx_ring() self._prune_rx_ring()
if not is_new: if not is_new:
@ -403,6 +407,9 @@ class Channel(contextlib.AbstractContextManager):
def _packet_delivered(self, packet: TPacket): def _packet_delivered(self, packet: TPacket):
self._packet_tx_op(packet, lambda env: True) self._packet_tx_op(packet, lambda env: True)
def _get_packet_timeout_time(self, tries: int) -> float:
return pow(2, tries - 1) * max(self._outlet.rtt, 0.01) * 5
def _packet_timeout(self, packet: TPacket): def _packet_timeout(self, packet: TPacket):
def retry_envelope(envelope: Envelope) -> bool: def retry_envelope(envelope: Envelope) -> bool:
if envelope.tries >= self._max_tries: if envelope.tries >= self._max_tries:
@ -412,9 +419,11 @@ class Channel(contextlib.AbstractContextManager):
return True return True
envelope.tries += 1 envelope.tries += 1
self._outlet.resend(envelope.packet) self._outlet.resend(envelope.packet)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
return False return False
self._packet_tx_op(packet, retry_envelope) if self._outlet.get_packet_state(packet) != MessageState.MSGSTATE_DELIVERED:
self._packet_tx_op(packet, retry_envelope)
def send(self, message: MessageBase) -> Envelope: def send(self, message: MessageBase) -> Envelope:
""" """
@ -439,7 +448,7 @@ class Channel(contextlib.AbstractContextManager):
envelope.packet = self._outlet.send(envelope.raw) envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1 envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered) self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout) self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
return envelope return envelope
@property @property
@ -473,6 +482,7 @@ class LinkChannelOutlet(ChannelOutletBase):
return packet return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet: def resend(self, packet: RNS.Packet) -> RNS.Packet:
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
if not packet.resend(): if not packet.resend():
RNS.log("Failed to resend packet", RNS.LOG_ERROR) RNS.log("Failed to resend packet", RNS.LOG_ERROR)
return packet return packet
@ -511,7 +521,7 @@ class LinkChannelOutlet(ChannelOutletBase):
def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None, def set_packet_timeout_callback(self, packet: RNS.Packet, callback: Callable[[RNS.Packet], None] | None,
timeout: float | None = None): timeout: float | None = None):
if timeout: if timeout and packet.receipt:
packet.receipt.set_timeout(timeout) packet.receipt.set_timeout(timeout)
def inner(receipt: RNS.PacketReceipt): def inner(receipt: RNS.PacketReceipt):

View File

@ -86,6 +86,8 @@ class LocalClientInterface(Interface):
self.online = True self.online = True
self.writing = False self.writing = False
self._force_bitrate = False
self.announce_rate_target = None self.announce_rate_target = None
self.announce_rate_grace = None self.announce_rate_grace = None
self.announce_rate_penalty = None self.announce_rate_penalty = None
@ -137,6 +139,9 @@ class LocalClientInterface(Interface):
def processIncoming(self, data): def processIncoming(self, data):
if self._force_bitrate:
time.sleep(len(data) / self.bitrate * 8)
self.rxb += len(data) self.rxb += len(data)
if hasattr(self, "parent_interface") and self.parent_interface != None: if hasattr(self, "parent_interface") and self.parent_interface != None:
self.parent_interface.rxb += len(data) self.parent_interface.rxb += len(data)
@ -154,6 +159,8 @@ class LocalClientInterface(Interface):
if self.online: if self.online:
try: try:
self.writing = True self.writing = True
if self._force_bitrate:
time.sleep(len(data) / self.bitrate * 8)
data = bytes([HDLC.FLAG])+HDLC.escape(data)+bytes([HDLC.FLAG]) data = bytes([HDLC.FLAG])+HDLC.escape(data)+bytes([HDLC.FLAG])
self.socket.sendall(data) self.socket.sendall(data)
self.writing = False self.writing = False

View File

@ -809,9 +809,9 @@ class Link:
if not self._channel: if not self._channel:
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG) RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
else: else:
packet.prove()
plaintext = self.decrypt(packet.data) plaintext = self.decrypt(packet.data)
self._channel._receive(plaintext) self._channel._receive(plaintext)
packet.prove()
elif packet.packet_type == RNS.Packet.PROOF: elif packet.packet_type == RNS.Packet.PROOF:
if packet.context == RNS.Packet.RESOURCE_PRF: if packet.context == RNS.Packet.RESOURCE_PRF:

View File

@ -882,6 +882,8 @@ class Transport:
return True return True
if packet.context == RNS.Packet.CACHE_REQUEST: if packet.context == RNS.Packet.CACHE_REQUEST:
return True return True
if packet.context == RNS.Packet.CHANNEL:
return True
if packet.destination_type == RNS.Destination.PLAIN: if packet.destination_type == RNS.Destination.PLAIN:
if packet.packet_type != RNS.Packet.ANNOUNCE: if packet.packet_type != RNS.Packet.ANNOUNCE:

View File

@ -155,6 +155,7 @@ class ProtocolHarness(contextlib.AbstractContextManager):
def __init__(self, rtt: float): def __init__(self, rtt: float):
self.outlet = ChannelOutletTest(mdu=500, rtt=rtt) self.outlet = ChannelOutletTest(mdu=500, rtt=rtt)
self.channel = Channel(self.outlet) self.channel = Channel(self.outlet)
Packet.timeout = self.channel._get_packet_timeout_time(1)
def cleanup(self): def cleanup(self):
self.channel._shutdown() self.channel._shutdown()
@ -169,9 +170,7 @@ class ProtocolHarness(contextlib.AbstractContextManager):
class TestChannel(unittest.TestCase): class TestChannel(unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
print("") print("")
self.rtt = 0.001 self.rtt = 0.01
self.retry_interval = self.rtt * 150
Packet.timeout = self.retry_interval
self.h = ProtocolHarness(self.rtt) self.h = ProtocolHarness(self.rtt)
def tearDown(self) -> None: def tearDown(self) -> None:
@ -201,14 +200,14 @@ class TestChannel(unittest.TestCase):
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
self.assertEqual(envelope.raw, packet.raw) self.assertEqual(envelope.raw, packet.raw)
time.sleep(self.retry_interval * 1.5) time.sleep(self.h.channel._get_packet_timeout_time(1) * 1.1)
self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(2, envelope.tries) self.assertEqual(2, envelope.tries)
self.assertEqual(2, packet.tries) self.assertEqual(2, packet.tries)
self.assertEqual(1, packet.instances) self.assertEqual(1, packet.instances)
time.sleep(self.retry_interval) time.sleep(self.h.channel._get_packet_timeout_time(2) * 1.1)
self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(self.h.outlet.packets[0], packet) self.assertEqual(self.h.outlet.packets[0], packet)
@ -221,7 +220,7 @@ class TestChannel(unittest.TestCase):
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
time.sleep(self.retry_interval) time.sleep(self.h.channel._get_packet_timeout_time(3) * 1.1)
self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(3, envelope.tries) self.assertEqual(3, envelope.tries)
@ -253,7 +252,11 @@ class TestChannel(unittest.TestCase):
self.assertEqual(MessageState.MSGSTATE_SENT, packet.state) self.assertEqual(MessageState.MSGSTATE_SENT, packet.state)
self.assertEqual(envelope.raw, packet.raw) self.assertEqual(envelope.raw, packet.raw)
time.sleep(self.retry_interval * 7.5) time.sleep(self.h.channel._get_packet_timeout_time(1))
time.sleep(self.h.channel._get_packet_timeout_time(2))
time.sleep(self.h.channel._get_packet_timeout_time(3))
time.sleep(self.h.channel._get_packet_timeout_time(4))
time.sleep(self.h.channel._get_packet_timeout_time(5) * 1.1)
self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(5, envelope.tries) self.assertEqual(5, envelope.tries)
@ -317,7 +320,7 @@ class TestChannel(unittest.TestCase):
self.assertEqual(len(self.h.outlet.packets), 0) self.assertEqual(len(self.h.outlet.packets), 0)
envelope = self.h.channel.send(message) envelope = self.h.channel.send(message)
time.sleep(self.retry_interval * 0.5) time.sleep(self.h.channel._get_packet_timeout_time(1) * 0.5)
self.assertIsNotNone(envelope) self.assertIsNotNone(envelope)
self.assertIsNotNone(envelope.raw) self.assertIsNotNone(envelope.raw)
@ -339,7 +342,7 @@ class TestChannel(unittest.TestCase):
self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state) self.assertEqual(MessageState.MSGSTATE_DELIVERED, packet.state)
time.sleep(self.retry_interval * 2) time.sleep(self.h.channel._get_packet_timeout_time(1))
self.assertEqual(1, len(self.h.outlet.packets)) self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(1, envelope.tries) self.assertEqual(1, envelope.tries)
@ -460,6 +463,7 @@ class TestChannel(unittest.TestCase):
packet = self.h.outlet.packets[0] packet = self.h.outlet.packets[0]
self.h.channel._receive(packet.raw) self.h.channel._receive(packet.raw)
packet.delivered()
self.assertEqual(1, callbacks) self.assertEqual(1, callbacks)
self.assertEqual(len(data), last_cb_value) self.assertEqual(len(data), last_cb_value)
@ -472,6 +476,27 @@ class TestChannel(unittest.TestCase):
decoded = result.decode("utf-8") decoded = result.decode("utf-8")
self.assertEqual(data, decoded) self.assertEqual(data, decoded)
self.assertEqual(1, len(self.h.outlet.packets))
result = reader.read(1)
self.assertIsNone(result)
self.assertTrue(self.h.channel.is_ready_to_send())
writer.close()
self.assertEqual(2, len(self.h.outlet.packets))
packet = self.h.outlet.packets[1]
self.h.channel._receive(packet.raw)
packet.delivered()
result = reader.read(1)
self.assertIsNotNone(result)
self.assertTrue(len(result) == 0)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -4,10 +4,14 @@ import subprocess
import shlex import shlex
import threading import threading
import time import time
from unittest import skipIf
import RNS import RNS
import os import os
from tests.channel import MessageTest from tests.channel import MessageTest
from RNS.Channel import MessageBase from RNS.Channel import MessageBase
from RNS.Buffer import StreamDataMessage
from RNS.Interfaces.LocalInterface import LocalClientInterface
from math import ceil
APP_NAME = "rns_unit_tests" APP_NAME = "rns_unit_tests"
@ -438,6 +442,113 @@ class TestLink(unittest.TestCase):
time.sleep(0.5) time.sleep(0.5)
self.assertEqual(l1.status, RNS.Link.CLOSED) self.assertEqual(l1.status, RNS.Link.CLOSED)
def test_12_buffer_round_trip_big(self, local_bitrate = None):
global c_rns
init_rns(self)
print("")
print("Buffer round trip test")
local_interface = next(filter(lambda iface: isinstance(iface, LocalClientInterface), RNS.Transport.interfaces), None)
self.assertIsNotNone(local_interface)
original_bitrate = local_interface.bitrate
try:
if local_bitrate is not None:
local_interface.bitrate = local_bitrate
local_interface._force_bitrate = True
print("Forcing local bitrate of " + str(local_bitrate) + " bps (" + str(round(local_bitrate/8, 0)) + " B/s)")
# TODO: Load this from public bytes only
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
self.assertEqual(id1.hash, bytes.fromhex(fixed_keys[0][1]))
dest = RNS.Destination(id1, RNS.Destination.OUT, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
self.assertEqual(dest.hash, bytes.fromhex("fb48da0e82e6e01ba0c014513f74540d"))
l1 = RNS.Link(dest)
# delay a reasonable time for link to come up at current bitrate
link_sleep = max(RNS.Link.MDU * 3 / local_interface.bitrate * 8, 2)
timeout_at = time.time() + link_sleep
print("Waiting " + str(round(link_sleep, 1)) + " sec for link to come up")
while l1.status != RNS.Link.ACTIVE and time.time() < timeout_at:
time.sleep(0.01)
self.assertEqual(l1.status, RNS.Link.ACTIVE)
buffer = None
received = []
def handle_data(ready_bytes: int):
data = buffer.read(ready_bytes)
received.append(data)
channel = l1.get_channel()
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_data)
# try to make the message big enough to split across packets, but
# small enough to make the test complete in a reasonable amount of time
seed_text = "0123456789"
message = seed_text*ceil(min(max(local_interface.bitrate / 8,
StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)),
1000))
# the return message will have an appendage string " back at you"
# for every StreamDataMessage that arrives. To verify, we need
# to insert that string every MAX_DATA_LEN and also at the end.
expected_rx_message = ""
for i in range(0, len(message)):
if i > 0 and (i % StreamDataMessage.MAX_DATA_LEN) == 0:
expected_rx_message += " back at you"
expected_rx_message += message[i]
expected_rx_message += " back at you"
# since the segments will be received at max length for a
# StreamDataMessage, the appended text will end up in a
# separate packet.
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
transfer_sleep = max(expected_chunk_count * 3 * c_rns.MTU / local_interface.bitrate * 8, 3)
print("Will take up to " + str(round(transfer_sleep, 0)) + " seconds to transfer")
expected_ready_time = time.time() + transfer_sleep
buffer.write(message.encode("utf-8"))
buffer.flush()
# delay a reasonable time for the send and receive
# a chunk each way plus a little more for a proof each way
while time.time() < expected_ready_time and len(received) < expected_chunk_count:
time.sleep(0.1)
# sleep for at least one more chunk round trip in case there
# are more chunks than expected
if time.time() < expected_ready_time:
time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
# Why does this not always work out correctly?
# self.assertEqual(expected_chunk_count, len(received))
data = bytearray()
for rx in received:
data.extend(rx)
rx_message = data.decode("utf-8")
self.assertEqual(len(expected_rx_message), len(rx_message))
for i in range(0, len(expected_rx_message)):
self.assertEqual(expected_rx_message[i], rx_message[i])
self.assertEqual(expected_rx_message, rx_message)
l1.teardown()
time.sleep(0.5)
self.assertEqual(l1.status, RNS.Link.CLOSED)
finally:
local_interface.bitrate = original_bitrate
local_interface._force_bitrate = False
# Run with
# RUN_SLOW_TESTS=1 python tests/link.py TestLink.test_13_buffer_round_trip_big_slow
# Or
# make RUN_SLOW_TESTS=1 test
@skipIf(int(os.getenv('RUN_SLOW_TESTS', 0)) < 1, "Not running slow tests")
def test_13_buffer_round_trip_big_slow(self):
self.test_12_buffer_round_trip_big(local_bitrate=410)
def size_str(self, num, suffix='B'): def size_str(self, num, suffix='B'):
units = ['','K','M','G','T','P','E','Z'] units = ['','K','M','G','T','P','E','Z']