Skip to content

Commit 6cdb71a

Browse files
Merge pull request #78 from ChrisRackauckas-Claude/inplace-support
Add in-place solver support following OrdinaryDiffEq cache pattern
2 parents 7dbeb9a + 4cbd034 commit 6cdb71a

File tree

7 files changed

+1250
-37
lines changed

7 files changed

+1250
-37
lines changed

src/DASSL.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DASSL
22

3-
export dasslIterator, dasslSolve
3+
export dasslIterator, dasslSolve, dasslSolve!
4+
export DASSLCache, alg_cache
45

56
using ArrayInterface: fast_scalar_indexing
67
using Reexport: @reexport
@@ -14,11 +15,12 @@ import DiffEqBase: solve
1415

1516
export dassl
1617

17-
include("common.jl")
18-
1918
const MAXORDER = 6
2019
const MAXIT = 10
2120

21+
# Include cache and in-place implementations first
22+
include("cache.jl")
23+
2224
mutable struct JacData{T <: Real, M}
2325
a::T
2426
jac::M # Jacobian matrix for the newton solver
@@ -836,6 +838,12 @@ function numerical_jacobian(F, reltol, abstol, weights)
836838
end
837839
end
838840

841+
# Include in-place implementations (needs functions defined above)
842+
include("inplace.jl")
843+
844+
# Include DiffEqBase interface (needs cache and inplace)
845+
include("common.jl")
846+
839847
@setup_workload begin
840848
@compile_workload begin
841849
# Precompile the DiffEqBase interface with Float64

src/cache.jl

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
# Cache structures for in-place DASSL operations
2+
# Following OrdinaryDiffEq.jl pattern: pre-allocate all working arrays
3+
4+
"""
5+
DASSLCache{T, uType, jacType}
6+
7+
Pre-allocated cache for in-place DASSL operations.
8+
All working arrays are allocated once and reused across integration steps.
9+
"""
10+
mutable struct DASSLCache{T, uType, jacType}
11+
# Current state vectors (reused each step)
12+
yn::uType # Next y value (corrector output)
13+
dyn::uType # Next dy value
14+
y0::uType # Predictor y value
15+
dy0::uType # Predictor dy value
16+
17+
# Newton iteration working vectors
18+
delta::uType # Newton step delta
19+
residual::uType # Residual vector for in-place F evaluation
20+
ytmp::uType # Temporary vector for computations
21+
ytmp2::uType # Second temporary vector
22+
23+
# Jacobian storage
24+
jac::jacType # Jacobian matrix buffer
25+
jac_factorized::Any # Factorized Jacobian (stored as Any for flexibility)
26+
a::T # Current Jacobian coefficient
27+
28+
# Numerical Jacobian working arrays
29+
f_plus::uType # F(y + delta) result
30+
f_base::uType # F(y) result
31+
32+
# History buffers (circular buffer, fixed size for BDF order 1-6)
33+
t_hist::Vector{T} # Time history
34+
y_hist::Vector{uType} # Solution history
35+
dy_hist::Vector{uType} # Derivative history
36+
hist_start::Int # Start index in circular buffer
37+
hist_len::Int # Current number of valid entries
38+
39+
# Error weights buffer
40+
wt::uType
41+
end
42+
43+
"""
44+
alg_cache(alg, u, p, t, ::Val{true})
45+
46+
Create a mutable cache for in-place DASSL operations.
47+
All working arrays are pre-allocated based on the size of `u`.
48+
"""
49+
function alg_cache(alg, u::uType, p, t::T, ::Val{true}) where {T, uType}
50+
n = length(u)
51+
52+
# State vectors
53+
yn = similar(u)
54+
dyn = similar(u)
55+
y0 = similar(u)
56+
dy0 = similar(u)
57+
58+
# Newton iteration vectors
59+
delta = similar(u)
60+
residual = similar(u)
61+
ytmp = similar(u)
62+
ytmp2 = similar(u)
63+
64+
# Jacobian storage
65+
jac = zeros(eltype(u), n, n)
66+
jac_factorized = nothing
67+
a = zero(T)
68+
69+
# Numerical Jacobian vectors
70+
f_plus = similar(u)
71+
f_base = similar(u)
72+
73+
# History buffers (MAXORDER + 3 = 9 entries max needed)
74+
max_hist = MAXORDER + 3
75+
t_hist = zeros(T, max_hist)
76+
y_hist = [similar(u) for _ in 1:max_hist]
77+
dy_hist = [similar(u) for _ in 1:max_hist]
78+
79+
# Error weights
80+
wt = similar(u)
81+
82+
return DASSLCache(
83+
yn, dyn, y0, dy0,
84+
delta, residual, ytmp, ytmp2,
85+
jac, jac_factorized, a,
86+
f_plus, f_base,
87+
t_hist, y_hist, dy_hist, 1, 0,
88+
wt
89+
)
90+
end
91+
92+
"""
93+
alg_cache(alg, u, p, t, ::Val{false})
94+
95+
For out-of-place problems, return nothing (no cache needed).
96+
"""
97+
function alg_cache(alg, u, p, t, ::Val{false})
98+
return nothing
99+
end
100+
101+
# ============================================================================
102+
# Circular buffer history management
103+
# ============================================================================
104+
105+
"""
106+
push_history!(cache, t, y, dy)
107+
108+
Add a new entry to the history circular buffer.
109+
If buffer is full, overwrites the oldest entry.
110+
"""
111+
function push_history!(cache::DASSLCache, t, y, dy)
112+
max_size = length(cache.t_hist)
113+
114+
if cache.hist_len < max_size
115+
# Buffer not full, append
116+
cache.hist_len += 1
117+
else
118+
# Buffer full, advance start pointer (overwrite oldest)
119+
cache.hist_start = mod1(cache.hist_start + 1, max_size)
120+
end
121+
122+
# Compute index for new entry
123+
idx = mod1(cache.hist_start + cache.hist_len - 1, max_size)
124+
125+
cache.t_hist[idx] = t
126+
copyto!(cache.y_hist[idx], y)
127+
return copyto!(cache.dy_hist[idx], dy)
128+
end
129+
130+
"""
131+
pop_oldest_history!(cache)
132+
133+
Remove the oldest entry from the history buffer.
134+
"""
135+
function pop_oldest_history!(cache::DASSLCache)
136+
return if cache.hist_len > 0
137+
cache.hist_start = mod1(cache.hist_start + 1, length(cache.t_hist))
138+
cache.hist_len -= 1
139+
end
140+
end
141+
142+
"""
143+
get_t_at(cache, i)
144+
145+
Get time at history index i (1 = oldest in current window).
146+
"""
147+
function get_t_at(cache::DASSLCache, i::Integer)
148+
max_size = length(cache.t_hist)
149+
idx = mod1(cache.hist_start + i - 1, max_size)
150+
return cache.t_hist[idx]
151+
end
152+
153+
"""
154+
get_y_at(cache, i)
155+
156+
Get y vector at history index i (1 = oldest in current window).
157+
Returns a reference to the stored vector.
158+
"""
159+
function get_y_at(cache::DASSLCache, i::Integer)
160+
max_size = length(cache.y_hist)
161+
idx = mod1(cache.hist_start + i - 1, max_size)
162+
return cache.y_hist[idx]
163+
end
164+
165+
"""
166+
get_dy_at(cache, i)
167+
168+
Get dy vector at history index i (1 = oldest in current window).
169+
Returns a reference to the stored vector.
170+
"""
171+
function get_dy_at(cache::DASSLCache, i::Integer)
172+
max_size = length(cache.dy_hist)
173+
idx = mod1(cache.hist_start + i - 1, max_size)
174+
return cache.dy_hist[idx]
175+
end
176+
177+
"""
178+
get_latest_t(cache)
179+
180+
Get the most recent time value.
181+
"""
182+
function get_latest_t(cache::DASSLCache)
183+
return get_t_at(cache, cache.hist_len)
184+
end
185+
186+
"""
187+
get_latest_y(cache)
188+
189+
Get the most recent y vector.
190+
"""
191+
function get_latest_y(cache::DASSLCache)
192+
return get_y_at(cache, cache.hist_len)
193+
end
194+
195+
"""
196+
get_latest_dy(cache)
197+
198+
Get the most recent dy vector.
199+
"""
200+
function get_latest_dy(cache::DASSLCache)
201+
return get_dy_at(cache, cache.hist_len)
202+
end
203+
204+
"""
205+
get_history_t(cache, ord)
206+
207+
Get a vector of the last `ord` time values for interpolation.
208+
Returns values from oldest to newest within the window.
209+
Note: This allocates a small vector - could be optimized further.
210+
"""
211+
function get_history_t(cache::DASSLCache, ord::Integer)
212+
n = min(ord, cache.hist_len)
213+
t_vec = Vector{eltype(cache.t_hist)}(undef, n)
214+
start_idx = cache.hist_len - n + 1
215+
@inbounds for i in 1:n
216+
t_vec[i] = get_t_at(cache, start_idx + i - 1)
217+
end
218+
return t_vec
219+
end
220+
221+
"""
222+
get_history_y(cache, ord)
223+
224+
Get a vector of the last `ord` y vectors for interpolation.
225+
Returns references from oldest to newest within the window.
226+
Note: This allocates a small vector of references.
227+
"""
228+
function get_history_y(cache::DASSLCache, ord::Integer)
229+
n = min(ord, cache.hist_len)
230+
y_vec = Vector{eltype(cache.y_hist)}(undef, n)
231+
start_idx = cache.hist_len - n + 1
232+
@inbounds for i in 1:n
233+
y_vec[i] = get_y_at(cache, start_idx + i - 1)
234+
end
235+
return y_vec
236+
end
237+
238+
"""
239+
get_history_dy(cache, ord)
240+
241+
Get a vector of the last `ord` dy vectors for interpolation.
242+
Returns references from oldest to newest within the window.
243+
"""
244+
function get_history_dy(cache::DASSLCache, ord::Integer)
245+
n = min(ord, cache.hist_len)
246+
dy_vec = Vector{eltype(cache.dy_hist)}(undef, n)
247+
start_idx = cache.hist_len - n + 1
248+
@inbounds for i in 1:n
249+
dy_vec[i] = get_dy_at(cache, start_idx + i - 1)
250+
end
251+
return dy_vec
252+
end
253+
254+
"""
255+
clear_history!(cache)
256+
257+
Reset the history buffer to empty state.
258+
"""
259+
function clear_history!(cache::DASSLCache)
260+
cache.hist_start = 1
261+
return cache.hist_len = 0
262+
end
263+
264+
"""
265+
weights!(wt, y, reltol, abstol)
266+
267+
Compute error weights in-place.
268+
"""
269+
function weights!(wt, y, reltol, abstol)
270+
return @. wt = reltol * abs(y) + abstol
271+
end

src/common.jl

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -47,44 +47,43 @@ function solve(
4747
end
4848

4949
tspan = [prob.tspan[1], prob.tspan[2]]
50-
51-
#sizeu = size(prob.u0)
52-
#sizedu = size(prob.du0)
5350
p = prob.p
5451

55-
### Fix inplace functions to the non-inplace version
5652
if isinplace
57-
f = (t, u, du) -> (out = similar(u); prob.f(out, du, u, p, t); out)
53+
# In-place path: use pre-allocated cache for zero-allocation inner loop
54+
cache = alg_cache(alg, prob.u0, p, tspan[1], Val(true))
55+
56+
# In-place function wrapper (no allocation per call!)
57+
F! = (out, t, u, du) -> prob.f(out, du, u, p, t)
58+
59+
ts, timeseries, dus = dasslSolve!(
60+
cache, F!, prob.u0, tspan,
61+
abstol = abstol,
62+
reltol = reltol,
63+
maxstep = dtmax,
64+
minstep = dtmin,
65+
initstep = dt,
66+
dy0 = prob.du0,
67+
maxorder = alg.maxorder,
68+
factorize_jacobian = alg.factorize_jacobian
69+
)
5870
else
71+
# Out-of-place path (unchanged, backward compatible)
5972
f = (t, u) -> prob.f(u, p, t)
60-
end
6173

62-
### Finishing Routine
63-
64-
ts, timeseries,
65-
dus = dasslSolve(
66-
f, prob.u0, tspan,
67-
abstol = abstol,
68-
reltol = reltol,
69-
maxstep = dtmax,
70-
minstep = dtmin,
71-
initstep = dt,
72-
dy0 = prob.du0,
73-
maxorder = alg.maxorder,
74-
factorize_jacobian = alg.factorize_jacobian
75-
)
76-
#=
77-
timeseries = Vector{uType}(0)
78-
if typeof(prob.u0)<:Number
79-
for i=1:length(ures)
80-
push!(timeseries,ures[i][1])
81-
end
82-
else
83-
for i=1:length(ures)
84-
push!(timeseries,reshape(ures[i],sizeu))
85-
end
74+
ts, timeseries, dus = dasslSolve(
75+
f, prob.u0, tspan,
76+
abstol = abstol,
77+
reltol = reltol,
78+
maxstep = dtmax,
79+
minstep = dtmin,
80+
initstep = dt,
81+
dy0 = prob.du0,
82+
maxorder = alg.maxorder,
83+
factorize_jacobian = alg.factorize_jacobian
84+
)
8685
end
87-
=#
86+
8887
return build_solution(
8988
prob, alg, ts, timeseries, du = dus,
9089
timeseries_errors = timeseries_errors

0 commit comments

Comments
 (0)