% Basic test script for DOLPHIn (DictiOnary Learning for PHase retrIeval)
%
% A variety of model, design and algorithmic parameters can be set;
% for details as to their meaning/role, see DOLPHIn.m, setupParams.m and/or
% instanceGenerator.m
%
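% The reference example setup runs DOLPHIn on 5 octanary coded diffraction
% pattern measurements, corrupted by 15dB-(white Gaussian)-noise, of the
% 256x256 Cameraman test image (grayscale), with lots of output. Note that
% the values currently selected below (selectTypeNumber = 9, i.e. 'absfft';
% selectImageNumber = 11, i.e. the 64x64 'smallone' image; data.cdptype =
% 'complex') differ from this setup; to reproduce it, select type 1
% ('cdp'), image 1, and octanary cdp masks via data.cdptype.
% (To also run Wirtinger Flow, set runWF to 1 in the section "also run
% standard Wirtinger Flow?" below.)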
%
% Note: RGB images are handled directly by this script using only DOLPHIn.m 
% (i.e., DOLPHInRGB.m is *not* employed here), simply for regrouping the
% statistics and storing them in a slightly different fashion. 
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% This script is part of the DOLPHIn package (version 1.10)
% last modified: 06/08/2016, A. M. Tillmann 
%
% You may freely use and modify the code for academic purposes, though we
% would appreciate it if you could let us know (particularly should you
% find a bug); if you use DOLPHIn for your own work, please cite the paper
%
%    "DOLPHIn -- Dictionary Learning for Phase Retrieval",
%    Andreas M. Tillmann, Yonina C. Eldar and Julien Mairal, 2016.
%    http://arxiv.org/abs/1602.02263
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
addpath('tools/');
% ========== initialization: test setup / data ============================

% set random seed (e.g. rng(0) for reproducibility, rng('shuffle') else):
%rng(0);%rng('shuffle');

% specify measurement type (the matching measurement function handle is
% selected further below); entries 1-4 are well-tested, entries 5-10 are
% experimental (not fully implemented/tested):
typeList = { ...
    'cdp';           ... %  1  -- well-tested options
    'gauss';         ... %  2
    'gkronsymm';     ... %  3
    'gkron';         ... %  4
    'lingauss';      ... %  5  -- experimental options
    'lingkronsymm';  ... %  6
    'lingkron';      ... %  7
    'sqfft';         ... %  8
    'absfft';        ... %  9
    'user'           ... % 10
    };

selectTypeNumber = 9;
type = typeList{selectTypeNumber};

% specify original image (index -> image file; size/color noted where the
% image group changes):
imageList = { ...
    'images/cameraman256.png';      ... %  1 - 256x256 grayscale
    'images/house256.png';          ... %  2
    'images/peppers256.png';        ... %  3
    'images/lena512.png';           ... %  4 - 512x512 grayscale
    'images/barbara512.png';        ... %  5
    'images/boat512.png';           ... %  6
    'images/fingerprint512.png';    ... %  7
    'images/mandril512.png';        ... %  8
    'images/mandrill512color.png';  ... %  9 - 512x512 RGB
    'images/waldspirale.png';       ... % 10 - 2816x2112 RGB
    'images/smallone.png'           ... % 11 - 64x64 grayscale
    };
% All images used in the paper can be found online, directly or as part of
% some test image library or another, so we did not include them in the
% DOLPHIn package to keep its size down. However, they are available as a
% separate zip-file along with the code package, at the same webpage.
% The "waldspirale" image was found on wikimedia commons at
% upload.wikimedia.org/wikipedia/commons/b/bc/Darmstadt-Waldspirale-Hundertwasser4.jpg

selectImageNumber = 11;
imageOrig = imageList{selectImageNumber}; % set own image path, if desired

% create function handle for the measurement operator, by measurement type:
%   'lin*'   -> fun_lin_op
%   'absfft' -> fun_absfft_op
%   'user'   -> user-specified handle (must be provided below)
%   other    -> fun_general_op
% strncmpi is used for the prefix test so that type strings shorter than
% three characters do not cause an indexing error.
if( strncmpi(type,'lin',3) )
    fun = @fun_lin_op;
elseif( strcmpi(type,'absfft') )
    fun = @fun_absfft_op;
elseif( strcmpi(type,'user') )
    % here, provide arbitrary user-specified measurement function (handle),
    % e.g. fun = @fun_user; for formal requirements, see the description of
    % 'input data.fun' in instanceGenerator.m
    fun = []; % <-- replace [] by your measurement function handle
    % Fail explicitly instead of silently reusing a stale 'fun' that may
    % still be present in the workspace from an earlier run of this script.
    if( isempty(fun) )
        error('DOLPHIn:noUserFun', ...
            'Measurement type ''user'' selected, but no user measurement function handle was provided.');
    end
else
    fun = @fun_general_op;
end


% create (phase retrieval) instance (see instanceGenerator.m for details).
% measurement model / design settings passed via the data struct:
data.cdptype       = 'complex'; % type of cdp masks
data.numM          = 5;         % number of cdp masks
data.oversamplfact = 4;         % oversampling factor (non-cdp models)
% toggle normalizing columns of matrices used for non-cdp measurement operator:
normalize = 0;
% level of noise (added to measurements):
SNR = 15;
% filename to save instance data:
savefile = [];
% pass along the user-specified measurement function, defined earlier:
if( strcmpi(type,'user') ), data.fun = fun; end
% other fields that can be specified by user, see instanceGenerator.m:
% data.G, data.H, data.M

[Xopt,Y,F,G,H,M,S] = instanceGenerator(imageOrig,type,data,savefile,SNR,normalize);

% create function handle for projection onto the constraint set for X (by
% default the box 0<=X<=1, via proj_box) and toggle using support
% information. If activated, a binary support mask S (same size as Xopt)
% is required; by default, S is only specified in the experimental absfft
% and sqfft cases, i.e. for measurement types ending in 'fft'.
useSupp = double(strcmpi(type(end-2:end),'fft'));
if( useSupp )
    proj = @(XX)proj_box(XX.*S);   % restrict to known support, then box
else
    proj = @(XX)proj_box(XX);
end
% Note: if you wish to use a different constraint than the box 0<=X<=1, set
% proj to a corresponding function (handle) that computes the projection
% onto it. (Then, a known support might already be integrated into that
% function, in which case S is not needed and useSupp can remain 0.)
% For pre-implemented functions/operators, see DOLPHIn/tools folder.


% ========== also run standard Wirtinger Flow? ============================
% toggle for additionally running standard Wirtinger Flow (real-valued
% variant using projections onto constraints after each gradient step);
% for a direct comparison with DOLPHIn, the same X0 (and Y) is used.
runWF = 0;

% ========== initialization: DOLPHIn parameters ===========================
% clean up the workspace: params (so defaults for phase 1 are reset
% correctly), the statistics struct, and all reconstruction and
% initialization variables; then close all figures.
clear('params','stat','X0','D0','A0','X','D','A','RDA','Xwf');
close('all');

% time spent on initialization in this script is later added to the
% DOLPHIn output stat{col}.timeInit, so that the reported times are complete
timer = tic;
                
% required parameters: image height/width, and number of color channels c
% (c=1: grayscale, c=3: RGB image)
[params.h,params.w,c] = size(Xopt);

% To treat grayscale and RGB uniformly in this file, the measurements of a
% single channel are wrapped into a 1x1 cell, so that Y{col} always works.
if( c==1 )
    Y = {Y};
end
params.mY = numel(Y{1}); % number of measurements per channel


% optional parameters: (here, we set some values as in the experiments from
%                      the paper, depending on measurement type; all others
%                      are set to default values, see setupParams.m)
params.verbose = 2; % in this script, verbose > 1 also plots the results

% DOLPHIn variant -- standard DOLPHIn (cf. Algorithm 1 in the paper):
params.L       = -1;
params.overlap = 0;
% -- or the sparsity-constrained DOLPHIn variant (uncomment to use):
%params.L       = 4;
%params.overlap = 1;

% AfixK1: toggle not updating A during the first K1 iterations (i.e.,
% fixing both A and D there)
params.AfixK1  = 0;
params.adjustA = 0;

%params.L = -1; % iterations for D-update, see setupParams.m
%params.Xiter = 1;
%params.BCDiter = 1;

% regularization parameters mu and lambda, measurement-type-dependent:
% they are only set here for types 'gauss', 'gkron' and 'gkronsymm' (first
% letter 'g') combined with standard (l1-regularized) DOLPHIn (L <= 0); in
% all other cases the defaults from setupParams.m are used, e.g.
%   sparsity-constrained DOLPHIn: params.mu     = 5e-3 * params.mY;
%   type 'cdp':                   params.mu     = 5e-2 * params.mY;
%                                 params.lambda = 3e-3 * params.mY;
if( strcmpi(type(1),'g') && params.L <= 0 )
    params.mu = 0.5 * params.mY;
    if( strcmpi(type,'gauss') )
        params.lambda = 0.105 * params.mY;
    else
        params.lambda = 0.210 * params.mY;
    end
end
%params.lambda = 7e-3 * params.mY; % used in large-scale cdp experiment
                                   % ("waldspirale" image)

% iteration limits for DOLPHIn phases 1 (D fixed) and 2 (D updated);
% defaults set in setupParams.m: K1 = 25, K2 = 50 if L<=0 (K2 = K1 if L>0)
%params.K1 = 25;
%params.K2 = 50;
%params.K1 = 75; params.K2 = 0; %params.mu = 0; params.lambda = 0; %<-> WF
%params.K1 = 15; params.K2 = 60; % might be a bit better than current defaults

params = setupParams(params);

% ========== loop over color channels =====================================
% For each color channel col (c=1: grayscale, c=3: RGB), build the initial
% point (X0, D0, A0), run DOLPHIn -- or DOLPHInAbsFFT for the experimental
% 'absfft' measurement type -- and, if runWF is set, Wirtinger Flow from the
% same X0; then correct the initialization timings and report the results.
for col=1:c
    if( col > 1 )
        timer = tic; % reset watch for next run
        % params need not be reset as it is left unchanged by DOLPHIn/WF
    end

    % ========== initialization: X0,D0,A0 =================================
    % initial image estimate:
    X0(:,:,col) = proj(rand([params.h,params.w])); % default, as in DOLPHIn.m
    %X0(:,:,col) = WFinitX0(F,Y{col},[params.h,params.w],proj,10); % spectral init.

    % initial dictionary:
    if( col > 1 )
        D0{col} = D{col-1}; % init. w/ final dict. from last run
    else
        s = params.s1*params.s2; % pixels per patch
        D0{col} = [eye(s),dict_dct(params.s1,s)]; % default, as in DOLPHIn.m
        %D0{col} = randn(s,2*s);
    end

    % initial patch representations:
    A0{col} = []; % default: will be set to D0{col}\E(X0(:,:,col)) by DOLPHIn

    timeExternalInit = toc(timer);

    % ========== run DOLPHIn (recording statistics) and WF, if chosen: ====
    % The whole type string is compared: the experimental 'absfft' type uses
    % its own solver DOLPHInAbsFFT (same call signature as DOLPHIn).
    if( strcmpi(type,'absfft') )
        % DOLPHInAbsFFT:
        [D{col}, A{col}, X(:,:,col), RDA(:,:,col), stat{col}] = ...
            DOLPHInAbsFFT(F,Y{col},fun,proj,params,Xopt(:,:,col),X0(:,:,col),D0{col},A0{col});
    else
        % DOLPHIn:
        [D{col}, A{col}, X(:,:,col), RDA(:,:,col), stat{col}] = ...
            DOLPHIn(F,Y{col},fun,proj,params,Xopt(:,:,col),X0(:,:,col),D0{col},A0{col});
    end

    % Wirtinger Flow (same X0 and Y as DOLPHIn, K1+K2 iterations):
    if( runWF )
        timer = tic;
        Xwf(:,:,col) = WF_RealCons(F,Y{col},fun,proj,params.h,...
            params.w,params.K1+params.K2,params.gammaX,X0(:,:,col),params.verbose);
        % add additional fields to statistics struct:
        stat{col}.timeWF = toc(timer);
        stat{col}.timeWFtotal = stat{col}.timeWF + timeExternalInit;
        stat{col}.mseXwf = immse(Xwf(:,:,col),Xopt(:,:,col));
        stat{col}.ssimXwf = ssim(Xwf(:,:,col),Xopt(:,:,col));
        [stat{col}.psnrXwf,stat{col}.snrXwf] = psnr(Xwf(:,:,col),Xopt(:,:,col));
    end

    % adjust/correct initialization time:
    stat{col}.timeInit = stat{col}.timeInit + timeExternalInit;
    stat{col}.timeTotal = stat{col}.timeTotal + timeExternalInit;

    % display final results (text for verbose >= 1, figures for verbose > 1):
    if( params.verbose )
        if( c > 1 )
            fprintf('=== channel %d: ================================================\n',col);
        else
            fprintf('===============================================================\n');
        end
        fprintf('Init.  X0 : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat{col}.timeInit,stat{col}.mseX0,stat{col}.psnrX0,stat{col}.snrX0,stat{col}.ssimX0);
        fprintf('DOLPHIn X : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat{col}.timeK1+stat{col}.timeK2,stat{col}.mseX,stat{col}.psnrX,stat{col}.snrX,stat{col}.ssimX);
        fprintf('P_X(R(DA)): (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f (DA), avg. sparsity: %.2f\n',stat{col}.timeK1+stat{col}.timeK2,stat{col}.mseRDA,stat{col}.psnrRDA,stat{col}.snrRDA,stat{col}.ssimRDA,stat{col}.nnzA);
        % WF results are reported at the same verbosity as DOLPHIn results:
        if( runWF )
            fprintf('Wirt.F. X : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat{col}.timeWF,stat{col}.mseXwf,stat{col}.psnrXwf,stat{col}.snrXwf,stat{col}.ssimXwf);
        end

        if( params.verbose > 1 )
            nPlots = 3 + (runWF ~= 0); % 4th panel only if WF was run
            figure(1+col); colormap(gray);
            subplot(1,nPlots,1);
            imagesc(Xopt(:,:,col));
            if( c > 1 )
                title(sprintf('Original image (ch.%d)',col));
            else
                title('Original image');
            end
            subplot(1,nPlots,2);
            imagesc(RDA(:,:,col));
            title('Image reconstr. P_X(R(D*A))');
            subplot(1,nPlots,3);
            imagesc(X(:,:,col));
            title('Image reconstr. X');
            if( runWF )
                subplot(1,nPlots,4);
                imagesc(Xwf(:,:,col));
                title('Wirt. Flow reconstr.');
            end
        end
    end
end % end of loop over color channels

% ========== gather final statistics ======================================
% Grayscale: stat becomes the statistics struct of the single channel.
% RGB: the per-channel structs are kept as statR/statG/statB, timings and
% sparsity are aggregated over the channels, and the quality measures are
% recomputed on the full RGB images. Wirtinger Flow statistics are only
% gathered if runWF is set, since the per-channel WF fields and Xwf exist
% only in that case.
if( c < 2 )
    stat = stat{1}; stat.colorInfo = 'grayscale';
else
    statTotal.colorInfo = 'RGB';
    statTotal.statR = stat{1}; statTotal.statG = stat{2}; statTotal.statB = stat{3};
    statTotal.timeInit = stat{1}.timeInit+stat{2}.timeInit+stat{3}.timeInit;
    if( runWF )
        statTotal.timeWF = stat{1}.timeWF+stat{2}.timeWF+stat{3}.timeWF;
        statTotal.timeWFtotal = stat{1}.timeWFtotal+stat{2}.timeWFtotal+stat{3}.timeWFtotal;
    end
    statTotal.noPatchesPerChannel = stat{1}.noPatches;
    statTotal.timeK1 = stat{1}.timeK1+stat{2}.timeK1+stat{3}.timeK1;
    statTotal.timeK2 = stat{1}.timeK2+stat{2}.timeK2+stat{3}.timeK2;
    statTotal.timeTotal = statTotal.timeInit+statTotal.timeK1+statTotal.timeK2;
    statTotal.nnzA = mean([stat{1}.nnzA,stat{2}.nnzA,stat{3}.nnzA]);
    statTotal.mseX0 = immse(X0,Xopt);
    statTotal.ssimX0 = ssim(X0,Xopt);
    [statTotal.psnrX0,statTotal.snrX0] = psnr(X0,Xopt);
    statTotal.mseX = immse(X,Xopt);
    statTotal.ssimX = ssim(X,Xopt);
    [statTotal.psnrX,statTotal.snrX] = psnr(X,Xopt);
    if( runWF )
        statTotal.mseXwf = immse(Xwf,Xopt);
        statTotal.ssimXwf = ssim(Xwf,Xopt);
        [statTotal.psnrXwf,statTotal.snrXwf] = psnr(Xwf,Xopt);
    end
    statTotal.mseRDA = immse(RDA,Xopt);
    statTotal.ssimRDA = ssim(RDA,Xopt);
    [statTotal.psnrRDA,statTotal.snrRDA] = psnr(RDA,Xopt);
    stat = statTotal; clear statTotal;
    % display final results (RGB; text for verbose >= 1, figure for verbose > 1):
    if( params.verbose )
        fprintf('=== RGB: ======================================================\n');
        fprintf('Init.  X0 : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat.timeInit,stat.mseX0,stat.psnrX0,stat.snrX0,stat.ssimX0);
        fprintf('DOLPHIn X : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat.timeK1+stat.timeK2,stat.mseX,stat.psnrX,stat.snrX,stat.ssimX);
        fprintf('P_X(R(DA)): (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f (DA), avg. sparsity: %.2f\n',stat.timeK1+stat.timeK2,stat.mseRDA,stat.psnrRDA,stat.snrRDA,stat.ssimRDA,stat.nnzA);
        if( runWF )
            fprintf('Wirt.F. X : (time: %4.2fs)  MSE = %.2e, PSNR = %.2f, SNR = %.2f, SSIM = %.2f\n',stat.timeWF,stat.mseXwf,stat.psnrXwf,stat.snrXwf,stat.ssimXwf);
        end
        if( params.verbose > 1 )
            nPlots = 3 + (runWF ~= 0); % 4th panel only if WF was run
            figure(2+c);
            subplot(1,nPlots,1);
            imagesc(Xopt);
            title('Original image (RGB)');
            subplot(1,nPlots,2);
            imagesc(RDA);
            title('Image reconstr. P_X(R(D*A))');
            subplot(1,nPlots,3);
            imagesc(X);
            title('Image reconstr. X');
            if( runWF )
                subplot(1,nPlots,4);
                imagesc(Xwf);
                title('Wirt. Flow reconstr.');
            end
        end
    end
end