% This experimental code is published in connection with a scientific publication, Silent error detection in numerical time-stepping schemes,
% by Austin R. Benson, Sven Schmit, and Robert Schreiber, to appear in International Journal of High Performance Computing Applications in 2014.
% This code is made available solely to allow the readers of that publication to verify and reproduce the results described. 
% This experimental code is published "as is", with no representation, warranty, indemnification of any kind. 
% Hewlett-Packard excludes all liability that may result from the use of this experimental code in any form.

%
% Numerically solve the nonhomogeneous heat equation with finite
% differences.
%
%    u_t(t, x) = K u_xx(t, x) + f(t, x)
%
%    u(0, x) = g(x)
%
%    u(t, 0) = BC0(t)
%    u(t, 1) = BC1(t)
%
%    x \in [0, 1], t \in [0, T]
%
function [n_diff_CN_FE, n_diff_CN_Ri, n_diff_CN_DF, ...
    n_diff_BE, U_CN, U_Ri, U_BE, U_FE, U_DF, t, x, errs] = ...
    heat_solve(K, f, g, BC0, BC1, T, dx, dt, t_fault, x_fault, fault_amt)
% HEAT_SOLVE  Solve u_t = K*u_xx + f on [0,1] x [0,T] with Dirichlet
% boundary values using several finite-difference schemes, optionally
% injecting a single fault at one time step.
%
% Inputs
%   K          diffusion coefficient
%   f          forcing, called as f(t, xrow) on the interior grid points;
%              must return nx-2 values
%   g          initial condition, called as g(x) on the full grid row x
%   BC0, BC1   boundary values at x = 0 and x = 1, called on the time row t
%   T, dx, dt  final time and nominal step sizes
%   t_fault    time-step index k at which a fault is injected (no fault is
%              injected unless t_fault is in 2:nt)
%   x_fault    index of the corrupted entry (meaning depends on f_type below)
%   fault_amt  fault multiplier or injected value (depends on f_type below)
%
% Outputs
%   n_diff_CN_FE, n_diff_CN_Ri, n_diff_CN_DF, n_diff_BE
%              per-step norms (n_type) of CN-FE, CN-Richardson,
%              CN-DuFort-Frankel and BE-FE differences
%   U_CN, U_Ri, U_BE, U_FE, U_DF
%              nt-by-nx solution histories (row k is time t(k))
%   t, x       time and space grids
%   errs       [CN, BE] norm distance at step t_fault between the computed
%              step and a recomputation with matrix faults undone and the
%              forcing re-evaluated; 0 if step t_fault is never reached
%
% The Forward Euler step is taken from the previous Backward Euler row, and
% the DuFort-Frankel / Richardson steps from the previous Crank-Nicolson
% rows, so each n_diff_* compares one cheap step against a primary scheme.
%
% NOTE(review): x and t come from linspace(0,1,nx) and linspace(0,T,nt), whose
% spacings are 1/(nx-1) and T/(nt-1), while lambda uses dx and dt. These only
% agree approximately; left as-is to preserve previously published results.

nx = floor(1 / dx);
nt = floor(T / dt);

lambda = K * dt / dx^2;
alpha = 2 * lambda;   % DuFort-Frankel coefficient
errs = 0;

x = linspace(0, 1, nx);
t = linspace(0, T, nt);

% norm difference stats
n_diff_CN_FE = zeros(nt, 1);
n_diff_CN_Ri = n_diff_CN_FE;
n_diff_CN_DF = n_diff_CN_FE;
n_diff_BE = n_diff_CN_FE;

% solution histories: initial condition in row 1, boundary values in the
% first and last columns
U_CN = zeros(nt, nx); % Crank-Nicolson
U_CN(1, :) = g(x);
U_CN(:, 1) = BC0(t);   % BC
U_CN(:, end) = BC1(t); % BC

% copy boundary and initial conditions
U_Ri = U_CN; % Richardson
U_DF = U_CN; % DuFort-Frankel
U_BE = U_CN; % Backward Euler
U_FE = U_CN; % Forward Euler

% Interior operators on the nx-2 unknowns, built directly as sparse
% tridiagonal matrices (no dense toeplitz intermediate).
n_int = nx - 2;

% Crank-Nicolson: B*U^{n+1} = A*U^n + forcing + boundary terms,
% with B = R_CN' * R_CN (Cholesky)
B = sym_tridiag(n_int, 1 + lambda, -lambda / 2);
R_CN = chol(B);
A = sym_tridiag(n_int, 1 - lambda, lambda / 2);

% Backward Euler: C*U^{n+1} = U^n + forcing + boundary terms,
% with C = R_BE' * R_BE (Cholesky)
C = sym_tridiag(n_int, 1 + 2 * lambda, -lambda);
R_BE = chol(C);


% solve system
%fprintf('Producing error of %f at iteration %d\n', fault_amt, fault_ind);

% Fault model injected at step t_fault:
%   f_eval               F1(x_fault) *= fault_amt (the corrupted F1 is also
%                        carried into the next step through F_old)
%   prev_soln            U_CN/U_BE(k-1, x_fault) *= fault_amt (stored history
%                        is modified permanently; x_fault indexes the full row)
%   RHS_overall          RHS(x_fault) *= fault_amt for CN and BE
%   CN_A_*, chol_*       entries of A / R_CN / R_BE altered for one step only
f_type = 'f_eval';
%f_type = 'prev_soln';
%f_type = 'RHS_overall';
%f_type = 'CN_A_diag2zero';
%f_type = 'CN_A_offdiag2zero';
%f_type = 'CN_A_offdiag2nonzero';
%f_type = 'chol_offdiag2zero';
%f_type = 'chol_offdiag2nonzero';
%f_type = '';

n_type = 'inf'; % norm to use

% Disabled diagnostic: eigen-decomposition of the FE/BE interior update
% matrices (defines Q_FE, which the disabled plots in the loop rely on).
if 0
Op_vals = [-2 * lambda + 1; lambda; zeros(nx - 4,1)];
[Q1, EVals] = eig(toeplitz(Op_vals));
%[~, I] = sort(abs(diag(EVals)), 'descend');
Q_FE = Q1; %(:, I);
D = diag(EVals);
figure;
plot(D);
xlabel('i');
ylabel('Eigenvalue corresponding to w_i');
title('Eigenvalues of Forward Euler update matrix');
set(findall(gcf,'type','text'),'fontSize', 10);
Op_vals = [2 * lambda + 1; -lambda; zeros(nx - 4,1)];
[Q2, EVals] = eig(toeplitz(Op_vals));
%[~, I] = sort(abs(diag(EVals)), 'descend');
Q_BE = Q2; %(:, I);
end

F_old = f(t(1), x(2:end-1))';
for k = 2:nt
    F1 = f(t(k), x(2:end-1))';   % forcing at t(k)
    F2 = F_old;                   % forcing at t(k-1), as used last step

    % inject the fault for this step
    if (k == t_fault)
        if strcmp(f_type, 'f_eval')
            F1(x_fault) = F1(x_fault) * fault_amt;
            F_save = F1;
            %F2(x_fault) = F2(x_fault) * fault_amt;
        end
        if strcmp(f_type, 'prev_soln')
            U_CN(k-1, x_fault) = U_CN(k-1, x_fault) * fault_amt;
            U_BE(k-1, x_fault) = U_BE(k-1, x_fault) * fault_amt;
        end
        if strcmp(f_type, 'chol_offdiag2zero')
            old_val_CN = R_CN(x_fault, x_fault+1);
            old_val_BE = R_BE(x_fault, x_fault+1);
            R_CN(x_fault, x_fault+1) = 0;
            R_BE(x_fault, x_fault+1) = 0;
        end
        if strcmp(f_type, 'chol_offdiag2nonzero')
            R_CN(x_fault, x_fault+2) = fault_amt;
            R_BE(x_fault, x_fault+2) = fault_amt;
        end
        if strcmp(f_type, 'CN_A_diag2zero')
            old_val_A = A(x_fault, x_fault);
            A(x_fault, x_fault) = 0;
        end
        if strcmp(f_type, 'CN_A_offdiag2zero')
            old_val_A = A(x_fault, x_fault+1);
            A(x_fault, x_fault+1) = 0;
        end
        if strcmp(f_type, 'CN_A_offdiag2nonzero')
            A(x_fault, x_fault+2) = fault_amt;
        end
    end

    % Crank-Nicolson update (row k's boundary values are already set)
    RHS = cn_rhs(A, U_CN(k-1, :), U_CN(k, :), F1, F2, lambda, dt);
    if k == t_fault && strcmp(f_type, 'RHS_overall')
        RHS(x_fault) = RHS(x_fault) * fault_amt;
    end
    U_CN(k, 2:end-1) = chol_solve(R_CN, RHS);

    % Backward Euler update
    RHS = be_rhs(U_BE(k-1, :), U_BE(k, :), F1, lambda, dt);
    if k == t_fault && strcmp(f_type, 'RHS_overall')
        RHS(x_fault) = RHS(x_fault) * fault_amt;
    end
    U_BE(k, 2:end-1) = chol_solve(R_BE, RHS);

    % Forward Euler step taken from the previous Backward Euler row
    U_FE(k, 2:nx-1) = lambda * diff(U_BE(k-1,:), 2) +...
        U_BE(k-1, 2:nx-1) + dt * F2';
    % DuFort-Frankel and Richardson steps taken from the previous two
    % Crank-Nicolson rows; they need two history rows, so the first step
    % falls back to the Forward Euler result
    if (k > 2)
        U_DF(k, 2:nx-1) = ...
             ((1 - alpha) * U_CN(k - 2, 2:nx-1) +...
             alpha * (U_CN(k-1, 3:nx) + U_CN(k-1, 1:nx-2)) +...
             2 * dt * F2') / (1 + alpha);

        U_Ri(k, 2:nx-1) = 2 * lambda * diff(U_CN(k-1, :), 2) + ...
            U_CN(k-2, 2:nx-1) + 2 * dt * F2';
    else
        U_DF(k, 2:nx-1) = U_FE(k, 2:nx-1);
        U_Ri(k, 2:nx-1) = U_FE(k, 2:nx-1);
    end

    F_old = F1;

    % Compute difference function (norm type n_type)
    n_diff_CN_FE(k) = norm(U_CN(k,:) - U_FE(k,:), n_type);
    n_diff_CN_Ri(k) = norm(U_CN(k,:) - U_Ri(k,:), n_type);
    n_diff_CN_DF(k) = norm(U_CN(k,:) - U_DF(k,:), n_type);
    n_diff_BE(k) = norm(U_BE(k,:) - U_FE(k,:), n_type);

    % Disabled diagnostic plots around the fault (needs Q_FE from the
    % disabled eigen-decomposition block above).
    if 0
    if (k == t_fault - 1 || k == t_fault || k == t_fault + 1)
       if (k == t_fault - 1)
         figure;
       end
       vec1 = U_FE(k, 2:nx-1)';
       vec1 = vec1 / norm(vec1, 2);
       vec2 = U_BE(k, 2:nx-1)';
       vec2 = vec2 / norm(vec2, 2);
       vec3 = U_FE(k, 2:nx-1)' - U_BE(k, 2:nx-1)';
       vec3 = vec3 / norm(vec3, 2);
       if k == t_fault - 1
           subplot(1, 3, 1);
           semilogy(abs(Q_FE' * vec1)); hold on;
       elseif k == t_fault
           subplot(1, 3, 2);
           semilogy(abs(Q_FE' * vec1)); hold on;
       else
           subplot(1, 3, 3);
       end
       semilogy(abs(Q_FE' * vec1)); hold on;
       semilogy(abs(Q_FE' * vec2), 'r');
       semilogy(abs(Q_FE' * vec3), 'g'); hold off;
       if k == t_fault - 1
           title('Step before error');
       elseif k == t_fault
           title('Step of error');
       else
           title('Step after error');
       end
       legend('FE', 'BE', 'FE - BE', 'Location', 'SouthEast');
       xlabel('i');
       ylabel('|<w_i, U>| / |<U, U>|');
       set(findall(gcf,'type','text'),'fontSize', 10);
    end
    end

    %semilogy(abs(Q_FE' * U_FE(k, 2:nx-1)'));
    %pause(0.1);

    % restore old values (matrix faults last for a single step)
    if k == t_fault
        if strcmp(f_type, 'chol_offdiag2zero')
            R_CN(x_fault, x_fault+1) = old_val_CN;
            R_BE(x_fault, x_fault+1) = old_val_BE;
        end
        if strcmp(f_type, 'chol_offdiag2nonzero')
            R_CN(x_fault, x_fault+2) = 0;
            R_BE(x_fault, x_fault+2) = 0;
        end
        if strcmp(f_type, 'CN_A_diag2zero')
            A(x_fault, x_fault) = old_val_A;
        end
        if strcmp(f_type, 'CN_A_offdiag2zero')
            A(x_fault, x_fault+1) = old_val_A;
        end
        if strcmp(f_type, 'CN_A_offdiag2nonzero')
            A(x_fault, x_fault+2) = 0;
        end
    end

    % Recompute step k with the restored matrices and freshly evaluated
    % forcing, and record how far the faulty step landed from it.
    if k == t_fault
        F1 = f(t(k), x(2:end-1))';
        F2 = f(t(k-1), x(2:end-1))';

        % Crank-Nicolson update
        tmp1 = chol_solve(R_CN, ...
            cn_rhs(A, U_CN(k-1, :), U_CN(k, :), F1, F2, lambda, dt));

        % Backward Euler update
        tmp2 = chol_solve(R_BE, ...
            be_rhs(U_BE(k-1, :), U_BE(k, :), F1, lambda, dt));

        errs = [norm(tmp1' - U_CN(k, 2:end-1), n_type), ...
                norm(tmp2' - U_BE(k, 2:end-1), n_type)];
    end
end

% Disabled diagnostic: eigen-components of the BE/FE solutions around the
% fault (uses F_save, which is only set for f_type = 'f_eval').
if 0
    Op_vals = [-2 * lambda + 1; lambda; zeros(nx - 4,1)];
    [Q_FE, D] = eig(toeplitz(Op_vals));
    F = dt * (D \ (Q_FE' * F_save));

    set(gca, 'FontSize', 10, 'LineWidth', 2);

    vec1 = Q_FE' * U_BE(t_fault-1, 2:nx-1)';
    vec2 = Q_FE' * U_BE(t_fault, 2:nx-1)';
    vec3 = Q_FE' * U_BE(t_fault+1, 2:nx-1)';
    vec4 = Q_FE' * U_FE(t_fault+1, 2:nx-1)';
    vec5 = D * (vec2 + F);
    vec5 = Q_FE * D * Q_FE' * U_BE(t_fault, 2:nx-1)' + dt * F_save;
    vec5 = Q_FE' * vec5;
    semilogy(abs(vec1), 'r-', 'LineWidth', 1.5); hold on;
    semilogy(abs(vec2), 'b:', 'LineWidth', 3); hold on;
    semilogy(abs(vec3), 'k-.', 'LineWidth', 1.25); hold on;
    semilogy(abs(vec4), 'm--', 'LineWidth', 1.5); hold on;
    semilogy(abs(F), 'cyan'); hold on;
    semilogy(abs(vec5), 'g'); hold off;

    legend('BE step before fault', ...
           'BE step of fault', ...
           'BE step after fault', ...
           'FE step after fault', ...
           'Forcing term', ...
           'Prediction', ...
           'Location', 'NorthWest');
    xlabel('i');
    ylabel('|w_i^TU|');
    set(findall(gcf,'type','text'),'fontSize', 10);
    %saveplotaspdf('eig_components2.pdf', 14, 14);
end

% Disabled diagnostic: BE - FE differences in the FE eigenbasis around the
% fault.
if 0
    Op_vals = [-2 * lambda + 1; lambda; zeros(nx - 4,1)];
    [Q_FE, D] = eig(toeplitz(Op_vals));

    figure;
    for k = [t_fault-1, t_fault, t_fault+1]
        vec1 = Q_FE' * U_BE(k, 2:nx-1)';
        vec2 = Q_FE' * U_FE(k, 2:nx-1)';
        F = f(t(k), x(2:nx-1))';
        if k == t_fault
            F(x_fault) = F(x_fault) * fault_amt;
        end
        vec3 = dt * (D \ Q_FE' * F);
        vec4 = (U_BE(k, 2:nx-1)' - U_FE(k, 2:nx-1)');

        if (k == t_fault - 1)
            subplot(1, 3, 1);
        elseif k == t_fault
            subplot(1, 3, 2);
        else
            subplot(1, 3, 3);
        end
        dif = U_BE(k, 2:nx-1)' - U_FE(k, 2:nx-1)';
        vec6 = Q_FE' * (dif);
        semilogy(abs(dif), '-cyan'); hold on;
        %semilogy(89, abs(dif(89)), '-cyan*');
        semilogy(abs(vec6), '-m');
        vec7 = Q_FE' * dif;
        for i = 1:length(vec7)
            vec7(i) = sum(vec6(1:i).^2);
        end
        semilogy(vec7, 'b');
        if (k == t_fault - 1)
            title('Step before fault');
        elseif k == t_fault
            title('Step of fault');
        else
            title('Step after fault');
            legend('|diff|', '|Q^T * diff|', ...
                'Sum squares Q^T * diff', 'Location', 'SouthEast');
        end


        if 0
        if (k == t_fault - 1)
            subplot(1, 3, 1);
        elseif k == t_fault
            subplot(1, 3, 2);
        else
            subplot(1, 3, 3);
        end
        semilogy(abs(vec1), 'r-', 'LineWidth', 1.5); hold on;
        semilogy(abs(vec2), 'b:', 'LineWidth', 3); hold on;
        semilogy(abs(vec3), 'k-.', 'LineWidth', 1.25); hold on;
        semilogy(abs(vec4), 'cyan');
        if k == t_fault + 1
            vec5 = Q_FE' * (toeplitz(Op_vals) * U_FE(k-1, 2:nx-1)' + dt * F);
            semilogy(abs(vec5), 'g');
        end
        if (k == t_fault - 1)
            title('Step before fault');
        elseif k == t_fault
            title('Step of fault');
        else
            title('Step after fault');
            legend('|Q^TU_{BE}|', '|Q^TU_{FE}|', '|G|', '|U_{BE} - U_{FE}|', 'Extra step on FE', ...
                'Location', 'NorthWest');
        end
        xlabel('vector component');
        set(findall(gcf,'type','text'),'fontSize', 10);
        end
    end
    if 0
    figure
    semilogy(abs(vec1(1:50)), 'r-', 'LineWidth', 1.5); hold on;
    semilogy(abs(vec2(1:50)), 'b:', 'LineWidth', 3); hold on;
    semilogy(abs(vec3(1:50)), 'k-.', 'LineWidth', 1.25); hold on;
    semilogy(abs(vec4(1:50)), 'cyan');
    semilogy(abs(vec5(1:50)), 'green');
    legend('|Q^TU_{BE}|', '|Q^TU_{FE}|', '|G|', '|U_{BE} - U_{FE}|', 'Extra step on FE', ...
                'Location', 'NorthWest');
    end

    %saveplotaspdf('eig_components2.pdf', 14, 14);
end

end


% SYM_TRIDIAG  n-by-n sparse matrix with diagonal d and sub/super-diagonals e.
function M = sym_tridiag(n, d, e)
M = spdiags(repmat([e, d, e], n, 1), -1:1, n, n);
end

% CN_RHS  Right-hand side of the Crank-Nicolson interior system.
%   u_prev is the full previous row; only the boundary entries of u_next
%   (the full current row) are read. The Dirichlet values enter the first
%   and last interior equations through the (lambda/2) off-diagonal coupling.
function rhs = cn_rhs(A, u_prev, u_next, F_new, F_old, lambda, dt)
rhs = A * u_prev(2:end-1)' + dt * (1/2) * (F_new + F_old);
rhs(1) = rhs(1) + (lambda / 2) * (u_prev(1) + u_next(1));
rhs(end) = rhs(end) + (lambda / 2) * (u_prev(end) + u_next(end));
end

% BE_RHS  Right-hand side of the Backward Euler interior system.
%   Only the boundary entries of u_next (the full current row) are read;
%   they enter through the implicit lambda coupling.
function rhs = be_rhs(u_prev, u_next, F_new, lambda, dt)
rhs = u_prev(2:end-1)' + dt * F_new;
rhs(1) = rhs(1) + lambda * u_next(1);
rhs(end) = rhs(end) + lambda * u_next(end);
end

% CHOL_SOLVE  Solve (R'*R) * u = rhs given the upper-triangular factor R.
function u = chol_solve(R, rhs)
u = R \ ((R') \ rhs);
end