/* ==============================================
 Key Generation of the PMI(+) Cryptosystem
 The program generates a key pair of the PMI+ cryptosystem
 and stores the public and private keys in the files public_key.txt and private_key.txt, respectively.
===========================================================*/

// squareandmultiply(m, e, p): fast exponentiation, returns m^e mod p.
// Scans the binary digits of e from least to most significant (right-to-left
// square-and-multiply). Expects e to be a non-negative integer (it is split
// with Intseq(e,2)); for e = 0 the integer 1 is returned.
function squareandmultiply(m,e,p)
	res:=1;   // product of the squares selected by the 1-bits seen so far
	sq:=m;    // m^(2^k) mod p while bit k is being examined
	for b in Intseq(e,2) do
		if b eq 1 then
			res:=(res*sq) mod p;
		end if;
		sq:=(sq^2) mod p;
	end for;
	return res;
end function;

// ============ parameters =====================

q:=4; // field size
n:=5; // # variables
r:=1; // dimension of the perturbation space
s:=1; // # plus polynomials
theta:=2; // MI parameter

F<w>:=GaloisField(q);
Vectn:=VectorSpace(F,n); // used to draw the constant part cT of T

Pol<[x]>:=PolynomialRing(F,n);   // polynomials in x[1..n] over the ground field
Poll<[y]>:=PolynomialRing(Pol,r); // perturbation variables y[1..r] over Pol
Poly<X>:=PolynomialRing(Pol);     // univariate in X; reduced mod irrpol it models GF(q^n)
irrpol:=Poly!(IrreduciblePolynomial(F,n)); // degree-n irreducible over F, lifted into Poly

// field equations x[i]^q - x[i] over the ground field
FIELDEQ:=[x[i]^q-x[i] : i in [1..n]];
FIELDIdeal:=Ideal(FIELDEQ);
// =============================================

// thetainv is the inverse of q^theta+1 modulo q^n-1 (parameter for decryption)
g,thetainv,_:=XGCD(q^theta+1,q^n-1);
thetainv:=thetainv mod (q^n-1);

// The inverse only exists when gcd(q^theta+1, q^n-1) = 1. Stop here instead of
// merely printing a warning and going on to write a key pair that cannot be
// decrypted.
error if g gt 1, "ERROR!!! Wrong theta";


// affine map T ----------------------------------------------------------------
// T(x) = MT*x + cT, where MT is a random invertible n x n matrix over F and cT
// a random vector in F^n. T[i] is the i-th output coordinate, as an affine
// polynomial in x[1..n].
repeat
	MT:=RandomMatrix(F,n,n);
until IsInvertible(MT);
cT:=Random(Vectn);

T:=[cT[i] + &+[MT[i][j]*x[j] : j in [1..n]] : i in [1..n]];
// ----------------------------------------------------------------------------

//----------------affine map S ----------------------------------------------
// S acts on n+s coordinates as v -> MS*v + cS (applied to the central map when
// the public key is computed). MSF is drawn invertible over F. It is then
// lifted, together with cSF, into matrices over Pol so that it can multiply a
// column of polynomials.
repeat
	MSF:=RandomMatrix(F,n+s,n+s);
until IsInvertible(MSF);
MS:=Matrix(Pol,n+s,n+s,[Pol!a : a in Eltseq(MSF)]);
cSF:=RandomMatrix(F,n+s,1);
cS:=Matrix(Pol,n+s,1,[Pol!a : a in Eltseq(cSF)]);
// -----------------------------------------------------------------------

// Lift T into the extension field: the i-th coordinate T[i] becomes the
// coefficient of X^(i-1).
MIA:=&+[T[i]*X^(i-1) : i in [1..n]];

// MI central map: MIB = MIA^(1+q^theta) mod irrpol
MIB:=squareandmultiply(MIA,1+q^theta,irrpol);

// Move back down to the vector space. p[i] is the coefficient of X^(i-1),
// reduced modulo the field equations.
p:=[NormalForm(MonomialCoefficient(MIB,X^(i-1)),FIELDIdeal) : i in [1..n]];

// Reassemble MIB from the reduced coefficients (used for the console output).
MIB:=&+[p[i]*X^(i-1) : i in [1..n]];

// Internal Perturbation
// Z[i] (i = 1..r) is a random linear form in x[1..n] with coefficients in F
// and no constant term.
Z:=[&+[Random(F)*x[j] : j in [1..n]] : i in [1..r]];

// Fb[i] (i = 1..n) is a random polynomial of degree <= 2 in y[1..r] with
// coefficients drawn from F. It consists of the products y[j]*y[k] for
// j <= k, the linear terms y[j], and a constant term.
Fb:=[];
for i in [1..n] do
	fb:=Poll!0;
	for j in [1..r] do
		for k in [j..r] do
			fb+:=Random(F)*y[j]*y[k];
		end for;
		fb+:=Random(F)*y[j];
	end for;
	Append(~Fb, fb+Random(F));
end for;

// mu and lambda table
// Lambda lists every point of F^r (q^r of them) as a sequence over F.
// Mu[i] is the vector (Fb[1],...,Fb[n]) evaluated at y = Lambda[i], as a
// sequence over F.
A:=CartesianPower(F,r);
B:=[i: i in A];
Lambda:=[[t[j] : j in [1..r]] : t in B];

Mu:=[];
for i:=1 to q^r do
	mu:=Fb;
	for j:=1 to n do
		for k:=1 to r do
			mu[j]:=Evaluate(mu[j],y[k],Lambda[i][k]);
		end for;
		// After substituting every y[k], mu[j] is a constant of Poll. Taking its
		// constant coefficient once is enough: the value is coerced back into
		// Poll (the universe of mu), so a second extraction would change
		// nothing. ChangeUniverse below moves the whole vector into F.
		mu[j]:=MonomialCoefficient(mu[j],1);
	end for;
	Mu:=Mu cat [ChangeUniverse(mu,F)];
end for;

// Add the internal perturbation to the central map: substitute y[j] := Z[j]
// into each Fb[i], then add the resulting polynomial in x to p[i].
Fc:=Fb;
for i in [1..n] do
	for j in [1..r] do
		Fc[i]:=Evaluate(Fc[i],y[j],Z[j]);
	end for;
	// with every y substituted, the coefficient of the monomial 1 holds the
	// whole perturbation term as an element of Pol
	p[i]+:=MonomialCoefficient(Fc[i],1);
end for;
//+++++++========================================

// Plus polynomials =========================
// p[n+1..n+s] are random polynomials of degree <= 2 in x[1..n] with
// coefficients from F. Each draws a coefficient for every product x[j]*x[k]
// (j, k = 1..n), every x[j], and a constant term.
for i in [n+1..n+s] do
	plus:=Pol!0;
	for j in [1..n] do
		for k in [1..n] do
			plus+:=Random(F)*x[j]*x[k];
		end for;
		plus+:=Random(F)*x[j];
	end for;
	p[i]:=plus+Random(F);
end for;
//===============================================

// Public Key Computation: the column vector of p[1..n+s] is pushed through the
// affine map S, Pk = MS*P + cS. Each entry is then reduced modulo the field
// equations.
P:=Matrix(Pol,n+s,1,p);
Pk:=MS*P+cS;
for i in [1..n+s] do
	Pk[i][1]:=NormalForm(Pk[i][1],FIELDIdeal);
end for;


// Output ==========================================================================
// Print the parameters, every private-key component (S, cS, T, cT, the central
// map before and after reduction, Z, Fb and the Lambda/Mu table), the
// central-map polynomials p and the public key to the console.
// NOTE(review): this prints secret key material to stdout.
printf "q:= %o  \n \n", q;
printf "n:= %o  \n \n", n;
printf "irrpol:= %o  \n \n", irrpol;
printf "theta:= %o \n \n", theta;
printf "thetainv:= %o \n \n", thetainv;
printf "S:= \n%o \n \n", MS;
printf "cS:= %o \n \n", cS;

printf "T:= \n%o \n \n", MT;
printf "cT:= %o \n \n", cT;
printf "MIA:= %o \n \n", MIA;
printf "MIB:= \n%o \n \n", MIB;
printf "Z:= %o \n \n", Z;
printf "Fb:= \n%o \n \n", Fb;
printf "Lambda \t \t Mu \n";

// one row per point of F^r: the point Lambda[i] and the values Mu[i] of Fb there
for i:=1 to q^r do
	printf "%o \t \t %o \n", Lambda[i], Mu[i];
end for;

printf "\n";
printf "p:= \n%o \n \n", p;
printf "pk:= \n%o \n \n", Pk;

// Write both keys as Magma assignment statements, so each file can be loaded
// to rebuild the values (presumably by separate encryption/decryption
// scripts — confirm). Each file is overwritten if it already exists.
// Matrices and sequences are written flattened with Eltseq, so the reader has
// to reshape them.
printf"Write public_key.txt \n \n";
SetOutputFile("public_key.txt":Overwrite:=true);
// public key: field/ring setup plus the n+s public polynomials
printf "q:= %o ; \n \n", q;
printf "n:= %o ; \n \n", n;
printf "s:= %o ; \n \n", s;
printf "F<w>:=GaloisField(q); \n \n";
printf "Pol<[x]>:=PolynomialRing(F,n); \n \n";
printf "Pk:= %o ; \n \n", Eltseq(Pk);
UnsetOutputFile();
 
printf"Write private_key.txt \n \n";
SetOutputFile("private_key.txt":Overwrite:=true);
// private key: setup, thetainv, irrpol, the affine maps T and S, and the
// perturbation data (Lambda/Mu table and the linear forms Z)
printf "q:= %o ; \n \n", q;
printf "n:= %o ; \n \n", n;
printf "r:= %o ; \n \n", r;
printf "s:= %o ; \n \n", s;
printf "F<w>:=GaloisField(q); \n \n";
printf "Pol<[x]>:=PolynomialRing(F,n); \n \n";
printf "Poly<X>:=PolynomialRing(F); \n \n";
printf "thetainv:= %o ;\n \n", thetainv;
printf "irrpol:= %o ; \n \n", irrpol;
printf "MT:= %o ; \n \n", Eltseq(MT);
printf "cT:= %o ; \n \n", Eltseq(cT);
printf "MS:= \n%o ; \n \n", Eltseq(MS);
printf "cS:= \n%o ; \n \n", Eltseq(cS);
printf "Lambda:= \n%o ; \n \n", Eltseq(Lambda);
printf "Mu:= \n%o ; \n \n", Eltseq(Mu);
printf "Z:= \n%o ; \n \n", Eltseq(Z);
UnsetOutputFile();

