|
| 1 | +# Cache structures for in-place DASSL operations |
| 2 | +# Following OrdinaryDiffEq.jl pattern: pre-allocate all working arrays |
| 3 | + |
| 4 | +""" |
| 5 | + DASSLCache{T, uType, jacType} |
| 6 | +
|
| 7 | +Pre-allocated cache for in-place DASSL operations. |
| 8 | +All working arrays are allocated once and reused across integration steps. |
| 9 | +""" |
| 10 | +mutable struct DASSLCache{T, uType, jacType} |
| 11 | + # Current state vectors (reused each step) |
| 12 | + yn::uType # Next y value (corrector output) |
| 13 | + dyn::uType # Next dy value |
| 14 | + y0::uType # Predictor y value |
| 15 | + dy0::uType # Predictor dy value |
| 16 | + |
| 17 | + # Newton iteration working vectors |
| 18 | + delta::uType # Newton step delta |
| 19 | + residual::uType # Residual vector for in-place F evaluation |
| 20 | + ytmp::uType # Temporary vector for computations |
| 21 | + ytmp2::uType # Second temporary vector |
| 22 | + |
| 23 | + # Jacobian storage |
| 24 | + jac::jacType # Jacobian matrix buffer |
| 25 | + jac_factorized::Any # Factorized Jacobian (stored as Any for flexibility) |
| 26 | + a::T # Current Jacobian coefficient |
| 27 | + |
| 28 | + # Numerical Jacobian working arrays |
| 29 | + f_plus::uType # F(y + delta) result |
| 30 | + f_base::uType # F(y) result |
| 31 | + |
| 32 | + # History buffers (circular buffer, fixed size for BDF order 1-6) |
| 33 | + t_hist::Vector{T} # Time history |
| 34 | + y_hist::Vector{uType} # Solution history |
| 35 | + dy_hist::Vector{uType} # Derivative history |
| 36 | + hist_start::Int # Start index in circular buffer |
| 37 | + hist_len::Int # Current number of valid entries |
| 38 | + |
| 39 | + # Error weights buffer |
| 40 | + wt::uType |
| 41 | +end |
| 42 | + |
| 43 | +""" |
| 44 | + alg_cache(alg, u, p, t, ::Val{true}) |
| 45 | +
|
| 46 | +Create a mutable cache for in-place DASSL operations. |
| 47 | +All working arrays are pre-allocated based on the size of `u`. |
| 48 | +""" |
| 49 | +function alg_cache(alg, u::uType, p, t::T, ::Val{true}) where {T, uType} |
| 50 | + n = length(u) |
| 51 | + |
| 52 | + # State vectors |
| 53 | + yn = similar(u) |
| 54 | + dyn = similar(u) |
| 55 | + y0 = similar(u) |
| 56 | + dy0 = similar(u) |
| 57 | + |
| 58 | + # Newton iteration vectors |
| 59 | + delta = similar(u) |
| 60 | + residual = similar(u) |
| 61 | + ytmp = similar(u) |
| 62 | + ytmp2 = similar(u) |
| 63 | + |
| 64 | + # Jacobian storage |
| 65 | + jac = zeros(eltype(u), n, n) |
| 66 | + jac_factorized = nothing |
| 67 | + a = zero(T) |
| 68 | + |
| 69 | + # Numerical Jacobian vectors |
| 70 | + f_plus = similar(u) |
| 71 | + f_base = similar(u) |
| 72 | + |
| 73 | + # History buffers (MAXORDER + 3 = 9 entries max needed) |
| 74 | + max_hist = MAXORDER + 3 |
| 75 | + t_hist = zeros(T, max_hist) |
| 76 | + y_hist = [similar(u) for _ in 1:max_hist] |
| 77 | + dy_hist = [similar(u) for _ in 1:max_hist] |
| 78 | + |
| 79 | + # Error weights |
| 80 | + wt = similar(u) |
| 81 | + |
| 82 | + return DASSLCache( |
| 83 | + yn, dyn, y0, dy0, |
| 84 | + delta, residual, ytmp, ytmp2, |
| 85 | + jac, jac_factorized, a, |
| 86 | + f_plus, f_base, |
| 87 | + t_hist, y_hist, dy_hist, 1, 0, |
| 88 | + wt |
| 89 | + ) |
| 90 | +end |
| 91 | + |
| 92 | +""" |
| 93 | + alg_cache(alg, u, p, t, ::Val{false}) |
| 94 | +
|
| 95 | +For out-of-place problems, return nothing (no cache needed). |
| 96 | +""" |
| 97 | +function alg_cache(alg, u, p, t, ::Val{false}) |
| 98 | + return nothing |
| 99 | +end |
| 100 | + |
| 101 | +# ============================================================================ |
| 102 | +# Circular buffer history management |
| 103 | +# ============================================================================ |
| 104 | + |
| 105 | +""" |
| 106 | + push_history!(cache, t, y, dy) |
| 107 | +
|
| 108 | +Add a new entry to the history circular buffer. |
| 109 | +If buffer is full, overwrites the oldest entry. |
| 110 | +""" |
| 111 | +function push_history!(cache::DASSLCache, t, y, dy) |
| 112 | + max_size = length(cache.t_hist) |
| 113 | + |
| 114 | + if cache.hist_len < max_size |
| 115 | + # Buffer not full, append |
| 116 | + cache.hist_len += 1 |
| 117 | + else |
| 118 | + # Buffer full, advance start pointer (overwrite oldest) |
| 119 | + cache.hist_start = mod1(cache.hist_start + 1, max_size) |
| 120 | + end |
| 121 | + |
| 122 | + # Compute index for new entry |
| 123 | + idx = mod1(cache.hist_start + cache.hist_len - 1, max_size) |
| 124 | + |
| 125 | + cache.t_hist[idx] = t |
| 126 | + copyto!(cache.y_hist[idx], y) |
| 127 | + return copyto!(cache.dy_hist[idx], dy) |
| 128 | +end |
| 129 | + |
| 130 | +""" |
| 131 | + pop_oldest_history!(cache) |
| 132 | +
|
| 133 | +Remove the oldest entry from the history buffer. |
| 134 | +""" |
| 135 | +function pop_oldest_history!(cache::DASSLCache) |
| 136 | + return if cache.hist_len > 0 |
| 137 | + cache.hist_start = mod1(cache.hist_start + 1, length(cache.t_hist)) |
| 138 | + cache.hist_len -= 1 |
| 139 | + end |
| 140 | +end |
| 141 | + |
| 142 | +""" |
| 143 | + get_t_at(cache, i) |
| 144 | +
|
| 145 | +Get time at history index i (1 = oldest in current window). |
| 146 | +""" |
| 147 | +function get_t_at(cache::DASSLCache, i::Integer) |
| 148 | + max_size = length(cache.t_hist) |
| 149 | + idx = mod1(cache.hist_start + i - 1, max_size) |
| 150 | + return cache.t_hist[idx] |
| 151 | +end |
| 152 | + |
| 153 | +""" |
| 154 | + get_y_at(cache, i) |
| 155 | +
|
| 156 | +Get y vector at history index i (1 = oldest in current window). |
| 157 | +Returns a reference to the stored vector. |
| 158 | +""" |
| 159 | +function get_y_at(cache::DASSLCache, i::Integer) |
| 160 | + max_size = length(cache.y_hist) |
| 161 | + idx = mod1(cache.hist_start + i - 1, max_size) |
| 162 | + return cache.y_hist[idx] |
| 163 | +end |
| 164 | + |
| 165 | +""" |
| 166 | + get_dy_at(cache, i) |
| 167 | +
|
| 168 | +Get dy vector at history index i (1 = oldest in current window). |
| 169 | +Returns a reference to the stored vector. |
| 170 | +""" |
| 171 | +function get_dy_at(cache::DASSLCache, i::Integer) |
| 172 | + max_size = length(cache.dy_hist) |
| 173 | + idx = mod1(cache.hist_start + i - 1, max_size) |
| 174 | + return cache.dy_hist[idx] |
| 175 | +end |
| 176 | + |
| 177 | +""" |
| 178 | + get_latest_t(cache) |
| 179 | +
|
| 180 | +Get the most recent time value. |
| 181 | +""" |
| 182 | +function get_latest_t(cache::DASSLCache) |
| 183 | + return get_t_at(cache, cache.hist_len) |
| 184 | +end |
| 185 | + |
| 186 | +""" |
| 187 | + get_latest_y(cache) |
| 188 | +
|
| 189 | +Get the most recent y vector. |
| 190 | +""" |
| 191 | +function get_latest_y(cache::DASSLCache) |
| 192 | + return get_y_at(cache, cache.hist_len) |
| 193 | +end |
| 194 | + |
| 195 | +""" |
| 196 | + get_latest_dy(cache) |
| 197 | +
|
| 198 | +Get the most recent dy vector. |
| 199 | +""" |
| 200 | +function get_latest_dy(cache::DASSLCache) |
| 201 | + return get_dy_at(cache, cache.hist_len) |
| 202 | +end |
| 203 | + |
| 204 | +""" |
| 205 | + get_history_t(cache, ord) |
| 206 | +
|
| 207 | +Get a vector of the last `ord` time values for interpolation. |
| 208 | +Returns values from oldest to newest within the window. |
| 209 | +Note: This allocates a small vector - could be optimized further. |
| 210 | +""" |
| 211 | +function get_history_t(cache::DASSLCache, ord::Integer) |
| 212 | + n = min(ord, cache.hist_len) |
| 213 | + t_vec = Vector{eltype(cache.t_hist)}(undef, n) |
| 214 | + start_idx = cache.hist_len - n + 1 |
| 215 | + @inbounds for i in 1:n |
| 216 | + t_vec[i] = get_t_at(cache, start_idx + i - 1) |
| 217 | + end |
| 218 | + return t_vec |
| 219 | +end |
| 220 | + |
| 221 | +""" |
| 222 | + get_history_y(cache, ord) |
| 223 | +
|
| 224 | +Get a vector of the last `ord` y vectors for interpolation. |
| 225 | +Returns references from oldest to newest within the window. |
| 226 | +Note: This allocates a small vector of references. |
| 227 | +""" |
| 228 | +function get_history_y(cache::DASSLCache, ord::Integer) |
| 229 | + n = min(ord, cache.hist_len) |
| 230 | + y_vec = Vector{eltype(cache.y_hist)}(undef, n) |
| 231 | + start_idx = cache.hist_len - n + 1 |
| 232 | + @inbounds for i in 1:n |
| 233 | + y_vec[i] = get_y_at(cache, start_idx + i - 1) |
| 234 | + end |
| 235 | + return y_vec |
| 236 | +end |
| 237 | + |
| 238 | +""" |
| 239 | + get_history_dy(cache, ord) |
| 240 | +
|
| 241 | +Get a vector of the last `ord` dy vectors for interpolation. |
| 242 | +Returns references from oldest to newest within the window. |
| 243 | +""" |
| 244 | +function get_history_dy(cache::DASSLCache, ord::Integer) |
| 245 | + n = min(ord, cache.hist_len) |
| 246 | + dy_vec = Vector{eltype(cache.dy_hist)}(undef, n) |
| 247 | + start_idx = cache.hist_len - n + 1 |
| 248 | + @inbounds for i in 1:n |
| 249 | + dy_vec[i] = get_dy_at(cache, start_idx + i - 1) |
| 250 | + end |
| 251 | + return dy_vec |
| 252 | +end |
| 253 | + |
| 254 | +""" |
| 255 | + clear_history!(cache) |
| 256 | +
|
| 257 | +Reset the history buffer to empty state. |
| 258 | +""" |
| 259 | +function clear_history!(cache::DASSLCache) |
| 260 | + cache.hist_start = 1 |
| 261 | + return cache.hist_len = 0 |
| 262 | +end |
| 263 | + |
| 264 | +""" |
| 265 | + weights!(wt, y, reltol, abstol) |
| 266 | +
|
| 267 | +Compute error weights in-place. |
| 268 | +""" |
| 269 | +function weights!(wt, y, reltol, abstol) |
| 270 | + return @. wt = reltol * abs(y) + abstol |
| 271 | +end |
0 commit comments