Compare commits

..

No commits in common. "d7375bc4c3df22c43ac3beb90699933cfd1d954d" and "817ee0721a5b9816c19163200c5da460d15b4922" have entirely different histories.

6 changed files with 125 additions and 161 deletions

2
.gitignore vendored
View File

@ -10,6 +10,6 @@ docs/build
rns*.egg-info
profile.data
tests/rnsconfig/storage
tests/rnsconfig/logfile*
tests/rnsconfig/logfile
*.data
*.result

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import bz2
import sys
import time
import threading
from threading import RLock
import struct
@ -66,7 +65,6 @@ class StreamDataMessage(MessageBase):
self.compressed = (0x4000 & self.stream_id) > 0
self.stream_id = self.stream_id & 0x3fff
self.data = raw[2:]
if self.compressed:
self.data = bz2.decompress(self.data)
@ -131,7 +129,7 @@ class RawChannelReader(RawIOBase, AbstractContextManager):
self._eof = True
for listener in self._listeners:
try:
threading.Thread(target=listener, name="Message Callback", args=[len(self._buffer)], daemon=True).start()
listener(len(self._buffer))
except Exception as ex:
RNS.log("Error calling RawChannelReader(" + str(self._stream_id) + ") callback: " + str(ex))
return True
@ -209,15 +207,6 @@ class RawChannelWriter(RawIOBase, AbstractContextManager):
return 0
def close(self):
try:
link_rtt = self._channel._outlet.link.rtt
timeout = time.time() + (link_rtt * len(self._channel._tx_ring) * 1)
except Exception as e:
timeout = time.time() + 15
while time.time() < timeout and not self._channel.is_ready_to_send():
time.sleep(0.05)
self._eof = True
self.write(bytes())

View File

@ -176,9 +176,6 @@ class Envelope:
raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}")
message = ctor()
message.unpack(raw)
self.unpacked = True
self.message = message
return message
def pack(self) -> bytes:
@ -186,7 +183,6 @@ class Envelope:
raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE")
data = self.message.pack()
self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data
self.packed = True
return self.raw
def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None):
@ -198,8 +194,6 @@ class Envelope:
self.sequence = sequence
self.outlet = outlet
self.tries = 0
self.unpacked = False
self.packed = False
self.tracked = False
@ -377,29 +371,22 @@ class Channel(contextlib.AbstractContextManager):
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock:
i = 0
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
for existing in ring:
if envelope.sequence == existing.sequence:
if existing.sequence > envelope.sequence \
and not existing.sequence // 2 > envelope.sequence: # account for overflow
ring.insert(i, envelope)
return True
if existing.sequence == envelope.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
return False
if envelope.sequence < existing.sequence and not envelope.sequence < window_overflow:
ring.insert(i, envelope)
RNS.log("Inserted seq "+str(envelope.sequence)+" at "+str(i), RNS.LOG_DEBUG)
envelope.tracked = True
return True
i += 1
envelope.tracked = True
ring.append(envelope)
return True
def _run_callbacks(self, message: MessageBase):
cbs = self._message_callbacks.copy()
with self._lock:
cbs = self._message_callbacks.copy()
for cb in cbs:
try:
@ -418,11 +405,12 @@ class Channel(contextlib.AbstractContextManager):
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
if window_overflow < self._next_rx_sequence:
if envelope.sequence > window_overflow:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
return
else:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return
if envelope.sequence < self._next_rx_sequence:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
return
is_new = self._emplace_envelope(envelope, self._rx_ring)
@ -438,13 +426,9 @@ class Channel(contextlib.AbstractContextManager):
self._next_rx_sequence = (self._next_rx_sequence + 1) % Channel.SEQ_MODULUS
for e in contigous:
if not e.unpacked:
m = e.unpack(self._message_factories)
else:
m = e.message
m = e.unpack(self._message_factories)
self._rx_ring.remove(e)
self._run_callbacks(m)
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start()
except Exception as e:
RNS.log("An error ocurred while receiving data on "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR)
@ -485,7 +469,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_min += 1
# TODO: Remove at some point
# RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG)
if self._outlet.rtt != 0:
if self._outlet.rtt > Channel.RTT_FAST:
@ -499,17 +483,19 @@ class Channel(contextlib.AbstractContextManager):
if self.window_max < Channel.WINDOW_MAX_MEDIUM and self.medium_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_MEDIUM
# TODO: Remove at some point
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
self.fast_rate_rounds += 1
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_FAST
# TODO: Remove at some point
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_EXTREME)
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_DEBUG)
if not envelope:
RNS.log("Spurious message received on "+str(self), RNS.LOG_EXTREME)
@ -539,7 +525,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_max -= 1
# TODO: Remove at some point
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
return False
@ -557,18 +543,16 @@ class Channel(contextlib.AbstractContextManager):
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
self._emplace_envelope(envelope, self._tx_ring)
self._emplace_envelope(envelope, self._tx_ring)
if envelope is None:
raise BlockingIOError()
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
@ -607,6 +591,7 @@ class LinkChannelOutlet(ChannelOutletBase):
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
receipt = packet.resend()
if not receipt:
RNS.log("Failed to resend packet", RNS.LOG_ERROR)

View File

@ -334,8 +334,7 @@ class Transport:
for receipt in Transport.receipts:
receipt.check_timeout()
if receipt.status != RNS.PacketReceipt.SENT:
if receipt in Transport.receipts:
Transport.receipts.remove(receipt)
Transport.receipts.remove(receipt)
Transport.receipts_last_checked = time.time()

View File

@ -404,104 +404,104 @@ class TestChannel(unittest.TestCase):
self.assertEqual(data, decoded)
def test_buffer_big(self):
writer = RNS.Buffer.create_writer(15, self.h.channel)
reader = RNS.Buffer.create_reader(15, self.h.channel)
data = "01234556789"*1024 # 10 KB
count = 0
write_finished = False
# def test_buffer_big(self):
# writer = RNS.Buffer.create_writer(15, self.h.channel)
# reader = RNS.Buffer.create_reader(15, self.h.channel)
# data = "01234556789"*1024 # 10 KB
# count = 0
# write_finished = False
def write_thread():
nonlocal count, write_finished
count = writer.write(data.encode("utf-8"))
writer.flush()
writer.close()
write_finished = True
threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()
# def write_thread():
# nonlocal count, write_finished
# count = writer.write(data.encode("utf-8"))
# writer.flush()
# writer.close()
# write_finished = True
# threading.Thread(target=write_thread, name="Write Thread", daemon=True).start()
while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED,
self.h.outlet.packets), None) is not None:
with self.h.outlet.lock:
for packet in self.h.outlet.packets:
if packet.state != MessageState.MSGSTATE_DELIVERED:
self.h.channel._receive(packet.raw)
packet.delivered()
time.sleep(0.0001)
# while not write_finished or next(filter(lambda x: x.state != MessageState.MSGSTATE_DELIVERED,
# self.h.outlet.packets), None) is not None:
# with self.h.outlet.lock:
# for packet in self.h.outlet.packets:
# if packet.state != MessageState.MSGSTATE_DELIVERED:
# self.h.channel._receive(packet.raw)
# packet.delivered()
# time.sleep(0.0001)
self.assertEqual(len(data), count)
# self.assertEqual(len(data), count)
read_finished = False
result = bytes()
# read_finished = False
# result = bytes()
def read_thread():
nonlocal read_finished, result
result = reader.read()
read_finished = True
threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()
# def read_thread():
# nonlocal read_finished, result
# result = reader.read()
# read_finished = True
# threading.Thread(target=read_thread, name="Read Thread", daemon=True).start()
timeout_at = time.time() + 7
while not read_finished and time.time() < timeout_at:
time.sleep(0.001)
# timeout_at = time.time() + 7
# while not read_finished and time.time() < timeout_at:
# time.sleep(0.001)
self.assertTrue(read_finished)
self.assertEqual(len(data), len(result))
# self.assertTrue(read_finished)
# self.assertEqual(len(data), len(result))
decoded = result.decode("utf-8")
# decoded = result.decode("utf-8")
self.assertSequenceEqual(data, decoded)
# self.assertSequenceEqual(data, decoded)
def test_buffer_small_with_callback(self):
callbacks = 0
last_cb_value = None
# def test_buffer_small_with_callback(self):
# callbacks = 0
# last_cb_value = None
def callback(ready: int):
nonlocal callbacks, last_cb_value
callbacks += 1
last_cb_value = ready
# def callback(ready: int):
# nonlocal callbacks, last_cb_value
# callbacks += 1
# last_cb_value = ready
data = "Hello\n"
with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
reader.add_ready_callback(callback)
count = writer.write(data.encode("utf-8"))
writer.flush()
# data = "Hello\n"
# with RNS.RawChannelWriter(0, self.h.channel) as writer, RNS.RawChannelReader(0, self.h.channel) as reader:
# reader.add_ready_callback(callback)
# count = writer.write(data.encode("utf-8"))
# writer.flush()
self.assertEqual(len(data), count)
self.assertEqual(1, len(self.h.outlet.packets))
# self.assertEqual(len(data), count)
# self.assertEqual(1, len(self.h.outlet.packets))
packet = self.h.outlet.packets[0]
self.h.channel._receive(packet.raw)
packet.delivered()
# packet = self.h.outlet.packets[0]
# self.h.channel._receive(packet.raw)
# packet.delivered()
self.assertEqual(1, callbacks)
self.assertEqual(len(data), last_cb_value)
# self.assertEqual(1, callbacks)
# self.assertEqual(len(data), last_cb_value)
result = reader.readline()
# result = reader.readline()
self.assertIsNotNone(result)
self.assertEqual(len(result), len(data))
# self.assertIsNotNone(result)
# self.assertEqual(len(result), len(data))
decoded = result.decode("utf-8")
# decoded = result.decode("utf-8")
self.assertEqual(data, decoded)
self.assertEqual(1, len(self.h.outlet.packets))
# self.assertEqual(data, decoded)
# self.assertEqual(1, len(self.h.outlet.packets))
result = reader.read(1)
# result = reader.read(1)
self.assertIsNone(result)
self.assertTrue(self.h.channel.is_ready_to_send())
# self.assertIsNone(result)
# self.assertTrue(self.h.channel.is_ready_to_send())
writer.close()
# writer.close()
self.assertEqual(2, len(self.h.outlet.packets))
# self.assertEqual(2, len(self.h.outlet.packets))
packet = self.h.outlet.packets[1]
self.h.channel._receive(packet.raw)
packet.delivered()
# packet = self.h.outlet.packets[1]
# self.h.channel._receive(packet.raw)
# packet.delivered()
result = reader.read(1)
# result = reader.read(1)
self.assertIsNotNone(result)
self.assertTrue(len(result) == 0)
# self.assertIsNotNone(result)
# self.assertTrue(len(result) == 0)
if __name__ == '__main__':

View File

@ -4,7 +4,6 @@ import subprocess
import shlex
import threading
import time
import random
from unittest import skipIf
import RNS
import os
@ -24,8 +23,6 @@ fixed_keys = [
("08bb35f92b06a0832991165a0d9b4fd91af7b7765ce4572aa6222070b11b767092b61b0fd18b3a59cae6deb9db6d4bfb1c7fcfe076cfd66eea7ddd5f877543b9", "d13712efc45ef87674fb5ac26c37c912"),
]
BUFFER_TEST_TARGET = 32000
def targets_job(caller):
cmd = "python -c \"from tests.link import targets; targets()\""
print("Opening subprocess for "+str(cmd)+"...", RNS.LOG_VERBOSE)
@ -458,7 +455,7 @@ class TestLink(unittest.TestCase):
# @skipIf(os.getenv('SKIP_NORMAL_TESTS') != None and os.getenv('RUN_SLOW_TESTS') == None, "Skipping")
def test_12_buffer_round_trip_big(self, local_bitrate = None):
global c_rns, buffer_read_target
global c_rns
init_rns(self)
print("")
print("Buffer round trip test")
@ -493,9 +490,9 @@ class TestLink(unittest.TestCase):
buffer = None
received = []
def handle_data(ready_bytes: int):
global received_bytes
# TODO: Remove
RNS.log("Handling data")
data = buffer.read(ready_bytes)
received.append(data)
@ -512,11 +509,10 @@ class TestLink(unittest.TestCase):
if local_interface.bitrate < 1000:
target_bytes = 3000
else:
target_bytes = BUFFER_TEST_TARGET
target_bytes = 16000
random.seed(154889)
message = random.randbytes(target_bytes)
buffer_read_target = len(message)
message = os.urandom(target_bytes)
# the return message will have an appendage string " back at you"
# for every StreamDataMessage that arrives. To verify, we need
@ -531,23 +527,34 @@ class TestLink(unittest.TestCase):
# since the segments will be received at max length for a
# StreamDataMessage, the appended text will end up in a
# separate packet.
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, ")
expected_chunk_count = ceil(len(message)/StreamDataMessage.MAX_DATA_LEN * 2)-1
print("Sending " + str(len(message)) + " bytes, receiving " + str(len(expected_rx_message)) + " bytes, " +
"expecting " + str(expected_chunk_count) + " chunks of " + str(StreamDataMessage.MAX_DATA_LEN) + " bytes")
buffer.write(message)
buffer.flush()
timeout = time.time() + 4
while not time.time() > timeout:
time.sleep(1)
print(f"Received {len(received)} chunks so far")
time.sleep(1)
# delay a reasonable time for the send and receive
# a chunk each way plus a little more for a proof each way
# while time.time() < expected_ready_time and len(received) < expected_chunk_count:
# time.sleep(0.1)
# # sleep for at least one more chunk round trip in case there
# # are more chunks than expected
# if time.time() < expected_ready_time:
# time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
timeout = time.time() + 10
while len(received) < expected_chunk_count and not time.time() > timeout:
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks so far")
time.sleep(2)
print(f"Received {len(received)} out of {expected_chunk_count} chunks")
data = bytearray()
for rx in received:
data.extend(rx)
rx_message = data
print(f"Received {len(received)} chunks, totalling {len(rx_message)} bytes")
rx_message = data
self.assertEqual(len(expected_rx_message), len(rx_message))
for i in range(0, len(expected_rx_message)):
@ -591,7 +598,7 @@ class TestLink(unittest.TestCase):
if __name__ == '__main__':
unittest.main(verbosity=1)
buffer_read_len = 0
def targets(yp=False):
if yp:
import yappi
@ -638,26 +645,10 @@ def targets(yp=False):
buffer = None
response_data = []
def handle_buffer(ready_bytes: int):
global buffer_read_len, BUFFER_TEST_TARGET
data = buffer.read(ready_bytes)
buffer_read_len += len(data)
response_data.append(data)
if data == "Hi there".encode("utf-8"):
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len = 0
if buffer_read_len == BUFFER_TEST_TARGET:
RNS.log("Sending response")
for data in response_data:
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer_read_len = 0
buffer.write(data + " back at you".encode("utf-8"))
buffer.flush()
buffer = RNS.Buffer.create_bidirectional_buffer(0, 0, channel, handle_buffer)