From 481062fca10966cd2b94bc033179c4fe31145495 Mon Sep 17 00:00:00 2001 From: Mark Qvist Date: Mon, 18 Sep 2023 00:39:27 +0200 Subject: [PATCH] Added adaptive compression to Buffer class --- RNS/Buffer.py | 49 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/RNS/Buffer.py b/RNS/Buffer.py index 0dd7e18..078e15e 100644 --- a/RNS/Buffer.py +++ b/RNS/Buffer.py @@ -29,7 +29,7 @@ class StreamDataMessage(MessageBase): calculcated based on the value of OVERHEAD """ - def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False): + def __init__(self, stream_id: int = None, data: bytes = None, eof: bool = False, compressed: bool = False): """ This class is used to encapsulate binary stream data to be sent over a ``Channel``. @@ -42,7 +42,7 @@ class StreamDataMessage(MessageBase): if stream_id is not None and stream_id > self.STREAM_ID_MAX: raise ValueError("stream_id must be 0-16383") self.stream_id = stream_id - self.compressed = False + self.compressed = compressed self.data = data or bytes() self.eof = eof @@ -50,13 +50,6 @@ class StreamDataMessage(MessageBase): if self.stream_id is None: raise ValueError("stream_id") - compressed_data = bz2.compress(self.data) - saved = len(self.data)-len(compressed_data) - - if saved > 0: - self.data = compressed_data - self.compressed = True - header_val = (0x3fff & self.stream_id) | (0x8000 if self.eof else 0x0000) | (0x4000 if self.compressed > 0 else 0x0000) return bytes(struct.pack(">H", header_val) + (self.data if self.data else bytes())) @@ -133,7 +126,7 @@ class RawChannelReader(RawIOBase, AbstractContextManager): try: threading.Thread(target=listener, name="Message Callback", args=[len(self._buffer)], daemon=True).start() except Exception as ex: - RNS.log("Error calling RawChannelReader(" + str(self._stream_id) + ") callback: " + str(ex)) + RNS.log("Error calling RawChannelReader(" + str(self._stream_id) + ") callback: " + str(ex), RNS.LOG_ERROR) return True return False @@ -186,6 +179,10 @@ class RawChannelWriter(RawIOBase, AbstractContextManager): object, see the Python documentation for ``RawIOBase``. """ + + MAX_CHUNK_LEN = 1024*16 + COMPRESSION_TRIES = 4 + def __init__(self, stream_id: int, channel: Channel): """ Create a raw channel writer. @@ -199,10 +196,36 @@ class RawChannelWriter(RawIOBase, AbstractContextManager): def write(self, __b: bytes) -> int | None: try: - chunk = bytes(__b[:StreamDataMessage.MAX_DATA_LEN]) - message = StreamDataMessage(self._stream_id, chunk, self._eof) + comp_tries = RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + chunk_len = len(__b) + if chunk_len > RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RawChannelWriter.MAX_CHUNK_LEN + __b = __b[:RawChannelWriter.MAX_CHUNK_LEN] + chunk_segment = None + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(__b[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(__b[:StreamDataMessage.MAX_DATA_LEN]) + processed_length = len(chunk) + + message = StreamDataMessage(self._stream_id, chunk, self._eof, comp_success) + self._channel.send(message) - return len(chunk) + return processed_length + except RNS.Channel.ChannelException as cex: if cex.type != RNS.Channel.CEType.ME_LINK_NOT_READY: raise