function [Xopt,Y,F,G,H,M,S] = instanceGenerator(imgfile,type,data,savefile,SNR,normalize)
%function [Xopt,Y,F,G,H,M,S] = instanceGenerator(imgfile,type,data,savefile,SNR,normalize)
%
% instance generator for generalized 2D phase retrieval dictionary learning
%
% INPUT: 
% imgfile  - filename (incl. path) of input original image Xopt
% type     - string specifying measurement type to create (case-insensitive,
%            must be exactly one of the names below); choose from:
%            'cdp' - Sq. Fourier magnitudes of coded diffraction patterns
%            'gauss' - complex Gaussian measurements |G*Xopt|.^2
%            'gkronsymm' - complex Gaussian meas. |G*Xopt*G'|.^2
%            'gkron' - complex Gaussian meas. |G*Xopt*H'|.^2
%            --- experimental other options : -----------------------------
%            'lingauss' - linear (complex) Gaussian measurements G*Xopt
%            'lingkronsymm' - linear (complex) Gaussian meas. G*Xopt*G'
%            'lingkron' - linear (complex) Gaussian meas. G*Xopt*H'
%            'sqfft' - Squared Fourier magnitude measurements 
%              (output Xopt is embedding of input Xopt into larger 0-image)
%            'absfft' - Fourier magnitude measurements 
%              (output Xopt is embedding of input Xopt into larger 0-image)
%            'user' - use measurement function provided by user as data.fun
%              [data.fun MUST adhere to the same form as the input function
%              'fun' in DOLPHIn.m, i.e., [Yhat,gradX]=data.fun(F,X,Y) must
%              give Yhat = user-specified measurement function eval. at X
%              and gradient w.r.t. X of 0.25*|| Y - Yhat ||_F^2; input 'F'
%              is not used here but needed for compatibility (just use []),
%              input Y is only needed to compute the gradient and can be 
%              omitted otherwise. Note: user-defined measurement function 
%              must be differentiable.]
%
% Note: Of course, if the user specifies matrices G and/or H by setting the
% fields in the data struct (described next), these can be arbitrary and
% non-Gaussian; the terms used above pertain to what this function creates
% if no user-given matrices (or masks M, for that matter) are available!
%
% OPTIONAL INPUT: 
% data     - struct with predefined parts or parameters to be used in the 
%            instance generation process; leave empty to use defaults. 
%            Fields irrelevant to specified measurement type will be 
%            ignored. Valid fields are:
%            .G, .H - the matrices G,H used for Gaussian meas. types
%            .oversamplfact - sampling ratio in Gaussian case [default:4]
%            .M - diffraction masks (3D-array; M(:,:,j) is j-th mask)
%            .numM - number of diffraction masks [default:5]
%            .cdptype - type of diffraction masks; choose from 
%              'complex' [default], 'ternary', or 'binary' [experimental]
%            .fun - user-provided measurement function (required for 'user')
%
% savefile - filename (incl. path) for saving the generated instance data
%            [default: don't save anything]
% SNR      - value inf gives noiseless measurements [default], otherwise 
%            specifies SNR value (in dB) s.t. white Gaussian noise is added 
%            to the measurements accordingly
% normalize - toggle whether to (L2-)normalize rows of measurement matrix 
%            (only applied to internally generated matrices; not available
%            in cdp case) [default: no normalization]
%
% OUTPUT:
% Xopt     - original image as doubles; integer-class images are divided
%            by the maximum value of their class (e.g. 255 for uint8), so
%            entries lie in [0,1] for unsigned integer images; 
% Y        - measurements of Xopt according to specified type;
%            if Xopt is an RGB color image, Y will contain measurements 
%            of all three color channels, so that, e.g., the measurements
%            for the "green" channel will be stored in Y{2}
% F        - linear operator used in measurements: a function handle whose
%            2nd argument specifies forward (0) or adjoint (1) application
% G,H,M    - matrices or diffraction masks used to create F
% S        - support of Xopt (as binary mask of the same size); 
%            only created in experimental 'sqfft'/'absfft' cases, where 
%            Xopt is embedded into larger all-zero image and the support 
%            is thus known (exactly).
%            For other types, it is just S=1 (so that S.*X = X itself)
%
% Note: All output variables that are not needed (not related to selected
% measurement model) will be set to [].
% On invalid input (unknown type, non-square image for (lin)gkronsymm, or
% type 'user' without data.fun) a message is printed and ALL outputs are
% returned as [].
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% This function is part of the DOLPHIn package (version 1.10)
% last modified: 04/26/2016, A. M. Tillmann
%
% You may freely use and modify the code for academic purposes, though we
% would appreciate it if you could let us know (particularly should you find 
% a bug); if you use DOLPHIn for your own work, please cite the paper
%
%    "DOLPHIn -- Dictionary Learning for Phase Retrieval",
%    Andreas M. Tillmann, Yonina C. Eldar and Julien Mairal, 2016.
%    http://arxiv.org/abs/1602.02263
%
% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% set some parameters and defaults:
if( ~exist('SNR','var') ),               SNR = inf; end
if( exist('data','var') )
    if( isfield(data,'G') ),             G = data.G; end
    if( isfield(data,'H') ),             H = data.H; end
    if( isfield(data,'oversamplfact') ), oversamplfact = data.oversamplfact; end
    if( isfield(data,'M') ),             M = data.M; end
    if( isfield(data,'numM') ),          numM = data.numM; end
    if( isfield(data,'cdptype') ),       cdptype = data.cdptype; end
end
if( ~exist('oversamplfact','var') ),     oversamplfact = 4; end
if( ~exist('normalize','var') ),         normalize = 0; end

% reset random seed, if desired:
%rng(0); 

% read original image and create Xopt; integer classes are scaled by their
% maximum value (uint8 -> /255, exactly as before; uint16 -> /65535, ...),
% other classes (logical, single, double) are only converted to double:
img = imread(imgfile);
if( isinteger(img) )
    Xopt = double(img)/double(intmax(class(img)));
else
    Xopt = double(img);
end
clear img;
[m,n,c] = size(Xopt);
% if c>1: same operators, but different noise for each channel 
Y = cell(1,c);

% create matrices/masks, measurement operator and measurements Y, 
% according to selected type (exact, case-insensitive name match; the
% previous suffix matching via type(end-4:end) crashed for short names
% such as 'user'):
switch lower(type)
    case 'cdp'
        if( ~exist('M','var') )
            if( ~exist('cdptype','var') ), cdptype = 'complex'; end % default
            if( ~exist('numM','var') ), numM = 5; end % default
            M = create_cdp_masks(m,n,numM,cdptype);  
            % add ID-"mask":
            %M(:,:,numM+1) = ones(m,n);
        end 
        F = @(XX,tt)op_cdp_2d(XX,M,tt); % measurement operator 
        for col=1:c
            Y{col} = fun_general_op(F,Xopt(:,:,col)); % = abs(F(Xopt,0)).^2
        end
        G = []; H = []; S = 1; 

    case {'gauss','lingauss'}
        if( ~exist('G','var') ) 
            % explicit Gaussian measurement matrix
            G = complex_gaussian_matrix(oversamplfact*max(m,n),m,normalize);
        end
        F = @(XX,tt)op_general_mat(G,XX,tt);
        islin = strcmpi(type,'lingauss');
        for col=1:c
            Y{col} = measure_channel(F,Xopt(:,:,col),islin);
        end
        H = []; M = []; S = 1;

    case {'gkronsymm','lingkronsymm'}
        if( m ~= n )
            fprintf('Xopt is not square, cannot use (lin)gkronsymm measurement type; aborting.\n');
            Xopt = []; Y = []; F = []; G = []; H = []; M = []; S = [];
            return;
        end
        if( ~exist('G','var') )
            G = complex_gaussian_matrix(oversamplfact*m,m,normalize);
        end        
        % measurement operator: G*X*G'; adjoint G'*Z*G
        F = @(XX,tt)op_kron_mat(G,G,XX,tt);
        islin = strcmpi(type,'lingkronsymm');
        for col=1:c
            Y{col} = measure_channel(F,Xopt(:,:,col),islin);
        end
        H = G; M = []; S = 1;

    case {'gkron','lingkron'}
        % G is generated before H so the random stream is consumed in the
        % same order as before
        if( ~exist('G','var') )
            G = complex_gaussian_matrix(oversamplfact*m,m,normalize);
        end        
        if( ~exist('H','var') )
            H = complex_gaussian_matrix(oversamplfact*n,n,normalize);
        end    
        % measurement operator: G*X*H'; adjoint G'*Z*H
        F = @(XX,tt)op_kron_mat(G,H,XX,tt);
        islin = strcmpi(type,'lingkron');
        for col=1:c
            Y{col} = measure_channel(F,Xopt(:,:,col),islin);
        end
        M = []; S = 1;

    case 'user'
        if( ~exist('data','var') || ~isfield(data,'fun') )
            fprintf('Type ''user'' requires a measurement function in data.fun; aborting.\n');
            Xopt = []; Y = []; F = []; G = []; H = []; M = []; S = [];
            return;
        end
        for col=1:c
            Y{col} = data.fun([],Xopt(:,:,col)); % user-specified measurements
        end
        F = []; G = []; H = []; M = []; S = 1;

    case {'sqfft','absfft'}
        F = @(XX,tt)op_fft_2d(XX,tt); % measurement operator
        % zero block placed before and after the image on the diagonal:
        pad = zeros(m*(oversamplfact/4),n*(oversamplfact/4));
        issq = strcmpi(type,'sqfft');
        origXopt = Xopt; clear Xopt;
        for col=1:c
            Xopt(:,:,col) = blkdiag(pad, origXopt(:,:,col), pad);   % embedding
            S(:,:,col) = blkdiag(pad, ones(m,n), pad); % (exact!) support mask
            if( issq )
                Y{col} = fun_general_op(F,Xopt(:,:,col)); % = abs(F(Xopt,0)).^2
            else
                Y{col} = abs(F(Xopt(:,:,col),0)); 
            end
        end
        G = []; H = []; M = [];

    otherwise
        fprintf('Unknown type; aborting.\n');
        Xopt = []; Y = []; F = []; G = []; H = []; M = []; S = [];
        return;
end

% add white Gaussian noise to each channel separately (requires awgn):
if( ~isinf(SNR) )
    for col=1:c
        Y{col} = reshape(awgn(Y{col}(:),SNR,'measured','dB'),size(Y{col}));
    end
end

if( c == 1 ) % just one color channel to consider...
    Y=Y{1};
end

if( exist('savefile','var') && ~isempty(savefile) )
    save(savefile,'Xopt','Y','F','G','H','M','S','type');
end


function A = complex_gaussian_matrix(r,k,normalize)
% COMPLEX_GAUSSIAN_MATRIX  r-by-k matrix with entries (a+1i*b)/sqrt(2),
% a,b drawn by randn (real parts drawn first); if normalize is true, each
% row is scaled to unit Euclidean norm.
A = (1/sqrt(2))*randn(r,k)+(1i/sqrt(2))*randn(r,k);
if( normalize )
    for ii=1:size(A,1), A(ii,:)=A(ii,:)./norm(A(ii,:)); end % normalize rows
end


function Yc = measure_channel(F,X,islin)
% MEASURE_CHANNEL  measurements of one image channel X under operator F:
% linear measurements via fun_lin_op if islin is true, otherwise
% squared-magnitude measurements via fun_general_op.
if( islin )
    Yc = fun_lin_op(F,X); % = F(X,0)
else
    Yc = fun_general_op(F,X); % = abs(F(X,0)).^2
end
