//*********************************************************
// Linearization Equations Attack against MI 
// The program reads in the public key from public_key.txt
// and performs Patarin's Linearization Equations Attack to decrypt
// a randomly chosen ciphertext
//*********************************************************


// Banner
printf "**************************************************************************\n";
printf "Linearization Equations Attack against the Matsumoto-Imai Cryptosystem *** \n";
printf "*************************************************************************** \n \n";


// NOTE(review): public_key.txt is expected to define (at least) the base field F,
// the number of variables n, the public polynomials Pk and the polynomial ring Pol
// whose variables are accessible as x[1..n]; all of these are used below but are
// not defined in this script -- confirm against the key file.
load "public_key.txt";


// Plaintexts and ciphertexts both live in F^n (KSpace is Magma's synonym for VectorSpace).
Vn:=KSpace(F,n);

// The target: a random ciphertext whose preimage the attack will recover.
ciphertext:=Random(Vn);


printf "We want to find the plaintext corresponding to the ciphertext %o \n \n", ciphertext;


// A linearization equation has the form
//     \sum a_ij x_i y_j + \sum b_i x_i + \sum c_j y_j + d = 0
// with n*n + n + n + 1 = (n+1)^2 unknown coefficients a_ij, b_i, c_j, d.
// Every plaintext/ciphertext pair contributes one linear condition on them.
// Sampling exactly (n+1)^2 pairs leaves no slack: if the sample matrix happens to
// be rank deficient, its kernel contains spurious "equations" that the true
// plaintext need not satisfy, and the final step can then miss the solution.
// A few extra pairs make that failure far less likely at negligible cost.
margin:=10;
number1:=(n+1)^2 + margin;

printf" Create %o plaintext/ciphertext pairs \n \n ", number1; 
// Table of plaintext/ciphertext pairs: Ciphertexts[i] is Pk evaluated at Plaintexts[i].
Plaintexts:=[];
Ciphertexts:=[];
for i:=1 to number1 do
	Pub:=Pk; 
	plaintext:=Random(Vn);
	// Substitute x[j] -> plaintext[j] in each of the n public polynomials.
	for loop:=1 to n do
		for j:=1 to n do
			Pub[loop]:=Evaluate(Pub[loop],x[j],plaintext[j]);
		end for;
	end for;
	Plaintexts[i]:=plaintext;
	Ciphertexts[i]:=Vn!Pub;
end for;
printf" Plaintexts \t \t Ciphertexts \n";
for i:=1 to number1 do
	printf" %o \t  \t %o \n", Plaintexts[i], Ciphertexts[i];
end for;
printf" \n \n";

printf"Find Linearization equations of the form \sum x_iy_i + \sum x_i + \sum y_i + d =0 \n \n";
// One row per sampled pair (pt, ct). The columns hold the values of the monomials
// that multiply the unknown coefficients, laid out as
//   a_ij -> column (i-1)*n + j      (value pt_i*ct_j)
//   b_i  -> column n*n + i          (value pt_i)
//   c_j  -> column n*n + n + j      (value ct_j)
//   d    -> column n*n + 2*n + 1    (value 1)
LinEq:=ZeroMatrix(F,number1,n*n+n+n+1);
for r:=1 to number1 do
	pt:=Plaintexts[r];
	ct:=Ciphertexts[r];
	for i:=1 to n do
		LinEq[r,n*n+i]:=pt[i];
		LinEq[r,n*n+n+i]:=ct[i];
		for j:=1 to n do
			LinEq[r,(i-1)*n+j]:=pt[i]*ct[j];
		end for;
	end for;
	LinEq[r,n*n+2*n+1]:=1;
end for;

// Coefficient vectors v satisfy LinEq * v^T = 0. Kernel returns the left null
// space (v*X = 0), hence the transpose.
Kera:=Kernel(Transpose(LinEq));
dim:=Dimension(Kera);
printf " Found %o linear independent Linearization Equations \n \n", dim;
Base:=Basis(Kera);
//=============================================================================================================

// Ring in 2n variables: u[1..n] play the role of the plaintext variables,
// v[1..n] the role of the ciphertext variables.
Poly:=PolynomialRing(F,2*n);

nam:=["u[" cat IntegerToString(k) cat "]" : k in [1..n]]
	cat ["v[" cat IntegerToString(k) cat "]" : k in [1..n]];

AssignNames(~Poly,nam);


// Rebuild each kernel basis vector as a polynomial, reading its entries with the
// column layout a_ij -> (i-1)*n + j, b_i -> n*n + i, c_j -> n*n + n + j,
// d -> n*n + 2*n + 1.
Lineq1:=[]; //Linearization Equations
for e:=1 to #Base do
	coef:=Base[e];
	lin:=Poly!coef[n*n+2*n+1]; // d
	for i:=1 to n do
		lin +:= coef[n*n+i]*Poly.i; // b_i
		lin +:= coef[n*n+n+i]*Poly.(n+i); // c_i
		for j:=1 to n do
			lin +:= coef[(i-1)*n+j]*Poly.i*Poly.(n+j); // a_ij
		end for;
	end for;
	Lineq1[e]:=lin;
end for;
Lineq1;
//===========================================================================================================


printf "Insert ciphertext into the linearization equations \n \n";

// Substitute v[k] -> ciphertext[k] in one go (each u[k] is mapped to itself);
// afterwards every equation is linear in the u-variables.
subs:=[Poly.k : k in [1..n]] cat [Poly!ciphertext[k] : k in [1..n]];
for e:=1 to #Base do
	Lineq1[e]:=Evaluate(Lineq1[e],subs);
end for;

// Row e: coefficients of u[1..n], then the constant term in the last column.
Mat:=ZeroMatrix(F,#Base,n+1); // Equations for plaintext variables
for e:=1 to #Base do
	Mat[e,n+1]:=MonomialCoefficient(Lineq1[e],1);
	for k:=1 to n do
		Mat[e,k]:=MonomialCoefficient(Lineq1[e],Poly.k);
	end for;
end for;

// After echelonising, the first t rows are the independent equations.
Mat:=EchelonForm(Mat);
t:=Rank(Mat);

printf"Found %o linear equations in the plaintext variables \n \n", t;

// Express row r over Pol: sum_k Mat[r,k]*x[k] + Mat[r,n+1] (= 0).
pol:=[];
for r:=1 to t do
	pol[r]:=&+[Mat[r,k]*x[k] : k in [1..n]] + Mat[r,n+1];
end for;

for eqn in pol do
	printf" %o = 0 \n", eqn;
end for;

// Enumerate the F-rational solutions of the linear system in x. Build the ideal
// explicitly over Pol so that an empty pol (t = 0) still gives a well-defined ideal.
V:=Variety(ideal<Pol | pol>);
printf "\n possible Plaintexts \t \t Ciphertexts \n";
// Re-encrypt each candidate with the public key and keep the one that matches.
// The flag (rather than testing sol afterwards) ensures we never print an unset or
// stale sol, e.g. one left over from a previous run in the same session.
found:=false;
for i:=1 to #V do
	ciph:= Pk;
	for j:=1 to n do
		for k:=1 to n do
			ciph[j]:= Evaluate(ciph[j],x[k],V[i][k]);
		end for;
	end for;
	printf" %o \t \t %o \n", V[i], Vn!ciph;
	if Vn!ciph eq ciphertext then
		sol:=V[i];
		found:=true;
	end if;
end for;


if found then
	printf"\nOur ciphertext %o corresponds to the plaintext %o \n \n", ciphertext, sol;
else
	// Can happen if spurious linearization equations excluded the true plaintext.
	printf"\nNo candidate plaintext re-encrypts to the ciphertext %o; the attack failed \n \n", ciphertext;
end if;
