// ******************************************
// MQ identification scheme (3-round version)
// Creates a random quadratic system P of m equations in n variables over GF(q)
// and a key pair (s, v = P(s)) for the MQ identification scheme.
// The function MQIDround(Ch) produces, for a challenge Ch in {0,1,2}, the
// prover's commitments and response; MQver(COM,Ch,Rsp) checks them.
// **********************************************


// ---- Public parameters of the scheme ----
q:=4;          // size of the base field
m:=4;          // number of quadratic equations
n:=4;          // number of variables
F<w>:=GF(q);

// Ambient spaces: F^n holds s, r0, r1, t0, t1; F^m holds v, e0, e1.
Vn:=VectorSpace(F,n);
Vm:=VectorSpace(F,m);
Vnm:=VectorSpace(F,n+m);      // one packed commitment (n + m entries)
VRsp:=VectorSpace(F,2*n+m);   // one packed response (2n + m entries)
Pol<[x]>:=PolynomialRing(F,n);

printf "******************************************************** \n";
printf "*** MQ based identification scheme (3 round version) *** \n";
printf "******************************************************** \n\n\n";


printf "***Parameters*** \n\n";
printf "F:= GF(%o) \n \n",q;
printf "m:= %o \n\n",m;
printf "n:= %o \n\n", n;

ev:= function(m,n,pol,s)
    // Evaluate the first m polynomials of pol at the point s and return the
    // values as an element of Vm.
    // Variables are substituted one at a time (x[1], then x[2..n]); x is the
    // global generator sequence of Pol.
    vals:=[];
    for idx in [1..m] do
        f:=Evaluate(pol[idx],x[1],s[1]);
        for k in [2..n] do
            f:=Evaluate(f,x[k],s[k]);
        end for;
        Append(~vals,f);
    end for;
    // Each f is now a constant polynomial; move it into F before coercing.
    return Vm!ChangeUniverse(vals,F);
end function;

// Random MQ system P = (pol[1],...,pol[m]). Each pol[i] is a sum of random
// multiples of x[j]*x[k] (j <= k) and of x[j], with no constant term.
// The absence of a constant term makes G(a,b) = P(a+b) - P(a) - P(b) bilinear.
pol:=[];
for i in [1..m] do
    pol[i]:=Pol!0;
    for j in [1..n] do
        // cross terms x[j]*x[k] for k = j..n, then the linear term x[j]
        pol[i] +:= &+[Random(F)*x[j]*x[k] : k in [j..n]];
        pol[i] +:= Random(F)*x[j];
    end for;
end for;

// ---- Key generation: secret s in F^n, public v = P(s) in F^m ----
printf "\n\n***Key Generation***\n\n";
printf "P:= %o \n \n", pol;

s:= Random(Vn); // private key
printf "s:= %o \n \n", s;

v:=ev(m,n,pol,s); // public key
printf "v:= %o \n \n", v;

MQIDround:= function(Ch)
    // Prover side of one round of the 3-pass MQ identification protocol.
    // Reads the globals s (secret key), pol (the system P), m, n, Vn, Vm, Vnm, VRsp.
    //   Ch : the challenge; must be 0, 1 or 2.
    // Returns COM = [c0, c1, c2] (elements of Vnm) and Rsp, a sequence over F:
    //   Ch = 0 -> r0 || t1 || e1
    //   Ch = 1 -> r1 || t1 || e1
    //   Ch = 2 -> r1 || t0 || e0
    // NOTE(review): the "commitments" are the plain concatenated vectors, not a
    // hiding commitment; this is a demonstration only.

    // Reject bad challenges before doing any work. Without this check the case
    // statement below would leave Rsp unassigned and fail only at the return.
    error if Ch notin {0,1,2}, "MQIDround: challenge Ch must be 0, 1 or 2";

    // Split the secret: s = r0 + r1, r0 = t0 + t1, P(r0) = e0 + e1.
    r0:=Random(Vn);
    t0:=Random(Vn);
    e0:=Random(Vm);

    r1:=s-r0;
    t1:=r0-t0;

    h:=ev(m,n,pol,r0);

    e1:=h-e0;

    printf "r0:= %o \n \n", r0;
    printf "t0:= %o \n \n", t0;
    printf "e0:= %o \n \n", e0;
    printf "r1:= %o \n \n", r1;
    printf "t1:= %o \n \n", t1;
    printf "e1:= %o \n \n\n", e1;

    // g = G(t0, r1) + e0, where G(a,b) = P(a+b) - P(a) - P(b).
    g:=ev(m,n,pol,t0+r1)-ev(m,n,pol,t0)-ev(m,n,pol,r1) + e0;

    c0:= Vnm!(Eltseq(r1) cat Eltseq(g));
    c1:= Vnm!(Eltseq(t0) cat Eltseq(e0));
    c2:= Vnm!(Eltseq(t1) cat Eltseq(e1));
    COM:=[c0,c1,c2];

    printf "c0:=Com %o \n \n", c0;
    printf "c1:=Com %o \n \n", c1;
    printf "c2:=Com %o \n \n", c2;  

    printf"Ch:= %o \n\n", Ch;

    case Ch:
        when 0: Rsp:=Eltseq(r0) cat Eltseq(t1) cat Eltseq(e1);
        when 1: Rsp:=Eltseq(r1) cat Eltseq(t1) cat Eltseq(e1);
        when 2: Rsp:=Eltseq(r1) cat Eltseq(t0) cat Eltseq(e0);
    end case;

    printf "Rsp:= %o \n \n", VRsp!(Rsp);
    return COM, Rsp;
end function; 

MQver:= function(COM,Ch,Rsp)
    // Verifier side of one round of the 3-pass MQ identification protocol.
    //   COM : [c0, c1, c2] as returned by MQIDround.
    //   Ch  : the challenge that was issued (0, 1 or 2).
    //   Rsp : the response, 2*n+m field elements laid out as
    //         (a in F^n) || (b in F^n) || (c in F^m).
    // Reads the globals pol (the system P), v (public key), m, n, Vn, Vm, Vnm.
    // Returns true iff both commitments recomputable from Rsp match COM.

    // A challenge outside {0,1,2} is invalid and must be rejected. Previously
    // it matched no branch of the case statement and the function returned true.
    if Ch notin {0,1,2} then
        return false;
    end if;

    // Split the response. The last component has length m, so it belongs
    // in Vm (coercing it into Vn only worked because m = n).
    a:=Vn![Rsp[i] : i in [1..n]];
    b:=Vn![Rsp[i] : i in [n+1..2*n]];
    c:=Vm![Rsp[i] : i in [2*n+1..2*n+m]];

    tr:=false;
    case Ch:
        when 0:
            // Rsp = (r0, t1, e1): t0 = r0 - t1 and e0 = P(r0) - e1 rebuild c1;
            // c2 = (t1, e1).
            r0:=a; t1:=b; e1:=c;
            chk1:= Vnm!(Eltseq(r0-t1) cat Eltseq(ev(m,n,pol,r0)-e1));
            chk2:= Vnm!(Eltseq(t1) cat Eltseq(e1));
            tr:= chk1 eq COM[2] and chk2 eq COM[3];

        when 1:
            // Rsp = (r1, t1, e1): G(t0,r1) + e0 = v - P(r1) - G(t1,r1) - e1,
            // with G(a,b) = P(a+b) - P(a) - P(b), rebuilds c0; c2 = (t1, e1).
            r1:=a; t1:=b; e1:=c;
            gval:= v - ev(m,n,pol,r1) - ev(m,n,pol,t1+r1) + ev(m,n,pol,t1)
                   + ev(m,n,pol,r1) - e1;
            chk0:= Vnm!(Eltseq(r1) cat Eltseq(gval));
            chk2:= Vnm!(Eltseq(t1) cat Eltseq(e1));
            tr:= chk0 eq COM[1] and chk2 eq COM[3];

        when 2:
            // Rsp = (r1, t0, e0): c0 = (r1, G(t0,r1) + e0) and c1 = (t0, e0).
            r1:=a; t0:=b; e0:=c;
            gval:= ev(m,n,pol,t0+r1) - ev(m,n,pol,t0) - ev(m,n,pol,r1) + e0;
            chk0:= Vnm!(Eltseq(r1) cat Eltseq(gval));
            chk1:= Vnm!(Eltseq(t0) cat Eltseq(e0));
            tr:= chk0 eq COM[1] and chk1 eq COM[2];
    end case;
    return tr;
end function;