package owl-ode

  1. Overview
  2. Docs

Source file symplectic_generic.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# 1 "src/ode/symplectic/symplectic_generic.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * OWL-ODE - Ordinary Differential Equation Solvers
 *
 * Copyright (c) 2019 Ta-Chu Kao <tck29@cam.ac.uk>
 * Copyright (c) 2019 Marcello Seri <m.seri@rug.nl>
 *)

open Types

module Make (M : Owl_types_ndarray_algodiff.Sig with type elt = float) = struct
  module C = Common.Make (M)

  (** Type of the force function: [f (xs, ps) t] returns the time derivative
      of the momenta for positions [xs], momenta [ps] and time [t]. *)
  type f_t = M.arr * M.arr -> float -> M.arr

  (* Local extension of [M] with the infix operators used by the steppers. *)
  module M = struct
    include M

    (* TODO: implement this in owl *)
    let ( *$ ) = M.mul_scalar
    let ( + ) = M.add
  end

  (** [prepare step f (x0, p0) tspec ()] integrates from the initial state
      [(x0, p0)] over the time span described by [tspec]. It uses
      [step f ~dt] as the single-step update and delegates the time loop to
      [C.symplectic_integrate].

      @raise Owl_exception.NOT_IMPLEMENTED if [tspec] is a [T3]
      specification. *)
  let prepare step f (x0, p0) tspec () =
    let tspan, dt =
      match tspec with
      | T1 { t0; duration; dt } -> (t0, t0 +. duration), dt
      | T2 { tspan; dt } -> tspan, dt
      | T3 _ -> raise Owl_exception.(NOT_IMPLEMENTED "T3 not implemented")
    in
    let step = step f ~dt in
    C.symplectic_integrate ~step ~tspan ~dt (x0, p0)


  (** [make_solver step] packs a single-step function into a first-class
      [Solver] module. The module's [step] is [step] and its [solve] is
      [prepare step]. Every solver below is built through this helper, so all
      of them share exactly the same package type. *)
  let make_solver
      (step :
        f_t -> dt:float -> M.arr * M.arr -> float -> (M.arr * M.arr) * float)
    =
    (module struct
      type state = M.arr * M.arr
      type f = M.arr * M.arr -> float -> M.arr
      type step_output = (M.arr * M.arr) * float
      type solve_output = M.arr * M.arr * M.arr

      let step = step
      let solve = prepare step
    end
    : Solver
      with type state = M.arr * M.arr
       and type f = M.arr * M.arr -> float -> M.arr
       and type step_output = (M.arr * M.arr) * float
       and type solve_output = M.arr * M.arr * M.arr)


  (** [symplectic_euler_s f ~dt (xs, ps) t0] performs one step of the
      symplectic Euler method. It first updates the momenta with the force at
      the current state, [ps1 = ps + dt * f (xs, ps) t0], and then updates
      the positions with the new momenta, [xs1 = xs + dt * ps1]. It returns
      the new state together with the time [t0 +. dt].

      The force is evaluated at [t0], the time of the positions it is
      evaluated on. This is the same convention that [symint] uses. *)
  let symplectic_euler_s (f : f_t) ~dt (xs, ps) t0 =
    let t = t0 +. dt in
    let fxs = f (xs, ps) t0 in
    let ps' = M.(ps + (fxs *$ dt)) in
    let xs' = M.(xs + (ps' *$ dt)) in
    (xs', ps'), t


  (** First-class solver for the symplectic Euler method. *)
  let symplectic_euler = make_solver symplectic_euler_s


  (** [leapfrog_s f ~dt (xs, ps) t0] performs one velocity Verlet step:
      - [xs1 = xs + dt * ps + (dt^2 / 2) * f (xs, ps) t0];
      - [ps1 = ps + (dt / 2) * (f (xs, ps) t0 + f (xs1, ps) (t0 + dt))].

      It returns the new state together with the time [t0 +. dt]. Each force
      evaluation uses the time of the positions it is evaluated on. *)
  let leapfrog_s (f : f_t) ~dt (xs, ps) t0 =
    let t = t0 +. dt in
    (* force at the start of the step *)
    let fxs = f (xs, ps) t0 in
    let xs' = M.(xs + (ps *$ dt) + (fxs *$ (dt *. dt *. 0.5))) in
    (* force at the updated positions, i.e. at the end of the step *)
    let fxs' = f (xs', ps) t in
    let ps' = M.(ps + ((fxs + fxs') *$ (dt *. 0.5))) in
    (xs', ps'), t


  (** First-class solver for the velocity Verlet (leapfrog) method. *)
  let leapfrog = make_solver leapfrog_s


  (* For the values used in the implementations below
     see Candy-Rozmus (https://www.sciencedirect.com/science/article/pii/002199919190299Z)
     and https://en.wikipedia.org/wiki/Symplectic_integrator *)

  (** [symint ~coeffs f ~dt] returns a step function [fun (xs, ps) t -> ...]
      that applies the splitting scheme described by [coeffs], a list of
      [(ai, bi)] pairs. For each pair, in order, the step does three things:
      - it adds [dt *. bi] times [f (xs, ps) t] to the momenta;
      - it then adds [dt *. ai] times the new momenta to the positions;
      - it advances the time by [dt *. ai].

      The returned time is therefore [t +. dt] whenever the [ai] sum to 1. *)
  let symint ~coeffs (f : f_t) ~dt (xs, ps) t =
    List.fold_left
      (fun ((xs, ps), t) (ai, bi) ->
        let ps' = M.(ps + (f (xs, ps) t *$ (dt *. bi))) in
        let xs' = M.(xs + (ps' *$ (dt *. ai))) in
        let t = t +. (dt *. ai) in
        (xs', ps'), t)
      ((xs, ps), t)
      coeffs


  (* Coefficient lists for [symint], as (ai, bi) pairs: ai weights the
     position (drift) update and bi the momentum (kick) update. *)
  let leapfrog_c = [ 0.5, 0.0; 0.5, 1.0 ]
  let pseudoleapfrog_c = [ 1.0, 0.5; 0.0, 0.5 ]
  let ruth3_c = [ 2.0 /. 3.0, 7.0 /. 24.0; -2.0 /. 3.0, 0.75; 1.0, -1.0 /. 24.0 ]

  (* Fourth-order coefficients built from c = 2^(1/3), every entry divided
     by (2 - c). *)
  let ruth4_c =
    let c = Owl.Maths.pow 2.0 (1.0 /. 3.0) in
    [ 0.5, 0.0; 0.5 *. (1.0 -. c), 1.0; 0.5 *. (1.0 -. c), -.c; 0.5, 1.0 ]
    |> List.map (fun (v1, v2) -> v1 /. (2.0 -. c), v2 /. (2.0 -. c))


  (* Leapfrog expressed through [symint]; the leading underscore presumably
     marks it as intentionally unused. *)
  let _leapfrog_s' f ~dt = symint ~coeffs:leapfrog_c f ~dt

  (** Single step of the pseudo-leapfrog scheme ([pseudoleapfrog_c]). *)
  let pseudoleapfrog_s f ~dt = symint ~coeffs:pseudoleapfrog_c f ~dt

  (** First-class solver for the pseudo-leapfrog scheme. *)
  let pseudoleapfrog = make_solver pseudoleapfrog_s


  (** Single step of the third-order Ruth scheme ([ruth3_c]). *)
  let ruth3_s f ~dt = symint ~coeffs:ruth3_c f ~dt

  (** First-class solver for the third-order Ruth scheme. *)
  let ruth3 = make_solver ruth3_s


  (** Single step of the fourth-order Ruth scheme ([ruth4_c]). *)
  let ruth4_s f ~dt = symint ~coeffs:ruth4_c f ~dt

  (** First-class solver for the fourth-order Ruth scheme. *)
  let ruth4 = make_solver ruth4_s


  (*
    XXX: an implicit leapfrog integrator is not implemented yet.
    We would like to do the equivalent of this Python (SciPy) sketch:

        pint = so.fsolve(
            lambda pint: p - pint + 0.5*h*acc(x, pint, t0+i*h),
            p
        )[0]
        xnew = x + h*pint
        pnew = pint + 0.5*h*acc(xnew, pint, t0+(i+1)*h)
        sol[i+1] = np.array((pnew, xnew))

    but http://ocaml.xyz/apidoc/owl_maths_root.html does not seem
    powerful enough for that in general.

    Draft of the intended entry point, kept for reference:

    let leapfrog_implicit ~f y0 (t0, t1) dt =
      let _, elts = M.shape y0 in
      assert (M.s.is_even elts);

      let steps = steps t0 t1 dt in
      let sol = M.empty steps elts in

      sol.${[[0]]}<- y0;
      for idx = 1 to steps-1 do
        (* TODO *)
        ()
      done;
      sol
  *)

  (* ----- helper functions ----- *)

  (** [to_state_array ?axis (dim1, dim2) xs ps] splits [xs] and [ps] into
      slices along [axis] ([0] for rows, the default, or [1] for columns). It
      reshapes every slice to a [dim1] by [dim2] array and returns the pair of
      slice arrays.

      NOTE(review): presumably [xs] and [ps] are the position and momentum
      trajectories returned by a solver, one time point per slice; confirm
      against callers.

      @raise Owl_exception.INDEX_OUT_OF_BOUND if [axis] is neither [0] nor [1].
      @raise Owl_exception.DIFFERENT_SHAPE if the first slice of [xs] or of [ps]
      does not have exactly [dim1 * dim2] elements. *)
  let to_state_array ?(axis = 0) (dim1, dim2) xs ps =
    let unpack =
      match axis with
      | 0 -> M.to_rows
      | 1 -> M.to_cols
      | _ -> raise Owl_exception.INDEX_OUT_OF_BOUND
    in
    let xs = unpack xs in
    let ps = unpack ps in
    let expected = dim1 * dim2 in
    (* all slices of one matrix share a size, so checking the first suffices *)
    let check_first_slice slices =
      let n = M.numel slices.(0) in
      if n <> expected
      then raise (Owl_exception.DIFFERENT_SHAPE ([| n |], [| expected |]))
    in
    check_first_slice xs;
    check_first_slice ps;
    let reshape a = M.reshape a [| dim1; dim2 |] in
    Array.map reshape xs, Array.map reshape ps
end