Added per-packet compression to buffer

This commit is contained in:
Mark Qvist 2023-05-09 22:13:57 +02:00
parent d96a4853fe
commit f522cb1db1
3 changed files with 23 additions and 9 deletions

View File

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
import bz2
import sys import sys
import threading import threading
from threading import RLock from threading import RLock
@ -9,7 +10,6 @@ from io import RawIOBase, BufferedRWPair, BufferedReader, BufferedWriter
from typing import Callable from typing import Callable
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
class StreamDataMessage(MessageBase): class StreamDataMessage(MessageBase):
MSGTYPE = SystemMessageTypes.SMT_STREAM_DATA MSGTYPE = SystemMessageTypes.SMT_STREAM_DATA
""" """
@ -17,9 +17,9 @@ class StreamDataMessage(MessageBase):
uses a system-reserved message type. uses a system-reserved message type.
""" """
STREAM_ID_MAX = 0x7fff # 32767 STREAM_ID_MAX = 0x3fff # 16383
""" """
The stream id is limited to 2 bytes - 1 bit The stream id is limited to 2 bytes - 2 bit
""" """
MAX_DATA_LEN = RNS.Link.MDU - 2 - 6 # 2 for stream data message header, 6 for channel envelope MAX_DATA_LEN = RNS.Link.MDU - 2 - 6 # 2 for stream data message header, 6 for channel envelope
@ -39,22 +39,34 @@ class StreamDataMessage(MessageBase):
""" """
super().__init__() super().__init__()
if stream_id is not None and stream_id > self.STREAM_ID_MAX: if stream_id is not None and stream_id > self.STREAM_ID_MAX:
raise ValueError("stream_id must be 0-32767") raise ValueError("stream_id must be 0-16383")
self.stream_id = stream_id self.stream_id = stream_id
self.compressed = False
self.data = data or bytes() self.data = data or bytes()
self.eof = eof self.eof = eof
def pack(self) -> bytes: def pack(self) -> bytes:
if self.stream_id is None: if self.stream_id is None:
raise ValueError("stream_id") raise ValueError("stream_id")
header_val = (0x7fff & self.stream_id) | (0x8000 if self.eof else 0x0000)
compressed_data = bz2.compress(self.data)
saved = len(self.data)-len(compressed_data)
if saved > 0:
self.data = compressed_data
self.compressed = True
header_val = (0x3fff & self.stream_id) | (0x8000 if self.eof else 0x0000) | (0x4000 if self.compressed > 0 else 0x0000)
return bytes(struct.pack(">H", header_val) + (self.data if self.data else bytes())) return bytes(struct.pack(">H", header_val) + (self.data if self.data else bytes()))
def unpack(self, raw): def unpack(self, raw):
self.stream_id = struct.unpack(">H", raw[:2])[0] self.stream_id = struct.unpack(">H", raw[:2])[0]
self.eof = (0x8000 & self.stream_id) > 0 self.eof = (0x8000 & self.stream_id) > 0
self.stream_id = self.stream_id & 0x7fff self.compressed = (0x4000 & self.stream_id) > 0
self.stream_id = self.stream_id & 0x3fff
self.data = raw[2:] self.data = raw[2:]
if self.compressed:
self.data = bz2.decompress(self.data)
class RawChannelReader(RawIOBase, AbstractContextManager): class RawChannelReader(RawIOBase, AbstractContextManager):

View File

@ -358,12 +358,12 @@ class Channel(contextlib.AbstractContextManager):
message = envelope.unpack(self._message_factories) message = envelope.unpack(self._message_factories)
prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None prev_env = self._rx_ring[0] if len(self._rx_ring) > 0 else None
if prev_env and envelope.sequence != (prev_env.sequence + 1) % 0x10000: if prev_env and envelope.sequence != (prev_env.sequence + 1) % 0x10000:
RNS.log("Channel: Out of order packet received", RNS.LOG_DEBUG) RNS.log("Channel: Out of order packet received", RNS.LOG_EXTREME)
return
is_new = self._emplace_envelope(envelope, self._rx_ring) is_new = self._emplace_envelope(envelope, self._rx_ring)
self._prune_rx_ring() self._prune_rx_ring()
if not is_new: if not is_new:
RNS.log("Channel: Duplicate message received", RNS.LOG_DEBUG) RNS.log("Channel: Duplicate message received", RNS.LOG_EXTREME)
return return
RNS.log(f"Message received: {message}", RNS.LOG_DEBUG) RNS.log(f"Message received: {message}", RNS.LOG_DEBUG)
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start() threading.Thread(target=self._run_callbacks, name="Message Callback", args=[message], daemon=True).start()

View File

@ -521,6 +521,8 @@ class TestLink(unittest.TestCase):
if time.time() < expected_ready_time: if time.time() < expected_ready_time:
time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1)) time.sleep(max(c_rns.MTU * 2 / local_interface.bitrate * 8, 1))
time.sleep(0.25)
# Why does this not always work out correctly? # Why does this not always work out correctly?
# self.assertEqual(expected_chunk_count, len(received)) # self.assertEqual(expected_chunk_count, len(received))