% This experimental code is published in connection with a scientific publication, Silent error detection in numerical time-stepping schemes,
% by Austin R. Benson, Sven Schmit, and Robert Schreiber, to appear in International Journal of High Performance Computing Applications in 2014.
% This code is made available solely to allow the readers of that publication to verify and reproduce the results described. 
% This experimental code is published "as is", with no representation, warranty, indemnification of any kind. 
% Hewlett-Packard excludes all liability that may result from the use of this experimental code in any form.

%
% Numerically solve the nonhomogeneous heat equation with finite
% differences.
%
%    u_t(t, x) = Ku_xx(t, x) + f(t, x)
%
%    u(0, x) = g(x)
%
%    u(t, 0) = BC0(t)
%    u(t, 1) = BC1(t)
%
%    x \in [0, 1], t \in [0, T]
%
function [n_diff_CN, n_diff_BE, errs, lte_est] = ...
    heat_solve_update(K, f, T, dx, dt, t_fault, x_fault, fault_amt, ...
    U_CN, U_Ri, U_BE, U_FE, n_diff_CN, n_diff_BE, f_type)
% HEAT_SOLVE_UPDATE  Re-run time steps t_fault and t_fault+1 of two
% implicit/explicit scheme pairs with a single injected fault. Report the
% resulting difference functions, the error caused by the fault, and the
% fault-free difference between the schemes.
%
% Scheme pairs advanced side by side:
%   * Crank-Nicolson (U_CN), checked against the Richardson (leapfrog)
%     scheme (U_Ri), which is evaluated from the U_CN history.
%   * Backward Euler (U_BE), checked against Forward Euler (U_FE), which is
%     evaluated from the U_BE history.
%
% Inputs
%   K          Diffusion coefficient in u_t = K u_xx + f.
%   f          Function handle f(t, x). It is called with a scalar time and
%              the row vector of interior grid points; the result is
%              transposed to a column.
%   T, dx, dt  Final time and step sizes. The grid has nx = floor(1/dx)
%              points in x and nt = floor(T/dt) points in t.
%              NOTE(review): x = linspace(0,1,nx) has spacing 1/(nx-1),
%              while lambda is built from dx; confirm this is intended.
%   t_fault    Time (row) index of the faulty step. Must be >= 2.
%   x_fault    Index of the corrupted entry. For 'prev_soln' it indexes a
%              column of the full solution row. For every other fault type
%              it indexes the interior unknowns (entry of F / RHS, or
%              row of A / R_CN / R_BE).
%   fault_amt  Either a multiplicative factor ('f_eval', 'prev_soln',
%              'RHS_overall'), or the value written into a structurally
%              zero entry ('chol_offdiag2nonzero', 'CN_A_offdiag2nonzero').
%   U_CN, U_Ri, U_BE, U_FE
%              Solution histories, one time step per row and one grid
%              point per column. Only interior columns are written here;
%              boundary columns are used as given (presumably filled by
%              the caller -- confirm). Rows before t_fault must already
%              hold the solution.
%   n_diff_CN, n_diff_BE
%              Difference-function histories. Entries t_fault and
%              t_fault+1 are overwritten.
%   f_type     Fault type: 'f_eval', 'prev_soln', 'RHS_overall',
%              'chol_offdiag2zero', 'chol_offdiag2nonzero',
%              'CN_A_diag2zero', 'CN_A_offdiag2zero', or
%              'CN_A_offdiag2nonzero'. Any other string injects no fault.
%
% Outputs
%   n_diff_CN, n_diff_BE
%              Updated max-norm difference functions.
%   errs       [CN, BE] max-norm difference, over interior points, between
%              the faulty and the fault-free implicit solution at t_fault.
%   lte_est    [CN, BE] max-norm difference, over interior points, between
%              the fault-free implicit and explicit solutions at t_fault.
%              This is the value the difference function would have had
%              without the fault.
%
% The solution arrays are modified only locally; they are not returned.
nx = floor(1 / dx);
nt = floor(T / dt);

lambda = K * dt / dx^2;
errs = 0;

x = linspace(0,1,nx);
t = linspace(0,T,nt);

% Crank-Nicolson: B U^{n+1} = A U^n + dt/2 (F^{n+1} + F^n).
% B is symmetric, tridiagonal and diagonally dominant, so it is factored
% once with Cholesky (B = R_CN' * R_CN). Each solve is then two triangular
% solves.
mat_vals = zeros(nx-2,1);
mat_vals(1) = 1 + lambda;
mat_vals(2) = -lambda / 2;
B = sparse(toeplitz(mat_vals));
R_CN = chol(B);

mat_vals(1) = 1 - lambda;
mat_vals(2) = lambda / 2;
A = sparse(toeplitz(mat_vals));

% Backward Euler: C U^{n+1} = U^n + dt F^{n+1}, also Cholesky-factored
% (C = R_BE' * R_BE).
mat_vals(1) = 1 + 2 * lambda;
mat_vals(2) = -lambda;
C = sparse(toeplitz(mat_vals));
R_BE = chol(C);

% Max-norm for all difference functions and errors.
n_type = 'inf';


F_old = f(t(t_fault - 1), x(2:end-1))';
for k = t_fault:t_fault+1
    F1 = f(t(k), x(2:end-1))';   % forcing at the new time level t(k)
    F2 = F_old;                  % forcing at the previous level t(k-1)

    % Inject the fault (first step only). Original values are saved so
    % they can be restored after the step.
    if (k == t_fault)
        if strcmp(f_type, 'f_eval')
            F1(x_fault) = F1(x_fault) * fault_amt;
        end
        if strcmp(f_type, 'prev_soln')
            old_val_U_CN = U_CN(k-1, x_fault);
            old_val_U_BE = U_BE(k-1, x_fault);
            U_CN(k-1, x_fault) = U_CN(k-1, x_fault) * fault_amt;
            U_BE(k-1, x_fault) = U_BE(k-1, x_fault) * fault_amt;
        end
        if strcmp(f_type, 'chol_offdiag2zero')
            old_val_CN = R_CN(x_fault, x_fault+1);
            old_val_BE = R_BE(x_fault, x_fault+1);
            R_CN(x_fault, x_fault+1) = 0;
            R_BE(x_fault, x_fault+1) = 0;
        end
        if strcmp(f_type, 'chol_offdiag2nonzero')
            R_CN(x_fault, x_fault+2) = fault_amt;
            R_BE(x_fault, x_fault+2) = fault_amt;
        end
        if strcmp(f_type, 'CN_A_diag2zero')
            old_val_A = A(x_fault, x_fault);
            A(x_fault, x_fault) = 0;
        end
        if strcmp(f_type, 'CN_A_offdiag2zero')
            old_val_A = A(x_fault, x_fault+1);
            A(x_fault, x_fault+1) = 0;
        end
        if strcmp(f_type, 'CN_A_offdiag2nonzero')
            A(x_fault, x_fault+2) = fault_amt;
        end
    end

    % Crank-Nicolson update
    RHS = A * U_CN(k-1, 2:end-1)' + dt * (1/2) * (F1 + F2);
    if k == t_fault && strcmp(f_type, 'RHS_overall')
       RHS(x_fault) = RHS(x_fault) * fault_amt;
    end
    U_CN(k,2:end-1) = R_CN\((R_CN')\RHS);

    % Backward Euler update
    RHS = U_BE(k-1, 2:end-1)' + dt * F1;
    if k == t_fault && strcmp(f_type, 'RHS_overall')
       RHS(x_fault) = RHS(x_fault) * fault_amt;
    end
    U_BE(k,2:end-1) = R_BE\((R_BE')\RHS);

    % Forward Euler (from U_BE) and Richardson/leapfrog (from U_CN)
    % updates. Both use the forcing at the previous time level. Richardson
    % needs U^{n-1}; for k <= 2 it falls back to Forward Euler.
    U_FE(k, 2:nx-1) = lambda * diff(U_BE(k-1,:), 2) +...
        U_BE(k-1, 2:nx-1) + dt * F2';
    if (k > 2)
        U_Ri(k, 2:nx-1) = 2 * lambda * diff(U_CN(k-1, :), 2) + ...
            U_CN(k-2, 2:nx-1) + 2 * dt * F2';
    else
        U_Ri(k, 2:nx-1) = U_FE(k, 2:nx-1);
    end

    % Difference functions (max-norm over the full row)
    n_diff_CN(k) = norm(U_CN(k,:) - U_Ri(k,:), n_type);
    n_diff_BE(k) = norm(U_BE(k,:) - U_FE(k,:), n_type);

    % Note: for 'f_eval' the corrupted F1 is carried forward, so it is
    % also used as F2 in the next step.
    F_old = F1;

    % Restore the values modified by the fault injection.
    if k == t_fault
        if strcmp(f_type, 'chol_offdiag2zero')
            R_CN(x_fault, x_fault+1) = old_val_CN; %#ok<*SPRIX>
            R_BE(x_fault, x_fault+1) = old_val_BE;
        end
        if strcmp(f_type, 'chol_offdiag2nonzero')
            R_CN(x_fault, x_fault+2) = 0;
            R_BE(x_fault, x_fault+2) = 0;
        end
        if strcmp(f_type, 'CN_A_diag2zero')
            A(x_fault, x_fault) = old_val_A;
        end
        if strcmp(f_type, 'CN_A_offdiag2zero')
            A(x_fault, x_fault+1) = old_val_A;
        end
        if strcmp(f_type, 'CN_A_offdiag2nonzero')
            A(x_fault, x_fault+2) = 0;
        end
        if strcmp(f_type, 'prev_soln')
            U_CN(k-1, x_fault) = old_val_U_CN;
            U_BE(k-1, x_fault) = old_val_U_BE;
        end
    end

    % Recompute step t_fault without the fault. This measures the error
    % the fault caused and the fault-free difference between the schemes.
    if k == t_fault
        F1 = f(t(k), x(2:end-1))';
        F2 = f(t(k-1), x(2:end-1))';

        % Fault-free Crank-Nicolson step
        RHS = A * U_CN(k-1, 2:end-1)' + dt * (1/2) * (F1 + F2);
        tmp1_B = (R_CN\((R_CN')\RHS))';

        % Fault-free Backward Euler step
        RHS = U_BE(k-1, 2:end-1)' + dt * F1;
        tmp2_B = (R_BE\((R_BE')\RHS))';

        % Fault-free Forward Euler step. Like U_FE above, it uses the
        % forcing at the previous time level (F2).
        tmp2_A = lambda * diff(U_BE(k-1,:), 2) +...
            U_BE(k-1, 2:nx-1) + dt * F2';

        % Fault-free Richardson step, with the same k <= 2 Forward Euler
        % fallback as U_Ri above (there is no U^{n-1} row to use).
        if (k > 2)
            tmp1_A = 2 * lambda * diff(U_CN(k-1, :), 2) + ...
                U_CN(k-2, 2:nx-1) + 2 * dt * F2';
        else
            tmp1_A = tmp2_A;
        end

        errs = [norm(tmp1_B - U_CN(k, 2:end-1), n_type), ...
                norm(tmp2_B - U_BE(k, 2:end-1), n_type)];
        lte_est = [norm(tmp1_B - tmp1_A, n_type), ...
                   norm(tmp2_B - tmp2_A, n_type)];
    end

end

end