Compare commits

..

7 Commits

Author SHA1 Message Date
Mark Qvist
d7375bc4c3 Fixed callback invocation on channel receive 2023-05-19 01:58:28 +02:00
Mark Qvist
1a860c6ffd Add EOF signal on buffer close 2023-05-19 01:57:20 +02:00
Mark Qvist
800ed3af7a Fixed ready callback invocation 2023-05-18 23:35:28 +02:00
Mark Qvist
9c8e79546c Fixed missing check in receipt culling 2023-05-18 23:33:26 +02:00
Mark Qvist
4c272aa536 Updated buffer tests for windowed channel 2023-05-18 23:32:29 +02:00
Mark Qvist
e184861822 Enabled channel tests 2023-05-18 23:31:29 +02:00
Mark Qvist
d40e19f08d Updated gitignore 2023-05-18 23:29:31 +02:00
6 changed files with 161 additions and 125 deletions

2
.gitignore vendored
View File

@ -10,6 +10,6 @@ docs/build
rns*.egg-info
profile.data
tests/rnsconfig/storage
tests/rnsconfig/logfile
tests/rnsconfig/logfile*
*.data
*.result

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import bz2
import sys
import time
import threading
from threading import RLock
import struct
@ -65,6 +66,7 @@ class StreamDataMessage(MessageBase):
self.compressed = (0x4000 & self.stream_id) > 0
self.stream_id = self.stream_id & 0x3fff
self.data = raw[2:]
if self.compressed:
self.data = bz2.decompress(self.data)
@ -129,7 +131,7 @@ class RawChannelReader(RawIOBase, AbstractContextManager):
self._eof = True
for listener in self._listeners:
try:
listener(len(self._buffer))
threading.Thread(target=listener, name="Message Callback", args=[len(self._buffer)], daemon=True).start()
except Exception as ex:
RNS.log("Error calling RawChannelReader(" + str(self._stream_id) + ") callback: " + str(ex))
return True
@ -207,6 +209,15 @@ class RawChannelWriter(RawIOBase, AbstractContextManager):
return 0
def close(self):
try:
link_rtt = self._channel._outlet.link.rtt
timeout = time.time() + (link_rtt * len(self._channel._tx_ring) * 1)
except Exception as e:
timeout = time.time() + 15
while time.time() < timeout and not self._channel.is_ready_to_send():
time.sleep(0.05)
self._eof = True
self.write(bytes())

View File

@ -176,6 +176,9 @@ class Envelope:
raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}")
message = ctor()
message.unpack(raw)
self.unpacked = True
self.message = message
return message
def pack(self) -> bytes:
@ -183,6 +186,7 @@ class Envelope:
raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE")
data = self.message.pack()
self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data
self.packed = True
return self.raw
def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None):
@ -194,6 +198,8 @@ class Envelope:
self.sequence = sequence
self.outlet = outlet
self.tries = 0
self.unpacked = False
self.packed = False
self.tracked = False
@ -371,22 +377,29 @@ class Channel(contextlib.AbstractContextManager):
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock:
i = 0
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
for existing in ring:
if existing.sequence > envelope.sequence \
and not existing.sequence // 2 > envelope.sequence: # account for overflow
ring.insert(i, envelope)
return True
if existing.sequence == envelope.sequence:
if envelope.sequence == existing.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
return False
if envelope.sequence < existing.sequence and not envelope.sequence < window_overflow:
ring.insert(i, envelope)
RNS.log("Inserted seq "+str(envelope.sequence)+" at "+str(i), RNS.LOG_DEBUG)
envelope.tracked = True
return True
i += 1
envelope.tracked = True
ring.append(envelope)
return True
def _run_callbacks(self, message: MessageBase):
with self._lock:
cbs = self._message_callbacks.copy()
cbs = self._message_callbacks.copy()
for cb in cbs:
try:
@ -405,12 +418,11 @@ class Channel(contextlib.AbstractContextManager):
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
if window_overflow < self._next_rx_sequence:
if envelope.sequence > window_overflow:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return
else:
if envelope.sequence < self._next_rx_sequence:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
return
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return
is_new = self._emplace_envelope(envelope, self._rx_ring)
@ -426,9 +438,13 @@ class Channel(contextlib.AbstractContextManager):
self._next_rx_sequence = (self._next_rx_sequence + 1) % Channel.SEQ_MODULUS
for e in contigous:
m = e.unpack(self._message_factories)
if not e.unpacked:
m = e.unpack(self._message_factories)
else:
m = e.message
self._rx_ring.remove(e)
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start()
self._run_callbacks(m)
except Exception as e:
RNS.log("An error ocurred while receiving data on "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR)
@ -469,7 +485,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_min += 1
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG)
# RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
if self._outlet.rtt != 0:
if self._outlet.rtt > Channel.RTT_FAST:
@ -483,19 +499,17 @@ class Channel(contextlib.AbstractContextManager):
if self.window_max < Channel.WINDOW_MAX_MEDIUM and self.medium_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_MEDIUM
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
self.fast_rate_rounds += 1
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_FAST
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_DEBUG)
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_EXTREME)
if not envelope:
RNS.log("Spurious message received on "+str(self), RNS.LOG_EXTREME)
@ -525,7 +539,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_max -= 1
# TODO: Remove at some point
RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
return False
@ -543,16 +557,18 @@ class Channel(contextlib.AbstractContextManager):
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
self._emplace_envelope(envelope, self._tx_ring)
if envelope is None:
raise BlockingIOError()
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
@ -591,7 +607,6 @@ class LinkChannelOutlet(ChannelOutletBase):
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
receipt = packet.resend()
if not receipt:
RNS.log("Failed to resend packet", RNS.LOG_ERROR)

View File

@ -334,7 +334,8 @@ class Transport:
for receipt in Transport.receipts:
receipt.check_timeout()
if receipt.status != RNS.PacketReceipt.SENT:
Transport.receipts.remove(receipt)
if receipt in Transport.receipts:
Transport.receipts.remove(receipt)
Transport.receipts_last_checked = time.time()

View File

@ -404,104 +404,104 @@ class TestChannel(unittest.TestCase):
self.assertEqual(data, decoded)
# def test_buffer_big(self):
# writer = RNS.Buffer.create_writer(15, self.h.channel)
# reader = RNS.Buffer.create_reader(15, self.h.channel)
# data = "01234556789"*1024 # 10 KB
# count = 0
# write_finished = False
def test_buffer_big(self):
writer = RNS.Buffer.create_writer(15, self.h.channel)
reader = RNS.Buffer.create_reader(15, self.h.channel)
data = "01234556789"*1024 # 10 KB
count = 0
write_finished = False
# def write_thread():
# nonlocal count, write_finished
# count = writer.write(data.encode("utf-8"))
# writer.flush()
# writer.close()
# write_finished = True
# threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()
def write_thread():
nonlocal count, write_finished
count = writer.write(data.encode("utf-8"))
writer.flush()
writer.close()
write_finished = True
threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()
# while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED,
# self.h.outlet.packets), None) is not None:
# with self.h.outlet.lock:
# for packet in self.h.outlet.packets:
# if packet.state != MessageState.MSGSTATE_DELIVERED:
# self.h.channel._receive(packet.raw)
# packet.delivered()
# time.sleep(0.0001)
while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED,
self.h.outlet.packets), None) is not None:
with self.h.outlet.lock:
for packet in self.h.outlet.packets:
if packet.state != MessageState.MSGSTATE_DELIVERED:
self.h.channel._receive(packet.raw)
packet.delivered()
time.sleep(0.0001)
# self.assertEqual(len(data), count)
self.assertEqual(len(data), count)
# read_finished = False
# result = bytes()
read_finished = False
result = bytes()
# def read_thread():
# nonlocal read_finished, result
# result = reader.read()
# read_finished = True
# threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()
def read_thread():
nonlocal read_finished, result
result = reader.read()
read_finished = True
threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()
# timeout_at = time.time() + 7
# while not read_finished and time.time() < timeout_at:
# time.sleep(0.001)
timeout_at = time.time() + 7
while not read_finished and time.time() < timeout_at:
time.sleep(0.001)
# self.assertTrue(read_finished)
# self.assertEqual(len(data), len(result))
self.assertTrue(read_finished)
self.assertEqual(len(data), len(result))
# decoded = result.decode("utf-8")
decoded = result.decode("utf-8")
# self.assertSequenceEqual(data, decoded)
self.assertSequenceEqual(data, decoded)
# def test_buffer_small_with_callback(self):
# callbacks = 0
# last_cb_value = None
def test_buffer_small_with_callback(self):
callbacks = 0
last_cb_value = None
# def callback(ready: int):
# nonlocal callbacks, last_cb_value
# callbacks += 1
# last_cb_value = ready
def callback(ready: int):
nonlocal callbacks, last_cb_value
callbacks += 1
last_cb_value = ready
# data = "Hello\n"
# with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
# reader.add_ready_callback(callback)
# count = writer.write(data.encode("utf-8"))
# writer.flush()
data = "Hello\n"
with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
reader.add_ready_callback(callback)
count = writer.write(data.encode("utf-8"))
writer.flush()
# self.assertEqual(len(data), count)
# self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(len(data), count)
self.assertEqual(1, len(self.h.outlet.packets))
# packet = self.h.outlet.packets[0]
# self.h.channel._receive(packet.raw)
# packet.delivered()
packet = self.h.outlet.packets[0]
self.h.channel._receive(packet.raw)
packet.delivered()
# self.assertEqual(1, callbacks)
# self.assertEqual(len(data), last_cb_value)
self.assertEqual(1, callbacks)
self.assertEqual(len(data), last_cb_value)
# result = reader.readline()
result = reader.readline()
# self.assertIsNotNone(result)
# self.assertEqual(len(result), len(data))
self.assertIsNotNone(result)
self.assertEqual(len(result), len(data))
# decoded = result.decode("utf-8")
decoded = result.decode("utf-8")
# self.assertEqual(data, decoded)
# self.assertEqual(1, len(self.h.outlet.packets))
self.assertEqual(data, decoded)
self.assertEqual(1, len(self.h.outlet.packets))
# result = reader.read(1)
result = reader.read(1)
# self.assertIsNone(result)
# self.assertTrue(self.h.channel.is_ready_to_send())
self.assertIsNone(result)
self.assertTrue(self.h.channel.is_ready_to_send())
# writer.close()
writer.close()
# self.assertEqual(2, len(self.h.outlet.packets))
self.assertEqual(2, len(self.h.outlet.packets))
# packet = self.h.outlet.packets[1]
# self.h.channel._receive(packet.raw)
# packet.delivered()
packet = self.h.outlet.packets[1]
self.h.channel._receive(packet.raw)
packet.delivered()
# result = reader.read(1)
result = reader.read(1)
# self.assertIsNotNone(result)
# self.assertTrue(len(result) == 0)
self.assertIsNotNone(result)
self.assertTrue(len(result) == 0)
if __name__ == '__main__':

View File

@ -4,6 +4,7 @@ import subprocess
import shlex
import threading
import time
import random
from unittest import skipIf
import RNS
import os
@ -23,6 +24,8 @@ fixed_keys = [
("08bb35f92b06a0832991165a0d9b4fd91af7b7765ce4572aa6222070b11b767092b61b0fd18b3a59cae6deb9db6d4bfb1c7fcfe076cfd66eea7ddd5f877543b9", "d13712efc45ef87674fb5ac26c37c912"),
]
BUFFER_TEST_TARGET = 32000
def targets_job(caller):
cmd = "python -c \"from tests.link import targets; targets()\""
print("Opening subprocess for "+str(cmd)+"...", RNS.LOG_VERBOSE)
@ -455,7 +458,7 @@ class TestLink(unittest.TestCase):
# @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping")
def test_12_buffer_round_trip_big(self, local_bitrate = None):
global c_rns
global c_rns, buffer_read_target
init_rns(self)
print("")
print("Buffer round trip test")
@ -490,9 +493,9 @@ class TestLink(unittest.TestCase):
buffer = None
received = []
def handle_data(ready_bytes: int):
# TODO: Remove
RNS.log("Handling data")
global received_bytes
data = buffer.read(ready_bytes)
received.append(data)
@ -509,10 +512,11 @@ class TestLink(unittest.TestCase):
if local_interface.bitrate < 1000:
target_bytes = 3000
else:
target_bytes = 16000
target_bytes = BUFFER_TEST_TARGET
message = os.urandom(target_bytes)
random.seed(154889)
message = random.randbytes(target_bytes)
buffer_read_target = len(message)
# the return message will have an appendage string " back at you"
# for every StreamDataMessage that arrives. To verify, we need
@ -527,35 +531,24 @@ class TestLink(unittest.TestCase):
# since the segments will be received at max length for a
# StreamDataMessage, the appended text will end up in a
# separate packet.
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)-1
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, ")
buffer.write(message)
buffer.flush()
# delay a reasonable time for the send and receive
# a chunk each way plus a little more for a proof each way
# while time.time() < expected_ready_time and len(received) < expected_chunk_count:
# time.sleep(0.1)
# # sleep for at least one more chunk round trip in case there
# # are more chunks than expected
# if time.time() < expected_ready_time:
# time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
timeout = time.time() + 10
while len(received) < expected_chunk_count and not time.time() > timeout:
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks so far")
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks")
timeout = time.time() + 4
while not time.time() > timeout:
time.sleep(1)
print(f"Received {len(received)} chunks so far")
time.sleep(1)
data = bytearray()
for rx in received:
data.extend(rx)
rx_message = data
print(f"Received {len(received)} chunks, totalling {len(rx_message)} bytes")
self.assertEqual(len(expected_rx_message), len(rx_message))
for i in range(0, len(expected_rx_message)):
self.assertEqual(expected_rx_message[i], rx_message[i])
@ -598,7 +591,7 @@ class TestLink(unittest.TestCase):
if __name__ == '__main__':
unittest.main(verbosity=1)
buffer_read_len = 0
def targets(yp=False):
if yp:
import yappi
@ -645,10 +638,26 @@ def targets(yp=False):
buffer = None
response_data = []
def handle_buffer(ready_bytes: int):
global buffer_read_len, BUFFER_TEST_TARGET
data = buffer.read(ready_bytes)
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len += len(data)
response_data.append(data)
if data == "Hi there".encode("utf-8"):
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len = 0
if buffer_read_len == BUFFER_TEST_TARGET:
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len = 0
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer)