Files
sqisign_new/scripts/precomp/cformat.py

128 lines
5.0 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import sys, itertools
from math import ceil, floor, log
import sage.all
class Ibz:
    """An integer to be emitted as a constant GMP ``mpz_t`` initializer."""

    def __init__(self, v):
        # Normalise to a plain Python int (e.g. from a Sage Integer or str).
        self.v = int(v)

    def _literal(self, sz):
        """Return a C brace initializer for this value split into *sz*-bit limbs.

        The result fills ``_mp_alloc`` (always 0), ``_mp_size`` (limb count,
        negated for negative values) and ``_mp_d`` (a compound-literal limb
        array, least-significant limb first). Zero gets size 0 but still one
        ``0x0`` limb so the array literal is never empty.
        """
        value = int(self.v)
        magnitude = abs(value)
        mask = (1 << sz) - 1
        count = (magnitude.bit_length() + sz - 1) // sz if value else 0
        digits = []
        for i in range(max(count, 1)):
            digits.append(hex((magnitude >> (sz * i)) & mask))
        size = -count if value < 0 else count
        fields = [
            ('._mp_alloc', 0),
            ('._mp_size', size),
            ('._mp_d', '(mp_limb_t[]) {' + ','.join(digits) + '}'),
        ]
        body = ', '.join(f'{key} = {val}' for key, val in fields)
        # Double braces: the mpz_t typedef is a one-element struct array.
        return '{{' + body + '}}'
class FpEl:
    """A GF(p) element to be emitted as a limb-array initializer.

    Args:
        n: the element value. Presumably an integer-like value; it is only
            used in ``(n * R) % p`` followed by ``int(...)`` — TODO confirm
            the accepted types against callers.
        p: the field modulus.
        montgomery: when true, the emitted value is ``n * R mod p`` with
            ``R = 2**(radix * limbs)``; otherwise it is ``n mod p``.
    """

    # Bits per limb used by the reference arithmetic for each supported prime,
    # keyed by machine word size (16/32/64).
    ref_p5248_radix_map = { 16: 13, 32: 29, 64: 51 }
    ref_p65376_radix_map = { 16: 13, 32: 28, 64: 55 }
    ref_p27500_radix_map = { 16: 13, 32: 29, 64: 57 }

    def __init__(self, n, p, montgomery=True):
        self.n = n
        self.p = p
        self.montgomery = montgomery

    def __get_radix(self, word_size, arith=None):
        """Return the limb radix (bits per limb) for *word_size* and backend *arith*.

        ``"ref"`` (also the default) only supports the three hard-coded primes
        and looks the radix up in the matching map (an unsupported word size
        raises ``KeyError``). ``"broadwell"`` uses full-word limbs.

        Raises:
            ValueError: unknown prime for ``"ref"``, or unknown *arith*.
        """
        if arith == "ref" or arith is None:
            # lvl1
            if self.p == 0x4ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:
                return self.ref_p5248_radix_map[word_size]
            # lvl3
            elif self.p == 0x40ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:
                return self.ref_p65376_radix_map[word_size]
            # lvl5
            elif self.p == 0x1afffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:
                return self.ref_p27500_radix_map[word_size]
            raise ValueError(f'Invalid prime \"{self.p}\"')
        elif arith == "broadwell":
            return word_size
        raise ValueError(f'Invalid arithmetic implementation type \"{arith}\"')

    def _literal(self, sz, arith=None):
        """Return ``'{l0, l1, ...}'``: the (possibly Montgomery) value as hex limbs.

        Limbs are least-significant first, each holding ``radix`` bits, and
        there are exactly as many limbs as needed to represent ``p``.
        """
        radix = self.__get_radix(sz, arith=arith)
        # Exact number of base-2**radix digits of p. Integer arithmetic avoids
        # float log() misrounding for moduli just below a limb boundary
        # (e.g. 2**256 - 189 with 64-bit limbs would otherwise get 5 limbs).
        limbs = (int(self.p).bit_length() + radix - 1) // radix
        # Montgomery factor R = 2^(w*n) (n = limb number, w = radix); using the
        # same limb count keeps R consistent with the emitted array length.
        R = 2 ** (radix * limbs) if self.montgomery else 1
        el = int((self.n * R) % self.p)
        return '{' + ', '.join(hex((el >> (radix * i)) % 2**radix) for i in range(limbs)) + '}'
class Object:
    """A named C constant that can render its declaration and its definition.

    ``ty`` is the C element type, optionally followed by array suffixes such as
    ``'[]'`` or ``'[4]'``. Empty brackets are sized from the nesting of ``obj``
    (which must be rectangular at those levels); explicit sizes are kept
    verbatim.
    """

    def __init__(self, ty, name, obj):
        if '[' in ty:
            import re  # local: only array types need bracket parsing

            idx = ty.index('[')
            # Contents of each bracket group, e.g. '[][NW]' -> ['', 'NW'].
            groups = re.findall(r'\[\s*([^\]]*?)\s*\]', ty[idx:])
            # Only levels up to the last empty group must be inferred from obj.
            depth = max((i + 1 for i, g in enumerate(groups) if not g), default=0)

            def rec(os, d):
                """Return the sizes of the first *d* nesting levels of *os*."""
                assert d >= 0
                if not d:
                    return ()
                assert isinstance(os, (list, tuple))
                shapes = {rec(o, d - 1) for o in os}
                if len(shapes) != 1:
                    raise ValueError(
                        f'cannot infer array dimensions for {name!r}: '
                        'empty or non-rectangular data')
                (r,) = shapes
                return (len(os),) + r

            dims = rec(obj, depth)
            self.ty = ty[:idx], ''.join(
                f'[{g}]' if g else f'[{dims[i]}]' for i, g in enumerate(groups))
        else:
            self.ty = ty, ''
        self.name = name
        self.obj = obj

    def _declaration(self):
        """Return the ``extern const`` declaration for a header file."""
        return f'extern const {self.ty[0]} {self.name}{self.ty[1]};'

    def _literal(self):
        """Return the C initializer expression for ``self.obj``.

        Integers (Python or Sage) are printed in decimal below 256 and in hex
        otherwise; ``Ibz``/``FpEl`` become preprocessor-selected variants;
        strings are inserted verbatim; lists/tuples become brace-enclosed,
        comma-separated initializers.

        Raises:
            NotImplementedError: for any other element type.
        """

        def rec(obj):
            if isinstance(obj, (int, sage.all.Integer)):
                return str(obj) if obj < 256 else hex(obj)
            if isinstance(obj, Ibz):
                # One variant per GMP limb width, selected at C compile time.
                lines = ['', '#if 0']
                for sz in (16, 32, 64):
                    lines += [f'#elif GMP_LIMB_BITS == {sz}', obj._literal(sz)]
                return '\n'.join(lines + ['#endif', ''])
            if isinstance(obj, FpEl):
                # One variant per RADIX; 64-bit further distinguishes the
                # Broadwell backend from the reference one.
                lines = ['', '#if 0']
                for sz in (16, 32, 64):
                    lines.append(f'#elif RADIX == {sz}')
                    if sz == 64:
                        lines += [
                            '#if defined(SQISIGN_GF_IMPL_BROADWELL)',
                            obj._literal(sz, 'broadwell'),
                            '#else',
                            obj._literal(sz, 'ref'),
                            '#endif',
                        ]
                    else:
                        lines.append(obj._literal(sz, 'ref'))
                return '\n'.join(lines + ['#endif', ''])
            if isinstance(obj, (list, tuple)):
                return '{' + ', '.join(map(rec, obj)) + '}'
            if isinstance(obj, str):
                return obj
            raise NotImplementedError(f'unknown type {type(obj)} in Formatter')

        return rec(self.obj)

    def _definition(self):
        """Return the full ``const`` definition, initializer included."""
        return f'const {self.ty[0]} {self.name}{self.ty[1]} = ' + self._literal() + ';'
class ObjectFormatter:
    """Print the C header or implementation text for a collection of Objects."""

    def __init__(self, objs):
        # Iterable of Object instances; stored as-is and walked on each emit.
        self.objs = objs

    def _emit(self, render, file):
        """Print ``render(item)`` for every item, one per line, to *file*."""
        for item in self.objs:
            assert isinstance(item, Object)
            print(render(item), file=file)

    def header(self, file=None):
        """Write one ``extern`` declaration per object (stdout when *file* is None)."""
        self._emit(lambda item: item._declaration(), file)

    def implementation(self, file=None):
        """Write one definition per object (stdout when *file* is None)."""
        self._emit(lambda item: item._definition(), file)