package owl-base

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_algodiff_core_sig.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# 1 "src/base/algodiff/owl_algodiff_core_sig.ml"
(*
 * OWL - OCaml Scientific Computing
 * Copyright (c) 2016-2022 Liang Wang <liang@ocaml.xyz>
 *)

module type Sig = sig
  module A : Owl_types_ndarray_algodiff.Sig

  (** {5 Type definition} *)

  include Owl_algodiff_types_sig.Sig with type elt := A.elt and type arr := A.arr

  (** {5 Core functions} *)

  val tag : unit -> int
  (** [tag ()] returns an [int] tag from the global tagging counter.
      NOTE(review): presumably the counter advances on each call so that
      every tag is fresh; confirm against the implementation. *)

  val primal : t -> t
  (** [primal x] returns the primal component of [x] when [x] is a [DF] or
      [DR] value. *)

  val primal' : t -> t
  (** [primal' x] repeatedly takes the primal component of [x] until the
      result is no longer a [DF] or [DR] value. *)

  val zero : t -> t
  (** [zero x] returns a zero value whose type is decided by the type of
      [x]. *)

  val reset_zero : t -> t
  (** [reset_zero x] iteratively resets all elements included in [x]. *)

  val tangent : t -> t
  (** [tangent x] returns the tangent component of [x], if its data type has
      one. *)

  val adjref : t -> t ref
  (** [adjref x] returns the adjref component of [x] (a mutable [t ref]
      cell), if its data type has one. *)

  val adjval : t -> t
  (** [adjval x] returns the adjval component of [x], if its data type has
      one. NOTE(review): presumably the value currently held by
      [adjref x]; confirm against the implementation. *)

  val shape : t -> int array
  (** [shape x] returns the shape of the [primal'] value of [x]. *)

  val is_float : t -> bool
  (** [is_float x] checks whether [x] is a float value; if [x] is a [DF] or
      [DR] value, its [primal'] value is checked instead. *)

  val is_arr : t -> bool
  (** [is_arr x] checks whether [x] is an ndarray value; if [x] is a [DF] or
      [DR] value, its [primal'] value is checked instead. *)

  val row_num : t -> int
  (** [row_num x] returns the first dimension of the shape of the [primal']
      value of [x]. *)

  val col_num : t -> int
  (** [col_num x] returns the second dimension of the shape of the [primal']
      value of [x]. *)

  val numel : t -> int
  (** [numel x] returns the total number of elements of [x] when [x] is an
      ndarray. *)

  val clip_by_value : amin:A.elt -> amax:A.elt -> t -> t
  (** [clip_by_value ~amin ~amax x] clips [x] by value using the bounds
      [amin] and [amax]. The operation does not track gradients. *)

  val clip_by_l2norm : A.elt -> t -> t
  (** [clip_by_l2norm a x] clips [x] by its L2 norm using [a] (presumably
      the maximum allowed norm; confirm). The operation does not track
      gradients. *)

  val copy_primal' : t -> t
  (** [copy_primal' x]: if the [primal'] value of [x] is an ndarray, copies
      it into a new AD ndarray value. *)

  val tile : t -> int array -> t
  (** [tile x reps]: if the [primal'] value of [x] is an ndarray, applies
      the ndarray tile function to it with [reps]. *)

  val repeat : t -> int array -> t
  (** [repeat x reps]: if the [primal'] value of [x] is an ndarray, applies
      the ndarray repeat function to it with [reps]. *)

  val pack_elt : A.elt -> t
  (** [pack_elt x] converts [x] from the [elt] type to the [t] type. *)

  val unpack_elt : t -> A.elt
  (** [unpack_elt x] converts [x] from the [t] type to the [elt] type. *)

  val pack_flt : float -> t
  (** [pack_flt x] converts [x] from the [float] type to the [t] type. *)

  val _f : float -> t
  (** [_f x] is a shortcut for [F A.(float_to_elt x)]. *)

  val unpack_flt : t -> float
  (** [unpack_flt x] converts [x] from the [t] type to the [float] type. *)

  val pack_arr : A.arr -> t
  (** [pack_arr x] converts [x] from the [arr] type to the [t] type. *)

  val unpack_arr : t -> A.arr
  (** [unpack_arr x] converts [x] from the [t] type to the [arr] type. *)

  (** {5 Error reporting}

      Functions that report errors and help with debugging. *)

  val deep_info : t -> string
  (** [deep_info x] returns a string describing [x], for error messages and
      debugging. *)

  val type_info : t -> string
  (** [type_info x] returns a string with type information about [x], for
      error messages and debugging. *)

  val error_binop : string -> t -> t -> 'a
  (** [error_binop op a b] reports an error for the binary operation [op]
      (presumably the operation name) applied to [a] and [b]. Its fully
      polymorphic result type means it never returns normally; presumably
      it raises (confirm which exception). *)

  val error_uniop : string -> t -> 'a
  (** [error_uniop op a] reports an error for the unary operation [op]
      (presumably the operation name) applied to [a]. Its fully polymorphic
      result type means it never returns normally; presumably it raises
      (confirm which exception). *)
end