From bf6e73e163d3380f456c872d82167efa9b4831e4 Mon Sep 17 00:00:00 2001 From: Mark Qvist Date: Sat, 11 Jan 2025 11:43:47 +0100 Subject: [PATCH] Path MTU discovery for links --- RNS/Interfaces/Interface.py | 7 ++-- RNS/Interfaces/LocalInterface.py | 3 -- RNS/Interfaces/TCPInterface.py | 3 +- RNS/Link.py | 65 ++++++++++++++++++++++++++------ RNS/Transport.py | 34 +++++++++++++++-- 5 files changed, 90 insertions(+), 22 deletions(-) diff --git a/RNS/Interfaces/Interface.py b/RNS/Interfaces/Interface.py index 40eef7d..580bca4 100755 --- a/RNS/Interfaces/Interface.py +++ b/RNS/Interfaces/Interface.py @@ -65,11 +65,12 @@ class Interface: IC_HELD_RELEASE_INTERVAL = 30 def __init__(self): - self.rxb = 0 - self.txb = 0 + self.rxb = 0 + self.txb = 0 self.created = time.time() - self.online = False + self.online = False self.bitrate = 1e6 + self.HW_MTU = None self.ingress_control = True self.ic_max_held_announces = Interface.MAX_HELD_ANNOUNCES diff --git a/RNS/Interfaces/LocalInterface.py b/RNS/Interfaces/LocalInterface.py index 9dff4c2..b9c615d 100644 --- a/RNS/Interfaces/LocalInterface.py +++ b/RNS/Interfaces/LocalInterface.py @@ -56,9 +56,6 @@ class LocalClientInterface(Interface): def __init__(self, owner, name, target_port = None, connected_socket=None): super().__init__() - # TODO: Remove at some point - # self.rxptime = 0 - self.HW_MTU = 32768 self.online = False diff --git a/RNS/Interfaces/TCPInterface.py b/RNS/Interfaces/TCPInterface.py index bd60149..b1e8ee4 100644 --- a/RNS/Interfaces/TCPInterface.py +++ b/RNS/Interfaces/TCPInterface.py @@ -96,8 +96,7 @@ class TCPClientInterface(Interface): connect_timeout = c.as_int("connect_timeout") if "connect_timeout" in c else None max_reconnect_tries = c.as_int("max_reconnect_tries") if "max_reconnect_tries" in c else None - self.HW_MTU = 32768 - + self.HW_MTU = 32768 self.IN = True self.OUT = False self.socket = None diff --git a/RNS/Link.py b/RNS/Link.py index f90adad..fe1323c 100644 --- a/RNS/Link.py +++ b/RNS/Link.py @@ -28,6 +28,7 @@ from time import sleep from .vendor import umsgpack as umsgpack import threading import inspect +import struct import math import time import RNS @@ -107,6 +108,29 @@ class Link: ACCEPT_ALL = 0x02 resource_strategies = [ACCEPT_NONE, ACCEPT_APP, ACCEPT_ALL] + @staticmethod + def mtu_bytes(mtu): + return struct.pack(">I", mtu & 0xFFFFFF)[1:] + + @staticmethod + def mtu_from_lr_packet(packet): + if len(packet.data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE: + return (packet.data[Link.ECPUBSIZE] << 16) + (packet.data[Link.ECPUBSIZE+1] << 8) + (packet.data[Link.ECPUBSIZE+2]) + else: + return None + + @staticmethod + def mtu_from_lp_packet(packet): + if len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE: + mtu_bytes = packet.data[RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE] + return (mtu_bytes[0] << 16) + (mtu_bytes[1] << 8) + (mtu_bytes[2]) + else: + return None + + @staticmethod + def link_id_from_lr_packet(packet): + return RNS.Identity.truncated_hash(packet.get_hashable_part()[:Link.ECPUBSIZE]) + @staticmethod def validate_request(owner, data, packet): if len(data) == Link.ECPUBSIZE or len(data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE: @@ -115,15 +139,17 @@ class Link: link.set_link_id(packet) if len(data) == Link.ECPUBSIZE+Link.LINK_MTU_SIZE: + RNS.log("Link request includes MTU signalling") # TODO: Remove debug try: - link.mtu = (ord(a[Link.ECPUBSIZE]) << 16) + (ord(a[Link.ECPUBSIZE+1]) << 8) + (ord(a[Link.ECPUBSIZE+2])) + link.mtu = Link.mtu_from_lr_packet(packet) or Reticulum.MTU except Exception as e: - link.mtu = Reticulum.MTU + RNS.trace_exception(e) + link.mtu = RNS.Reticulum.MTU link.destination = packet.destination link.establishment_timeout = Link.ESTABLISHMENT_TIMEOUT_PER_HOP * max(1, packet.hops) + Link.KEEPALIVE link.establishment_cost += len(packet.raw) - RNS.log(f"Validating link request {RNS.prettyhexrep(link.link_id), RNS.LOG_VERBOSE}") + RNS.log(f"Validating link request {RNS.prettyhexrep(link.link_id)}", RNS.LOG_VERBOSE) RNS.log(f"Link MTU configured to {RNS.prettysize(link.mtu)}", RNS.LOG_EXTREME) RNS.log(f"Establishment timeout is {RNS.prettytime(link.establishment_timeout)} for incoming link request "+RNS.prettyhexrep(link.link_id), RNS.LOG_EXTREME) link.handshake() @@ -143,7 +169,7 @@ class Link: return None else: - RNS.log("Invalid link request payload size, dropping request", RNS.LOG_DEBUG) + RNS.log(f"Invalid link request payload size of {len(data)} bytes, dropping request", RNS.LOG_DEBUG) return None @@ -151,7 +177,7 @@ class Link: if destination != None and destination.type != RNS.Destination.SINGLE: raise TypeError("Links can only be established to the \"single\" destination type") self.rtt = None - self.mtu = Reticulum.MTU + self.mtu = RNS.Reticulum.MTU self.establishment_cost = 0 self.establishment_rate = None self.callbacks = LinkCallbacks() @@ -218,8 +244,13 @@ class Link: if closed_callback != None: self.set_link_closed_callback(closed_callback) - if (self.initiator): - self.request_data = self.pub_bytes+self.sig_pub_bytes + if self.initiator: + link_mtu = b"" + nh_hw_mtu = RNS.Transport.next_hop_interface_hw_mtu(destination.hash) + if nh_hw_mtu: + link_mtu = Link.mtu_bytes(nh_hw_mtu) + RNS.log(f"Signalling link MTU of {RNS.prettysize(nh_hw_mtu)} for link") # TODO: Remove debug + self.request_data = self.pub_bytes+self.sig_pub_bytes+link_mtu self.packet = RNS.Packet(destination, self.request_data, packet_type=RNS.Packet.LINKREQUEST) self.packet.pack() self.establishment_cost += len(self.packet.raw) @@ -244,7 +275,7 @@ class Link: self.peer_pub.curve = Link.CURVE def set_link_id(self, packet): - self.link_id = packet.getTruncatedHash() + self.link_id = Link.link_id_from_lr_packet(packet) self.hash = self.link_id def handshake(self): @@ -263,10 +294,14 @@ class Link: def prove(self): - signed_data = self.link_id+self.pub_bytes+self.sig_pub_bytes + mtu_bytes = b"" + if self.mtu != RNS.Reticulum.MTU: + mtu_bytes = Link.mtu_bytes(self.mtu) + + signed_data = self.link_id+self.pub_bytes+self.sig_pub_bytes+mtu_bytes signature = self.owner.identity.sign(signed_data) - proof_data = signature+self.pub_bytes + proof_data = signature+self.pub_bytes+mtu_bytes proof = RNS.Packet(self, proof_data, packet_type=RNS.Packet.PROOF, context=RNS.Packet.LRPROOF) proof.send() self.establishment_cost += len(proof.raw) @@ -289,6 +324,14 @@ class Link: def validate_proof(self, packet): try: if self.status == Link.PENDING: + mtu_bytes = b"" + confirmed_mtu = None + if len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2+Link.LINK_MTU_SIZE: + confirmed_mtu = Link.mtu_from_lp_packet(packet) + mtu_bytes = Link.mtu_bytes(confirmed_mtu) + packet.data = packet.data[:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2] + RNS.log(f"Destination confirmed link MTU of {RNS.prettysize(confirmed_mtu)}") # TODO: Remove debug + if self.initiator and len(packet.data) == RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2: peer_pub_bytes = packet.data[RNS.Identity.SIGLENGTH//8:RNS.Identity.SIGLENGTH//8+Link.ECPUBSIZE//2] peer_sig_pub_bytes = self.destination.identity.get_public_key()[Link.ECPUBSIZE//2:Link.ECPUBSIZE] @@ -296,7 +339,7 @@ class Link: self.handshake() self.establishment_cost += len(packet.raw) - signed_data = self.link_id+self.peer_pub_bytes+self.peer_sig_pub_bytes + signed_data = self.link_id+self.peer_pub_bytes+self.peer_sig_pub_bytes+mtu_bytes signature = packet.data[:RNS.Identity.SIGLENGTH//8] if self.destination.identity.validate(signature, signed_data): diff --git a/RNS/Transport.py b/RNS/Transport.py index a6b0d6e..c6f1e9f 100755 --- a/RNS/Transport.py +++ b/RNS/Transport.py @@ -1282,6 +1282,22 @@ class Transport: now = time.time() proof_timeout = Transport.extra_link_proof_timeout(packet.receiving_interface) proof_timeout += now + RNS.Link.ESTABLISHMENT_TIMEOUT_PER_HOP * max(1, remaining_hops) + + path_mtu = RNS.Link.mtu_from_lr_packet(packet) + nh_mtu = outbound_interface.HW_MTU + if path_mtu: + RNS.log(f"Seeing transported LR path MTU of {RNS.prettysize(path_mtu)}") # TODO: Remove debug + if outbound_interface.HW_MTU == None: + RNS.log(f"No next-hop HW MTU, disabling link MTU upgrade") # TODO: Remove debug + path_mtu = None + new_raw = new_raw[:RNS.Link.ECPUBSIZE] + else: + if nh_mtu < path_mtu: + path_mtu = nh_mtu + clamped_mtu = RNS.Link.mtu_bytes(path_mtu) + RNS.log(f"Clamping link MTU to {RNS.prettysize(nh_mtu)}: {RNS.hexrep(clamped_mtu)}") # TODO: Remove debug + RNS.log(f"New raw: {RNS.hexrep(new_raw)}") + new_raw = new_raw[:-RNS.Link.LINK_MTU_SIZE]+clamped_mtu # Entry format is link_entry = [ now, # 0: Timestamp, @@ -1294,7 +1310,7 @@ class Transport: False, # 7: Validated proof_timeout] # 8: Proof timeout timestamp - Transport.link_table[packet.getTruncatedHash()] = link_entry + Transport.link_table[RNS.Link.link_id_from_lr_packet(packet)] = link_entry else: # Entry format is @@ -1790,12 +1806,16 @@ class Transport: if packet.hops == link_entry[3]: if packet.receiving_interface == link_entry[2]: try: - if len(packet.data) == RNS.Identity.SIGLENGTH//8+RNS.Link.ECPUBSIZE//2: + if len(packet.data) == RNS.Identity.SIGLENGTH//8+RNS.Link.ECPUBSIZE//2 or len(packet.data) == RNS.Identity.SIGLENGTH//8+RNS.Link.ECPUBSIZE//2+RNS.Link.LINK_MTU_SIZE: + mtu_bytes = b"" + if len(packet.data) == RNS.Identity.SIGLENGTH//8+RNS.Link.ECPUBSIZE//2+RNS.Link.LINK_MTU_SIZE: + mtu_bytes = RNS.Link.mtu_bytes(RNS.Link.mtu_from_lp_packet(packet)) + peer_pub_bytes = packet.data[RNS.Identity.SIGLENGTH//8:RNS.Identity.SIGLENGTH//8+RNS.Link.ECPUBSIZE//2] peer_identity = RNS.Identity.recall(link_entry[6]) peer_sig_pub_bytes = peer_identity.get_public_key()[RNS.Link.ECPUBSIZE//2:RNS.Link.ECPUBSIZE] - signed_data = packet.destination_hash+peer_pub_bytes+peer_sig_pub_bytes + signed_data = packet.destination_hash+peer_pub_bytes+peer_sig_pub_bytes+mtu_bytes signature = packet.data[:RNS.Identity.SIGLENGTH//8] if peer_identity.validate(signature, signed_data): @@ -2189,6 +2209,14 @@ class Transport: else: return None + @staticmethod + def next_hop_interface_hw_mtu(destination_hash): + next_hop_interface = Transport.next_hop_interface(destination_hash) + if next_hop_interface != None: + return next_hop_interface.HW_MTU + else: + return None + @staticmethod def next_hop_per_bit_latency(destination_hash): next_hop_interface_bitrate = Transport.next_hop_interface_bitrate(destination_hash)