Fixed callback invocation on channel receive

This commit is contained in:
Mark Qvist 2023-05-19 01:58:28 +02:00
parent 1a860c6ffd
commit d7375bc4c3

View File

@ -176,6 +176,9 @@ class Envelope:
raise ChannelException(CEType.ME_NOT_REGISTERED, f"Unable to find constructor for Channel MSGTYPE {hex(msgtype)}")
message = ctor()
message.unpack(raw)
self.unpacked = True
self.message = message
return message
def pack(self) -> bytes:
@ -183,6 +186,7 @@ class Envelope:
raise ChannelException(CEType.ME_NO_MSG_TYPE, f"{self.message.__class__} lacks MSGTYPE")
data = self.message.pack()
self.raw = struct.pack(">HHH", self.message.MSGTYPE, self.sequence, len(data)) + data
self.packed = True
return self.raw
def __init__(self, outlet: ChannelOutletBase, message: MessageBase = None, raw: bytes = None, sequence: int = None):
@ -194,6 +198,8 @@ class Envelope:
self.sequence = sequence
self.outlet = outlet
self.tries = 0
self.unpacked = False
self.packed = False
self.tracked = False
@ -371,22 +377,29 @@ class Channel(contextlib.AbstractContextManager):
def _emplace_envelope(self, envelope: Envelope, ring: collections.deque[Envelope]) -> bool:
with self._lock:
i = 0
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
for existing in ring:
if existing.sequence > envelope.sequence \
and not existing.sequence // 2 > envelope.sequence: # account for overflow
ring.insert(i, envelope)
return True
if existing.sequence == envelope.sequence:
if envelope.sequence == existing.sequence:
RNS.log(f"Envelope: Emplacement of duplicate envelope with sequence "+str(envelope.sequence), RNS.LOG_EXTREME)
return False
if envelope.sequence < existing.sequence and not envelope.sequence < window_overflow:
ring.insert(i, envelope)
RNS.log("Inserted seq "+str(envelope.sequence)+" at "+str(i), RNS.LOG_DEBUG)
envelope.tracked = True
return True
i += 1
envelope.tracked = True
ring.append(envelope)
return True
def _run_callbacks(self, message: MessageBase):
with self._lock:
cbs = self._message_callbacks.copy()
cbs = self._message_callbacks.copy()
for cb in cbs:
try:
@ -405,12 +418,11 @@ class Channel(contextlib.AbstractContextManager):
window_overflow = (self._next_rx_sequence+Channel.WINDOW_MAX) % Channel.SEQ_MODULUS
if window_overflow < self._next_rx_sequence:
if envelope.sequence > window_overflow:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return
else:
if envelope.sequence < self._next_rx_sequence:
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_DEBUG)
return
RNS.log("Invalid packet sequence ("+str(envelope.sequence)+") received on channel "+str(self), RNS.LOG_EXTREME)
return
is_new = self._emplace_envelope(envelope, self._rx_ring)
@ -426,9 +438,13 @@ class Channel(contextlib.AbstractContextManager):
self._next_rx_sequence = (self._next_rx_sequence + 1) % Channel.SEQ_MODULUS
for e in contigous:
m = e.unpack(self._message_factories)
if not e.unpacked:
m = e.unpack(self._message_factories)
else:
m = e.message
self._rx_ring.remove(e)
threading.Thread(target=self._run_callbacks, name="Message Callback", args=[m], daemon=True).start()
self._run_callbacks(m)
except Exception as e:
RNS.log("An error ocurred while receiving data on "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR)
@ -469,7 +485,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_min += 1
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_DEBUG)
# RNS.log("Increased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
if self._outlet.rtt != 0:
if self._outlet.rtt > Channel.RTT_FAST:
@ -483,19 +499,17 @@ class Channel(contextlib.AbstractContextManager):
if self.window_max < Channel.WINDOW_MAX_MEDIUM and self.medium_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_MEDIUM
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
self.fast_rate_rounds += 1
if self.window_max < Channel.WINDOW_MAX_FAST and self.fast_rate_rounds == Channel.FAST_RATE_THRESHOLD:
self.window_max = Channel.WINDOW_MAX_FAST
# TODO: Remove at some point
RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
# RNS.log("Increased "+str(self)+" max window to "+str(self.window_max), RNS.LOG_EXTREME)
else:
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_DEBUG)
RNS.log("Envelope not found in TX ring for "+str(self), RNS.LOG_EXTREME)
if not envelope:
RNS.log("Spurious message received on "+str(self), RNS.LOG_EXTREME)
@ -525,7 +539,7 @@ class Channel(contextlib.AbstractContextManager):
self.window_max -= 1
# TODO: Remove at some point
RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
# RNS.log("Decreased "+str(self)+" window to "+str(self.window), RNS.LOG_EXTREME)
return False
@ -543,16 +557,18 @@ class Channel(contextlib.AbstractContextManager):
with self._lock:
if not self.is_ready_to_send():
raise ChannelException(CEType.ME_LINK_NOT_READY, f"Link is not ready")
envelope = Envelope(self._outlet, message=message, sequence=self._next_sequence)
self._next_sequence = (self._next_sequence + 1) % Channel.SEQ_MODULUS
self._emplace_envelope(envelope, self._tx_ring)
if envelope is None:
raise BlockingIOError()
envelope.pack()
if len(envelope.raw) > self._outlet.mdu:
raise ChannelException(CEType.ME_TOO_BIG, f"Packed message too big for packet: {len(envelope.raw)} > {self._outlet.mdu}")
envelope.packet = self._outlet.send(envelope.raw)
envelope.tries += 1
self._outlet.set_packet_delivered_callback(envelope.packet, self._packet_delivered)
@ -591,7 +607,6 @@ class LinkChannelOutlet(ChannelOutletBase):
return packet
def resend(self, packet: RNS.Packet) -> RNS.Packet:
RNS.log("Resending packet " + RNS.prettyhexrep(packet.packet_hash), RNS.LOG_DEBUG)
receipt = packet.resend()
if not receipt:
RNS.log("Failed to resend packet", RNS.LOG_ERROR)