// =================================================
// Key generation for (projected) SFLASH (PFLASH)
// Generates a random (P)FLASH key pair and stores the public key in
// public_key.txt and the private key in private_key.txt.
// =================================================

// Fast modular exponentiation (right-to-left square-and-multiply).
//   m : base (an integer, or a polynomial when p is a polynomial modulus)
//   e : non-negative integer exponent
//   p : modulus understood by the `mod` operator for the type of m
// Returns m^e reduced modulo p; returns the integer 1 when e = 0.
function squareandmultiply(m,e,p) // for fast exponentiation
	result:=1;
	power:=m; // holds m^(2^k) mod p while scanning bit k of e
	for bit in Intseq(e,2) do // binary digits of e, least significant first
		if bit eq 1 then
			result:=result*power mod p;
		end if;
		power:=power^2 mod p;
	end for;
	return result;
end function;

// parameters
//   q     : size of the ground field F = GF(q)
//   n     : number of variables, also the degree of the extension GF(q^n) over F
//   a     : number of public equations removed ("minus" modifier)
//   s     : number of projected variables (set to 0 in the public key)
//   m     : number of public equations
//   theta : exponent parameter of the central map X -> X^(q^theta+1)

q:=4;
n:=5; // # variables
a:=2; // # Minus Equations
s:=1; // # projected variables
m:=n-a; // # equations
theta:=2;

// X -> X^(q^theta+1) can only be inverted on GF(q^n) when
// gcd(q^theta+1, q^n-1) = 1. Abort instead of silently continuing
// with an invalid key.
g,thetainv,_:=XGCD(q^theta+1, q^n-1);
if g ne 1 then
	error "FEHLER!!! Falsches theta";
end if;

// Inverse exponent: thetainv*(q^theta+1) = 1 mod (q^n-1), so raising to
// thetainv undoes the central map (exponent for decryption/signing).
thetainv:= thetainv mod (q^n-1);

// Ground field, vector spaces and polynomial rings -------------------------
// F = GF(q) with generator w; Pol = F[x[1..n]] holds the public polynomials;
// Poly = Pol[X] is used to compute in GF(q^n) with polynomial coefficients
// (elements are reduced modulo irrpol).
F<w>:=GaloisField(q);
Vectn:=VectorSpace(F,n);
Vectm:=VectorSpace(F,m);
Pol<[x]>:=PolynomialRing(F,n);
Poly<X>:=PolynomialRing(Pol);
Poll<[y]>:=PolynomialRing(Pol,n); // NOTE(review): not used in the visible code

// Degree-n irreducible polynomial over F, lifted into Poly as the modulus.
irrpol:=Poly!(IrreduciblePolynomial(F,n));

// Field equations x[i]^q - x[i] over the ground field. Every c in F satisfies
// c^q = c, so reducing modulo these does not change values on F^n.
FIELDEQ:=[ x[i]^q-x[i] : i in [1..n] ];
FIELDIdeal:=Ideal(FIELDEQ);

// affine map T : x -> MT*x + cT ----------------------------------------------
// Random invertible n x n matrix MT over F and a random shift vector cT.
repeat
	MT:=RandomMatrix(F,n,n);
until IsInvertible(MT);
cT:=Random(Vectn);

// T[i] = cT[i] + sum_j MT[i][j]*x[j], the i-th coordinate of T as an element of Pol.
T:=[ cT[i] + &+[ MT[i][j]*Pol.j : j in [1..n] ] : i in [1..n] ];
// ----------------------------------------------------------------------------

//----------------affine map S : y -> MS*y + cS ------------------------------
// MSF is a random m x n matrix over F of full rank m; because m = n-a, it also
// drops a output directions ("minus"). cSF is a random m x 1 shift.
// Both are copied into Pol so they can act on vectors of polynomials.
repeat
	MSF:=RandomMatrix(F,m,n);
until Rank(MSF) eq m;
MS:=Matrix(Pol,m,n,[ Pol!c : c in Eltseq(MSF) ]);
cSF:=RandomMatrix(F,m,1);
cS:=Matrix(Pol,m,1,[ Pol!c : c in Eltseq(cSF) ]);
// -----------------------------------------------------------------------

// Central map -----------------------------------------------------------------
// Lift the n linear forms T[1..n] to one element of the extension field
// (T[i] is the coefficient of X^(n-i)), raise it to q^theta+1 modulo irrpol,
// then read the n coefficients back out in the same order. Each coefficient
// is reduced modulo the field equations.
MIA:=&+[ T[i]*X^(n-i) : i in [1..n] ];

MIB:=squareandmultiply(MIA,q^theta+1,irrpol) mod irrpol;

p:=[ NormalForm(MonomialCoefficient(MIB,X^(n-i)),FIELDIdeal) : i in [1..n] ];

// Rebuild MIB from the reduced coefficients. Use the same convention as the
// lift and the extraction: p[i] is the coefficient of X^(n-i). The previous
// X^(i-1) reversed the coefficient order, so the printed MIB did not match
// the reduced central map.
MIB:=Poly!0;
for i:=1 to n do
	MIB:=MIB+p[i]*X^(n-i);
end for;

// Column vector (n x 1 matrix over Pol) of the n central polynomials.
P:=Matrix(Pol,n,1,p);

// Public key: apply S to the central polynomials. Pk = MS*P + cS, flattened
// to a sequence of m polynomials in Pol.
Pk:=Eltseq(MS*P+cS);

// Projection ----------------------------
// Set the last s variables x[n-s+1..n] to 0 in every public polynomial.
for j:=n-s+1 to n do
	Pk:=[ Evaluate(f,x[j],F!0) : f in Pk ];
end for;
// ---------------------------------------------

// Output ===========================================
// Echo the parameters, the secret affine maps, the central map and the
// public key to the console.

printf "************************************************ \n*** SFLASH Signature Scheme - Key Generation *** \n************************************************ \n \n";

printf "q:= %o  \n \nn:= %o  \n \nMinus:= %o  \n \nProjection:= %o  \n \n", q, n, a, s;
printf "irrpol:= %o  \n \ntheta:= %o \n \nthetainv:= %o \n \n", irrpol, theta, thetainv;
printf "S:= \n%o \n \ncS:= %o \n \n", MS, VectorSpace(F,m)!Eltseq(cS);

printf "T:= \n%o \n \ncT:= %o \n \n", MT, cT;
printf "MIA:= %o \n \nMIB:= \n%o \n \n", MIA, MIB;
printf "pk:= \n%o \n \n", Pk;

// Write the public key to public_key.txt as Magma assignments (presumably
// so it can be read back with `load`): parameters, field/ring declarations,
// and the projected public polynomials Pk.
printf"Write public_key.txt \n \n";
SetOutputFile("public_key.txt":Overwrite:=true);
printf "q:= %o ; \n \nn:= %o ; \n \na:= %o ; \n \ns:= %o ; \n \n", q, n, a, s;
printf "F<w>:=GaloisField(q); \n \nPol<[x]>:=PolynomialRing(F,n); \n \n";
printf "Pk:= %o ; \n \n", Pk;
UnsetOutputFile();
 
// Write the private key to private_key.txt as Magma assignments. It contains
// the parameters, the ring declarations, thetainv and irrpol, and the affine
// maps (MT, cT, MS, cS) as flat coefficient sequences.
printf"Write private_key.txt \n \n";
SetOutputFile("private_key.txt":Overwrite:=true);
printf "q:= %o ; \n \nn:= %o ; \n \na:= %o ; \n \ns:= %o ; \n \n", q, n, a, s;
printf "F<w>:=GaloisField(q); \n \nPol<[x]>:=PolynomialRing(F,n); \n \nP<X>:=PolynomialRing(F); \n \n";
printf "thetainv:= %o ;\n \nirrpol:= %o ; \n \n", thetainv, irrpol;
printf "MT:= %o ; \n \ncT:= %o ; \n \n", Eltseq(MT), Eltseq(cT);
printf "MS:= \n%o ; \n \ncS:= \n%o ; \n \n", Eltseq(MS), Eltseq(cS);
UnsetOutputFile();
