Reticulum/RNS/Cryptography/pure25519/basic.py
Pavol Rusnak ce9dc56048
modernize
2024-10-11 23:13:52 +02:00

369 lines
13 KiB
Python

# MIT License
#
# Copyright (c) 2015 Brian Warner and other contributors
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import binascii, hashlib, itertools
# Field prime p = 2^255 - 19; all coordinate arithmetic is done modulo Q.
Q = (1 << 255) - 19
# Order of the prime-order subgroup generated by the base point.
L = (1 << 252) + 27742317777372353535851937790883648493
def inv(x):
    """Return the multiplicative inverse of x modulo Q.

    Uses Fermat's little theorem: for prime Q, x**(Q-2) == x**-1 (mod Q).
    inv(0) returns 0 instead of raising.
    """
    fermat_exponent = Q - 2
    return pow(x, fermat_exponent, Q)
# Twisted Edwards curve constant d = -121665/121666 (mod Q). It is left
# unreduced (negative); every use in this module reduces its result mod Q.
d = -121665 * inv(121666)
# I = 2**((Q-1)/4) mod Q, a square root of -1 modulo Q (2 is a non-residue
# because Q = 5 mod 8). xrecover() uses it to fix up a candidate square root.
I = pow(2,(Q-1)//4,Q)
def xrecover(y):
    # Recover the even x coordinate for a given y on the curve
    # -x^2 + y^2 = 1 + d*x^2*y^2, i.e. x^2 = (y^2 - 1) / (d*y^2 + 1) (mod Q).
    # If no such x exists the returned value is meaningless; callers must
    # validate the resulting point with isoncurve().
    xx = (y*y-1) * inv(d*y*y+1)
    # Candidate root xx^((Q+3)/8). Since Q = 5 (mod 8), when xx is a square
    # this is either sqrt(xx) or sqrt(-xx)...
    x = pow(xx,(Q+3)//8,Q)
    # ...and in the latter case multiplying by I = sqrt(-1) fixes it.
    if (x*x - xx) % Q != 0: x = (x*I) % Q
    # Of the two roots {x, Q-x}, always return the even one.
    if x % 2 != 0: x = Q-x
    return x
# Ed25519 base point: y = 4/5 (mod Q), x recovered by xrecover() (which
# always returns the even root). B is stored as a reduced affine [x, y] list.
By = 4 * inv(5)
Bx = xrecover(By)
B = [Bx % Q,By % Q]
# Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html
def xform_affine_to_extended(pt):
    """Convert an affine (x, y) pair to extended coordinates (X, Y, Z, T).

    Z is fixed at 1, so X = x, Y = y and T = x*y, each reduced modulo Q.
    """
    x_aff, y_aff = pt
    t_coord = (x_aff * y_aff) % Q
    return (x_aff % Q, y_aff % Q, 1, t_coord)
def xform_extended_to_affine(pt):
    """Convert extended coordinates (X, Y, Z, T) to an affine (x, y) tuple.

    Returns (X/Z mod Q, Y/Z mod Q); T is ignored. Z is inverted once and
    the inverse reused for both coordinates, since each inversion is a
    full modular exponentiation.
    """
    (x, y, z, _) = pt
    z_inv = inv(z)
    return ((x * z_inv) % Q, (y * z_inv) % Q)
def double_element(pt): # extended->extended
    # Return 2*pt for a point in extended coordinates (X, Y, Z, T), using the
    # explicit doubling formulas for twisted Edwards curves with a = -1.
    # The input T coordinate is not needed; every output coordinate is
    # reduced modulo Q.
    # dbl-2008-hwcd
    (X1, Y1, Z1, _) = pt
    A = (X1*X1)
    B = (Y1*Y1)
    C = (2*Z1*Z1)
    D = (-A) % Q  # D = a*A with curve parameter a = -1
    J = (X1+Y1) % Q
    E = (J*J-A-B) % Q
    G = (D+B) % Q
    F = (G-C) % Q
    H = (D-B) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    Z3 = (F*G) % Q
    T3 = (E*H) % Q
    return (X3, Y3, Z3, T3)
def add_elements(pt1, pt2): # extended->extended
    # Return pt1 + pt2 for two points in extended coordinates (X, Y, Z, T);
    # every output coordinate is reduced modulo Q.
    # add-2008-hwcd-3 . Slightly slower than add-2008-hwcd-4, but -3 is
    # unified (it also works when pt1 == pt2), so it's safe for
    # general-purpose addition
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1-X1)*(Y2-X2)) % Q
    B = ((Y1+X1)*(Y2+X2)) % Q
    C = T1*(2*d)*T2 % Q  # uses k = 2*d from the formula
    D = Z1*2*Z2 % Q
    E = (B-A) % Q
    F = (D-C) % Q
    G = (D+C) % Q
    H = (B+A) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    T3 = (E*H) % Q
    Z3 = (F*G) % Q
    return (X3, Y3, Z3, T3)
def scalarmult_element_safe_slow(pt, n):
    """Return n*pt (extended coordinates) using the unified addition formula.

    This form is slightly slower than scalarmult_element(), but tolerates
    arbitrary points, including those which are not in the main 1*L
    subgroup: points of order 1 (the neutral element Zero), 2, 4 and 8.

    n must be a non-negative integer; n == 0 returns the neutral element.

    Left-to-right double-and-add over the bits of n (most significant
    first). This is the same sequence of doublings and additions the
    recursive formulation performs, but it is iterative, so a large n
    cannot exceed Python's recursion limit.
    """
    assert n >= 0
    acc = xform_affine_to_extended((0,1))
    if n == 0:
        # Return the neutral element directly; running the loop would give
        # an equivalent but differently-scaled tuple.
        return acc
    for bit in bin(n)[2:]:
        acc = double_element(acc)
        if bit == "1":
            acc = add_elements(acc, pt)
    return acc
def _add_elements_nonunfied(pt1, pt2): # extended->extended
    # Return pt1 + pt2 in extended coordinates using the NON-unified formula
    # ("nonunfied" is a misspelling of "non-unified"; the name is kept
    # because scalarmult_element() calls it).
    # add-2008-hwcd-4 : NOT unified, only for pt1!=pt2. About 10% faster than
    # the (unified) add-2008-hwcd-3, and safe to use inside scalarmult if you
    # aren't using points of order 1/2/4/8
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1-X1)*(Y2+X2)) % Q
    B = ((Y1+X1)*(Y2-X2)) % Q
    C = (Z1*2*T2) % Q
    D = (T1*2*Z2) % Q
    E = (D+C) % Q
    F = (B-A) % Q
    G = (B+A) % Q
    H = (D-C) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    Z3 = (F*G) % Q
    T3 = (E*H) % Q
    return (X3, Y3, Z3, T3)
def scalarmult_element(pt, n): # extended->extended
    """Return n*pt (extended coordinates) using the faster non-unified addition.

    This form only works properly when given points that are a member of
    the main 1*L subgroup. It will give incorrect answers when called with
    the points of order 1/2/4/8, including point Zero. (It will also work
    properly when given points of order 2*L/4*L/8*L.)

    n must be a non-negative integer; n == 0 returns the neutral element.

    Left-to-right double-and-add over the bits of n (most significant
    first). The loop is iterative, so a large n cannot exceed Python's
    recursion limit.
    """
    assert n >= 0
    acc = xform_affine_to_extended((0,1))
    if n == 0:
        # Return the neutral element directly; running the loop would give
        # an equivalent but differently-scaled tuple.
        return acc
    for bit in bin(n)[2:]:
        acc = double_element(acc)
        if bit == "1":
            acc = _add_elements_nonunfied(acc, pt)
    return acc
# points are encoded as 32 bytes, little-endian: bits 0..254 hold y, bit 255 holds the low bit ("sign") of x
def encodepoint(P):
    """Encode affine point P = (x, y) as 32 little-endian bytes.

    Bits 0..254 of the output hold y; bit 255 holds the low bit of x (its
    "sign"), which decodepoint() uses to pick between x and Q-x.
    y must already lie in [0, 2**255).
    """
    x, y = P[0], P[1]
    assert 0 <= y < (1<<255) # always < 0x7fff..ff
    word = y + (1 << 255) if x & 1 else y
    most_significant_first = binascii.unhexlify(format(word, "064x"))
    return most_significant_first[::-1]
def isoncurve(P):
    """Return True if affine P = (x, y) satisfies -x^2 + y^2 = 1 + d*x^2*y^2 (mod Q)."""
    x, y = P[0], P[1]
    x_sq = x * x
    y_sq = y * y
    residual = (y_sq - x_sq) - 1 - d * x_sq * y_sq
    return residual % Q == 0
class NotOnCurve(Exception):
    """Raised by decodepoint() when the decoded (x, y) is not a curve point."""
def decodepoint(s):
    """Decode a point encoding into an affine [x, y] list.

    Only the first 32 bytes of s are read (little-endian). Bits 0..254 give
    y and bit 255 gives the parity of x. Raises NotOnCurve when the
    resulting (x, y) does not satisfy the curve equation.

    NOTE(review): no canonicality checks are performed. y >= Q is accepted,
    and x == 0 with the sign bit set yields x == Q instead of being rejected.
    RFC 8032 section 5.1.3 rejects both. Callers needing strict decoding
    must check for these cases themselves.
    """
    encoded = int(binascii.hexlify(s[:32][::-1]), 16)
    x_is_odd = (encoded >> 255) & 1
    y = encoded & ((1 << 255) - 1)
    x = xrecover(y)
    if (x & 1) != x_is_odd:
        x = Q - x
    P = [x, y]
    if not isoncurve(P):
        raise NotOnCurve("decoding point that is not on curve")
    return P
# scalars are encoded as 32-bytes little-endian
def bytes_to_scalar(s):
    """Interpret s (exactly 32 bytes) as a little-endian unsigned integer."""
    length = len(s)
    assert length == 32, length
    most_significant_first = s[::-1]
    return int(binascii.hexlify(most_significant_first), 16)
def bytes_to_clamped_scalar(s):
    """Decode 32 little-endian bytes as an Ed25519-clamped secret scalar.

    Ed25519 private keys clamp the scalar to ensure two things:
      1. Bit 255 is cleared and bit 254 is set, so the result always lies in
         [2**254, 2**255 - 8]. That is well above L, not in L/2 .. L. It is
         never a small number, and every key has the same top bit.
      2. The low three bits are cleared, so the scalar is a multiple of 8
         (the cofactor) and a small-subgroup attack won't learn any
         information.
    """
    a_unclamped = bytes_to_scalar(s)
    # Keep bits 3..253: clears bits 255 and 254 and the low three bits.
    AND_CLAMP = (1<<254) - 1 - 7
    # Then force bit 254 on.
    OR_CLAMP = (1<<254)
    a_clamped = (a_unclamped & AND_CLAMP) | OR_CLAMP
    return a_clamped
def random_scalar(entropy_f): # 0..L-1 inclusive
    """Return a random scalar in [0, L).

    entropy_f(n) is expected to return n random bytes. 64 bytes (512 bits)
    are drawn and reduced modulo the ~253-bit L, so the modulo bias is
    negligible.
    """
    raw = entropy_f(64)
    as_int = int(binascii.hexlify(raw), 16)
    return as_int % L
def password_to_scalar(pw):
    """Hash pw (bytes) with SHA-512 and reduce the big-endian digest modulo L."""
    digest = hashlib.sha512(pw).digest()
    as_int = int(binascii.hexlify(digest), 16)
    return as_int % L
def scalar_to_bytes(y):
    """Encode y reduced modulo L as 32 little-endian bytes."""
    reduced = y % L
    assert 0 <= reduced < 2**256
    most_significant_first = binascii.unhexlify(format(reduced, "064x"))
    return most_significant_first[::-1]
# Elements, of various orders
def is_extended_zero(XYTZ):
    """Return True if an extended-coordinate point is the neutral element Zero.

    The tuple is laid out (X, Y, Z, T). The identity (x=0, y=1) is any tuple
    with X == 0 and Y == Z != 0 modulo Q. All three compared coordinates are
    reduced modulo Q first, so unreduced tuples are recognised too.
    """
    (X, Y, Z, _) = XYTZ
    X = X % Q
    Y = Y % Q
    Z = Z % Q
    return X == 0 and Y == Z and Y != 0
class ElementOfUnknownGroup:
    """A curve point that may lie outside the main order-L subgroup.

    This is used for points of order 2, 4, 8, 2*L, 4*L or 8*L, and is the
    base class of Element and of the Zero singleton. The point is stored in
    extended coordinates as the 4-tuple self.XYTZ, laid out (X, Y, Z, T).
    """
    def __init__(self, XYTZ):
        assert isinstance(XYTZ, tuple)
        assert len(XYTZ) == 4
        self.XYTZ = XYTZ

    def add(self, other):
        """Return self + other, or Zero if the sum is the neutral element."""
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
        if is_extended_zero(sum_XYTZ):
            return Zero
        return ElementOfUnknownGroup(sum_XYTZ)

    def scalarmult(self, s):
        """Return s*self for a non-negative integer s (unified, slow formulas).

        s is not reduced modulo L, because that is not valid for points
        outside the main subgroup.
        """
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        assert s >= 0
        product = scalarmult_element_safe_slow(self.XYTZ, s)
        return ElementOfUnknownGroup(product)

    def to_bytes(self):
        """Return the 32-byte encoding of this point."""
        return encodepoint(xform_extended_to_affine(self.XYTZ))

    def __eq__(self, other):
        # Many extended tuples represent the same point, so compare the
        # canonical encodings. Non-elements are not comparable: returning
        # NotImplemented lets Python fall back instead of raising
        # AttributeError on other.to_bytes().
        if not isinstance(other, ElementOfUnknownGroup):
            return NotImplemented
        return self.to_bytes() == other.to_bytes()

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable; hash the same
        # encoding __eq__ compares so equal elements hash equally.
        return hash(self.to_bytes())
class Element(ElementOfUnknownGroup):
    """A point in the main order-L subgroup.

    This only holds elements in the main 1*L subgroup. It never holds Zero,
    or elements of order 1/2/4/8, or 2*L/4*L/8*L.
    """
    def add(self, other):
        """Return self + other.

        The result is an Element when other is an Element or Zero. It is
        Zero when the sum is the neutral element. Otherwise it is a plain
        ElementOfUnknownGroup, because subgroup membership is not known.
        """
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        if other is Zero:
            # P + 0 = P, which is still a subgroup member; don't demote it
            # to ElementOfUnknownGroup.
            return self
        sum_element = ElementOfUnknownGroup.add(self, other)
        if sum_element is Zero:
            return sum_element
        if isinstance(other, Element):
            # adding two subgroup elements results in another subgroup
            # element, or Zero, and we've already excluded Zero
            return Element(sum_element.XYTZ)
        # not necessarily a subgroup member, so assume not
        return sum_element

    def scalarmult(self, s):
        """Return s*self; any integer s is accepted and reduced modulo L."""
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        # scalarmult of subgroup members can be done modulo the subgroup
        # order, and using the faster non-unified function.
        s = s % L
        # scalarmult(s=0) gets you Zero
        if s == 0:
            return Zero
        # scalarmult(s=1) gets you self, which is a subgroup member
        # scalarmult(s<grouporder) gets you a different subgroup member
        return Element(scalarmult_element(self.XYTZ, s))

    # negation and subtraction only make sense for the main subgroup
    def negate(self):
        """Return -self.

        On this curve -(x, y) = (-x, y), which in extended coordinates is
        (-X, Y, Z, -T). This is O(1) and needs no scalar multiplication.
        (Multiplying by L-2 would give -2*self, not -self.)
        """
        (X, Y, Z, T) = self.XYTZ
        return Element(((-X) % Q, Y, Z, (-T) % Q))

    def subtract(self, other):
        """Return self - other."""
        return self.add(other.negate())
class _ZeroElement(ElementOfUnknownGroup):
    """The neutral (identity) element; instantiated once as the module-level Zero.

    Arguments are validated the same way as in ElementOfUnknownGroup, so
    Zero rejects non-elements exactly like every other element does.
    """
    def add(self, other):
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        return other # zero+anything = anything

    def scalarmult(self, s):
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        return self # zero*anything = zero

    def negate(self):
        return self # -zero = zero

    def subtract(self, other):
        return self.add(other.negate())
# The generator of the order-L subgroup (the point B above), as an Element.
Base = Element(xform_affine_to_extended(B))
Zero = _ZeroElement(xform_affine_to_extended((0,1))) # the neutral (identity) element
# Canonical 32-byte encoding of Zero; bytes_to_unknown_group_element()
# compares against it to map that exact encoding to the Zero singleton.
_zero_bytes = Zero.to_bytes()
def arbitrary_element(seed): # unknown DL
    """Deterministically map seed (bytes) to an Element of the main subgroup.

    The result is derived from SHA-512(seed), so its discrete log relative
    to Base is unknown. The exact search order below determines the output;
    changing it changes every derived element.
    """
    # TODO: if we don't need uniformity, maybe use just sha256 here?
    hseed = hashlib.sha512(seed).digest()
    y = int(binascii.hexlify(hseed), 16) % Q
    # we try successive Y values until we find a valid point
    for plus in itertools.count(0):
        y_plus = (y + plus) % Q
        x = xrecover(y_plus)
        Pa = [x,y_plus] # no attempt to use both "positive" and "negative" X
        # only about 50% of Y coordinates map to valid curve points (I think
        # the other half give you points on the "twist").
        if not isoncurve(Pa):
            continue
        P = ElementOfUnknownGroup(xform_affine_to_extended(Pa))
        # even if the point is on our curve, it may not be in our particular
        # (order=L) subgroup. The curve has order 8*L, so an arbitrary point
        # could have order 1,2,4,8,1*L,2*L,4*L,8*L (everything which divides
        # the group order).
        # [I MAY BE COMPLETELY WRONG ABOUT THIS, but my brief statistical
        # tests suggest it's not too far off] There are phi(x) points with
        # order x, so:
        # 1 element of order 1: [(x=0,y=1)=Zero]
        # 1 element of order 2 [(x=0,y=-1)]
        # 2 elements of order 4
        # 4 elements of order 8
        # L-1 elements of order L (including Base)
        # L-1 elements of order 2*L
        # 2*(L-1) elements of order 4*L
        # 4*(L-1) elements of order 8*L
        # So 50% of random points will have order 8*L, 25% will have order
        # 4*L, 12.5% order 2*L, and 12.5% will have our desired order 1*L
        # (and a vanishingly small fraction will have 1/2/4/8). If we
        # multiply any of the 8*L points by 2, we're sure to get a 4*L point
        # (and multiplying a 4*L point by 2 gives us a 2*L point, and so on).
        # Multiplying a 1*L point by 2 gives us a different 1*L point. So
        # multiplying by 8 gets us from almost any point into a uniform point
        # on the correct 1*L subgroup.
        P8 = P.scalarmult(8)
        # if we got really unlucky and picked one of the 8 low-order points,
        # multiplying by 8 will get us to the identity (Zero), which we check
        # for explicitly.
        if is_extended_zero(P8.XYTZ):
            continue
        # Test that we're finally in the right group. We want to scalarmult
        # by L, and we want to *not* use the trick in Element.scalarmult()
        # which does x%L, because that would bypass the check we care about.
        # P8 is still an ElementOfUnknownGroup, whose scalarmult doesn't use
        # x%L because that's not correct for points outside the main group.
        assert is_extended_zero(P8.scalarmult(L).XYTZ)
        return Element(P8.XYTZ)
    # never reached
def bytes_to_unknown_group_element(bytes):
    """Decode a point encoding into an ElementOfUnknownGroup, or Zero.

    Accepts every decodable point, including the neutral element and points
    outside the main subgroup. Exactly the canonical encoding of the neutral
    element maps to the Zero singleton. decodepoint() raises NotOnCurve for
    encodings that are not curve points.
    """
    if bytes == _zero_bytes:
        return Zero
    affine = decodepoint(bytes)
    return ElementOfUnknownGroup(xform_affine_to_extended(affine))
def bytes_to_element(bytes):
    """Decode a point encoding into an Element of the main order-L subgroup.

    Raises ValueError if the encoding is the neutral element (canonical or
    not) or a point outside the main subgroup. decodepoint() raises
    NotOnCurve for encodings that are not curve points.
    """
    P = bytes_to_unknown_group_element(bytes)
    # Non-canonical encodings of the identity (y=1 with the sign bit set, or
    # y=Q+1) do not equal _zero_bytes, so they come back as an
    # ElementOfUnknownGroup wrapping the identity rather than as Zero. They
    # would also pass the subgroup test below, so reject them explicitly:
    # an Element must never hold Zero.
    if P is Zero or is_extended_zero(P.XYTZ):
        raise ValueError("element was Zero")
    if not is_extended_zero(P.scalarmult(L).XYTZ):
        raise ValueError("element is not in the right group")
    # the point is in the expected 1*L subgroup, not in the 2/4/8 groups,
    # or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
    return Element(P.XYTZ)