package phylogenetics

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file linear_algebra.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
(** A vector of floats. Indices are 0-based. *)
module type Vector = sig
  type t

  (** Number of elements of a vector. *)
  val length : t -> int

  (** [init n ~f] is the vector of length [n] whose element [i] is [f i],
      for [0 <= i < n]. *)
  val init : int -> f:(int -> float) -> t

  (** [map v ~f] applies [f] to every element of [v]. *)
  val map : t -> f:(float -> float) -> t

  (** [inplace_scal_mul s v] multiplies every element of [v] by [s],
      overwriting [v]. *)
  val inplace_scal_mul: float -> t -> unit

  (** [scal_mul s v] is a new vector equal to [s] times [v]. *)
  val scal_mul : float -> t -> t

  (** [scal_add s v] is [v] with [s] added to every element. *)
  val scal_add: float -> t -> t

  (** Vector addition. *)
  val add : t -> t -> t

  (** Element-wise product of two vectors. *)
  val mul : t -> t -> t

  (** Sum of the elements of a vector. *)
  val sum : t -> float

  (** Element-wise logarithm of a vector. *)
  val log : t -> t

  (** Element-wise exponential of a vector. *)
  val exp : t -> t

  (** Minimum element in a vector. *)
  val min : t -> float

  (** Maximum element in a vector. *)
  val max : t -> float

  (** [get v i] is element [i] of [v]. *)
  val get : t -> int -> float

  (** [set v i x] overwrites element [i] of [v] with [x]. *)
  val set : t -> int -> float -> unit

  (** [robust_equal ~tol v1 v2] is [true] when every element of [v2] lies
      within a relative distance [tol] of the corresponding element of [v1].
      @raise Invalid_argument if the vectors have different lengths. *)
  val robust_equal : tol:float -> t -> t -> bool

  (** Conversions from/to OCaml float arrays; element [i] of the array is
      element [i] of the vector. *)
  val of_array : float array -> t
  val to_array : t -> float array

  (** Pretty-prints a vector on the given formatter. *)
  val pp : Format.formatter -> t -> unit
end

(** A matrix of floats. Most operations expect square matrices; indices are
    0-based. *)
module type Matrix = sig
  type vec
  type t

  (** [dim m] is [(number of rows, number of columns)]. *)
  val dim : t -> int * int

  (** {5 Matrix and vector creation} *)

  (** [init n ~f] is the [n] x [n] matrix whose element [(i, j)] is
      [f i j]. *)
  val init : int -> f:(int -> int -> float) -> t

  (** [init_sym n ~f] creates a symmetric [n] x [n] matrix by calling [f]
      only for elements such that [i <= j]; element [(j, i)] is a copy of
      element [(i, j)]. *)
  val init_sym : int -> f:(int -> int -> float) -> t

  (** Initializes a square diagonal matrix from the vector of its diagonal
      elements. *)
  val diagm : vec -> t

  (** Element-wise (Hadamard) product of two matrices. *)
  val mul : t -> t -> t

  (** Matrix addition. *)
  val add : t -> t -> t

  (** [scal_mul s m] is a new matrix equal to [s] times [m]. *)
  val scal_mul : float -> t -> t

  (** [inplace_scal_mul s m] multiplies every element of [m] by [s],
      overwriting [m]. *)
  val inplace_scal_mul: float -> t -> unit

  (** Matrix product [op(a) * op(b)]; passing [`T] for [?transa]
      (resp. [?transb]) uses the transpose of [a] (resp. [b]). *)
  val dot :
    ?transa:[`N | `T] ->
    ?transb:[`N | `T] ->
    t -> t -> t

  (** Matrix-vector product [op(m) * v]; [~trans:`T] uses the transpose
      of [m]. *)
  val apply : ?trans:[`N | `T] -> t -> vec -> vec

  (** [pow m k] is [m] raised to the integer power [k] ([k >= 0]). *)
  val pow : t -> int -> t

  (** Matrix exponential [e^m]. *)
  val expm : t -> t

  (** Element-wise logarithm of a matrix. *)
  val log : t -> t

  (** [robust_equal ~tol m1 m2] compares two matrices up to tolerance [tol].
      The intended contract is relative: every element of [m2] lies between
      [1 - tol] and [1 + tol] times the corresponding element of [m1].
      NOTE(review): the Lacaml implementation in this file instead bounds
      the largest absolute element-wise difference by [tol]; confirm which
      semantics callers rely on.
      @raise Invalid_argument if the dimensions differ. *)
  val robust_equal : tol:float -> t -> t -> bool

  (** [get m i j] is the element at row [i], column [j]. *)
  val get : t -> int -> int -> float

  (** [set m i j x] overwrites the element at row [i], column [j] with [x]. *)
  val set : t -> int -> int -> float -> unit

  (** [row m i] is a copy of row [i] of [m]. *)
  val row : t -> int -> vec

  (** Diagonalizes a symmetric matrix [M] so that [M = P x D x P^T]; returns
      [(v, P)] where [v] is the diagonal of [D] (the eigenvalues) and the
      columns of [P] are the matching eigenvectors. *)
  val diagonalize : t -> vec * t

  (** [transpose m] is the transpose of [m]. *)
  val transpose : t -> t

  (** Computes the inverse of a (square, non-singular) matrix. *)
  val inverse: t -> t

  (** [zero_eigen_vector m] is a vector [v] such that [Vector.sum v = 1] and
      [v^T m = 0], i.e. [apply ~trans:`T m v] is the zero vector
      ([v] is a left null vector of [m]). *)
  val zero_eigen_vector : t -> vec

  (** [of_arrays rows] builds a matrix from an array of rows, or returns
      [None] when the input cannot be converted (e.g. rows of different
      lengths). *)
  val of_arrays : float array array -> t option

  (** Same as {!of_arrays} but raises instead of returning [None]. *)
  val of_arrays_exn : float array array -> t

  (** Pretty-prints a matrix on the given formatter (display may be
      messy). *)
  val pp : Format.formatter -> t -> unit
end

(** A complete linear-algebra backend: a vector implementation and a matrix
    implementation that share the same vector type. *)
module type S = sig
  type mat
  type vec

  (** Vector operations on [vec]. *)
  module Vector : Vector with type t = vec

  (** Matrix operations on [mat]; the matrix module's [vec] type is
      substituted by the shared [vec]. *)
  module Matrix : Matrix with type vec := vec and type t = mat
end

(** Implementation of {!S} backed by Lacaml's double-precision module
    ([Lacaml.D]). Lacaml indexes vectors and matrices from 1, whereas this
    module's [init]/[get]/[set]/[row] take 0-based indices and shift them. *)
module Lacaml = struct
  open Lacaml.D

  type mat = Lacaml.D.mat
  type vec = Lacaml.D.vec

  (** [inplace_scal_mat_mul f a] multiplies every element of [a] by [f],
      overwriting [a]. *)
  let inplace_scal_mat_mul f a = Mat.scal f a

  (** [scal_mat_mul f a] is a fresh matrix equal to [f] times [a]; [a] is
      left untouched. *)
  let scal_mat_mul f a =
    let r = lacpy a in
    inplace_scal_mat_mul f r ;
    r

  module Vector = struct
    type t = vec

    let length x = Vec.dim x

    (* [f] receives 0-based indices; [Vec.init] passes 1-based ones. *)
    let init size ~f = Vec.init size (fun i -> f (i - 1))
    let add v1 v2 = Vec.add v1 v2
    let mul v1 v2 = Vec.mul v1 v2
    let sum v = Vec.sum v
    let log v = Vec.log v
    let exp v = Vec.exp v
    let min v = Vec.min v
    let max v = Vec.max v
    let get v i = v.{i + 1}
    let set v i x = v.{i + 1} <- x
    let pp = pp_vec
    let to_array = Vec.to_array
    let of_array = Vec.of_array
    let scal_add s v = Lacaml.D.Vec.add_const s v

    (* Works on a copy so that [v] is not modified. *)
    let scal_mul s v =
      let r = copy v in
      scal s r ; r

    let inplace_scal_mul s v = scal s v
    let map v ~f = Vec.map f v

    (** [robust_equal ~tol v1 v2] is [true] iff
        [|v1_i - v2_i| <= tol * |v1_i|] for every index [i].
        @raise Invalid_argument if the vectors have different lengths. *)
    let robust_equal ~tol:p v1 v2 =
      if length v1 <> length v2 then invalid_arg "incompatible dimensions" ;
      (* Compare against |v1_i|, not v1_i: dividing by a negative reference
         element yields a negative "relative difference", which made any
         pair of vectors with negative entries compare equal. *)
      let n = length v1 in
      let rec within i =
        i > n
        || (Float.abs (v1.{i} -. v2.{i}) <= p *. Float.abs v1.{i}
            && within (i + 1))
      in
      within 1
  end

  module Matrix = struct
    type t = mat

    let dim m = Mat.dim1 m, Mat.dim2 m

    (* [f] receives 0-based indices; [Mat.init_rows] passes 1-based ones. *)
    let init size ~f = Mat.init_rows size size (fun i j -> f (i - 1) (j - 1))

    (* [f] is evaluated only on the diagonal and the strict upper triangle;
       each value computed for (i, j) with i < j is mirrored to (j, i). *)
    let init_sym size ~f =
      let r = init size ~f:(fun _ _ -> 0.) in
      for i = 1 to size do
        r.{i, i} <- f (i - 1) (i - 1) ;
        for j = i + 1 to size do
          let r_ij = f (i - 1) (j - 1) in
          r.{i, j} <- r_ij ;
          r.{j, i} <- r_ij
        done
      done ;
      r

    let diagm v = Mat.of_diag v
    let add a b = Mat.add a b

    (* Matrix 1-norm (maximum absolute column sum); [expm] uses it to pick
       the Padé order and the scaling exponent. *)
    let norm1 x = lange ~norm:`O x

    let mul a b = Mat.mul a b

    (* Same operations as the module-level helpers above. *)
    let inplace_scal_mul f a = inplace_scal_mat_mul f a
    let scal_mul f a = scal_mat_mul f a

    let dot ?transa ?transb a b = gemm ?transa ?transb a b

    let apply ?trans m x = gemv ?trans m x

    let log m = Mat.log m

    (* NOTE(review): this is an absolute test (the largest element-wise
       |m1 - m2| must not exceed [tol]), whereas the documentation of
       [Matrix.robust_equal] in the signature describes a relative one.
       Kept as is because callers and the inline tests below are tuned to
       the current behavior. *)
    let robust_equal ~tol:p m1 m2 =
      if Mat.dim1 m1 <> Mat.dim1 m2 || Mat.dim2 m1 <> Mat.dim2 m2
      then invalid_arg "incompatible dimensions" ;
      let diff = Mat.sub m1 m2 in
      lange ~norm:`M diff <= p

    let get m i j = m.{i + 1, j + 1}
    let set m i j x = m.{i + 1, j + 1} <- x
    let row mat r = Mat.copy_row mat (r + 1) (* FIXME: costly operation! *)

    let transpose m = Mat.transpose_copy m

    (* Assumes [m] is symmetric: [syevr] is LAPACK's symmetric
       eigensolver. Works on a copy because [syevr] overwrites its input.
       Returns the eigenvalues and the matrix of eigenvectors. *)
    let diagonalize m =
      let tmp = lacpy m in
      let _, v, c, _ = syevr ~vectors:true tmp in
      v, c

    let%test "Lapack Matrix.diagonalize" =
      let m = init 13 ~f:(fun i j ->
          float i +. float j
        )
      in
      let vp, p = diagonalize m in
      robust_equal
        ~tol:1e-6
        (dot p (dot (diagm vp) (transpose p)))
        m

    let pp = pp_mat

    let of_arrays_exn xs = Mat.of_array xs

    (* Only [Invalid_argument] (what Bigarray raises for non-rectangular
       rows) is mapped to [None]; the former catch-all also swallowed
       unrelated exceptions such as [Out_of_memory]. *)
    let of_arrays xs =
      match of_arrays_exn xs with
      | m -> Some m
      | exception Invalid_argument _ -> None

    (* LU-factorizes a copy of [m] with [getrf], then inverts it in place
       with [getri], which requires that prior factorization. *)
    let inverse m =
      let tmp = lacpy m in
      let tmp_vec = getrf tmp in
      getri ~ipiv:tmp_vec tmp ;
      tmp

    (* Solves, in the least-squares sense ([gels]), the (n+1) x n system
       whose first n rows are [mat] transposed and whose last row is all
       ones, with right-hand side [0; ...; 0; 1]. The solution [v] thus
       satisfies [mat^T v = 0] and [sum v = 1]. [gels] overwrites [b]; its
       first n entries hold the solution. *)
    let zero_eigen_vector mat =
      let n = Mat.dim2 mat in
      if n <> Mat.dim1 mat then invalid_arg "Expected square matrix" ;
      let a = Mat.init_rows (n + 1) n (fun i j ->
          if i = n + 1 then 1. else mat.{j, i}
        )
      in
      let b = Mat.init_rows (n + 1) 1 (fun i _ -> if i <= n then 0. else 1.) in
      gels a b ;
      Vec.init n (fun i -> b.{i, 1})

    (* Exponentiation by squaring: O(log k) matrix products. *)
    let pow x k =
      let m = Mat.dim1 x in
      let n = Mat.dim2 x in
      if m <> n then invalid_arg "non-square matrix" ;
      if k < 0 then invalid_arg "negative power" ;
      let rec loop k =
        if k = 0 then Mat.identity m
        else if k mod 2 = 0 then
          let r = loop (k / 2) in
          gemm r r
        else
          let r = loop ((k - 1) / 2) in
          gemm x (gemm r r)
      in
      loop k

    (* Reference implementation of [pow] with k products, used by the
       inline test below. *)
    let rec naive_pow x k =
      let m = Mat.dim1 x in
      let n = Mat.dim2 x in
      if m <> n then invalid_arg "non-square matrix" ;
      if k < 0 then invalid_arg "negative power" ;
      if k = 0 then Mat.identity m
      else gemm x (naive_pow x (k - 1))

    (* NOTE(review): the input is built through [Linear_algebra_tools]
       rather than this module's [init]; confirm that dependency is
       intended. *)
    let%test "lacaml matrix pow" =
      let m = Linear_algebra_tools.Lacaml.Mat.init 5 ~f:(fun i j -> float (i + j)) in
      robust_equal ~tol:1e-6 (pow m 13) (naive_pow m 13)

    let log_2 = Float.log 2.
    let log2 x = Float.log x /. log_2

    (* Matrix exponential by Padé approximation, following Julia's
       implementation: for a small 1-norm, a lower-order approximant chosen
       by norm thresholds; otherwise the order-13 approximant with scaling
       (divide by 2^t) and squaring (square the result t times). *)
    let expm x =
      let m = Mat.dim1 x in
      let n = Mat.dim2 x in
      if m <> n then invalid_arg "matrix not square" ;
      (* trivial case *)
      if m = 1 && n = 1 then
        Mat.make 1 1 (Float.exp x.{1, 1})
      else (
        (* TODO: use gebal to balance to improve accuracy, refer to Julia's impl *)
        let xe = Mat.identity m in
        let norm_x = norm1 x in
        (* for small norm, use lower order Padé-approximation *)
        if norm_x <= 2.097847961257068 then (
          let c = (
            if norm_x > 0.9504178996162932 then
              [|17643225600.; 8821612800.; 2075673600.; 302702400.; 30270240.; 2162160.; 110880.; 3960.; 90.; 1.|]
            else if norm_x > 0.2539398330063230 then
              [|17297280.; 8648640.; 1995840.; 277200.; 25200.; 1512.; 56.; 1.|]
            else if norm_x > 0.01495585217958292 then
              [|30240.; 15120.; 3360.; 420.; 30.; 1.|]
            else
              [|120.; 60.; 12.; 1.|]
          ) in

          (* Odd-indexed coefficients accumulate into [u] (later multiplied
             by [x]), even-indexed ones into [v]; [p] walks the even powers
             x^0, x^2, x^4, ... *)
          let x2 = gemm x x in
          let p = ref (lacpy xe) in
          let u = scal_mat_mul c.(1) !p in
          let v = scal_mat_mul c.(0) !p in

          for i = 1 to Array.(length c / 2 - 1) do
            let j = 2 * i in
            let k = j + 1 in
            p := gemm !p x2 ;
            Mat.axpy ~alpha:c.(k) !p u ;
            Mat.axpy ~alpha:c.(j) !p v ;
          done;

          (* Result solves (v - u) X = (v + u); gesv overwrites [b] with X. *)
          let u = gemm x u in
          let a = Mat.sub v u in
          let b = Mat.add v u in
          gesv a b ;
          b
        )
        (* for larger norm, Padé-13 approximation *)
        else (
          (* Scale x by 2^-t so that its norm is about 5.4 or less. *)
          let s = log2 (norm_x /. 5.4) in
          let t = ceil s in
          let x = if s > 0. then scal_mul (2. ** (-. t)) x else x in

          let c =
            [|64764752532480000.; 32382376266240000.; 7771770303897600.;
              1187353796428800.;  129060195264000.;   10559470521600.;
              670442572800.;      33522128640.;       1323241920.;
              40840800.;          960960.;            16380.;
              182.;               1.|]
          in

          let x2 = gemm x x in
          let x4 = gemm x2 x2 in
          let x6 = gemm x2 x4 in
          (* u = x * (x6 * (c13 x6 + c11 x4 + c9 x2)
                      + c7 x6 + c5 x4 + c3 x2 + c1 I) *)
          let u =
            let m = lacpy x2 in
            inplace_scal_mat_mul c.(9) m ;
            Mat.axpy ~alpha:c.(11) x4 m ;
            Mat.axpy ~alpha:c.(13) x6 m ;
            let m = gemm x6 m in
            Mat.axpy ~alpha:c.(1) xe m ;
            Mat.axpy ~alpha:c.(3) x2 m ;
            Mat.axpy ~alpha:c.(5) x4 m ;
            Mat.axpy ~alpha:c.(7) x6 m ;
            gemm x m
          in
          (* v = x6 * (c12 x6 + c10 x4 + c8 x2)
                 + c6 x6 + c4 x4 + c2 x2 + c0 I *)
          let v =
            let m = lacpy x2 in
            inplace_scal_mat_mul c.(8) m ;
            Mat.axpy ~alpha:c.(10) x4 m ;
            Mat.axpy ~alpha:c.(12) x6 m ;
            let m = gemm x6 m in
            Mat.axpy ~alpha:c.(0) xe m ;
            Mat.axpy ~alpha:c.(2) x2 m ;
            Mat.axpy ~alpha:c.(4) x4 m ;
            Mat.axpy ~alpha:c.(6) x6 m ;
            m
          in
          let a = Mat.sub v u in
          let b = Mat.add v u in
          gesv a b ;

          (* Undo the scaling: exp(x) = exp(x / 2^t)^(2^t). *)
          let x = ref b in
          if s > 0. then (
            for _i = 1 to int_of_float t do
              x := gemm !x !x
            done;
          );
          !x
        )
      )

    let%test "lacaml expm 1d" =
      let c = 0.4 in
      let m = Mat.make 1 1 c in
      robust_equal ~tol:1e-6 (expm m) (Mat.make 1 1 (Float.exp c))
  end
end

(* Expose the Lacaml-backed implementation at the top level of this module:
   [vec], [mat], [Vector], [Matrix] and the [inplace_scal_mat_mul] /
   [scal_mat_mul] helpers defined in [Lacaml] become directly accessible. *)
include Lacaml