function [D, A, X, RDA, stat] = DOLPHIn(F,Y,fun,proj,params,Xopt,X0,D0,A0)
%function [D, A, X, RDA, stat] = DOLPHIn(F,Y,fun,proj,params,Xopt,X0,D0,A0)
%
% This is DOLPHIn, an alternating minimization scheme to simultaneously 
% solve the 2D (generalized) phase retrieval problem and learn a dictionary 
% along with sparse representations for image patches.
%
% Note: This function assumes measurements/image correspond to a single 
%       color channel! Use DOLPHInRGB.m for RGB images; see also
%       testDOLPHIn.m
%
% INPUT: 
%   F - function implementing the measurement operator; must be of the form
%       F(.,t) such that t=0 applies the op. and t=1 applies its adjoint.
%   Y - the measurements of true image Xopt, i.e., Y = fun(F,Xopt,0)
%   fun - function to evaluate Yhat := fun(F,X,0) = F(X,0) at given X and
%       compute gradient gradX w.r.t. X of the phase retrieval objective 
%       term 0.25*|| Y - Yhat ||_F^2; must be of the form 
%       [Yhat,gradX] = fun(F,X,Y)  (Y is only used in gradient computation)
%   proj - function implementing the projection of X onto some constraints
%       (e.g., proj=@(Z)max(min(Z,1),0).*S projects onto pixel values in 
%       [0,1] w.r.t. a binary support mask S (1 on support, 0 otherwise));
%       must be of the form [projX] = proj(X)
%   params - struct with the algorithmic parameters; except for params.h 
%       and params.w (specifying the image dimensions h x w), all fields 
%       can be left empty to use default values (see setupParams.m)
%
% OPTIONAL INPUT:
%   Xopt - original image (ground truth to be recovered); used to compute
%       some (more) statistics only and can be left empty or omitted
%   X0, D0, A0 - initial image estimate, dictionary and patch representa-
%       tions (columns of A0), resp.; can be left empty or omitted to use
%       defaults [random X0, D0=[ID DCT], A0=argmin||E(X0)-D0*A||_F^2]
%
% OUTPUT:
%   D, A, X - the learned dictionary, (sparse) patch representations and 
%       final image reconstruction
%   RDA - image reconstruction via sparse patch representations; often this
%       gives a better image estimate than the final X
%   stat - struct containing various statistics about the DOLPHIn run/sol.,
%       most of which are only evaluated if original image (ground truth) 
%       Xopt was provided by the user.
%       Fields independent of Xopt are:
%       .params1/.params2 - full set of DOLPHIn parameters in phase 1/2
%       .noPatches - number of patches (= column dimension of A)
%       .tabobj - array containing objective function values throughout the 
%                 iteration process; the first entry is the objective value
%                 at the initial point, followed by one entry per A-, X- 
%                 and D-update (and the final extra A-update, if any)
%       .timeInit - time spent on initialization
%       .timeK1 - time spent in DOLPHIn phase 1 (dictionary kept fixed)
%       .timeK2 - time spent in DOLPHIn phase 2 (dictionary updated),
%                 including the final extra A-update, if any
%       .timeTotal - total runtime of DOLPHIn call
%       .nnzA1  - average number of nonzeros in columns of A after first K1
%            iterations
%       .nnzA   - average number of nonzeros in columns of final A
%       Fields that are only created if Xopt was given are:
%       .mseX0/.snrX0/.psnrX0/.ssimX0 - MSE, SNR, PSNR and SSIM values of
%            initial image estimate X0
%       .mseX1/.snrX1/.psnrX1/.ssimX1 - same for X-estimate after first K1
%            iterations
%       .mseRDA1/.snrRDA1/.psnrRDA1/.ssimRDA1 - same for estimate from
%            patch representations, P_X(R(D*A)), after first K1 iterations
%       .mseX/.snrX/.psnrX/.ssimX - same for final X-estimate
%       .mseRDA/.snrRDA/.psnrRDA/.ssimRDA - same for final P_X(R(D*A))
%
% NOTE: This code requires an installation of the SPAMS package;
% the SPAMS path needs to be set accordingly in setupParams.m
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% This function is part of the DOLPHIn package (version 1.10)
% last modified: 06/16/2016, A. M. Tillmann
%
% You may freely use and modify the code for academic purposes, though we
% would appreciate if you could let us know (particularly should you find 
% a bug); if you use DOLPHIn for your own work, please cite the paper
%
%    "DOLPHIn -- Dictionary Learning for Phase Retrieval",
%    Andreas M. Tillmann, Yonina C. Eldar and Julien Mairal, 2016.
%    http://arxiv.org/abs/1602.02263
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% ========== initialization ===============================================
timer = tic;
format compact;

% initialize algorithmic/model parameters:

params.mY = numel(Y); % number of measurements (per color channel)
params = setupParams(params);
h = params.h; w = params.w; s1 = params.s1; s2 = params.s2; s = s1*s2;

% initialize image variable:
if( ~exist('X0','var') || isempty(X0) )
    X0 = proj(rand(h,w));
    %X0 = WFinitX0(F,Y,[h,w],proj); % spectral initialization, requires F
end
X = X0;

% initialize dictionary:
if( ~exist('D0','var') || isempty(D0) )
    D0 = [eye(s),dict_dct(s1,s)];
end
D = D0;
n = size(D,2); % number of dictionary atoms (columns of D)

% initialize patch representations:
Z = myim2col(X,[s1 s2],params.overlap); 
if( ~exist('A0','var') || isempty(A0) )    
    A0 = D\Z; % LS solution to "D0*A=E(X0)", i.e., argmin||E(X0)-D0*A||_F^2    
end
A = A0;
if( params.AfixK1 ) % if both A- and D-updates are deactivated at first...
    mu = params.mu; lambda = params.lambda; % (restored when entering phase 2)
    params.mu = 0; params.lambda = 0; % ...switch to WF for first K1 iter.
end


% initial objective function and gradient evaluation:
[Yhat, gradX] = fun(F,X,Y);
% objective = data fidelity + mu-weighted patch fit + lambda-weighted l1-norm of A
objval = @(YYhat,AA,ZZ,DD,muval,lambdaval)0.25*sum(abs(YYhat(:)-Y(:)).^2)... 
    +0.5*muval*sum(sum((ZZ-DD*AA).^2))...
    +lambdaval*sum(abs(AA(:))); 
obj = objval(Yhat,A,Z,D,params.mu,params.lambda);

recstats = (nargout > 4);               % statistics struct requested by caller?
trackObj = (params.verbose || recstats); % objective values are logged in stat.tabobj
if( trackObj )
    % one slot for the initial objective, one per inner update, plus one
    % for the optional final A-update (which is controlled by adjFinA and
    % only performed if K2 > 0)
    stat.tabobj = [obj;zeros((1+(~params.AfixK1))*params.K1+3*params.K2+(params.K2 > 0 && params.adjFinA),1)];
    stat.noPatches = size(A,2);
end
% index of the most recently written entry of stat.tabobj; starts at 1 so
% that the initial objective value (entry 1) is not overwritten by the
% first update; will end up at 1+(1+(~params.AfixK1))*K1+3*K2
innerii = 1;

% initialize step size (for X-updates):
params.gammaX = 1e4*params.gammaX/obj;
gammaX = params.gammaX;

% specify method for A-updates
if( params.L < 0 )      % use |L| (F)ISTA iterations
    updateA = @(AinitGuess,DD,ZZ,pars)mexFistaFlat(ZZ,DD,AinitGuess,pars);
    updateApars = params.FISTApar;
elseif( params.L > 0 )  % use OMP with sparsity bound L
    updateA = @(AinitGuess,DD,ZZ,pars)mexOMP(ZZ,DD,pars);
    updateApars = params.OMPpar;    
else                    % use Homotopy method
    updateA = @(AinitGuess,DD,ZZ,pars)mexLasso(ZZ,DD,pars);
    updateApars = params.LASSOpar;
end

% header and initial output, if iteration log is turned on:
if( params.verbose )
    fprintf('\n DOLPHIn -- DictiOnary Learning for PHase retrIeval -- v1.00\n');
    fprintf('===============================================================\n');
    fprintf(' Iteration  |  Objective  |  Avg. Patch Sparsity  |  Stepsize \n');
    fprintf('===============================================================\n');
    fprintf('     Init.  |  %.3e  |       %.3e       |  %.3e\n',obj,nnz(A)/stat.noPatches,gammaX);
    if( params.verbose > 1) % additionally plot progress
        figure;
        subplot(2,2,1);
        displayPatches(D); title('Dictionary D');
        subplot(2,2,2);
        semilogy(stat.tabobj); title('Objective/iter.');
        subplot(2,2,3);
        imagesc(X);
        colormap(gray);
        title('Image reconstruction');
        subplot(2,2,4);
        RDA = proj(mycol2im(D*A,[s1 s2],[h w],params.overlap));
        imagesc(RDA);
        title('Image repres. using D');
        drawnow;
    end
end

% initialize struct to record statistics:
haveXopt = (exist('Xopt','var') && ~isempty(Xopt));
if( recstats )     
    %stat.tabobj and stat.noPatches were already set
    stat.params1 = params;
    if( haveXopt )
        stat.mseX0  = immse(X0,Xopt);
        stat.ssimX0 = ssim(X0,Xopt);
        [stat.psnrX0,stat.snrX0] = psnr(X0,Xopt);
    end
    if( params.K1 < 1 )
        % phase 1 is skipped, so its "results" coincide with the initial state
        stat.timeK1 = 0;
        stat.nnzA1 = nnz(A)/stat.noPatches;
        if( haveXopt ) % quality measures require the ground truth
            stat.mseX1 = stat.mseX0; stat.ssimX1 = stat.ssimX0;
            stat.psnrX1 = stat.psnrX0; stat.snrX1 = stat.snrX0;
            % use a temporary variable here: leaving an RDA built from the
            % initial D*A in the workspace would suppress the recomputation
            % of the final RDA after the main loop
            if( exist('RDA','var') )
                RDA1 = RDA;
            else
                RDA1 = proj(mycol2im(D*A,[s1 s2],[h w],params.overlap));
            end
            stat.mseRDA1 = immse(RDA1,Xopt);
            stat.ssimRDA1 = ssim(RDA1,Xopt);
            [stat.psnrRDA1,stat.snrRDA1] = psnr(RDA1,Xopt);
            clear RDA1;
        end
    end    
    stat.timeInit = toc(timer);
end 


% ========== main loop (alternating minimization) =========================
for ii = 1:(params.K1+params.K2)
    if( params.L > 0 && ii == params.K1+1 ) % entering phase 2
         % adjust L and mu in sparsity-constrained DOLPHIn:         
         params.mu = 1.68*params.mu; 
         params.L = min(2*params.L,sqrt(s)); params.OMPpar.L = params.L;
         updateApars = params.OMPpar;    
    end
    if( params.AfixK1 && ii == params.K1+1 )
        % restore the regularization weights that were zeroed for phase 1
        % and re-initialize A by least squares for the current image
        params.mu = 1.68*mu; params.lambda = lambda;
        Z = myim2col(X,[s1 s2],params.overlap); A = D\Z;
    end
    
    % ---------- update A -------------------------------------------------
    if( ii > params.AfixK1*params.K1 ) % if AfixK1=true, fixes A until iter. K1+1
        innerii = innerii+1;
        A = updateA(A,D,Z,updateApars); 
        
        obj = objval(Yhat,A,Z,D,params.mu,params.lambda);   
        % output and statistics:
        if( params.verbose )
            fprintf(' %5d (A)  |  %.3e  |       %.3e       |\n',ii,obj,nnz(A)/stat.noPatches);
        end
        if( trackObj )
            stat.tabobj(innerii) = obj;
        end    
    end

    % ---------- update X -------------------------------------------------
    innerii = innerii+1;
     
    R = Z-D*A; % patch residuals; kept fixed during the Xiter gradient steps below
    for kk=1:params.Xiter
        grad = gradX +  params.mu*mycol2im(R,[s1 s2],[h w],2*params.overlap); % 2*overlap: no averaging (if patches do overlap, otherwise: irrelevant)
        oldobj = obj; Xold = X; count = 1;   
        % backtracking linesearch: halve the step size until the objective
        % does not increase (at most 100 trial steps)
        while(true)
            X = proj(Xold - gammaX*grad);
            
            Yhat = fun(F,X,Y);
            Z = myim2col(X,[s1 s2],params.overlap);
            obj = objval(Yhat,A,Z,D,params.mu,params.lambda);
            if( oldobj >= obj ), break; end % successfully decreased objective
            gammaX = gammaX/2;
            count = count+1;
            if( count >= 100 ), break; end  % linesearch failed (100 trials)
        end
        gammaX = 1.68*gammaX; % enlarge step size again for the next linesearch
        if( count == 100 )
            % linesearch failure only ends the X-update of this iteration
            if( params.verbose), fprintf('No descent after 100 trials (linesearch in X-update). Aborting.\n'); end
            break;
        end
        [Yhat, gradX] = fun(F,X,Y); % gradX is gradient of phase retrieval term, therefore independent of D and A, so recomputation here is valid
    end
    if( params.verbose )
        fprintf(' %5d (X)  |  %.3e  |                       |  %.3e\n',ii,obj,gammaX/2);
    end
    if( trackObj )
        stat.tabobj(innerii) = obj;
    end
 
    % --------- update D --------------------------------------------------
    if( ii > params.K1 ) % D is only updated after K1 iterations
        innerii = innerii+1;
        B = Z*A';
        C = A*A';        
        % block coordinate descent over the columns of D
        for kk=1:params.BCDiter
            for jj=1:n
                if(C(jj,jj) > 0)
                    newdj = (B(:,jj)-D*C(:,jj))/C(jj,jj) + D(:,jj);
                else
                    newdj = randn(s,1); % column jj is unused in A: draw a random replacement
                end
                normnewdj = sqrt(sum(newdj.^2));
                D(:,jj) = newdj / normnewdj;
                %D(:,jj) = newdj / max(1,normnewdj); % strictly speaking, this is the correct "projection"
                if( params.adjustA )
                    A(jj,:) = A(jj,:) * normnewdj; % warning: can give improved results but may destroy convergence!
                    %A(jj,:) = A(jj,:) * max(1,normnewdj);                   
                end                
            end
        end
        obj = objval(Yhat,A,Z,D,params.mu,params.lambda);
        if( params.verbose )
            fprintf(' %5d (D)  |  %.3e  |                       |\n',ii,obj);
        end
        if( trackObj )
            stat.tabobj(innerii) = obj;
        end
    end
   
    % --- remaining iteration log / output, if switched on: 
    if( params.verbose )       
        if( mod(ii,20) == 0 ) % repeat the table header every 20 iterations
            fprintf('===============================================================\n');
            fprintf(' Iteration  |  Objective  |  Avg. Patch Sparsity  |  Stepsize \n');
            fprintf('===============================================================\n');
        end
        if( params.verbose > 1 ) % additionally, plot progress
            subplot(2,2,1);
            displayPatches(D); colormap(gray); title('Dictionary D');
            subplot(2,2,2);
            semilogy(stat.tabobj); title('Objective/iter.');
            subplot(2,2,3);
            imagesc(X); title('Image reconstruction');
            subplot(2,2,4);
            RDA = proj(mycol2im(D*A,[s1 s2],[h w],params.overlap));
            imagesc(RDA);
            title('Image repres. using D');
            drawnow;
        end
    end
   
    % --- record more statistics after first K1 iterations, if switched on:
    if( ii == params.K1 && recstats )
        stat.nnzA1 = nnz(A)/stat.noPatches;
        if( haveXopt )
            stat.mseX1 = immse(X,Xopt);
            stat.ssimX1 = ssim(X,Xopt);
            [stat.psnrX1,stat.snrX1] = psnr(X,Xopt);
            if( ~exist('RDA','var') ), RDA = proj(mycol2im(D*A,[s1 s2],[h w],params.overlap)); end
            stat.mseRDA1 = immse(RDA,Xopt);
            stat.ssimRDA1 = ssim(RDA,Xopt);
            [stat.psnrRDA1,stat.snrRDA1] = psnr(RDA,Xopt);           
            if( params.K2 > 0 )
                clear RDA; % D and A will change in phase 2; recompute later
            else
                % no phase 2: final statistics coincide with those of phase 1
                stat.timeK2 = 0;
                stat.mseX = stat.mseX1; stat.ssimX = stat.ssimX1;
                stat.psnrX = stat.psnrX1; stat.snrX = stat.snrX1;
                stat.nnzA = stat.nnzA1;
                stat.mseRDA = stat.mseRDA1; stat.ssimRDA = stat.ssimRDA1;
                stat.psnrRDA = stat.psnrRDA1; stat.snrRDA = stat.snrRDA1;
            end
        end
        stat.timeK1 = toc(timer) - stat.timeInit;       
    end
end
% ========== end of main loop =============================================

% additional A-update, to adjust for final dictionary update:
if( params.K2 > 0 && params.adjFinA )
    A = updateA(A,D,Z,updateApars); 
    % output and statistics:
    if( trackObj )
        obj = objval(Yhat,A,Z,D,params.mu,params.lambda);
        if( params.verbose )
            fprintf(' adjust A   |  %.3e  |       %.3e       |\n',obj,nnz(A)/stat.noPatches);
        end
        stat.tabobj(innerii+1) = obj;
    end      
    clear RDA; % so that it is recomputed with updated A below
end

% compute image representation from sparsely coded patches, if necessary:
if( ~exist('RDA','var') && (nargout > 3 || haveXopt) )
    RDA = proj(mycol2im(D*A,[s1 s2],[h w],params.overlap));
end
% record final statistics, if switched on:
if( recstats )
    % elapsed time since the end of phase 1; stat.timeK1 does not contain
    % the initialization time, hence stat.timeInit is subtracted as well
    stat.timeK2 = toc(timer) - stat.timeK1 - stat.timeInit;
    if( haveXopt )
        stat.mseX = immse(X,Xopt);
        stat.ssimX = ssim(X,Xopt);
        [stat.psnrX,stat.snrX] = psnr(X,Xopt);
        stat.mseRDA = immse(RDA,Xopt);
        stat.ssimRDA = ssim(RDA,Xopt);
        [stat.psnrRDA,stat.snrRDA] = psnr(RDA,Xopt);
    end
    stat.nnzA = nnz(A)/stat.noPatches;
    stat.timeTotal = toc(timer);
    stat.params2 = params;
end