package rune

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file gradcheck.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
module T = Tensor

(** Outcome of a gradient check for a single input tensor, as returned by
    [check_gradient] and (one per input) by [check_gradients]. Each error
    compares an autodiff gradient entry with its finite-difference estimate;
    the relative error is [abs_error / max |autodiff| |finite_diff|], or [0.]
    when both magnitudes are at most [1e-12]. *)
type gradient_check_result = {
  max_abs_error : float;
      (** Largest absolute error among the checked elements ([0.] if none). *)
  max_rel_error : float;
      (** Largest relative error among the checked elements ([0.] if none). *)
  mean_abs_error : float;
      (** Mean absolute error over the checked elements ([0.] if none). *)
  mean_rel_error : float;
      (** Mean relative error over the checked elements ([0.] if none). *)
  failed_indices : (int array * float * float * float) list;
      (** One entry per failing element, in the order checked:
          [(nd_index, autodiff_value, finite_diff_value, abs_error)], where
          [nd_index] is the element's N-d index in the input's shape. *)
  passed : bool;  (** [true] iff [num_failed = 0]. *)
  num_checked : int;  (** Number of gradient elements compared. *)
  num_failed : int;  (** Number of elements that failed both tolerances. *)
}

(** Default relative tolerance. An element passes a check when its absolute
    error is [<= atol] or its relative error is [<= rtol]. *)
let default_rtol = 2e-3 (* JAX default for float32 *)

(** Default absolute tolerance; see [default_rtol] for how it is combined
    with the relative tolerance. *)
let default_atol = 2e-3 (* JAX default for float32 *)

(** Extracts the scalar held by [t] via [T.item []]. Applied to the results of
    [T.get [i]] on a flattened (1-d) gradient, i.e. presumably to 0-d
    tensors. *)
let to_float_value t = T.item [] t

(* Converts a row-major flat offset [flat_idx] into the corresponding N-d
   index of a tensor of shape [shape] (last dimension varies fastest). *)
let unravel_index shape flat_idx =
  let nd_idx = Array.make (Array.length shape) 0 in
  let rest = ref flat_idx in
  for dim = Array.length shape - 1 downto 0 do
    nd_idx.(dim) <- !rest mod shape.(dim);
    rest := !rest / shape.(dim)
  done;
  nd_idx

(* Renders an N-d index as "[i, j, ...]" for verbose diagnostics. *)
let string_of_nd_index nd_idx =
  nd_idx |> Array.to_list |> List.map string_of_int |> String.concat ", "
  |> Printf.sprintf "[%s]"

(* Relative error between an autodiff entry and a finite-difference entry:
   [abs_error / max |auto_val| |finite_val|], or [0.] when both magnitudes are
   at most [1e-12]. A NaN [abs_error] (a NaN operand, or [inf - inf]) is
   propagated as NaN so that the tolerance test fails: a NaN gradient next to
   a near-zero value must not be reported as having relative error 0. *)
let relative_error ~abs_error auto_val finite_val =
  if Float.is_nan abs_error then Float.nan
  else
    let scale = Float.max (abs_float auto_val) (abs_float finite_val) in
    if scale > 1e-12 then abs_error /. scale else 0.0

(* Compares two flattened (1-d) gradients at each flat offset in [indices] and
   aggregates the result. An element passes when
   [abs_error <= atol || rel_error <= rtol]; NaN errors never pass. [shape] is
   the original (unflattened) shape, used only to report failing elements by
   N-d index. When [verbose], each failure is printed as
   "<failure_prefix> at index ..." and a summary introduced by
   [summary_header] is printed at the end. *)
let compare_flat_gradients ~rtol ~atol ~verbose ~failure_prefix
    ~summary_header ~shape ~indices autodiff_flat finite_diff_flat =
  let failed = ref [] in
  let sum_abs = ref 0.0 and sum_rel = ref 0.0 in
  let max_abs = ref 0.0 and max_rel = ref 0.0 in
  List.iter
    (fun i ->
      let auto_val = to_float_value (T.get [ i ] autodiff_flat) in
      let finite_val = to_float_value (T.get [ i ] finite_diff_flat) in
      let abs_error = abs_float (auto_val -. finite_val) in
      let rel_error = relative_error ~abs_error auto_val finite_val in
      sum_abs := !sum_abs +. abs_error;
      sum_rel := !sum_rel +. rel_error;
      (* Float.max propagates NaN, so a NaN error is visible in the summary
         regardless of the order in which elements are checked. *)
      max_abs := Float.max !max_abs abs_error;
      max_rel := Float.max !max_rel rel_error;
      if not (abs_error <= atol || rel_error <= rtol) then (
        let nd_index = unravel_index shape i in
        failed := (nd_index, auto_val, finite_val, abs_error) :: !failed;
        if verbose then
          Printf.printf
            "%s at index %s: autodiff=%.6e, finite_diff=%.6e, \
             abs_error=%.6e, rel_error=%.6e\n"
            failure_prefix
            (string_of_nd_index nd_index)
            auto_val finite_val abs_error rel_error))
    indices;
  let num_checked = List.length indices in
  let num_failed = List.length !failed in
  let passed = num_failed = 0 in
  let mean total =
    if num_checked = 0 then 0.0 else total /. float_of_int num_checked
  in
  let mean_abs_error = mean !sum_abs in
  let mean_rel_error = mean !sum_rel in
  if verbose then (
    print_string summary_header;
    Printf.printf "  Checked: %d elements\n" num_checked;
    Printf.printf "  Failed: %d elements\n" num_failed;
    Printf.printf "  Max absolute error: %.6e\n" !max_abs;
    Printf.printf "  Max relative error: %.6e\n" !max_rel;
    Printf.printf "  Mean absolute error: %.6e\n" mean_abs_error;
    Printf.printf "  Mean relative error: %.6e\n" mean_rel_error;
    Printf.printf "  Status: %s\n" (if passed then "PASSED" else "FAILED"));
  {
    max_abs_error = !max_abs;
    max_rel_error = !max_rel;
    mean_abs_error;
    mean_rel_error;
    failed_indices = List.rev !failed;
    passed;
    num_checked;
    num_failed;
  }

(** [check_gradient f x] compares [Autodiff.grad f x] against
    [Finite_diff.finite_diff ~eps ~method_ f x] element by element.

    An element passes when its absolute error is [<= atol] or its relative
    error is [<= rtol]; an element whose error is NaN always fails.
    [check_indices], when [Some l], restricts the comparison to the flat
    offsets in [l] into the flattened gradient. The finite-difference gradient
    is still computed for the whole of [x]. When [verbose], every failing
    element and a final summary are printed to stdout.

    Returns [`Pass r] if no checked element failed, [`Fail r] otherwise. *)
let check_gradient ?(eps = Finite_diff.default_eps) ?(rtol = default_rtol)
    ?(atol = default_atol) ?(verbose = false) ?(check_indices = None)
    ?(method_ = `Central) f x =
  let autodiff_grad = Autodiff.grad f x in
  let finite_diff_grad = Finite_diff.finite_diff ~eps ~method_ f x in
  let shape = T.shape x in
  let numel = Array.fold_left ( * ) 1 shape in
  let indices =
    match check_indices with
    | None -> List.init numel Fun.id
    | Some indices -> indices
  in
  let result =
    compare_flat_gradients ~rtol ~atol ~verbose ~failure_prefix:"Failed"
      ~summary_header:"\nGradient check summary:\n" ~shape ~indices
      (T.reshape [| numel |] autodiff_grad)
      (T.reshape [| numel |] finite_diff_grad)
  in
  if result.passed then `Pass result else `Fail result

(** [check_gradients f xs] runs the element-wise comparison of
    [check_gradient] for every input in [xs]. The autodiff gradients come
    from [Autodiff.grads f xs]. The finite-difference gradient for the
    [idx]-th input is taken with respect to that input alone, with the other
    inputs held fixed. All elements of every input are checked, and verbose
    output is labelled with the input's position.

    Returns [`Pass results] if every input passed and [`Fail results]
    otherwise, with one result per input in the order of [xs].

    @raise Invalid_argument if [Autodiff.grads f xs] does not return exactly
    one gradient per input (via [List.combine]). *)
let check_gradients ?(eps = Finite_diff.default_eps) ?(rtol = default_rtol)
    ?(atol = default_atol) ?(verbose = false) ?(method_ = `Central) f xs =
  let autodiff_grads = Autodiff.grads f xs in
  let check_input idx (x, autodiff_grad) =
    (* [f] restricted to its [idx]-th argument; the others stay as in [xs]. *)
    let f_single x_i =
      f (List.mapi (fun j x_j -> if j = idx then x_i else x_j) xs)
    in
    let finite_diff_grad = Finite_diff.finite_diff ~eps ~method_ f_single x in
    let shape = T.shape x in
    let numel = Array.fold_left ( * ) 1 shape in
    compare_flat_gradients ~rtol ~atol ~verbose
      ~failure_prefix:(Printf.sprintf "Input %d failed" idx)
      ~summary_header:
        (Printf.sprintf "\nGradient check summary for input %d:\n" idx)
      ~shape
      ~indices:(List.init numel Fun.id)
      (T.reshape [| numel |] autodiff_grad)
      (T.reshape [| numel |] finite_diff_grad)
  in
  let results = List.mapi check_input (List.combine xs autodiff_grads) in
  if List.for_all (fun r -> r.passed) results then `Pass results
  else `Fail results