diff --git a/.gitignore b/.gitignore index c217835..de4e3f9 100755 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .DS_Store *.pyc t.py +t2.py TODO diff --git a/Notes/Header format b/Notes/Header format index 9464591..41e3409 100644 --- a/Notes/Header format +++ b/Notes/Header format @@ -1,8 +1,8 @@ header types ----------------- -type 1 00 One byte header, one 10 byte address field -type 2 01 One byte header, two 10 byte address fields -type 3 10 Reserved +type 1 00 Two byte header, one 10 byte address field +type 2 01 Two byte header, two 10 byte address fields +type 3 10 Two byte header, one 10 byte address field, used for link request proofs type 4 11 Reserved for extended header format diff --git a/RNS/Destination.py b/RNS/Destination.py index 3279954..402c3fc 100755 --- a/RNS/Destination.py +++ b/RNS/Destination.py @@ -9,6 +9,11 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import padding +class Callbacks: + def __init__(self): + self.link_established = None + self.packet = None + self.proof = None class Destination: KEYSIZE = RNS.Identity.KEYSIZE; @@ -59,11 +64,14 @@ class Destination: if "." in app_name: raise ValueError("Dots can't be used in app names") if not type in Destination.types: raise ValueError("Unknown destination type") if not direction in Destination.directions: raise ValueError("Unknown destination direction") + self.callbacks = Callbacks() self.type = type self.direction = direction self.proof_strategy = Destination.PROVE_NONE self.mtu = 0 + self.links = [] + if identity != None and type == Destination.SINGLE: aspects = aspects+(identity.hexhash,) @@ -87,11 +95,14 @@ class Destination: return "<"+self.name+"/"+self.hexhash+">" - def setCallback(self, callback): - self.callback = callback + def link_established_callback(self, callback): + self.callbacks.link_established = callback - def setProofCallback(self, callback): - self.proofcallback = callback + def packet_callback(self, callback): + self.callbacks.packet = callback + + def proof_callback(self, callback): + self.callbacks.proof = callback def setProofStrategy(self, proof_strategy): if not proof_strategy in Destination.proof_strategies: @@ -101,9 +112,19 @@ class Destination: def receive(self, packet): plaintext = self.decrypt(packet.data) - if plaintext != None and self.callback != None: - self.callback(plaintext, packet) + if plaintext != None: + if packet.packet_type == RNS.Packet.LINKREQUEST: + self.incomingLinkRequest(plaintext, packet) + if packet.packet_type == RNS.Packet.RESOURCE: + if self.callbacks.packet != None: + self.callbacks.packet(plaintext, packet) + + def incomingLinkRequest(self, data, packet): + link = RNS.Link.validateRequest(self, data, packet) + if link != None: + RNS.log(str(self)+" accepted link request", RNS.LOG_DEBUG) + self.links.append(link) def createKeys(self): if self.type == Destination.PLAIN: diff --git a/RNS/Identity.py b/RNS/Identity.py index 243569e..c652aa9 100644 --- a/RNS/Identity.py +++ b/RNS/Identity.py @@ -14,9 +14,9 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.asymmetric import padding class Identity: - # Configure key size - KEYSIZE = 1536 - DERKEYSIZE = 1808 + #KEYSIZE = 1536 + KEYSIZE = 1024 + DERKEYSIZE = KEYSIZE+272 # Padding size, not configurable PADDINGSIZE= 336 @@ -223,7 +223,7 @@ class Identity: ) ) except: - RNS.log("Decryption by "+RNS.prettyhexrep(self.hash)+" failed") + RNS.log("Decryption by "+RNS.prettyhexrep(self.hash)+" failed", RNS.LOG_VERBOSE) return plaintext; else: diff --git a/RNS/Link.py b/RNS/Link.py new file mode 100644 index 0000000..328894a --- /dev/null +++ b/RNS/Link.py @@ -0,0 +1,174 @@ +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.fernet import Fernet +import base64 +import RNS + +import traceback + +class LinkCallbacks: + def __init__(self): + self.link_established = None + self.packet = None + self.resource_started = None + self.resource_completed = None + +class Link: + CURVE = ec.SECP256R1() + ECPUBSIZE = 91 + + PENDING = 0x00 + ACTIVE = 0x01 + + @staticmethod + def validateRequest(owner, data, packet): + if len(data) == (Link.ECPUBSIZE): + try: + link = Link(owner = owner, peer_pub_bytes = data[:Link.ECPUBSIZE]) + link.setLinkID(packet) + RNS.log("Validating link request "+RNS.prettyhexrep(link.link_id), RNS.LOG_VERBOSE) + link.handshake() + link.attached_interface = packet.receiving_interface + link.prove() + RNS.Transport.registerLink(link) + if link.owner.callbacks.link_established != None: + link.owner.callbacks.link_established(link) + RNS.log("Incoming link request "+str(link)+" accepted", RNS.LOG_VERBOSE) + + except Exception as e: + RNS.log("Validating link request failed", RNS.LOG_VERBOSE) + return None + + + else: + RNS.log("Invalid link request payload size, dropping request", RNS.LOG_VERBOSE) + return None + + + def __init__(self, destination=None, owner=None, peer_pub_bytes = None): + self.callbacks = LinkCallbacks() + self.status = Link.PENDING + self.type = RNS.Destination.LINK + self.owner = owner + self.destination = destination + self.attached_interface = None + if self.destination == None: + self.initiator = False + else: + self.initiator = True + + self.prv = ec.generate_private_key(Link.CURVE, default_backend()) + self.pub = self.prv.public_key() + self.pub_bytes = self.pub.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + if peer_pub_bytes == None: + self.peer_pub = None + self.peer_pub_bytes = None + else: + self.loadPeer(peer_pub_bytes) + + if (self.initiator): + self.request_data = self.pub_bytes + self.packet = RNS.Packet(destination, self.request_data, packet_type=RNS.Packet.LINKREQUEST) + self.packet.pack() + self.setLinkID(self.packet) + RNS.Transport.registerLink(self) + self.packet.send() + RNS.log("Link request "+RNS.prettyhexrep(self.link_id)+" sent to "+str(self.destination), RNS.LOG_VERBOSE) + + + def loadPeer(self, peer_pub_bytes): + self.peer_pub_bytes = peer_pub_bytes + self.peer_pub = serialization.load_der_public_key(peer_pub_bytes, backend=default_backend()) + self.peer_pub.curce = Link.CURVE + + def setLinkID(self, packet): + self.link_id = RNS.Identity.truncatedHash(packet.raw) + self.hash = self.link_id + + def handshake(self): + self.shared_key = self.prv.exchange(ec.ECDH(), self.peer_pub) + self.derived_key = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=self.getSalt(), + info=self.getContext(), + backend=default_backend() + ).derive(self.shared_key) + + def prove(self): + signed_data = self.link_id+self.pub_bytes + signature = self.owner.identity.sign(signed_data) + + proof_data = self.pub_bytes+signature + proof = RNS.Packet(self, proof_data, packet_type=RNS.Packet.PROOF, header_type=RNS.Packet.HEADER_3) + proof.send() + + def validateProof(self, packet): + peer_pub_bytes = packet.data[:Link.ECPUBSIZE] + signed_data = self.link_id+peer_pub_bytes + signature = packet.data[Link.ECPUBSIZE:RNS.Identity.KEYSIZE/8+Link.ECPUBSIZE] + + if self.destination.identity.validate(signature, signed_data): + self.loadPeer(peer_pub_bytes) + self.handshake() + self.attached_interface = packet.receiving_interface + RNS.Transport.activateLink(self) + if self.callbacks.link_established != None: + self.callbacks.link_established(self) + RNS.log("Link "+str(self)+" established with "+str(self.destination), RNS.LOG_VERBOSE) + else: + RNS.log("Invalid link proof signature received by "+str(self), RNS.LOG_VERBOSE) + + + def getSalt(self): + return self.link_id + + def getContext(self): + return None + + def receive(self, packet): + if packet.receiving_interface != self.attached_interface: + RNS.log("Link-associated packet received on unexpected interface! Someone might be trying to manipulate your communication!", RNS.LOG_ERROR) + else: + plaintext = self.decrypt(packet.data) + if (self.callbacks.packet != None): + self.callbacks.packet(plaintext, packet) + + def encrypt(self, plaintext): + try: + fernet = Fernet(base64.urlsafe_b64encode(self.derived_key)) + ciphertext = base64.urlsafe_b64decode(fernet.encrypt(plaintext)) + return ciphertext + except Exception as e: + RNS.log("Encryption on link "+str(self)+" failed. The contained exception was: "+str(e), RNS.LOG_ERROR) + + + def decrypt(self, ciphertext): + try: + fernet = Fernet(base64.urlsafe_b64encode(self.derived_key)) + plaintext = fernet.decrypt(base64.urlsafe_b64encode(ciphertext)) + return plaintext + except Exception as e: + RNS.log("Decryption failed on link "+str(self)+". The contained exception was: "+str(e), RNS.LOG_ERROR) + + def link_established_callback(self, callback): + self.callbacks.link_established = callback + + def packet_callback(self, callback): + self.callbacks.packet = callback + + def resource_started_callback(self, callback): + self.callbacks.resource_started = callback + + def resource_completed_callback(self, callback): + self.callbacks.resource_completed = callback + + def __str__(self): + return RNS.prettyhexrep(self.link_id) \ No newline at end of file diff --git a/RNS/Packet.py b/RNS/Packet.py index 9c5a475..05e030e 100755 --- a/RNS/Packet.py +++ b/RNS/Packet.py @@ -12,7 +12,7 @@ class Packet: HEADER_1 = 0x00; # Normal header format HEADER_2 = 0x01; # Header format used for link packets in transport - HEADER_3 = 0x02; # Reserved + HEADER_3 = 0x02; # Normal header format, but used to indicate a link request proof HEADER_4 = 0x03; # Reserved header_types = [HEADER_1, HEADER_2, HEADER_3, HEADER_4] @@ -30,7 +30,7 @@ class Packet: self.transport_id = transport_id self.data = data self.flags = self.getPackedFlags() - self.MTU = self.destination.MTU + self.MTU = RNS.Reticulum.MTU self.raw = None self.packed = False @@ -45,7 +45,10 @@ class Packet: self.packet_hash = None def getPackedFlags(self): - packed_flags = (self.header_type << 6) | (self.transport_type << 4) | (self.destination.type << 2) | self.packet_type + if self.header_type == Packet.HEADER_3: + packed_flags = (self.header_type << 6) | (self.transport_type << 4) | RNS.Destination.LINK | self.packet_type + else: + packed_flags = (self.header_type << 6) | (self.transport_type << 4) | (self.destination.type << 2) | self.packet_type return packed_flags def pack(self): @@ -58,10 +61,14 @@ class Packet: else: raise IOError("Packet with header type 2 must have a transport ID") - self.header += self.destination.hash - if self.packet_type != Packet.ANNOUNCE: - self.ciphertext = self.destination.encrypt(self.data) - else: + if self.header_type == Packet.HEADER_1: + self.header += self.destination.hash + if self.packet_type != Packet.ANNOUNCE: + self.ciphertext = self.destination.encrypt(self.data) + else: + self.ciphertext = self.data + if self.header_type == Packet.HEADER_3: + self.header += self.destination.link_id self.ciphertext = self.data self.raw = self.header + self.ciphertext @@ -93,9 +100,10 @@ class Packet: def send(self): if not self.sent: - self.pack() - #RNS.log("Size: "+str(len(self.raw))+" header is "+str(len(self.header))+" payload is "+str(len(self.ciphertext)), RNS.LOG_DEBUG) - RNS.Transport.outbound(self.raw) + if not self.packed: + self.pack() + + RNS.Transport.outbound(self) self.packet_hash = RNS.Identity.fullHash(self.raw) self.sent_at = time.time() self.sent = True diff --git a/RNS/Reticulum.py b/RNS/Reticulum.py index 07ad9f3..9cfedbf 100755 --- a/RNS/Reticulum.py +++ b/RNS/Reticulum.py @@ -8,7 +8,7 @@ import os.path import os import RNS -import traceback +#import traceback class Reticulum: MTU = 500 diff --git a/RNS/Transport.py b/RNS/Transport.py index a20730b..a5a649f 100755 --- a/RNS/Transport.py +++ b/RNS/Transport.py @@ -10,15 +10,23 @@ class Transport: interfaces = [] destinations = [] + pending_links = [] + active_links = [] packet_hashlist = [] @staticmethod - def outbound(raw): - Transport.cacheRaw(raw) + def outbound(packet): + Transport.cacheRaw(packet.raw) for interface in Transport.interfaces: if interface.OUT: - RNS.log("Transmitting "+str(len(raw))+" bytes via: "+str(interface), RNS.LOG_DEBUG) - interface.processOutgoing(raw) + should_transmit = True + if packet.destination.type == RNS.Destination.LINK: + if interface != packet.destination.attached_interface: + should_transmit = False + + if should_transmit: + RNS.log("Transmitting "+str(len(packet.raw))+" bytes via: "+str(interface), RNS.LOG_DEBUG) + interface.processOutgoing(packet.raw) @staticmethod def inbound(raw, interface=None): @@ -30,24 +38,45 @@ class Transport: packet = RNS.Packet(None, raw) packet.unpack() packet.packet_hash = packet_hash + packet.receiving_interface = interface if packet.packet_type == RNS.Packet.ANNOUNCE: if RNS.Identity.validateAnnounce(packet): Transport.cache(packet) - - if packet.packet_type == RNS.Packet.RESOURCE: + + if packet.packet_type == RNS.Packet.LINKREQUEST: for destination in Transport.destinations: if destination.hash == packet.destination_hash and destination.type == packet.destination_type: packet.destination = destination destination.receive(packet) Transport.cache(packet) + + if packet.packet_type == RNS.Packet.RESOURCE: + if packet.destination_type == RNS.Destination.LINK: + for link in Transport.active_links: + if link.link_id == packet.destination_hash: + link.receive(packet) + Transport.cache(packet) + else: + for destination in Transport.destinations: + if destination.hash == packet.destination_hash and destination.type == packet.destination_type: + packet.destination = destination + destination.receive(packet) + Transport.cache(packet) if packet.packet_type == RNS.Packet.PROOF: - for destination in Transport.destinations: - if destination.hash == packet.destination_hash: - if destination.proofcallback != None: - destination.proofcallback(packet) - # TODO: add universal proof handling + if packet.header_type == RNS.Packet.HEADER_3: + # This is a link request proof, forward + # to a waiting link request + for link in Transport.pending_links: + if link.link_id == packet.destination_hash: + link.validateProof(packet) + else: + for destination in Transport.destinations: + if destination.hash == packet.destination_hash: + if destination.proofcallback != None: + destination.proofcallback(packet) + # TODO: add universal proof handling @staticmethod def registerDestination(destination): @@ -55,9 +84,35 @@ class Transport: if destination.direction == RNS.Destination.IN: Transport.destinations.append(destination) + @staticmethod + def registerLink(link): + RNS.log("Registering link "+str(link)) + if link.initiator: + Transport.pending_links.append(link) + else: + Transport.active_links.append(link) + + @staticmethod + def activateLink(link): + RNS.log("Activating link "+str(link)) + if link in Transport.pending_links: + Transport.pending_links.remove(link) + Transport.active_links.append(link) + link.status = RNS.Link.ACTIVE + else: + RNS.log("Attempted to activate a link that was not in the pending table", RNS.LOG_ERROR) + + + @staticmethod + def shouldCache(packet): + # TODO: Implement sensible rules for which + # packets to cache + return False + @staticmethod def cache(packet): - RNS.Transport.cacheRaw(packet.raw) + if RNS.Transport.shouldCache(packet): + RNS.Transport.cacheRaw(packet.raw) @staticmethod def cacheRaw(raw): diff --git a/RNS/__init__.py b/RNS/__init__.py index d44b171..f566c32 100755 --- a/RNS/__init__.py +++ b/RNS/__init__.py @@ -4,6 +4,7 @@ import time from .Reticulum import Reticulum from .Identity import Identity +from .Link import Link from .Transport import Transport from .Destination import Destination from .Packet import Packet