mirror of
https://github.com/markqvist/Reticulum.git
synced 2024-11-25 07:00:17 +00:00
369 lines
13 KiB
Python
369 lines
13 KiB
Python
|
# MIT License
|
||
|
#
|
||
|
# Copyright (c) 2015 Brian Warner and other contributors
|
||
|
|
||
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||
|
# of this software and associated documentation files (the "Software"), to deal
|
||
|
# in the Software without restriction, including without limitation the rights
|
||
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||
|
# copies of the Software, and to permit persons to whom the Software is
|
||
|
# furnished to do so, subject to the following conditions:
|
||
|
#
|
||
|
# The above copyright notice and this permission notice shall be included in all
|
||
|
# copies or substantial portions of the Software.
|
||
|
#
|
||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||
|
# SOFTWARE.
|
||
|
|
||
|
import binascii, hashlib, itertools
|
||
|
|
||
|
Q = 2**255 - 19
|
||
|
L = 2**252 + 27742317777372353535851937790883648493
|
||
|
|
||
|
def inv(x):
|
||
|
return pow(x, Q-2, Q)
|
||
|
|
||
|
d = -121665 * inv(121666)
|
||
|
I = pow(2,(Q-1)//4,Q)
|
||
|
|
||
|
def xrecover(y):
|
||
|
xx = (y*y-1) * inv(d*y*y+1)
|
||
|
x = pow(xx,(Q+3)//8,Q)
|
||
|
if (x*x - xx) % Q != 0: x = (x*I) % Q
|
||
|
if x % 2 != 0: x = Q-x
|
||
|
return x
|
||
|
|
||
|
By = 4 * inv(5)
|
||
|
Bx = xrecover(By)
|
||
|
B = [Bx % Q,By % Q]
|
||
|
|
||
|
# Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
|
||
|
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
|
||
|
|
||
|
def xform_affine_to_extended(pt):
|
||
|
(x, y) = pt
|
||
|
return (x%Q, y%Q, 1, (x*y)%Q) # (X,Y,Z,T)
|
||
|
|
||
|
def xform_extended_to_affine(pt):
|
||
|
(x, y, z, _) = pt
|
||
|
return ((x*inv(z))%Q, (y*inv(z))%Q)
|
||
|
|
||
|
def double_element(pt): # extended->extended
|
||
|
# dbl-2008-hwcd
|
||
|
(X1, Y1, Z1, _) = pt
|
||
|
A = (X1*X1)
|
||
|
B = (Y1*Y1)
|
||
|
C = (2*Z1*Z1)
|
||
|
D = (-A) % Q
|
||
|
J = (X1+Y1) % Q
|
||
|
E = (J*J-A-B) % Q
|
||
|
G = (D+B) % Q
|
||
|
F = (G-C) % Q
|
||
|
H = (D-B) % Q
|
||
|
X3 = (E*F) % Q
|
||
|
Y3 = (G*H) % Q
|
||
|
Z3 = (F*G) % Q
|
||
|
T3 = (E*H) % Q
|
||
|
return (X3, Y3, Z3, T3)
|
||
|
|
||
|
def add_elements(pt1, pt2): # extended->extended
|
||
|
# add-2008-hwcd-3 . Slightly slower than add-2008-hwcd-4, but -3 is
|
||
|
# unified, so it's safe for general-purpose addition
|
||
|
(X1, Y1, Z1, T1) = pt1
|
||
|
(X2, Y2, Z2, T2) = pt2
|
||
|
A = ((Y1-X1)*(Y2-X2)) % Q
|
||
|
B = ((Y1+X1)*(Y2+X2)) % Q
|
||
|
C = T1*(2*d)*T2 % Q
|
||
|
D = Z1*2*Z2 % Q
|
||
|
E = (B-A) % Q
|
||
|
F = (D-C) % Q
|
||
|
G = (D+C) % Q
|
||
|
H = (B+A) % Q
|
||
|
X3 = (E*F) % Q
|
||
|
Y3 = (G*H) % Q
|
||
|
T3 = (E*H) % Q
|
||
|
Z3 = (F*G) % Q
|
||
|
return (X3, Y3, Z3, T3)
|
||
|
|
||
|
def scalarmult_element_safe_slow(pt, n):
|
||
|
# this form is slightly slower, but tolerates arbitrary points, including
|
||
|
# those which are not in the main 1*L subgroup. This includes points of
|
||
|
# order 1 (the neutral element Zero), 2, 4, and 8.
|
||
|
assert n >= 0
|
||
|
if n==0:
|
||
|
return xform_affine_to_extended((0,1))
|
||
|
_ = double_element(scalarmult_element_safe_slow(pt, n>>1))
|
||
|
return add_elements(_, pt) if n&1 else _
|
||
|
|
||
|
def _add_elements_nonunfied(pt1, pt2): # extended->extended
|
||
|
# add-2008-hwcd-4 : NOT unified, only for pt1!=pt2. About 10% faster than
|
||
|
# the (unified) add-2008-hwcd-3, and safe to use inside scalarmult if you
|
||
|
# aren't using points of order 1/2/4/8
|
||
|
(X1, Y1, Z1, T1) = pt1
|
||
|
(X2, Y2, Z2, T2) = pt2
|
||
|
A = ((Y1-X1)*(Y2+X2)) % Q
|
||
|
B = ((Y1+X1)*(Y2-X2)) % Q
|
||
|
C = (Z1*2*T2) % Q
|
||
|
D = (T1*2*Z2) % Q
|
||
|
E = (D+C) % Q
|
||
|
F = (B-A) % Q
|
||
|
G = (B+A) % Q
|
||
|
H = (D-C) % Q
|
||
|
X3 = (E*F) % Q
|
||
|
Y3 = (G*H) % Q
|
||
|
Z3 = (F*G) % Q
|
||
|
T3 = (E*H) % Q
|
||
|
return (X3, Y3, Z3, T3)
|
||
|
|
||
|
def scalarmult_element(pt, n): # extended->extended
|
||
|
# This form only works properly when given points that are a member of
|
||
|
# the main 1*L subgroup. It will give incorrect answers when called with
|
||
|
# the points of order 1/2/4/8, including point Zero. (it will also work
|
||
|
# properly when given points of order 2*L/4*L/8*L)
|
||
|
assert n >= 0
|
||
|
if n==0:
|
||
|
return xform_affine_to_extended((0,1))
|
||
|
_ = double_element(scalarmult_element(pt, n>>1))
|
||
|
return _add_elements_nonunfied(_, pt) if n&1 else _
|
||
|
|
||
|
# points are encoded as 32-bytes little-endian, b255 is sign, b2b1b0 are 0
|
||
|
|
||
|
def encodepoint(P):
|
||
|
x = P[0]
|
||
|
y = P[1]
|
||
|
# MSB of output equals x.b0 (=x&1)
|
||
|
# rest of output is little-endian y
|
||
|
assert 0 <= y < (1<<255) # always < 0x7fff..ff
|
||
|
if x & 1:
|
||
|
y += 1<<255
|
||
|
return binascii.unhexlify("%064x" % y)[::-1]
|
||
|
|
||
|
def isoncurve(P):
|
||
|
x = P[0]
|
||
|
y = P[1]
|
||
|
return (-x*x + y*y - 1 - d*x*x*y*y) % Q == 0
|
||
|
|
||
|
class NotOnCurve(Exception):
|
||
|
pass
|
||
|
|
||
|
def decodepoint(s):
|
||
|
unclamped = int(binascii.hexlify(s[:32][::-1]), 16)
|
||
|
clamp = (1 << 255) - 1
|
||
|
y = unclamped & clamp # clear MSB
|
||
|
x = xrecover(y)
|
||
|
if bool(x & 1) != bool(unclamped & (1<<255)): x = Q-x
|
||
|
P = [x,y]
|
||
|
if not isoncurve(P): raise NotOnCurve("decoding point that is not on curve")
|
||
|
return P
|
||
|
|
||
|
# scalars are encoded as 32-bytes little-endian
|
||
|
|
||
|
def bytes_to_scalar(s):
|
||
|
assert len(s) == 32, len(s)
|
||
|
return int(binascii.hexlify(s[::-1]), 16)
|
||
|
|
||
|
def bytes_to_clamped_scalar(s):
|
||
|
# Ed25519 private keys clamp the scalar to ensure two things:
|
||
|
# 1: integer value is in L/2 .. L, to avoid small-logarithm
|
||
|
# non-wraparaound
|
||
|
# 2: low-order 3 bits are zero, so a small-subgroup attack won't learn
|
||
|
# any information
|
||
|
# set the top two bits to 01, and the bottom three to 000
|
||
|
a_unclamped = bytes_to_scalar(s)
|
||
|
AND_CLAMP = (1<<254) - 1 - 7
|
||
|
OR_CLAMP = (1<<254)
|
||
|
a_clamped = (a_unclamped & AND_CLAMP) | OR_CLAMP
|
||
|
return a_clamped
|
||
|
|
||
|
def random_scalar(entropy_f): # 0..L-1 inclusive
|
||
|
# reduce the bias to a safe level by generating 256 extra bits
|
||
|
oversized = int(binascii.hexlify(entropy_f(32+32)), 16)
|
||
|
return oversized % L
|
||
|
|
||
|
def password_to_scalar(pw):
|
||
|
oversized = hashlib.sha512(pw).digest()
|
||
|
return int(binascii.hexlify(oversized), 16) % L
|
||
|
|
||
|
def scalar_to_bytes(y):
|
||
|
y = y % L
|
||
|
assert 0 <= y < 2**256
|
||
|
return binascii.unhexlify("%064x" % y)[::-1]
|
||
|
|
||
|
# Elements, of various orders
|
||
|
|
||
|
def is_extended_zero(XYTZ):
|
||
|
# catch Zero
|
||
|
(X, Y, Z, T) = XYTZ
|
||
|
Y = Y % Q
|
||
|
Z = Z % Q
|
||
|
if X==0 and Y==Z and Y!=0:
|
||
|
return True
|
||
|
return False
|
||
|
|
||
|
class ElementOfUnknownGroup:
|
||
|
# This is used for points of order 2,4,8,2*L,4*L,8*L
|
||
|
def __init__(self, XYTZ):
|
||
|
assert isinstance(XYTZ, tuple)
|
||
|
assert len(XYTZ) == 4
|
||
|
self.XYTZ = XYTZ
|
||
|
|
||
|
def add(self, other):
|
||
|
if not isinstance(other, ElementOfUnknownGroup):
|
||
|
raise TypeError("elements can only be added to other elements")
|
||
|
sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
|
||
|
if is_extended_zero(sum_XYTZ):
|
||
|
return Zero
|
||
|
return ElementOfUnknownGroup(sum_XYTZ)
|
||
|
|
||
|
def scalarmult(self, s):
|
||
|
if isinstance(s, ElementOfUnknownGroup):
|
||
|
raise TypeError("elements cannot be multiplied together")
|
||
|
assert s >= 0
|
||
|
product = scalarmult_element_safe_slow(self.XYTZ, s)
|
||
|
return ElementOfUnknownGroup(product)
|
||
|
|
||
|
def to_bytes(self):
|
||
|
return encodepoint(xform_extended_to_affine(self.XYTZ))
|
||
|
def __eq__(self, other):
|
||
|
return self.to_bytes() == other.to_bytes()
|
||
|
def __ne__(self, other):
|
||
|
return not self == other
|
||
|
|
||
|
class Element(ElementOfUnknownGroup):
|
||
|
# this only holds elements in the main 1*L subgroup. It never holds Zero,
|
||
|
# or elements of order 1/2/4/8, or 2*L/4*L/8*L.
|
||
|
|
||
|
def add(self, other):
|
||
|
if not isinstance(other, ElementOfUnknownGroup):
|
||
|
raise TypeError("elements can only be added to other elements")
|
||
|
sum_element = ElementOfUnknownGroup.add(self, other)
|
||
|
if sum_element is Zero:
|
||
|
return sum_element
|
||
|
if isinstance(other, Element):
|
||
|
# adding two subgroup elements results in another subgroup
|
||
|
# element, or Zero, and we've already excluded Zero
|
||
|
return Element(sum_element.XYTZ)
|
||
|
# not necessarily a subgroup member, so assume not
|
||
|
return sum_element
|
||
|
|
||
|
def scalarmult(self, s):
|
||
|
if isinstance(s, ElementOfUnknownGroup):
|
||
|
raise TypeError("elements cannot be multiplied together")
|
||
|
# scalarmult of subgroup members can be done modulo the subgroup
|
||
|
# order, and using the faster non-unified function.
|
||
|
s = s % L
|
||
|
# scalarmult(s=0) gets you Zero
|
||
|
if s == 0:
|
||
|
return Zero
|
||
|
# scalarmult(s=1) gets you self, which is a subgroup member
|
||
|
# scalarmult(s<grouporder) gets you a different subgroup member
|
||
|
return Element(scalarmult_element(self.XYTZ, s))
|
||
|
|
||
|
# negation and subtraction only make sense for the main subgroup
|
||
|
def negate(self):
|
||
|
# slow. Prefer e.scalarmult(-pw) to e.scalarmult(pw).negate()
|
||
|
return Element(scalarmult_element(self.XYTZ, L-2))
|
||
|
def subtract(self, other):
|
||
|
return self.add(other.negate())
|
||
|
|
||
|
class _ZeroElement(ElementOfUnknownGroup):
|
||
|
def add(self, other):
|
||
|
return other # zero+anything = anything
|
||
|
def scalarmult(self, s):
|
||
|
return self # zero*anything = zero
|
||
|
def negate(self):
|
||
|
return self # -zero = zero
|
||
|
def subtract(self, other):
|
||
|
return self.add(other.negate())
|
||
|
|
||
|
|
||
|
Base = Element(xform_affine_to_extended(B))
|
||
|
Zero = _ZeroElement(xform_affine_to_extended((0,1))) # the neutral (identity) element
|
||
|
|
||
|
_zero_bytes = Zero.to_bytes()
|
||
|
|
||
|
|
||
|
def arbitrary_element(seed): # unknown DL
|
||
|
# TODO: if we don't need uniformity, maybe use just sha256 here?
|
||
|
hseed = hashlib.sha512(seed).digest()
|
||
|
y = int(binascii.hexlify(hseed), 16) % Q
|
||
|
|
||
|
# we try successive Y values until we find a valid point
|
||
|
for plus in itertools.count(0):
|
||
|
y_plus = (y + plus) % Q
|
||
|
x = xrecover(y_plus)
|
||
|
Pa = [x,y_plus] # no attempt to use both "positive" and "negative" X
|
||
|
|
||
|
# only about 50% of Y coordinates map to valid curve points (I think
|
||
|
# the other half give you points on the "twist").
|
||
|
if not isoncurve(Pa):
|
||
|
continue
|
||
|
|
||
|
P = ElementOfUnknownGroup(xform_affine_to_extended(Pa))
|
||
|
# even if the point is on our curve, it may not be in our particular
|
||
|
# (order=L) subgroup. The curve has order 8*L, so an arbitrary point
|
||
|
# could have order 1,2,4,8,1*L,2*L,4*L,8*L (everything which divides
|
||
|
# the group order).
|
||
|
|
||
|
# [I MAY BE COMPLETELY WRONG ABOUT THIS, but my brief statistical
|
||
|
# tests suggest it's not too far off] There are phi(x) points with
|
||
|
# order x, so:
|
||
|
# 1 element of order 1: [(x=0,y=1)=Zero]
|
||
|
# 1 element of order 2 [(x=0,y=-1)]
|
||
|
# 2 elements of order 4
|
||
|
# 4 elements of order 8
|
||
|
# L-1 elements of order L (including Base)
|
||
|
# L-1 elements of order 2*L
|
||
|
# 2*(L-1) elements of order 4*L
|
||
|
# 4*(L-1) elements of order 8*L
|
||
|
|
||
|
# So 50% of random points will have order 8*L, 25% will have order
|
||
|
# 4*L, 13% order 2*L, and 13% will have our desired order 1*L (and a
|
||
|
# vanishingly small fraction will have 1/2/4/8). If we multiply any
|
||
|
# of the 8*L points by 2, we're sure to get an 4*L point (and
|
||
|
# multiplying a 4*L point by 2 gives us a 2*L point, and so on).
|
||
|
# Multiplying a 1*L point by 2 gives us a different 1*L point. So
|
||
|
# multiplying by 8 gets us from almost any point into a uniform point
|
||
|
# on the correct 1*L subgroup.
|
||
|
|
||
|
P8 = P.scalarmult(8)
|
||
|
|
||
|
# if we got really unlucky and picked one of the 8 low-order points,
|
||
|
# multiplying by 8 will get us to the identity (Zero), which we check
|
||
|
# for explicitly.
|
||
|
if is_extended_zero(P8.XYTZ):
|
||
|
continue
|
||
|
|
||
|
# Test that we're finally in the right group. We want to scalarmult
|
||
|
# by L, and we want to *not* use the trick in Group.scalarmult()
|
||
|
# which does x%L, because that would bypass the check we care about.
|
||
|
# P is still an _ElementOfUnknownGroup, which doesn't use x%L because
|
||
|
# that's not correct for points outside the main group.
|
||
|
assert is_extended_zero(P8.scalarmult(L).XYTZ)
|
||
|
|
||
|
return Element(P8.XYTZ)
|
||
|
# never reached
|
||
|
|
||
|
def bytes_to_unknown_group_element(bytes):
|
||
|
# this accepts all elements, including Zero and wrong-subgroup ones
|
||
|
if bytes == _zero_bytes:
|
||
|
return Zero
|
||
|
XYTZ = xform_affine_to_extended(decodepoint(bytes))
|
||
|
return ElementOfUnknownGroup(XYTZ)
|
||
|
|
||
|
def bytes_to_element(bytes):
|
||
|
# this strictly only accepts elements in the right subgroup
|
||
|
P = bytes_to_unknown_group_element(bytes)
|
||
|
if P is Zero:
|
||
|
raise ValueError("element was Zero")
|
||
|
if not is_extended_zero(P.scalarmult(L).XYTZ):
|
||
|
raise ValueError("element is not in the right group")
|
||
|
# the point is in the expected 1*L subgroup, not in the 2/4/8 groups,
|
||
|
# or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
|
||
|
return Element(P.XYTZ)
|