# MIT License # # Copyright (c) 2015 Brian Warner and other contributors # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
import binascii
import hashlib
import itertools

# Q is the field prime; L is the order of the main ("1*L") subgroup.
Q = 2**255 - 19
L = 2**252 + 27742317777372353535851937790883648493


def inv(x):
    """Return the multiplicative inverse of x modulo Q (x**(Q-2) mod Q)."""
    return pow(x, Q-2, Q)


# Curve constant. It is deliberately left unreduced (a negative integer);
# every expression that uses it is reduced "% Q" afterwards.
d = -121665 * inv(121666)
# A square root of -1 mod Q, used by xrecover() to fix up its candidate root.
I = pow(2, (Q-1)//4, Q)


def xrecover(y):
    """Return the even x candidate for the given y coordinate.

    The result is only a valid curve coordinate if (x, y) passes
    isoncurve(); callers that accept untrusted input must check that.
    """
    xx = (y*y-1) * inv(d*y*y+1)
    x = pow(xx, (Q+3)//8, Q)
    if (x*x - xx) % Q != 0:
        # the first candidate was not a root: multiply by sqrt(-1)
        x = (x*I) % Q
    if x % 2 != 0:
        # always return the "even" root
        x = Q-x
    return x


# Affine coordinates of the base point
By = 4 * inv(5)
Bx = xrecover(By)
B = [Bx % Q, By % Q]

# Extended Coordinates: x=X/Z, y=Y/Z, x*y=T/Z
# http://www.hyperelliptic.org/EFD/g1p/auto-twisted-extended-1.html


def xform_affine_to_extended(pt):
    """Convert an affine (x, y) pair into an extended (X, Y, Z, T) tuple."""
    (x, y) = pt
    return (x % Q, y % Q, 1, (x*y) % Q)  # (X,Y,Z,T)


def xform_extended_to_affine(pt):
    """Convert an extended (X, Y, Z, T) tuple into an affine (x, y) tuple."""
    (x, y, z, _) = pt
    return ((x*inv(z)) % Q, (y*inv(z)) % Q)


def double_element(pt):  # extended->extended
    """Return 2*pt, for pt in extended coordinates."""
    # dbl-2008-hwcd
    (X1, Y1, Z1, _) = pt
    A = (X1*X1)
    B = (Y1*Y1)
    C = (2*Z1*Z1)
    D = (-A) % Q
    J = (X1+Y1) % Q
    E = (J*J-A-B) % Q
    G = (D+B) % Q
    F = (G-C) % Q
    H = (D-B) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    Z3 = (F*G) % Q
    T3 = (E*H) % Q
    return (X3, Y3, Z3, T3)


def add_elements(pt1, pt2):  # extended->extended
    """Return pt1+pt2 (extended coordinates) using the unified formula."""
    # add-2008-hwcd-3 . Slightly slower than add-2008-hwcd-4, but -3 is
    # unified, so it's safe for general-purpose addition
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1-X1)*(Y2-X2)) % Q
    B = ((Y1+X1)*(Y2+X2)) % Q
    C = T1*(2*d)*T2 % Q
    D = Z1*2*Z2 % Q
    E = (B-A) % Q
    F = (D-C) % Q
    G = (D+C) % Q
    H = (B+A) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    T3 = (E*H) % Q
    Z3 = (F*G) % Q
    return (X3, Y3, Z3, T3)


def scalarmult_element_safe_slow(pt, n):
    """Return n*pt (extended coordinates) for any curve point pt, n >= 0."""
    # this form is slightly slower, but tolerates arbitrary points, including
    # those which are not in the main 1*L subgroup. This includes points of
    # order 1 (the neutral element Zero), 2, 4, and 8.
    assert n >= 0
    if n == 0:
        return xform_affine_to_extended((0, 1))
    # recursive double-and-add, consuming the scalar from the top bit down
    doubled = double_element(scalarmult_element_safe_slow(pt, n >> 1))
    return add_elements(doubled, pt) if n & 1 else doubled


def _add_elements_nonunfied(pt1, pt2):  # extended->extended
    """Return pt1+pt2 (extended coordinates); only valid for pt1 != pt2."""
    # add-2008-hwcd-4 : NOT unified, only for pt1!=pt2. About 10% faster than
    # the (unified) add-2008-hwcd-3, and safe to use inside scalarmult if you
    # aren't using points of order 1/2/4/8
    (X1, Y1, Z1, T1) = pt1
    (X2, Y2, Z2, T2) = pt2
    A = ((Y1-X1)*(Y2+X2)) % Q
    B = ((Y1+X1)*(Y2-X2)) % Q
    C = (Z1*2*T2) % Q
    D = (T1*2*Z2) % Q
    E = (D+C) % Q
    F = (B-A) % Q
    G = (B+A) % Q
    H = (D-C) % Q
    X3 = (E*F) % Q
    Y3 = (G*H) % Q
    Z3 = (F*G) % Q
    T3 = (E*H) % Q
    return (X3, Y3, Z3, T3)


def scalarmult_element(pt, n):  # extended->extended
    """Return n*pt (extended coordinates) for pt in the main subgroup."""
    # This form only works properly when given points that are a member of
    # the main 1*L subgroup. It will give incorrect answers when called with
    # the points of order 1/2/4/8, including point Zero. (it will also work
    # properly when given points of order 2*L/4*L/8*L)
    assert n >= 0
    if n == 0:
        return xform_affine_to_extended((0, 1))
    doubled = double_element(scalarmult_element(pt, n >> 1))
    return _add_elements_nonunfied(doubled, pt) if n & 1 else doubled


# points are encoded as 32-bytes little-endian, b255 is sign, b2b1b0 are 0

def encodepoint(P):
    """Encode the affine point P=(x, y) as 32 bytes."""
    x = P[0]
    y = P[1]
    # MSB of output equals x.b0 (=x&1)
    # rest of output is little-endian y
    assert 0 <= y < (1 << 255)  # always < 0x7fff..ff
    if x & 1:
        y += 1 << 255
    return binascii.unhexlify("%064x" % y)[::-1]


def isoncurve(P):
    """Return True if the affine point P=(x, y) satisfies the curve equation."""
    x = P[0]
    y = P[1]
    return (-x*x + y*y - 1 - d*x*x*y*y) % Q == 0


class NotOnCurve(Exception):
    """Raised by decodepoint() when the bytes do not encode a curve point."""
    pass


def decodepoint(s):
    """Decode the first 32 bytes of s into an affine point [x, y].

    Raises NotOnCurve if the decoded coordinates are not on the curve.
    """
    unclamped = int(binascii.hexlify(s[:32][::-1]), 16)
    clamp = (1 << 255) - 1
    y = unclamped & clamp  # clear MSB
    x = xrecover(y)
    # the MSB of the encoding carries the low bit of x: pick the matching root
    if bool(x & 1) != bool(unclamped & (1 << 255)):
        x = Q-x
    P = [x, y]
    if not isoncurve(P):
        raise NotOnCurve("decoding point that is not on curve")
    return P


# scalars are encoded as 32-bytes little-endian

def bytes_to_scalar(s):
    """Interpret exactly 32 bytes as a little-endian integer."""
    assert len(s) == 32, len(s)
    return int(binascii.hexlify(s[::-1]), 16)


def bytes_to_clamped_scalar(s):
    """Interpret 32 bytes as a little-endian integer and clamp it."""
    # Ed25519 private keys clamp the scalar to ensure two things:
    #   1: integer value is in L/2 .. L, to avoid small-logarithm
    #      non-wraparound
    #   2: low-order 3 bits are zero, so a small-subgroup attack won't learn
    #      any information
    # set the top two bits to 01, and the bottom three to 000
    a_unclamped = bytes_to_scalar(s)
    AND_CLAMP = (1 << 254) - 1 - 7
    OR_CLAMP = (1 << 254)
    a_clamped = (a_unclamped & AND_CLAMP) | OR_CLAMP
    return a_clamped


def random_scalar(entropy_f):  # 0..L-1 inclusive
    """Return a scalar in 0..L-1 built from entropy_f(64).

    entropy_f is called once with the number of bytes wanted (64) and must
    return that many bytes.
    """
    # reduce the bias to a safe level by generating 256 extra bits
    oversized = int(binascii.hexlify(entropy_f(32+32)), 16)
    return oversized % L


def password_to_scalar(pw):
    """Map the bytes pw to a scalar in 0..L-1 via SHA-512."""
    oversized = hashlib.sha512(pw).digest()
    return int(binascii.hexlify(oversized), 16) % L


def scalar_to_bytes(y):
    """Reduce y modulo L and encode it as 32 little-endian bytes."""
    y = y % L
    assert 0 <= y < 2**256
    return binascii.unhexlify("%064x" % y)[::-1]


# Elements, of various orders

def is_extended_zero(XYTZ):
    """Return True if the extended-coordinate point is the neutral element.

    Despite the parameter name, the tuple is unpacked as (X, Y, Z, T), which
    is the order produced by every function in this module.
    """
    # catch Zero
    (X, Y, Z, T) = XYTZ
    Y = Y % Q
    Z = Z % Q
    if X == 0 and Y == Z and Y != 0:
        return True
    return False


class ElementOfUnknownGroup:
    """A curve point that is not known to be in the main subgroup."""
    # This is used for points of order 2,4,8,2*L,4*L,8*L

    def __init__(self, XYTZ):
        # XYTZ holds the extended coordinates, as an (X, Y, Z, T) tuple
        assert isinstance(XYTZ, tuple)
        assert len(XYTZ) == 4
        self.XYTZ = XYTZ

    def add(self, other):
        """Return self+other; the Zero singleton if the sum is neutral."""
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_XYTZ = add_elements(self.XYTZ, other.XYTZ)
        if is_extended_zero(sum_XYTZ):
            return Zero
        return ElementOfUnknownGroup(sum_XYTZ)

    def scalarmult(self, s):
        """Return s*self for an integer s >= 0, without reducing s mod L."""
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        assert s >= 0
        product = scalarmult_element_safe_slow(self.XYTZ, s)
        return ElementOfUnknownGroup(product)

    def to_bytes(self):
        """Return the 32-byte encoding of this point."""
        return encodepoint(xform_extended_to_affine(self.XYTZ))

    def __eq__(self, other):
        # compare encodings, since extended coordinates are not unique
        return self.to_bytes() == other.to_bytes()

    def __ne__(self, other):
        return not self == other


class Element(ElementOfUnknownGroup):
    """A point in the main 1*L subgroup."""
    # this only holds elements in the main 1*L subgroup. It never holds Zero,
    # or elements of order 1/2/4/8, or 2*L/4*L/8*L.

    def add(self, other):
        """Return self+other: Zero, an Element, or an ElementOfUnknownGroup."""
        if not isinstance(other, ElementOfUnknownGroup):
            raise TypeError("elements can only be added to other elements")
        sum_element = ElementOfUnknownGroup.add(self, other)
        if sum_element is Zero:
            return sum_element
        if isinstance(other, Element):
            # adding two subgroup elements results in another subgroup
            # element, or Zero, and we've already excluded Zero
            return Element(sum_element.XYTZ)
        # not necessarily a subgroup member, so assume not
        return sum_element

    def scalarmult(self, s):
        """Return (s mod L)*self: Zero if s is a multiple of L, else an Element."""
        if isinstance(s, ElementOfUnknownGroup):
            raise TypeError("elements cannot be multiplied together")
        # scalarmult of subgroup members can be done modulo the subgroup
        # order, and using the faster non-unified function.
        s = s % L
        # scalarmult(s=0) gets you Zero
        if s == 0:
            return Zero
        # scalarmult(s=1) gets you self, which is a subgroup member
        # scalarmult(s<grouporder) gets you a different subgroup member
        return Element(scalarmult_element(self.XYTZ, s))

    # negation and subtraction only make sense for the main subgroup
    def negate(self):
        """Return -self, so that self.add(self.negate()) is Zero."""
        # slow. Prefer e.scalarmult(-pw) to e.scalarmult(pw).negate()
        # For a point P of order L, (L-1)*P == L*P - P == -P. (Multiplying
        # by L-2 would give -2*P, which is not the inverse.)
        return Element(scalarmult_element(self.XYTZ, L-1))

    def subtract(self, other):
        """Return self-other, computed as self + (-other)."""
        return self.add(other.negate())


class _ZeroElement(ElementOfUnknownGroup):
    """The neutral element; used through the module-level Zero instance."""

    def add(self, other):
        return other  # zero+anything = anything

    def scalarmult(self, s):
        return self  # zero*anything = zero

    def negate(self):
        return self  # -zero = zero

    def subtract(self, other):
        return self.add(other.negate())


Base = Element(xform_affine_to_extended(B))
Zero = _ZeroElement(xform_affine_to_extended((0, 1)))  # the neutral (identity) element
_zero_bytes = Zero.to_bytes()


def arbitrary_element(seed):  # unknown DL
    """Deterministically map the bytes seed to an Element of the main subgroup."""
    # TODO: if we don't need uniformity, maybe use just sha256 here?
    hseed = hashlib.sha512(seed).digest()
    y = int(binascii.hexlify(hseed), 16) % Q

    # we try successive Y values until we find a valid point
    for plus in itertools.count(0):
        y_plus = (y + plus) % Q
        x = xrecover(y_plus)
        Pa = [x, y_plus]  # no attempt to use both "positive" and "negative" X

        # only about 50% of Y coordinates map to valid curve points (I think
        # the other half give you points on the "twist").
        if not isoncurve(Pa):
            continue

        P = ElementOfUnknownGroup(xform_affine_to_extended(Pa))
        # even if the point is on our curve, it may not be in our particular
        # (order=L) subgroup. The curve has order 8*L, so an arbitrary point
        # could have order 1,2,4,8,1*L,2*L,4*L,8*L (everything which divides
        # the group order).

        # [I MAY BE COMPLETELY WRONG ABOUT THIS, but my brief statistical
        # tests suggest it's not too far off] There are phi(x) points with
        # order x, so:
        #  1 element of order 1: [(x=0,y=1)=Zero]
        #  1 element of order 2 [(x=0,y=-1)]
        #  2 elements of order 4
        #  4 elements of order 8
        #  L-1 elements of order L (including Base)
        #  L-1 elements of order 2*L
        #  2*(L-1) elements of order 4*L
        #  4*(L-1) elements of order 8*L

        # So 50% of random points will have order 8*L, 25% will have order
        # 4*L, 13% order 2*L, and 13% will have our desired order 1*L (and a
        # vanishingly small fraction will have 1/2/4/8). If we multiply any
        # of the 8*L points by 2, we're sure to get an 4*L point (and
        # multiplying a 4*L point by 2 gives us a 2*L point, and so on).
        # Multiplying a 1*L point by 2 gives us a different 1*L point. So
        # multiplying by 8 gets us from almost any point into a uniform point
        # on the correct 1*L subgroup.

        P8 = P.scalarmult(8)

        # if we got really unlucky and picked one of the 8 low-order points,
        # multiplying by 8 will get us to the identity (Zero), which we check
        # for explicitly.
        if is_extended_zero(P8.XYTZ):
            continue

        # Test that we're finally in the right group. We want to scalarmult
        # by L, and we want to *not* use the trick in Element.scalarmult()
        # which does x%L, because that would bypass the check we care about.
        # P8 is still an ElementOfUnknownGroup, which doesn't use x%L because
        # that's not correct for points outside the main group.
        assert is_extended_zero(P8.scalarmult(L).XYTZ)

        return Element(P8.XYTZ)
    # never reached


def bytes_to_unknown_group_element(bytes):
    """Decode bytes into Zero or an ElementOfUnknownGroup.

    Raises NotOnCurve (from decodepoint) if the bytes are not a curve point.
    """
    # this accepts all elements, including Zero and wrong-subgroup ones
    if bytes == _zero_bytes:
        return Zero
    XYTZ = xform_affine_to_extended(decodepoint(bytes))
    return ElementOfUnknownGroup(XYTZ)


def bytes_to_element(bytes):
    """Decode bytes into an Element of the main subgroup.

    Raises ValueError if the point is Zero or lies outside the main
    subgroup, and NotOnCurve if the bytes are not a curve point.
    """
    # this strictly only accepts elements in the right subgroup
    P = bytes_to_unknown_group_element(bytes)
    if P is Zero:
        raise ValueError("element was Zero")
    if not is_extended_zero(P.scalarmult(L).XYTZ):
        raise ValueError("element is not in the right group")
    # the point is in the expected 1*L subgroup, not in the 2/4/8 groups,
    # or in the 2*L/4*L/8*L groups. Promote it to a correct-group Element.
    return Element(P.XYTZ)