package optiml-transport

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file transport.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
open Bigarray

(** Two-dimensional Bigarray of unboxed 64-bit floats in C (row-major)
    layout, indices starting at 0. *)
type mat =
  (float, Bigarray.float64_elt, Bigarray.c_layout) Bigarray.Array2.t

(** One-dimensional Bigarray of unboxed 64-bit floats in C layout. *)
type vec =
  (float, Bigarray.float64_elt, Bigarray.c_layout) Bigarray.Array1.t

(* must match EMD.h order *)
(* Raw status returned by the C solver through [kanto_solve].
   These are constant constructors, so OCaml represents them as the
   immediates 0, 1, 2, 3 in declaration order. The stub presumably returns
   EMD.h's status code as [Val_int code] (NOTE(review): confirm in the stub),
   so reordering or inserting constructors here silently changes which
   status each code maps to. [kantorovich] translates these into the public
   [result] type. *)
type result_internal =
  | Transport_Infeasible
  | Transport_Optimal
  | Transport_Unbounded
  | Transport_Max_iter_reached

(** Outcome of [kantorovich].
    - [Infeasible], [Unbounded]: the solver reported that status; no
      coupling or cost is returned.
    - [Optimal]: the solver reported optimality. [cost] is the value the
      solver wrote into its cost out-parameter, [coupling] is the matrix it
      filled, and [u] and [v] are the vectors it filled (presumably dual
      potentials; confirm against EMD.h).
    - [Max_iter_reached]: the same fields, but the solver stopped at its
      iteration limit, so the values are presumably not guaranteed
      optimal. *)
type result =
  | Infeasible
  | Unbounded
  | Optimal of { cost : float; coupling : mat; u : vec; v : vec }
  | Max_iter_reached of { cost : float; coupling : mat; u : vec; v : vec }

(* Mutable float cell used as an out-parameter of [kanto_solve]:
   [kantorovich] passes one in and reads [field] back as the cost.
   Because every field is a float, OCaml stores this record as a flat block
   of unboxed doubles. The C stub must therefore write it with
   [Store_double_field] rather than [Store_field].
   NOTE(review): confirm the stub does this. *)
type fref = { mutable field : float }

(* FFI binding to the C transport solver.
   Arguments, in order:
   - x and y: the two input vectors;
   - d: the input matrix;
   - the coupling output matrix;
   - the u and v output vectors;
   - the cost out-cell;
   - an int, which kantorovich fills with [num_iter].
   [kantorovich] allocates the outputs uninitialised and returns them
   directly, so the stub presumably fills them in place.
   Two symbols are given because an external with more than five arguments
   needs a separate bytecode stub (taking an argv array) and a native stub
   (taking the arguments directly). *)
external kanto_solve :
  vec -> vec -> mat -> mat -> vec -> vec -> fref -> int -> result_internal
  = "transport_stub_bytecode" "transport_stub_native"

(* let kantorovich_raw x y d num_iter =
 *   let n1     = Array1.dim x in
 *   let n2     = Array1.dim y in
 *   let gamma  = Array2.create Float64 c_layout n1 n2 in
 *   let u      = Array1.create Float64 c_layout n1 in
 *   let v      = Array1.create Float64 c_layout n2 in
 *   let cost   = { field = -. 1.0 } in
 *   let result = kanto_solve x y d gamma u v cost num_iter in
 *   (result, gamma, u, v, cost.field) *)

(** [kantorovich ~x ~y ~d ~num_iter] runs the C transport solver on vectors
    [x] (length [n1]) and [y] (length [n2]) with matrix [d].
    [d] must be [n1] x [n2], the same shape as the coupling matrix the
    solver fills.

    [num_iter] is passed unchanged to the solver. It is presumably the
    iteration limit behind [Max_iter_reached].

    Returns [Infeasible] or [Unbounded] when the solver reports that status.
    Otherwise it returns [Optimal] or [Max_iter_reached] carrying:
    - the cost written by the solver;
    - a freshly allocated [n1] x [n2] coupling matrix;
    - vectors [u] (length [n1]) and [v] (length [n2]) filled by the solver.

    @raise Invalid_argument if [d] is not [n1] x [n2]. The C stub receives
    raw buffers and presumably indexes [d] as [n1 * n2] doubles, so a shape
    mismatch would otherwise be an out-of-bounds read instead of an
    exception. *)
let kantorovich ~x ~y ~d ~num_iter =
  let n1 = Array1.dim x in
  let n2 = Array1.dim y in
  (* Validate before handing raw buffers to C, where a shape mismatch would
     be memory-unsafe rather than a catchable error. *)
  if Array2.dim1 d <> n1 || Array2.dim2 d <> n2 then
    invalid_arg
      (Printf.sprintf "kantorovich: d is %dx%d but x and y require %dx%d"
         (Array2.dim1 d) (Array2.dim2 d) n1 n2);
  let gamma = Array2.create Float64 c_layout n1 n2 in
  let u = Array1.create Float64 c_layout n1 in
  let v = Array1.create Float64 c_layout n2 in
  (* Sentinel value; the solver is expected to overwrite it with the cost. *)
  let cost = { field = -1.0 } in
  match kanto_solve x y d gamma u v cost num_iter with
  | Transport_Infeasible -> Infeasible
  | Transport_Unbounded -> Unbounded
  | Transport_Optimal -> Optimal { cost = cost.field; coupling = gamma; u; v }
  | Transport_Max_iter_reached ->
      Max_iter_reached { cost = cost.field; coupling = gamma; u; v }
OCaml

Innovation. Community. Security.