mirror of
https://github.com/markqvist/Reticulum.git
synced 2024-11-08 07:10:15 +00:00
Initial framework for channel windowing
This commit is contained in:
parent
7df11a6f67
commit
a4c64abed4
1
.gitignore
vendored
1
.gitignore
vendored
@ -10,5 +10,6 @@ docs/build
|
|||||||
rns*.egg-info
|
rns*.egg-info
|
||||||
profile.data
|
profile.data
|
||||||
tests/rnsconfig/storage
|
tests/rnsconfig/storage
|
||||||
|
tests/rnsconfig/logfile
|
||||||
*.data
|
*.data
|
||||||
*.result
|
*.result
|
||||||
|
136
RNS/Channel.py
136
RNS/Channel.py
@ -223,6 +223,38 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
``Channel`` is not instantiated directly, but rather
|
``Channel`` is not instantiated directly, but rather
|
||||||
obtained from a ``Link`` with ``get_channel()``.
|
obtained from a ``Link`` with ``get_channel()``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# The initial window size at channel setup
|
||||||
|
WINDOW = 2
|
||||||
|
|
||||||
|
# Absolute minimum window size
|
||||||
|
WINDOW_MIN = 1
|
||||||
|
|
||||||
|
# The maximum window size for transfers on slow links
|
||||||
|
WINDOW_MAX_SLOW = 10
|
||||||
|
|
||||||
|
# The maximum window size for transfers on fast links
|
||||||
|
WINDOW_MAX_FAST = 75
|
||||||
|
|
||||||
|
# For calculating maps and guard segments, this
|
||||||
|
# must be set to the global maximum window.
|
||||||
|
WINDOW_MAX = WINDOW_MAX_FAST
|
||||||
|
|
||||||
|
# If the fast rate is sustained for this many request
|
||||||
|
# rounds, the fast link window size will be allowed.
|
||||||
|
FAST_RATE_THRESHOLD = WINDOW_MAX_SLOW - WINDOW - 2
|
||||||
|
|
||||||
|
# If the RTT rate is higher than this value,
|
||||||
|
# the max window size for fast links will be used.
|
||||||
|
# The default is 50 Kbps (the value is stored in
|
||||||
|
# bytes per second, hence the "/ 8").
|
||||||
|
RATE_FAST = (50*1000) / 8
|
||||||
|
|
||||||
|
# The minimum allowed flexibility of the window size.
|
||||||
|
# The difference between window_max and window_min
|
||||||
|
# will never be smaller than this value.
|
||||||
|
WINDOW_FLEXIBILITY = 4
|
||||||
|
|
||||||
def __init__(self, outlet: ChannelOutletBase):
|
def __init__(self, outlet: ChannelOutletBase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -234,8 +266,10 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
self._rx_ring: collections.deque[Envelope] = collections.deque()
|
self._rx_ring: collections.deque[Envelope] = collections.deque()
|
||||||
self._message_callbacks: [MessageCallbackType] = []
|
self._message_callbacks: [MessageCallbackType] = []
|
||||||
self._next_sequence = 0
|
self._next_sequence = 0
|
||||||
|
self._next_rx_sequence = 0
|
||||||
self._message_factories: dict[int, Type[MessageBase]] = {}
|
self._message_factories: dict[int, Type[MessageBase]] = {}
|
||||||
self._max_tries = 5
|
self._max_tries = 5
|
||||||
|
self._max_outstanding = Channel.WINDOW
|
||||||
|
|
||||||
def __enter__(self) -> Channel:
|
def __enter__(self) -> Channel:
|
||||||
return self
|
return self
|
||||||
@ -325,21 +359,13 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
ring.insert(i, envelope)
|
ring.insert(i, envelope)
|
||||||
return True
|
return True
|
||||||
if existing.sequence == envelope.sequence:
|
if existing.sequence == envelope.sequence:
|
||||||
RNS.log(f"Envelope: Emplacement of duplicate envelope sequence.", RNS.LOG_EXTREME)
|
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
|
||||||
return False
|
return False
|
||||||
i += 1
|
i += 1
|
||||||
envelope.tracked = True
|
envelope.tracked = True
|
||||||
ring.append(envelope)
|
ring.append(envelope)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _prune_rx_ring(self):
|
|
||||||
with self._lock:
|
|
||||||
# Implementation for fixed window = 1
|
|
||||||
stale = list(sorted(self._rx_ring, key=lambda env: env.sequence, reverse=True))[1:]
|
|
||||||
for env in stale:
|
|
||||||
env.tracked = False
|
|
||||||
self._rx_ring.remove(env)
|
|
||||||
|
|
||||||
def _run_callbacks(self, message: MessageBase):
|
def _run_callbacks(self, message: MessageBase):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
cbs = self._message_callbacks.copy()
|
cbs = self._message_callbacks.copy()
|
||||||
@ -349,24 +375,61 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
if cb(message):
|
if cb(message):
|
||||||
return
|
return
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
RNS.log(f"Channel: Error running message callback: {ex}", RNS.LOG_ERROR)
|
RNS.log(f"Channel "+str(self)+" experienced an error while running message callback: {ex}", RNS.LOG_ERROR)
|
||||||
|
|
||||||
def _receive(self, raw: bytes):
|
def _receive(self, raw: bytes):
|
||||||
try:
|
try:
|
||||||
envelope = Envelope(outlet=self._outlet, raw=raw)
|
envelope = Envelope(outlet=self._outlet, raw=raw)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
message = envelope.unpack(self._message_factories)
|
message = envelope.unpack(self._message_factories)
|
||||||
prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None
|
|
||||||
if prev_env and envelope.sequence != (prev_env.sequence + 1) % 0x10000:
|
# TODO: Test sequence overflow
|
||||||
RNS.log("Channel: Out of order packet received", RNS.LOG_EXTREME)
|
if envelope.sequence < self._next_rx_sequence:
|
||||||
|
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % 0x10000
|
||||||
|
if window_overflow < self._next_rx_sequence:
|
||||||
|
if envelope.sequence > window_overflow:
|
||||||
|
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
if envelope.sequence < self._next_rx_sequence:
|
||||||
|
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
|
||||||
|
return
|
||||||
|
|
||||||
is_new = self._emplace_envelope(envelope, self._rx_ring)
|
is_new = self._emplace_envelope(envelope, self._rx_ring)
|
||||||
self._prune_rx_ring()
|
|
||||||
if not is_new:
|
if not is_new:
|
||||||
RNS.log("Channel: Duplicate message received", RNS.LOG_EXTREME)
|
RNS.log("Duplicate message received on channel "+str(self), RNS.LOG_EXTREME)
|
||||||
return
|
return
|
||||||
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
|
else:
|
||||||
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start()
|
|
||||||
|
# TODO: Remove
|
||||||
|
# rmsg = "RX Ring State:\n "
|
||||||
|
# for e in self._rx_ring:
|
||||||
|
# rmsg += "["+str(e.sequence)+"]"
|
||||||
|
# rmsg += "\n"
|
||||||
|
# RNS.log(rmsg)
|
||||||
|
# print(rmsg)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
contigous = []
|
||||||
|
for e in self._rx_ring:
|
||||||
|
if e.sequence == self._next_rx_sequence:
|
||||||
|
contigous.append(e)
|
||||||
|
self._next_rx_sequence = (self._next_rx_sequence + 1) % 0x10000
|
||||||
|
|
||||||
|
for e in contigous:
|
||||||
|
m = e.unpack(self._message_factories)
|
||||||
|
self._rx_ring.remove(e)
|
||||||
|
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start()
|
||||||
|
|
||||||
|
# TODO: Remove
|
||||||
|
# rmsg = "RX Ring State:\n "
|
||||||
|
# for e in self._rx_ring:
|
||||||
|
# rmsg += "["+str(e.sequence)+"]"
|
||||||
|
# rmsg += "\n"
|
||||||
|
# RNS.log(rmsg)
|
||||||
|
# print(rmsg)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
RNS.log(f"Channel: Error receiving data: {ex}")
|
RNS.log(f"Channel: Error receiving data: {ex}")
|
||||||
|
|
||||||
@ -381,14 +444,15 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
outstanding = 0
|
||||||
for envelope in self._tx_ring:
|
for envelope in self._tx_ring:
|
||||||
if envelope.outlet == self._outlet and (not envelope.packet
|
if envelope.outlet == self._outlet:
|
||||||
or self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_SENT):
|
if not envelope.packet or not self._outlet.get_packet_state(envelope.packet) == MessageState.MSGSTATE_DELIVERED:
|
||||||
# TODO: Check if this should be enabled with some kind of
|
outstanding += 1
|
||||||
# rate limiting, since it currently floods log output when
|
|
||||||
# messages are waiting.
|
if outstanding >= self._max_outstanding:
|
||||||
# RNS.log("Channel: Link has a pending message.", RNS.LOG_EXTREME)
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]):
|
def _packet_tx_op(self, packet: TPacket, op: Callable[[TPacket], bool]):
|
||||||
@ -419,6 +483,7 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
return True
|
return True
|
||||||
envelope.tries += 1
|
envelope.tries += 1
|
||||||
self._outlet.resend(envelope.packet)
|
self._outlet.resend(envelope.packet)
|
||||||
|
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
||||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -449,6 +514,25 @@ class Channel(contextlib.AbstractContextManager):
|
|||||||
envelope.tries += 1
|
envelope.tries += 1
|
||||||
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
|
||||||
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
self._outlet.set_packet_timeout_callback(envelope.packet, self._packet_timeout, self._get_packet_timeout_time(envelope.tries))
|
||||||
|
|
||||||
|
# TODO: Remove
|
||||||
|
# try:
|
||||||
|
# tmsg = "TX Ring State:\n "
|
||||||
|
# for e in self._tx_ring:
|
||||||
|
# estat="u"
|
||||||
|
# status = e.packet.receipt.get_status()
|
||||||
|
# if status == RNS.PacketReceipt.SENT:
|
||||||
|
# estat="s"
|
||||||
|
# if status == RNS.PacketReceipt.DELIVERED:
|
||||||
|
# estat="d"
|
||||||
|
# if status == RNS.PacketReceipt.FAILED:
|
||||||
|
# estat="f"
|
||||||
|
# tmsg += "["+str(e.sequence)+estat+"]"
|
||||||
|
# print(tmsg)
|
||||||
|
# RNS.log(tmsg)
|
||||||
|
# except:
|
||||||
|
# pass
|
||||||
|
|
||||||
return envelope
|
return envelope
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -483,7 +567,8 @@ class LinkChannelOutlet(ChannelOutletBase):
|
|||||||
|
|
||||||
def resend(self, packet: RNS.Packet) -> RNS.Packet:
|
def resend(self, packet: RNS.Packet) -> RNS.Packet:
|
||||||
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
|
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
|
||||||
if not packet.resend():
|
receipt = packet.resend()
|
||||||
|
if not receipt:
|
||||||
RNS.log("Failed to resend packet", RNS.LOG_ERROR)
|
RNS.log("Failed to resend packet", RNS.LOG_ERROR)
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
@ -538,4 +623,7 @@ class LinkChannelOutlet(ChannelOutletBase):
|
|||||||
packet.receipt.set_delivery_callback(inner if callback else None)
|
packet.receipt.set_delivery_callback(inner if callback else None)
|
||||||
|
|
||||||
def get_packet_id(self, packet: RNS.Packet) -> any:
|
def get_packet_id(self, packet: RNS.Packet) -> any:
|
||||||
|
if packet and hasattr(packet, "get_hash") and callable(packet.get_hash):
|
||||||
return packet.get_hash()
|
return packet.get_hash()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
13
RNS/Link.py
13
RNS/Link.py
@ -809,6 +809,19 @@ class Link:
|
|||||||
if not self._channel:
|
if not self._channel:
|
||||||
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
|
RNS.log(f"Channel data received without open channel", RNS.LOG_DEBUG)
|
||||||
else:
|
else:
|
||||||
|
# TODO: Remove packet loss simulator ######
|
||||||
|
# if not hasattr(self, "drop_counter"):
|
||||||
|
# self.drop_counter = 0
|
||||||
|
# self.drop_counter += 1
|
||||||
|
|
||||||
|
# if self.drop_counter%6 == 0:
|
||||||
|
# RNS.log("Dropping channel packet for testing", RNS.LOG_DEBUG)
|
||||||
|
# else:
|
||||||
|
# packet.prove()
|
||||||
|
# plaintext = self.decrypt(packet.data)
|
||||||
|
# self._channel._receive(plaintext)
|
||||||
|
############################################
|
||||||
|
|
||||||
packet.prove()
|
packet.prove()
|
||||||
plaintext = self.decrypt(packet.data)
|
plaintext = self.decrypt(packet.data)
|
||||||
self._channel._receive(plaintext)
|
self._channel._receive(plaintext)
|
||||||
|
@ -61,6 +61,7 @@ class TestLink(unittest.TestCase):
|
|||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
close_rns()
|
close_rns()
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_0_valid_announce(self):
|
def test_0_valid_announce(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -71,6 +72,7 @@ class TestLink(unittest.TestCase):
|
|||||||
ap.pack()
|
ap.pack()
|
||||||
self.assertEqual(RNS.Identity.validate_announce(ap), True)
|
self.assertEqual(RNS.Identity.validate_announce(ap), True)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_1_invalid_announce(self):
|
def test_1_invalid_announce(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -86,6 +88,7 @@ class TestLink(unittest.TestCase):
|
|||||||
ap.send()
|
ap.send()
|
||||||
self.assertEqual(RNS.Identity.validate_announce(ap), False)
|
self.assertEqual(RNS.Identity.validate_announce(ap), False)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_2_establish(self):
|
def test_2_establish(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -105,6 +108,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_3_packets(self):
|
def test_3_packets(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -171,6 +175,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_4_micro_resource(self):
|
def test_4_micro_resource(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -206,6 +211,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_5_mini_resource(self):
|
def test_5_mini_resource(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -241,6 +247,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_6_small_resource(self):
|
def test_6_small_resource(self):
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
print("")
|
print("")
|
||||||
@ -276,6 +283,7 @@ class TestLink(unittest.TestCase):
|
|||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_7_medium_resource(self):
|
def test_7_medium_resource(self):
|
||||||
if RNS.Cryptography.backend() == "internal":
|
if RNS.Cryptography.backend() == "internal":
|
||||||
print("Skipping medium resource test...")
|
print("Skipping medium resource test...")
|
||||||
@ -314,6 +322,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_9_large_resource(self):
|
def test_9_large_resource(self):
|
||||||
if RNS.Cryptography.backend() == "internal":
|
if RNS.Cryptography.backend() == "internal":
|
||||||
print("Skipping large resource test...")
|
print("Skipping large resource test...")
|
||||||
@ -352,6 +361,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
#@skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_10_channel_round_trip(self):
|
def test_10_channel_round_trip(self):
|
||||||
global c_rns
|
global c_rns
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
@ -393,13 +403,14 @@ class TestLink(unittest.TestCase):
|
|||||||
self.assertEqual("Hello back", rx_message.data)
|
self.assertEqual("Hello back", rx_message.data)
|
||||||
self.assertEqual(test_message.id, rx_message.id)
|
self.assertEqual(test_message.id, rx_message.id)
|
||||||
self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized)
|
self.assertNotEqual(test_message.not_serialized, rx_message.not_serialized)
|
||||||
self.assertEqual(1, len(l1._channel._rx_ring))
|
self.assertEqual(0, len(l1._channel._rx_ring))
|
||||||
|
|
||||||
l1.teardown()
|
l1.teardown()
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
self.assertEqual(0, len(l1._channel._rx_ring))
|
self.assertEqual(0, len(l1._channel._rx_ring))
|
||||||
|
|
||||||
|
# @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None, "Skipping")
|
||||||
def test_11_buffer_round_trip(self):
|
def test_11_buffer_round_trip(self):
|
||||||
global c_rns
|
global c_rns
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
@ -442,6 +453,7 @@ class TestLink(unittest.TestCase):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
self.assertEqual(l1.status, RNS.Link.CLOSED)
|
||||||
|
|
||||||
|
# @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping")
|
||||||
def test_12_buffer_round_trip_big(self, local_bitrate = None):
|
def test_12_buffer_round_trip_big(self, local_bitrate = None):
|
||||||
global c_rns
|
global c_rns
|
||||||
init_rns(self)
|
init_rns(self)
|
||||||
@ -479,6 +491,8 @@ class TestLink(unittest.TestCase):
|
|||||||
buffer = None
|
buffer = None
|
||||||
received = []
|
received = []
|
||||||
def handle_data(ready_bytes: int):
|
def handle_data(ready_bytes: int):
|
||||||
|
# TODO: Remove
|
||||||
|
RNS.log("Handling data")
|
||||||
data = buffer.read(ready_bytes)
|
data = buffer.read(ready_bytes)
|
||||||
received.append(data)
|
received.append(data)
|
||||||
|
|
||||||
@ -487,50 +501,60 @@ class TestLink(unittest.TestCase):
|
|||||||
|
|
||||||
# try to make the message big enough to split across packets, but
|
# try to make the message big enough to split across packets, but
|
||||||
# small enough to make the test complete in a reasonable amount of time
|
# small enough to make the test complete in a reasonable amount of time
|
||||||
seed_text = "0123456789"
|
# seed_text = "0123456789"
|
||||||
message = seed_text*ceil(min(max(local_interface.bitrate / 8,
|
# message = seed_text*ceil(min(max(local_interface.bitrate / 8,
|
||||||
StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)),
|
# StreamDataMessage.MAX_DATA_LEN * 2 / len(seed_text)),
|
||||||
1000))
|
# 1000))
|
||||||
|
|
||||||
|
if local_interface.bitrate < 1000:
|
||||||
|
target_bytes = 3000
|
||||||
|
else:
|
||||||
|
target_bytes = 16000
|
||||||
|
|
||||||
|
|
||||||
|
message = os.urandom(target_bytes)
|
||||||
|
|
||||||
# the return message will have an appendage string " back at you"
|
# the return message will have an appendage string " back at you"
|
||||||
# for every StreamDataMessage that arrives. To verify, we need
|
# for every StreamDataMessage that arrives. To verify, we need
|
||||||
# to insert that string every MAX_DATA_LEN and also at the end.
|
# to insert that string every MAX_DATA_LEN and also at the end.
|
||||||
expected_rx_message = ""
|
expected_rx_message = b""
|
||||||
for i in range(0, len(message)):
|
for i in range(0, len(message)):
|
||||||
if i > 0 and (i % StreamDataMessage.MAX_DATA_LEN) == 0:
|
if i > 0 and (i % StreamDataMessage.MAX_DATA_LEN) == 0:
|
||||||
expected_rx_message += " back at you"
|
expected_rx_message += " back at you".encode("utf-8")
|
||||||
expected_rx_message += message[i]
|
expected_rx_message += bytes([message[i]])
|
||||||
expected_rx_message += " back at you"
|
expected_rx_message += " back at you".encode("utf-8")
|
||||||
|
|
||||||
# since the segments will be received at max length for a
|
# since the segments will be received at max length for a
|
||||||
# StreamDataMessage, the appended text will end up in a
|
# StreamDataMessage, the appended text will end up in a
|
||||||
# separate packet.
|
# separate packet.
|
||||||
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)
|
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)-1
|
||||||
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
|
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
|
||||||
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
|
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
|
||||||
transfer_sleep = max(expected_chunk_count * 3 * c_rns.MTU / local_interface.bitrate * 8, 3)
|
|
||||||
print("Will take up to " + str(round(transfer_sleep, 0)) + " seconds to transfer")
|
buffer.write(message)
|
||||||
expected_ready_time = time.time() + transfer_sleep
|
|
||||||
buffer.write(message.encode("utf-8"))
|
|
||||||
buffer.flush()
|
buffer.flush()
|
||||||
|
|
||||||
# delay a reasonable time for the send and receive
|
# delay a reasonable time for the send and receive
|
||||||
# a chunk each way plus a little more for a proof each way
|
# a chunk each way plus a little more for a proof each way
|
||||||
while time.time() < expected_ready_time and len(received) < expected_chunk_count:
|
# while time.time() < expected_ready_time and len(received) < expected_chunk_count:
|
||||||
time.sleep(0.1)
|
# time.sleep(0.1)
|
||||||
# sleep for at least one more chunk round trip in case there
|
# # sleep for at least one more chunk round trip in case there
|
||||||
# are more chunks than expected
|
# # are more chunks than expected
|
||||||
if time.time() < expected_ready_time:
|
# if time.time() < expected_ready_time:
|
||||||
time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
|
# time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
|
||||||
|
|
||||||
time.sleep(0.25)
|
timeout = time.time() + 10
|
||||||
|
while len(received) < expected_chunk_count and not time.time() > timeout:
|
||||||
# Why does this not always work out correctly?
|
time.sleep(2)
|
||||||
# self.assertEqual(expected_chunk_count, len(received))
|
print(f"Received {len(received)} out of {expected_chunk_count} chunks so far")
|
||||||
|
time.sleep(2)
|
||||||
|
print(f"Received {len(received)} out of {expected_chunk_count} chunks")
|
||||||
|
|
||||||
data = bytearray()
|
data = bytearray()
|
||||||
for rx in received:
|
for rx in received:
|
||||||
data.extend(rx)
|
data.extend(rx)
|
||||||
|
|
||||||
rx_message = data.decode("utf-8")
|
rx_message = data
|
||||||
|
|
||||||
self.assertEqual(len(expected_rx_message), len(rx_message))
|
self.assertEqual(len(expected_rx_message), len(rx_message))
|
||||||
for i in range(0, len(expected_rx_message)):
|
for i in range(0, len(expected_rx_message)):
|
||||||
@ -548,7 +572,7 @@ class TestLink(unittest.TestCase):
|
|||||||
# RUN_SLOW_TESTS=1 python tests/link.py TestLink.test_13_buffer_round_trip_big_slow
|
# RUN_SLOW_TESTS=1 python tests/link.py TestLink.test_13_buffer_round_trip_big_slow
|
||||||
# Or
|
# Or
|
||||||
# make RUN_SLOW_TESTS=1 test
|
# make RUN_SLOW_TESTS=1 test
|
||||||
@skipIf(int(os.getenv('RUN_SLOW_TESTS', 0)) < 1, "Not running slow tests")
|
@skipIf(os.getenv('RUN_SLOW_TESTS') == None, "Not running slow tests")
|
||||||
def test_13_buffer_round_trip_big_slow(self):
|
def test_13_buffer_round_trip_big_slow(self):
|
||||||
self.test_12_buffer_round_trip_big(local_bitrate=410)
|
self.test_12_buffer_round_trip_big(local_bitrate=410)
|
||||||
|
|
||||||
@ -612,8 +636,10 @@ def targets(yp=False):
|
|||||||
channel = link.get_channel()
|
channel = link.get_channel()
|
||||||
|
|
||||||
def handle_message(message):
|
def handle_message(message):
|
||||||
|
if isinstance(message, MessageTest):
|
||||||
message.data = message.data + " back"
|
message.data = message.data + " back"
|
||||||
channel.send(message)
|
channel.send(message)
|
||||||
|
|
||||||
channel.register_message_type(MessageTest)
|
channel.register_message_type(MessageTest)
|
||||||
channel.add_message_handler(handle_message)
|
channel.add_message_handler(handle_message)
|
||||||
|
|
||||||
@ -621,12 +647,12 @@ def targets(yp=False):
|
|||||||
|
|
||||||
def handle_buffer(ready_bytes: int):
|
def handle_buffer(ready_bytes: int):
|
||||||
data = buffer.read(ready_bytes)
|
data = buffer.read(ready_bytes)
|
||||||
buffer.write((data.decode("utf-8") + " back at you").encode("utf-8"))
|
buffer.write(data + " back at you".encode("utf-8"))
|
||||||
buffer.flush()
|
buffer.flush()
|
||||||
|
|
||||||
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer)
|
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer)
|
||||||
|
|
||||||
m_rns = RNS.Reticulum("./tests/rnsconfig")
|
m_rns = RNS.Reticulum("./tests/rnsconfig", logdest=RNS.LOG_FILE, loglevel=RNS.LOG_EXTREME)
|
||||||
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
id1 = RNS.Identity.from_bytes(bytes.fromhex(fixed_keys[0][0]))
|
||||||
d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
|
d1 = RNS.Destination(id1, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME, "link", "establish")
|
||||||
d1.set_proof_strategy(RNS.Destination.PROVE_ALL)
|
d1.set_proof_strategy(RNS.Destination.PROVE_ALL)
|
||||||
|
Loading…
Reference in New Issue
Block a user