# Source code for pragma_sdk.common.randomness.randomness_utils

# Copyright (C) 2020 Eric Schorn, NCC Group Plc; Provided under the MIT license

# This code follows the (IETF) IRTF CFRG Verifiable Random Functions (VRFs) spec *very* closely.
# Please refer to https://tools.ietf.org/pdf/draft-irtf-cfrg-vrf-06.pdf

# https://github.com/nccgroup/draft-irtf-cfrg-vrf-06


import hashlib

# Public API


# Section 5.1. ECVRF Proving
def ecvrf_prove(secret_key, alpha_string):
    """
    Produce an ECVRF proof for alpha_string under secret_key (spec section 5.1).

    Input:
        secret_key - VRF private key, bytes (it is hashed with SHA-512 by the helpers)
        alpha_string - VRF input, an octet string
    Output:
        ("VALID", pi_string) - pi_string is the 80-byte proof:
            32-byte encoded Gamma || 16-byte little-endian c || 32-byte little-endian s
        ("INVALID", []) - if hashing to the curve, or decoding that hash, fails
    """
    # 1. Derive the secret scalar x and the encoded public key y = x*B.
    x = _get_secret_scalar(secret_key)
    y_encoded = get_public_key(secret_key)

    # 2. H = ECVRF_hash_to_curve(suite_string, y, alpha_string); returned in encoded form.
    h_encoded = _ecvrf_hash_to_curve_elligator2_25519(
        SUITE_STRING, y_encoded, alpha_string
    )
    if h_encoded == "INVALID":
        return "INVALID", []

    # 3. Recover the point H from its encoding for the curve arithmetic below.
    h_point = _decode_point(h_encoded)
    if h_point == "INVALID":
        return "INVALID", []

    # 4. Gamma = x*H
    gamma = _scalar_multiply(p=h_point, e=x)

    # 5. k = ECVRF_nonce_generation(secret_key, h_string); the nonce uses the *encoded* H.
    k = _ecvrf_nonce_generation_rfc8032(secret_key, h_encoded)

    # 6. c = ECVRF_hash_points(H, Gamma, k*B, k*H)
    k_b = _scalar_multiply(p=BASE, e=k)
    k_h = _scalar_multiply(p=h_point, e=k)
    c = _ecvrf_hash_points(h_point, gamma, k_b, k_h)

    # 7. s = (k + c*x) mod q
    s = (k + c * x) % ORDER

    # 8. pi_string = point_to_string(Gamma) || int_to_string(c, n) || int_to_string(s, qLen)
    gamma_encoded = _encode_point(gamma)
    pi_string = gamma_encoded + c.to_bytes(16, "little") + s.to_bytes(32, "little")

    # Test hook: only active when a module-level `test_dict` has been injected.
    if "test_dict" in globals():
        samples = {
            "secret_scalar_x": x.to_bytes(32, "little"),
            "public_key_y": y_encoded,
            "h": h_encoded,
            "gamma": gamma_encoded,
            "k_b": _encode_point(k_b),
            "k_h": _encode_point(k_h),
            "pi_string": pi_string,
        }
        _assert_and_sample(list(samples), list(samples.values()))

    # 9. Output pi_string
    return "VALID", pi_string
# Section 5.2. ECVRF Proof To Hash
def ecvrf_proof_to_hash(pi_string):
    """
    Derive the VRF output from a proof (spec section 5.2).

    Input:
        pi_string - VRF proof, octet string of length ptLen+n+qLen (80) bytes
    Output:
        ("VALID", beta_string) - beta_string is the value returned by `_short_hash`,
            i.e. a SHA-512 digest truncated to 31 bytes (not the full 64-byte hLen)
        ("INVALID", []) - if the proof cannot be decoded
    Note:
        This only decodes the proof; it does not verify it. Use ecvrf_verify for that.
    """
    # 1-2. D = ECVRF_decode_proof(pi_string); stop if it is "INVALID".
    decoded = _ecvrf_decode_proof(pi_string)
    if decoded == "INVALID":
        return "INVALID", []

    # 3. (Gamma, c, s) = D -- only Gamma is needed here.
    gamma, _, _ = decoded

    # 4-5. beta_string = Hash(suite_string || 0x03 || point_to_string(cofactor * Gamma))
    cleared_gamma = _scalar_multiply(p=gamma, e=COFACTOR)
    preimage = SUITE_STRING + b"\x03" + _encode_point(cleared_gamma)
    beta_string = _short_hash(preimage)

    # Test hook: only active when a module-level `test_dict` has been injected.
    if "test_dict" in globals():
        _assert_and_sample(["beta_string"], [beta_string])

    # 6. Output beta_string
    return "VALID", beta_string
# Section 5.3. ECVRF Verifying
def ecvrf_verify(y, pi_string, alpha_string) -> (str, list):  # type: ignore[syntax]
    """
    Verify an ECVRF proof and, on success, return its hash output (spec section 5.3).

    Input:
        y - public key, an encoded EC point as bytes
        pi_string - VRF proof, octet string of length ptLen+n+qLen (80) bytes
        alpha_string - VRF input, octet string
    Output:
        ("VALID", beta_string) - beta_string is whatever ecvrf_proof_to_hash(pi_string)
            yields (a SHA-512 digest truncated by `_short_hash`)
        ("INVALID", []) - if decoding, hashing to the curve, or the challenge check fails
    Note:
        The caller still has to compare the returned beta_string with the value it
        expects; checking only for "VALID" is an easy mistake to make. Production code
        would be better served by passing the expected beta_string in and getting a
        plain pass/fail back.
    """

    def negate(point):
        # Negation on the Edwards curve flips the sign of x; _edwards_add reduces mod PRIME.
        return [PRIME - point[0], point[1]]

    # 1-3. (Gamma, c, s) = ECVRF_decode_proof(pi_string), or stop on "INVALID".
    decoded = _ecvrf_decode_proof(pi_string)
    if decoded == "INVALID":
        return "INVALID", []
    gamma, c, s = decoded

    # 4. H = ECVRF_hash_to_curve(suite_string, y, alpha_string); returned in encoded form.
    h_encoded = _ecvrf_hash_to_curve_elligator2_25519(SUITE_STRING, y, alpha_string)
    if h_encoded == "INVALID":
        return "INVALID", []

    # Both the public key and H must decode to on-curve points before any arithmetic.
    y_point = _decode_point(y)
    if y_point == "INVALID":
        return "INVALID", []
    h_point = _decode_point(h_encoded)
    if h_point == "INVALID":
        return "INVALID", []

    # 5. U = s*B - c*y
    s_b = _scalar_multiply(p=BASE, e=s)
    c_y = _scalar_multiply(p=y_point, e=c)
    u = _edwards_add(s_b, negate(c_y))

    # 6. V = s*H - c*Gamma
    s_h = _scalar_multiply(p=h_point, e=s)
    c_gamma = _scalar_multiply(p=gamma, e=c)
    v = _edwards_add(negate(c_gamma), s_h)

    # 7. c' = ECVRF_hash_points(H, Gamma, U, V)
    c_prime = _ecvrf_hash_points(h_point, gamma, u, v)

    # Test hook: only active when a module-level `test_dict` has been injected.
    if "test_dict" in globals():
        _assert_and_sample(
            ["h", "u", "v"], [h_encoded, _encode_point(u), _encode_point(v)]
        )

    # 8. Accept only when the recomputed challenge matches the one carried in the proof.
    if c != c_prime:
        return "INVALID", []
    return ecvrf_proof_to_hash(pi_string)  # Includes logic for VALID/INVALID
def get_public_key(secret_key: bytes) -> bytes:
    """
    Calculate and return the public key for secret_key.

    Input:
        secret_key - VRF private key as bytes (it is hashed with SHA-512, which
            rejects str, so the previous `str` annotation was incorrect)
    Output:
        The public point x*B encoded as a 32-byte little-endian point string (bytes),
        where x is the secret scalar derived from secret_key.
    """
    secret_scalar = _get_secret_scalar(secret_key)
    public_point = _scalar_multiply(p=BASE, e=secret_scalar)
    return _encode_point(public_point)  # type: ignore[no-any-return]
# Internal functions


# Section 5.4.1.2. ECVRF_hash_to_curve_elligator2_25519
def _ecvrf_hash_to_curve_elligator2_25519(suite_string, y, alpha_string):
    """
    Hash (y, alpha_string) to a curve point using Elligator2.

    Input:
        suite_string - a single octet specifying the ECVRF ciphersuite; must equal SUITE_STRING
        y - public key, an encoded EC point as bytes
        alpha_string - value to be hashed, an octet string
    Output:
        H - the hashed point in *encoded* (32-byte) form, or "INVALID" upon failure
    Fixed options:
        p = 2^255-19 (PRIME), A = 486662 (Montgomery constant), cofactor = 8
    """
    assert suite_string == SUITE_STRING

    # 1-3. hash_string = Hash(suite_string || 0x01 || PK_string || alpha_string)
    digest = _hash(suite_string + b"\x01" + y + alpha_string)

    # 4-7. r = first 32 octets with the top bit of octet 31 cleared, read little-endian.
    r_bytes = bytearray(digest[:32])
    r_bytes[31] &= 0x7F
    r = int.from_bytes(r_bytes, "little")

    # 8. u = -A / (1 + 2*r^2) mod p
    u = ((PRIME - A) * _inverse(1 + 2 * r * r)) % PRIME

    # 9. w = u * (u^2 + A*u + 1) mod p  (Montgomery curve equation for Curve25519)
    w = (u * (u * u + A * u + 1)) % PRIME

    # 10. e = Legendre symbol of w and p, computed by Euler's criterion.
    e = pow(w, (PRIME - 1) // 2, PRIME)

    # 11. final_u = u when e == 1, (-A - u) mod p when e == p - 1. The branch-free formula
    # mirrors the spec; this implementation is not intended to be constant time.
    final_u = (e * u + (e - 1) * A * TWO_INV) % PRIME

    # 12-13. y_coordinate = (final_u - 1) / (final_u + 1) mod p, as a 32-byte string.
    y_coordinate = ((final_u - 1) * _inverse(final_u + 1)) % PRIME
    y_encoded = y_coordinate.to_bytes(32, "little")

    # 14. H_prelim = string_to_point(y_string)
    h_prelim = _decode_point(y_encoded)
    if h_prelim == "INVALID":
        return "INVALID"

    # 15-16. H = cofactor * H_prelim, returned encoded.
    h_encoded = _encode_point(_scalar_multiply(p=h_prelim, e=COFACTOR))

    # Test hook: only active when a module-level `test_dict` has been injected.
    if "test_dict" in globals():
        _assert_and_sample(
            ["r", "w", "e"],
            [r_bytes, w.to_bytes(32, "little"), e.to_bytes(32, "little")],
        )

    return h_encoded


# 5.4.2.2. ECVRF Nonce Generation From RFC 8032
def _ecvrf_nonce_generation_rfc8032(secret_key, h_string):
    """
    Derive the deterministic nonce k.

    Input:
        secret_key - an ECVRF secret key as bytes
        h_string - an octet string
    Output:
        k - an integer between 0 and q-1
    """
    # 1-2. Upper half (octets 32..63) of Hash(secret_key).
    key_digest_upper = _hash(secret_key)[32:]

    # 3. k_string = Hash(truncated_hashed_sk_string || h_string)
    k_string = _hash(key_digest_upper + h_string)

    # 4. k = string_to_int(k_string) mod q
    k = int.from_bytes(k_string, "little") % ORDER

    # Test hook: only active when a module-level `test_dict` has been injected.
    if "test_dict" in globals():
        _assert_and_sample(["k"], [k_string])

    return k


# Section 5.4.3. ECVRF Hash Points
def _ecvrf_hash_points(p1, p2, p3, p4):
    """
    Hash four curve points into the challenge value.

    Input:
        p1, p2, p3, p4 - EC points as [x, y] lists
    Output:
        c - integer read little-endian from the first 16 octets of the hash
    """
    # 1-3. str = suite_string || 0x02 || point_to_string(P1) || ... || point_to_string(P4)
    encoded_points = [_encode_point(point) for point in (p1, p2, p3, p4)]
    preimage = b"".join([SUITE_STRING, b"\x02", *encoded_points])

    # 4-7. c = string_to_int(Hash(str)[0..n-1]) with n = 16
    return int.from_bytes(_hash(preimage)[:16], "little")


# Section 5.4.4. ECVRF Decode Proof
def _ecvrf_decode_proof(pi_string):
    """
    Split a proof into its components.

    Input:
        pi_string - VRF proof, octet string (ptLen+n+qLen octets)
    Output:
        "INVALID" if the length is not 80 or Gamma does not decode, otherwise the tuple
        Gamma - EC point
        c - integer between 0 and 2^(8n)-1
        s - integer between 0 and 2^(8qLen)-1
    """
    if len(pi_string) != 80:  # ptLen+n+qLen octets = 32+16+32 = 80
        return "INVALID"

    # 1, 4, 5. Gamma occupies the first ptLen (32) octets and must decode to a point.
    gamma = _decode_point(pi_string[:32])
    if gamma == "INVALID":
        return "INVALID"

    # 2, 6. c occupies the next n (16) octets.
    c = int.from_bytes(pi_string[32:48], "little")

    # 3, 7. s occupies the remaining qLen (32) octets.
    s = int.from_bytes(pi_string[48:], "little")

    # 8. Output Gamma, c, and s
    return gamma, c, s


def _assert_and_sample(keys, actuals):
    """
    Test-vector hook used when a global `test_dict` exists.

    Input:
        keys - names of the values; each is the lookup key for an expected value and the
            basename (+ '_sample') under which the actual value is recorded
        actuals - the values to check, in the same order as keys
    Output:
        None; for each pair, if the key is present in test_dict and the actual value is
        truthy, asserts it equals the expected value, then stores the actual value in
        test_dict under key + '_sample'.
    """
    # noinspection PyGlobalUndefined
    global test_dict
    for key, actual in zip(keys, actuals):
        if key in test_dict and actual:
            assert (
                actual == test_dict[key]
            ), f"{key} actual:{actual.hex()} != expected:{test_dict[key].hex()}"
        test_dict[key + "_sample"] = actual


# Much of the following code has been adapted from ed25519.py
# at https://ed25519.cr.yp.to/software.html retrieved 27 Dec 2019.
# While it is gloriously inefficient, it provides an excellent demonstration of the
# underlying math. For example, production code would likely avoid inversion via Fermat's
# little theorem as it is extremely expensive with a cost of ~300 field multiplies.


def _edwards_add(p, q):
    """Edwards curve point addition; returns [x3, y3] reduced mod PRIME"""
    x1, y1 = p[0], p[1]
    x2, y2 = q[0], q[1]
    # Shared term d*x1*x2*y1*y2 of both denominators in the addition law.
    cross = D * x1 * x2 * y1 * y2
    x3 = (x1 * y2 + x2 * y1) * _inverse(1 + cross)
    y3 = (y1 * y2 + x1 * x2) * _inverse(1 - cross)
    return [x3 % PRIME, y3 % PRIME]


def _encode_point(p):
    """Encode point as 32 little-endian bytes: low 255 bits of y, with the LSB of x in bit 255"""
    y_bits = p[1] & ((1 << 255) - 1)
    x_sign = (p[0] & 1) << 255
    return (y_bits | x_sign).to_bytes(32, "little")


def _decode_point(s):
    """Decode string holding y (low 255 bits) and the LSB of x (bit 255) into a point.
    Checks on-curve. May return "INVALID"
    """
    raw = int.from_bytes(s, "little")
    y = raw & ((1 << 255) - 1)
    wanted_x_lsb = (raw >> (BITS - 1)) & 0x01
    x = _x_recover(y)
    # _x_recover yields one of the two roots; pick the one whose LSB matches the encoding.
    if (x & 1) != wanted_x_lsb:
        x = PRIME - x
    point = [x, y]
    return point if _is_on_curve(point) else "INVALID"


def _get_bit(h, i):
    """Return bit i of the little-endian integer encoded by byte string h"""
    return (int.from_bytes(h, "little") >> i) & 0x01


def _get_secret_scalar(secret_key):
    """Calculate and return the secret scalar integer from the first half of Hash(secret_key)"""
    scalar_bytes = bytearray(_hash(secret_key)[:32])
    # Clamp: clear the low 3 bits, clear the top bit and set the second-highest bit.
    scalar_bytes[0] &= 0xF8
    scalar_bytes[31] = (scalar_bytes[31] & 0x7F) | 0x40
    return int.from_bytes(scalar_bytes, "little")


def _hash(message):
    """Return 64-byte SHA512 hash of arbitrary-length byte message"""
    return hashlib.sha512(message).digest()


def _short_hash(message):
    """Return the SHA512 hash of arbitrary-length byte message, truncated to its first 31 bytes"""
    return _hash(message)[:31]


def _inverse(x):
    """Calculate the inverse of x modulo PRIME via Fermat's little theorem"""
    return pow(x, PRIME - 2, PRIME)


def _is_on_curve(p):
    """Check that point satisfies -x^2 + y^2 = 1 + D*x^2*y^2 (mod PRIME); return boolean"""
    x_sq = p[0] * p[0]
    y_sq = p[1] * p[1]
    return (y_sq - x_sq - 1 - D * x_sq * y_sq) % PRIME == 0


def _scalar_multiply(p, e):
    """Multiply curve point p by non-negative integer scalar e (recursive double-and-add)"""
    if e == 0:
        return [0, 1]  # Neutral element
    half = _scalar_multiply(p, e // 2)
    doubled = _edwards_add(half, half)
    return _edwards_add(doubled, p) if e & 1 else doubled


def _x_recover(y):
    """Recover the even x coordinate candidate for the given y coordinate"""
    x_squared = (y * y - 1) * _inverse(D * y * y + 1)
    root = pow(x_squared, (PRIME + 3) // 8, PRIME)
    # If the candidate is not a square root of x^2, multiply by sqrt(-1).
    if (root * root - x_squared) % PRIME != 0:
        root = (root * II) % PRIME
    return PRIME - root if root % 2 != 0 else root


# Constants, some of which are calculated/checked at runtime using above routines
# See https://ed25519.cr.yp.to/python/checkparams.py
SUITE_STRING = bytes([0x04])
BITS = 256
PRIME = 2**255 - 19
ORDER = 2**252 + 27742317777372353535851937790883648493
COFACTOR = 8
TWO_INV = _inverse(2)
II = pow(2, (PRIME - 1) // 4, PRIME)  # Square root of -1 mod PRIME (asserted below)
A = 486662
D = -121665 * _inverse(121666)
BASEy = 4 * _inverse(5)
BASEx = _x_recover(BASEy)
BASE = [BASEx % PRIME, BASEy % PRIME]

# Import-time parameter self-checks.
assert BITS >= 10
assert 8 * len(_hash("hash input".encode("UTF-8"))) == 2 * BITS
assert pow(2, PRIME - 1, PRIME) == 1
assert PRIME % 4 == 1
assert pow(2, ORDER - 1, ORDER) == 1
assert ORDER >= 2 ** (BITS - 4)
assert ORDER <= 2 ** (BITS - 3)
assert pow(D, (PRIME - 1) // 2, PRIME) == PRIME - 1
assert pow(II, 2, PRIME) == PRIME - 1
assert _is_on_curve(BASE)
assert _scalar_multiply(BASE, ORDER) == [0, 1]