Heat equation numerical solution

I attempted to create a webpage with a numerical solver for the 1D heat equation and add it to this website. The equation in question, for those unfamiliar with it, is

$$\frac{\partial u}{\partial t} = \frac{\partial^2 u}{\partial x^2}.$$

Yes, I do know it has an exact solution in the form of an infinite series, but I wanted to approximate its solution numerically on a webpage. I abandoned the attempt, however, as it proved too computationally expensive for a webpage. The numerical integration technique I attempted involved approximating the spatial derivatives from grid-point values using fast Fourier transforms (FFTs) and their inverses (IFFTs), then integrating in time using the Runge–Kutta–Fehlberg 4th-order method with 5th-order error checking (RKF45). The calculations took too long and froze the browser's JavaScript console when I ran them there.
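The spectral step described above can be sketched in a few lines of Julia. This is a minimal, standalone illustration, assuming a periodic domain of length $L$ sampled at an even number $N$ of equally spaced points; the helper name `spectral_uxx` and the test function $\cos 3x$ are choices made for this example, not part of the original page. The idea is to transform, multiply each Fourier coefficient by $-k^2$, and transform back.

```julia
using FFTW

# Hypothetical helper: second derivative of periodic samples u on [0, L)
# via FFT -> multiply by -k^2 -> IFFT. Assumes N = length(u) is even.
function spectral_uxx(u::AbstractVector{<:Real}, L::Real)
    N = length(u)
    # Integer wavenumbers in FFTW's output order, scaled to the domain length.
    # The Nyquist entry is kept as +N÷2; only k^2 is used, so its sign is irrelevant.
    k = [0:N÷2; -N÷2+1:-1] .* (2π / L)
    return real(ifft(-(k .^ 2) .* fft(u)))
end

N = 64
L = 2π
x = (L / N) .* (0:N-1)
u = cos.(3 .* x)
uxx = spectral_uxx(u, L)

# The exact second derivative of cos(3x) is -9cos(3x)
println("max error: ", maximum(abs.(uxx .+ 9 .* cos.(3 .* x))))
```

For smooth periodic data like this, the printed error should be near round-off level, which is what makes the FFT approach attractive despite its cost per evaluation.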

In fairness, RKF45 already takes a few seconds to solve the double elastic pendulum (DEP) and triple pendulum (TP) ordinary differential equations (ODEs). It is therefore not surprising that increasing the complexity by orders of magnitude—by adding an FFT and IFFT to every system evaluation and having more than 100 ODEs in the system (one for each grid point)—was simply too much for a webpage to handle.

These webpages (the DEP and TP solver pages) take longer to apply RKF45 than the other solver pages because they obtain the equations to be integrated by solving a linear system of equations at each step. The matrices are $4 \times 4$ for DEP and $3 \times 3$ for TP. Math.js's `lusolve` function is used to solve these linear systems; it employs LU decomposition, which is $O(N^3)$ (keeping in mind that there are also $N^2$ terms, which we typically ignore in Big-O notation, and that the $N^3$ term has a coefficient of $\tfrac{2}{3}$). Counting the $\tfrac{2}{3}N^3$ operations of the factorization plus roughly $2N^2$ for the forward and back substitutions, this corresponds to roughly 75 operations per function call for DEP and about 36 for TP. The FFT is $O(N \log N)$, and in practice it takes about $5N\log_2 N$ operations, as does its inverse. So for $N > 100$, we would potentially be looking at over 6,643 operations ($10 \times 100 \times \log_2 100$) per function call. On top of that, every Runge–Kutta stage performs element-wise arithmetic on arrays with more than 100 entries, so it is not surprising that this proved too computationally expensive for a webpage.
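The operation counts quoted above can be reproduced with a short calculation. This sketch assumes a full LU solve costs about $\tfrac{2}{3}N^3 + 2N^2$ operations (factorization plus forward and back substitution) and that each FFT or IFFT costs about $5N\log_2 N$; the function names are illustrative.

```julia
# Approximate operation counts (illustrative helper names)
lu_solve_ops(N) = (2 / 3) * N^3 + 2 * N^2   # LU factorization + two triangular solves
fft_pair_ops(N) = 2 * 5 * N * log2(N)        # one FFT plus one IFFT

for (name, N) in (("DEP", 4), ("TP", 3))
    println(name, ": about ", round(Int, lu_solve_ops(N)), " operations per call")
end
println("FFT + IFFT, N = 100: about ", round(Int, fft_pair_ops(100)))
println("FFT + IFFT, N = 128: about ", round(Int, fft_pair_ops(128)))
```

With the script's $N = 128$, the FFT/IFFT pair alone comes to about 8,960 operations per right-hand-side evaluation, and each attempted RKF45 step makes six such evaluations.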

Due to these limitations, the best I can offer people on this website is the Julia implementation below. It is sadly not interactive, so you cannot customize the problem's parameters yourself, but you can at least see how the solution works. Interestingly, if the tolerance type is set to relative instead of absolute, this implementation fails because the step size falls below `dtMin`.

using FFTW, StaticArrays, Plots, Interpolations

"""
    RKF45(f::Function, params::NamedTuple, t0::Float64, tf::Float64, conds::SVector, epsilon::Float64, dtInitial::Float64, tolType::String="absolute", dtMin::Float64=(tf-t0)/1e8)

`f` should be a function that returns the right-hand side (RHS) of the ODE being solved, expressed as a system of first-order ODEs,
in an array of the form `[dx1/dt, dx2/dt, dx3/dt, dx4/dt, ..., dxn/dt]`. Its arguments should be, in order: `params`
(a named tuple of problem parameters), `t` (a Float64) and `vars` (an SVector of the form
`[element1; element2; element3; ...; elementn]`).

`params` should be a named tuple containing parameter values. For the simple
pendulum problem with pendulum length 1 metre, for example, it can be written
as:

`params = (g = 9.81, l = 1.0)`.

`t0` is the value of t (the independent variable) at the beginning of the integration.

`tf` is the value of t at the end of the integration.

`conds` is an SVector containing initial conditions. For the simple pendulum
problem, for example, the following code may be used (where theta0 and
thetaDot0 have been defined elsewhere):

`conds = @SVector [theta0, thetaDot0]`.

`epsilon` is the error tolerance for the problem.

`dtInitial` is the initial guess for the step size.

`tolType` is the type of tolerance used. Either "absolute" or "relative".

`dtMin` is the minimum step size allowed. Default is (tf-t0)/1e8.
"""
function RKF45(f::Function, params::NamedTuple, t0::Float64, tf::Float64, conds::SVector, epsilon::Float64, dtInitial::Float64, tolType::String = "absolute", dtMin::Float64 = (tf-t0)/1e8)
    # Initialize relevant variables
    dt = dtInitial;
    t = Float64[t0];
    vars = [conds];
    i = 1;
    ti = t0;

    # Loop over t until the solution at tf has been found
    while ( ti < tf )
        varsi = vars[i];
        dt = min(dt, tf - ti);   # never step past tf

        # RKF45 approximators
        K1 = dt*f(params, ti, varsi);
        K2 = dt*f(params, ti + dt/4, varsi + K1/4);
        K3 = dt*f(params, ti + 3*dt/8, varsi + 3*K1/32 + 9*K2/32);
        K4 = dt*f(params, ti + 12*dt/13, varsi + 1932*K1/2197 - 7200*K2/2197 + 7296*K3/2197);
        K5 = dt*f(params, ti + dt, varsi + 439*K1/216 - 8*K2 + 3680*K3/513 - 845*K4/4104);
        K6 = dt*f(params, ti + dt/2, varsi - 8*K1/27 + 2*K2 - 3544*K3/2565 + 1859*K4/4104 - 11*K5/40);

        # 4/5th order approximations to next step value
        vars1 = varsi + 25*K1/216 + 1408*K3/2565 + 2197*K4/4104 - K5/5;
        vars2 = varsi + 16*K1/135 + 6656*K3/12825 + 28561*K4/56430 - 9*K5/50 + 2*K6/55;

        # Determine if error is small enough to move on to next step
        if (tolType in ["relative", "rel", "R", "r", "Rel", "Relative"])
            R = maximum(abs.(vars2 - vars1)./abs.(vars1))/dt;
        elseif (tolType in ["absolute", "abs", "A", "a", "Abs", "Absolute"])
            R = maximum(abs.(vars2 - vars1))/dt;
        else
            error("tolType is set to an invalid value ($tolType), so exiting...")
        end
        s = (epsilon/(2*R))^(0.25);
        if (R <= epsilon)
            push!(t, ti + dt);
            push!(vars, vars1);   # vars is a Vector of SVectors, so Base's push! applies
            i += 1;
            ti = t[i];
        end
        dt *= s;
        if (dt < dtMin)
            @warn("dt has reached $dt at t=$ti which is less than dtMin=$dtMin")
            if (tolType in ["absolute", "abs", "A", "a", "Abs", "Absolute"])
                tolType = "relative";
                @warn("As you are using an absolute tolerance type, we will switch to relative tolerance to see if this fixes the problem...")
            elseif (tolType in ["relative", "rel", "R", "r", "Rel", "Relative"])
                @warn("Breaking out of loop as tolerance type is already set to relative.")
                break
            else
                error("tolType is set to an invalid value ($tolType), so exiting...")
            end
        end
    end

    # Stack the stored SVectors into a matrix with one row per time step
    vars = transpose(reduce(hcat, vars));
    return t, vars;
end

# Right-hand side for RKF45: du/dt = α u_xx, with u_xx computed spectrally via FFT
function heat(params, t, u::SVector)
    α = params.α

    # Wavenumbers
    k = params.k

    # FFT of u
    u_hat = fft(collect(u))  # convert to Vector for FFTW

    # Second derivative in Fourier space: -(k^2) * u_hat
    uxx_hat = - (k .^ 2) .* u_hat

    # Back to real space
    u_xx = real(ifft(uxx_hat))

    # Return du/dt as SVector for RKF45
    return SVector{length(u_xx)}(α .* u_xx)
end

N = 128                # number of grid points
T = 2π                 # period
L = T                  # domain length
dx = L / N
α = 0.01               # thermal diffusivity
# Wavenumbers in FFT output order (assumes N even); note that the Nyquist mode is set to 0
k = vcat(0:N÷2-1, 0, -N÷2+1:-1) .* (2π/L)

# Initial condition
x = dx .* (0:N-1)
u0 = 10 * (2 .- cos.(2*pi/T * x))

params = (α=α, L=L, k=k)
conds = SVector{length(u0)}(u0)
t0 = 0.0
tf = 300.0
epsilon = 1e-9
dtInitial = 1e-3

t_vals, u_vals = RKF45(heat, params, t0, tf, conds, epsilon, dtInitial)
Umat = Matrix(u_vals)';
plotlyjs()  # use PlotlyJS for 3D plotting

# Create grids for x and t matching Umat dimensions
X = repeat(x, 1, length(t_vals))             # N × nt
Tgrid = repeat(t_vals', length(x), 1)        # N × nt (t_vals' is a row vector); named so it does not overwrite the period T

# Surface plot
surface(X, Tgrid, Umat, xlabel="x", ylabel="t", zlabel="u(x,t)",
        title="Heat equation", legend=false)
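# @OUTPUT is Franklin.jl's output-directory macro; replace it with a directory path to run this script outside Franklin
savefig(joinpath(@OUTPUT, "heatEquation.svg"))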

# Animate
nt_uniform = round(Int, (tf - t0) * 30);   # about 30 frames per unit of t (9,000 frames here)
t_uniform = range(t_vals[1], t_vals[end], length=nt_uniform)  # uniformly spaced frame times

U_interp = zeros(size(Umat, 1), length(t_uniform))

for i in 1:size(Umat, 1)
    itp = linear_interpolation(t_vals, Umat[i, :])   # piecewise-linear interpolant in t at grid point i
    U_interp[i, :] = itp.(t_uniform)
end

function fixed_decimals(x; digits=3)
    s = "t = " * string(round(x, digits=digits))
    # Check if decimal part exists
    if !occursin('.', s)
        # Add decimal point and zeros if missing
        s *= "." * "0"^digits
    else
        parts = split(s, '.')
        decimals = parts[2]
        zeros_needed = digits - length(decimals)
        s *= "0"^max(0, zeros_needed)
    end
    return s
end

ymin, ymax = extrema(U_interp)
anim = @animate for i in 1:nt_uniform
    plot(x, U_interp[:, i],
         ylim = (ymin, ymax),
         xlabel = "x",
         ylabel = "u(x, t)",
         title = fixed_decimals(t_uniform[i]),
         legend = false)
end

Deltat = t_uniform[2]-t_uniform[1];
mp4(anim, joinpath(@OUTPUT, "heat.mp4"), fps = 1/Deltat)
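Because the initial condition contains a single Fourier mode, an exact solution is available for checking the RKF45 output. With $L = T = 2\pi$ and the diffusivity $\alpha$ used in the code (the equation at the top of the page omits it), $u(x, t) = 20 - 10e^{-\alpha t}\cos x$ satisfies $u_t = \alpha u_{xx}$ and $u(x, 0) = 10(2 - \cos x)$. The standalone sketch below is not part of the original script: it evolves each Fourier coefficient exactly, $\hat{u}_k(t) = \hat{u}_k(0)e^{-\alpha k^2 t}$, which solves the Fourier-discretized system without time-stepping error, and compares the result with the closed form. The names `spectral_exact` and `u_exact` are illustrative.

```julia
using FFTW

# Hypothetical helper: exact solution of the Fourier-discretized heat equation
# u_t = α u_xx on a periodic grid, obtained by decaying each Fourier mode.
function spectral_exact(u0::AbstractVector{<:Real}, L::Real, α::Real, t::Real)
    N = length(u0)
    k = [0:N÷2; -N÷2+1:-1] .* (2π / L)   # wavenumbers in FFT order (N even)
    return real(ifft(exp.(-α .* k .^ 2 .* t) .* fft(u0)))
end

# Closed-form solution for the script's initial condition (L = 2π)
u_exact(x, t, α) = 20 - 10 * exp(-α * t) * cos(x)

N, L, α = 128, 2π, 0.01
x = (L / N) .* (0:N-1)
u0 = 10 .* (2 .- cos.(x))

u_spec = spectral_exact(u0, L, α, 300.0)
println("max |spectral - exact| at t = 300: ",
        maximum(abs.(u_spec .- u_exact.(x, 300.0, α))))
```

To check the script's own result, one could compare `u_vals[end, :]` with `u_exact.(x, t_vals[end], α)`; the gap then reflects the RKF45 time-stepping error.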