function [ f, alpha ] = pao_alpha_nonlinear_opt( kappa, f0, alpha0, avg, mask )
% PAO_ALPHA_NONLINEAR_OPT Fit f and a per-pixel alpha map to kappa with lsqnonlin.
%
%   [f, alpha] = pao_alpha_nonlinear_opt(kappa, f0, alpha0, avg)
%   [f, alpha] = pao_alpha_nonlinear_opt(kappa, f0, alpha0, avg, mask)
%
%   Inputs:
%     kappa  - H x W x C image, or an image file name (read with imread and
%              converted with im2double). Three f parameters are optimized
%              while the Jacobian pattern uses size(kappa,3) f-columns, so
%              C is presumably 3 -- TODO confirm.
%     f0     - initial f; flattened into the first 3 optimization variables.
%     alpha0 - initial alpha map; flattened, must have H*W elements.
%     avg    - image used to build per-pixel weights (see NOTE below: the
%              weights are currently overridden with ones).
%     mask   - optional H x W map. Pixels where mask == 0 are zeroed in every
%              channel of kappa. Omit or pass [] to disable masking.
%
%   Outputs:
%     f      - 1 x 3 optimized f, bounded to [0, 1].
%     alpha  - H x W optimized alpha map, bounded to [0, pi/2].

if nargin < 5
    mask = [];
end

if ischar( kappa )
    kappa = im2double(imread(kappa));
end

n_pix = size(kappa,1) * size(kappa,2);
n_vars = 3 + n_pix;

% Box constraints: f in [0, 1], alpha in [0, pi/2].
lb = zeros(n_vars,1);
ub = zeros(n_vars,1);
ub(1:3) = 1;
ub(4:end) = pi/2;

% Row vector, so x_opt (and therefore f) is returned as a row.
x0 = [reshape(f0,1,[]) reshape(alpha0,1,[])];

% Per-pixel weights, computed from kappa before it is masked.
kappa_w = compute_pixel_weights(kappa, avg, mask);

% NOTE(review): this overwrites the weights computed above, so every residual
% currently has weight 1 (only the size of kappa_w is kept). Delete this line
% to re-enable the weighting.
kappa_w(:) = 1;

% Mask every channel of kappa (a 2-D logical index on a 3-D array would only
% reach the first channel).
kappa = apply_mask(kappa, mask);

% Set options for non-linear optimization
lsqnonlin_options = optimset('Algorithm', 'trust-region-reflective');
lsqnonlin_options.Display = 'iter-detailed';
lsqnonlin_options.JacobPattern = jacobian_sparsity_pattern(size(kappa));
lsqnonlin_options.MaxIter = 10;

% Optimize objective function
fprintf('Running optimization\n');

x_opt = lsqnonlin(@(x) objective_function(x, kappa, kappa_w), x0,...
    lb, ub, lsqnonlin_options);

f = x_opt(1:3);
alpha = reshape(x_opt(4:end), [size(kappa,1) size(kappa,2)]);

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Pixel Weights

function [ kappa_w ] = compute_pixel_weights( kappa, avg, mask )
% Residual weights: each channel of avg divided by the channel sum (floored
% at 0.01), zeroed where mask == 0, then scaled by 1.1 - kmag / max(kmag),
% where kmag is the per-pixel maximum over channels of
% sqrt(kdx.^2 + kdy.^2) from pao_gradient(kappa, 3).

kappa_w = avg ./ repmat(max(0.01, sum(avg,3)), [1 1 3]);
kappa_w = apply_mask(kappa_w, mask);

[kdx, kdy] = pao_gradient(kappa, 3);
kmag = sqrt(kdx.^2 + kdy.^2);
kmag = max(kmag,[],3);
% Floor the normalizer so an all-zero kmag does not produce 0/0 = NaN.
kmag_max = max(max(kmag(:)), eps);
kappa_w = kappa_w .* repmat(1.1 - kmag / kmag_max, [1 1 size(kappa,3)]);

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Masking

function [ img ] = apply_mask( img, mask )
% Zero every channel of img at the pixels where the 2-D mask equals 0.
% An empty mask leaves img unchanged.

if isempty(mask)
    return;
end

for c = 1:size(img,3)
    channel = img(:,:,c);
    channel(mask == 0) = 0;
    img(:,:,c) = channel;
end

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Objective Function

function [ diff_pred ] = objective_function(x, kappa, kappa_w)
% Weighted residual vector minimized by lsqnonlin.
%
%   x         - optimization vector: x(1:3) is f, x(4:end) is the alpha map
%               (H*W elements, reshaped column-major to H x W).
%   kappa     - target image; only its first two dimensions (H, W) are used
%               to shape alpha.
%   kappa_w   - weights multiplied elementwise into the residual.
%
%   diff_pred - ((pao_compute_kappa(f, alpha) - kappa) .* kappa_w)(:), i.e.
%               the weighted residual flattened column-major into a column.

f = x(1:3);
alpha = reshape(x(4:end), [size(kappa,1), size(kappa,2)]);

residual = (pao_compute_kappa(f, alpha) - kappa) .* kappa_w;

% lsqnonlin expects the residuals as a vector.
diff_pred = residual(:);

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Jacobian Sparsity Pattern

function [ J ] = jacobian_sparsity_pattern( imsize )
% Sparsity pattern of d(residual)/d(x) for the objective.
%
% Residual row r belongs to channel c = ceil(r / (H*W)) and pixel
% p = r - (c-1)*H*W (column-major flattening of an H x W x C array). That row
% depends only on f(c) (column c) and on alpha(p) (column C + p).
%
%   imsize - size vector; imsize(1:3) are H, W and C.
%   J      - sparse (C*H*W) x (C + H*W) matrix of ones at those positions.

num_channels = imsize(3);
num_pixels = imsize(1) * imsize(2);
num_residuals = num_channels * num_pixels;

row_idx = (1:num_residuals)';
channel_of_row = ceil(row_idx / num_pixels);
pixel_of_row = row_idx - (channel_of_row - 1) * num_pixels;

% f entries first, then alpha entries (no position appears twice).
all_rows = [row_idx; row_idx];
all_cols = [channel_of_row; num_channels + pixel_of_row];

J = sparse(all_rows, all_cols, ones(numel(all_rows), 1), ...
    num_residuals, num_channels + num_pixels);

end
