using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, Optim
using Optimization, OptimizationFlux, OptimizationOptimJL # OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
using DiffEqSensitivity
using Lux
using JLD2, FileIO # needed for `save` of the .jld2 results file below
using Plots
gr()
using Statistics
# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(1234)

#### NOTE
# Since the recent release of DataDrivenDiffEq v0.6.0, where a complete overhaul of the
# optimizers took place, SR3 was used here. Right now, STLSQ performs better, so the
# script has been changed to use it.

# Create a name for saving (basically a prefix)
svname = "Scenario_1_"

## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameters
tspan = (0.0f0, 3.0f0)
u0 = Float32[0.44249296, 4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.1)

# Ideal data
X = Array(solution)
t = solution.t

# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-3)
Xₙ = X .+ (noise_magnitude * x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x .^ 2))

# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 2)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:, 1], T = t)
    Array(solve(prob_nn, Vern7(), u0 = X, p = θ,
                saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = ForwardDiffSensitivity()))
end

# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ .- X̂)
end

# Container to track the losses
losses = Float32[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

## Training
# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))
res1 = Optimization.solve(optprob, ADAM(0.1), callback = callback, maxiters = 200)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
# Train with BFGS
optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm = 0.01), callback = callback, maxiters = 10000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10,
                 xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10,
      xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
savefig(pl_losses, joinpath(pwd(), "plots", "$(svname)_losses.pdf"))
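## Optional sanity check of the two-stage training (a minimal sketch, not part
# of the original script; `adam_iters` is a name introduced here purely for
# illustration and mirrors the `maxiters = 200` of the ADAM stage above).
# Comparing the loss at the ADAM/BFGS hand-off against the final loss shows
# why the warm start is worthwhile: ADAM gets close, BFGS refines.
adam_iters = 200
println("Loss after ADAM stage: $(losses[adam_iters])")
println("Loss after BFGS stage: $(losses[end])")
println("BFGS refined the ADAM result by a factor of $(losses[adam_iters] / losses[end])")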
# Rename the best candidate
p_trained = res2.minimizer

## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
X̂ = predict(p_trained, Xₙ[:, 1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel = "x(t), y(t)", color = :red,
                     label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
savefig(pl_trajectory, joinpath(pwd(), "plots", "$(svname)_trajectory_reconstruction.pdf"))

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2] * (X̂[1, :] .* X̂[2, :])'; p_[3] * (X̂[1, :] .* X̂[2, :])']
# Neural network guess
Ŷ = U(X̂, p_trained, st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel = "U(x,y)", color = :red,
                         label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
savefig(pl_reconstruction, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction.pdf"))

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ - Ŷ)), yaxis = :log, xlabel = "t",
                               ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2, 1))
savefig(pl_missing, joinpath(pwd(), "plots", "$(svname)_missingterm_reconstruction_and_error.pdf"))

pl_overall = plot(pl_trajectory, pl_missing)
savefig(pl_overall, joinpath(pwd(), "plots", "$(svname)_reconstruction.pdf"))

## Symbolic regression via sparse regression (SINDy based)
# Create a Basis
@variables u[1:2]

# Generate the basis functions: multivariate polynomials up to degree 3 and sine
b = [polynomial_basis(u, 3); sin.(u)]
basis = Basis(b, u);

# Create the thresholds which should be used in the search process
λ = Float32.(exp10.(-7:0.5:0))
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Define different problems for the recovery
full_problem = ContinuousDataDrivenProblem(X, t)
ideal_problem = ContinuousDataDrivenProblem(X̂, ts, DX = Ȳ)
nn_problem = ContinuousDataDrivenProblem(X̂, ts, DX = Ŷ)

# Test on ideal derivative data for the unknown function (not available in practice)
println("Sparse regression")
full_res = solve(full_problem, basis, opt, maxiter = 10000, progress = true)
ideal_res = solve(ideal_problem, basis, opt, maxiter = 10000, progress = true)
nn_res = solve(nn_problem, basis, opt, maxiter = 10000, progress = true)

# Store the results
results = [full_res; ideal_res; nn_res]
# Show the results
map(println, results)
# Show the recovered equations
map(println ∘ result, results)
# Show the identified parameters
map(println ∘ parameter_map, results)

# Define the recovered, hybrid model
function recovered_dynamics!(du, u, p, t)
    û = nn_res(u, p) # Recovered interaction term
    du[1] = p_[1] * u[1] + û[1]
    du[2] = -p_[4] * u[2] + û[2]
end

estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, parameters(nn_res))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)

# Plot
plot(solution)
plot!(estimate)

## Simulation

# Look at long term prediction
t_long = (0.0f0, 50.0f0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5(), saveat = 0.1) # Note: using higher tolerances here can cause Julia to exit
plot(estimate_long)

true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)
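## Optional long-horizon error summary (a minimal sketch, not part of the
# original script). It relies only on variables defined above; since
# `true_solution_long` was saved at `estimate_long.t`, both solutions share
# one time grid and can be compared column by column.
err_long = norm.(eachcol(Array(true_solution_long) .- Array(estimate_long)))
println("Mean long-term L2 error: ", mean(err_long))
println("Max long-term L2 error: ", maximum(err_long))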
,"$(svname)recovery_$(noise_magnitude).jld2"), "solution", solution, "X", Xₙ, "t" , ts, "neural_network" , U, "initial_parameters", p, "trained_parameters" , p_trained, # Training "losses", losses, "result", nn_res, "recovered_parameters", parameters(nn_res), # Recovery "long_solution", true_solution_long, "long_estimate", estimate_long) # Estimation ## Post Processing and Plots c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red c3 = :blue # RGBA(255/255,90/255,0,1) # Orange c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple p1 = plot(t,abs.(Array(solution) .- estimate)' .+ eps(Float32), lw = 3, yaxis = :log, title = "Timeseries of UODE Error", color = [3 :orange], xlabel = "t", label = ["x(t)" "y(t)"], titlefont = "Helvetica", legendfont = "Helvetica", legend = :topright) # Plot L₂ p2 = plot3d(X̂[1,:], X̂[2,:], Ŷ[2,:], lw = 3, title = "Neural Network Fit of U2(t)", color = c1, label = "Neural Network", xaxis = "x", yaxis="y", titlefont = "Helvetica", legendfont = "Helvetica", legend = :bottomright) plot!(X̂[1,:], X̂[2,:], Ȳ[2,:], lw = 3, label = "True Missing Term", color=c2) p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"], title = "Extrapolated Fit From Short Training Data", titlefont = "Helvetica", legendfont = "Helvetica", markersize = 5) plot!(p3,true_solution_long, color = [c1 c2], linestyle = :dot, lw=5, label = ["True x(t)" "True y(t)"]) plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"]) plot!(p3,[2.99,3.01],[0.0,10.0],lw=1,color=:black, label = nothing) annotate!([(1.5,13,text("Training \nData", 10, :center, :top, :black, "Helvetica"))]) l = @layout [grid(1,2) grid(1,1)] plot(p1,p2,p3,layout = l) savefig(joinpath(pwd(),"plots","$(svname)full_plot.pdf"))