(* Source file type_functions.ml *)
(** Type-level functions
Skim through this section when reading the documentation
for the first (second, third and …) time
*)
(** {2 Small integer type representation } *)
(* Each type below is a closed polymorphic variant with a single tag whose
   name encodes a small natural number. The payload ['a] is left free so the
   type functions of this file can thread other type variables through it
   (e.g. [rank_diff] produces ['p1 z] or ['p1 one]). *)
type 'a z = [`zero of 'a] (* 0 *)
type 'a one = [`one of 'a] (* 1 *)
type 'a two = [`two of 'a] (* 2 *)
type 'a three = [`three of 'a] (* 3 *)
type 'a four = [`four of 'a] (* 4 *)
(** [('a,'b,'c) any] constrains ['c] to carry one of the tags [`zero],
    [`one] or [`two] ([`three] and [`four] are not accepted). The payload of
    that tag is the conjunction ['b & 'a], so once the tag is known its
    payload must be compatible with both ['b] and ['a]. *)
type ('a,'b,'c) any =
[< `zero of 'b & 'a | `one of 'b & 'a | `two of 'b & 'a] as 'c
(** {2 Type-level functions} *)
(**
   [(x,y,z,d1,d2,d3,_) product] computes the types of [(x,d1) * (y,d2)] and
   puts the result inside [z] and [d3].
   In practice, the aim is to direct the unification of the type variables
   using the type values of the inputs.
   For the product we have the following types
   [('dim1, 'rank1) t -> ('dim2,'rank2) t -> ('dim3, 'rank3) t]
   and we want to unify ['rank3] and ['dim3] with the right values.

   Ranks zero, one and two stand for scalar, vector and matrix. When both
   operands have a positive rank, ['dim2] is unified with ['dim1] and the
   result shares that dimension. When exactly one operand is a scalar, the
   result takes the dimension of the other operand.
*)
type ('rank1, 'rank2,'rank3,'dim1,'dim2,'dim3, 'parameters) product =
  [< `zero of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 z * 'p2 one
          (* scalar * scalar ⇒ scalar *)
       | `one of 'rank3 * 'dim3 & 'p1 one * 'dim2
          (* scalar * vector('dim) ⇒ vector('dim) *)
       | `two of 'rank3 * 'dim3 & 'p1 two * 'dim2
          (* scalar * matrix('dim) ⇒ matrix('dim) *)
       ]
  | `one of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 one * 'dim1
          (* vector('dim) * scalar ⇒ vector('dim): the result keeps the
             vector's dimension ['dim1], not the scalar's ['dim2] *)
       | `one of 'rank3 * 'dim2 * 'dim3 & 'p1 one * 'dim1 * 'dim1
          (* vector('dim) * vector('dim) ⇒ vector('dim) *)
       | `two of 'rank3 * 'dim2 * 'dim3 & 'p1 one * 'dim1 * 'dim1
          (* vector('dim) * matrix('dim) ⇒ vector('dim) *)
       ]
  | `two of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 two * 'dim1
          (* matrix('dim) * scalar ⇒ matrix('dim): the result keeps the
             matrix's dimension ['dim1], not the scalar's ['dim2] *)
       | `one of 'rank3 * 'dim2 * 'dim3 & 'p1 one * 'dim1 * 'dim1
          (* matrix('dim) * vector('dim) ⇒ vector('dim) *)
       | `two of 'rank3 * 'dim2 * 'dim3 & 'p1 two * 'dim1 * 'dim1
          (* matrix('dim) * matrix('dim) ⇒ matrix('dim) *)
       ]
  ] as 'rank1
  constraint 'parameters = 'p1 * 'p2 * 'p3
type ('rank1, 'rank2,'rank3,'dim1,'dim2,'dim3, 'parameters) div =
  [< `zero of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 z * 'p2 one ]
       (* scalar / scalar ⇒ scalar *)
  | `one of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 one * 'dim1
          (* vector / scalar ⇒ vector with the vector's dimension *)
       | `one of 'rank3 * 'dim2 * 'dim3 & 'p1 one * 'dim1 * 'dim1
          (* vector / vector ⇒ vector; dimensions must agree *)
       | `two of 'rank3 * 'dim2 * 'dim3 & 'p1 one * 'dim1 * 'dim1
          (* vector / matrix ⇒ vector; dimensions must agree *)
       ]
  | `two of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 two * 'dim1
          (* matrix / scalar ⇒ matrix with the matrix's dimension *)
       | `two of 'rank3 * 'dim2 * 'dim3 & 'p1 two * 'dim1 * 'dim1
          (* matrix / matrix ⇒ matrix; the result rank is built from ['p1],
             as in every other case of this type *)
       ]
  ] as 'rank1
  constraint 'parameters = 'p1 * 'p2 * 'p3
(** [(x,y,z,d1,d2,d3,_) div] computes the rank of [x / y] and
    puts the result inside [z] (and its dimension inside [d3]).
    A scalar can only be divided by a scalar, and a matrix cannot be
    divided by a vector. *)
type ('rank1, 'rank2,'rank3, 'parameters) rank_diff =
[<
| `one of 'rank2 & [< `one of 'rank3 & 'p1 z] (* one - one = zero *)
| `two of 'rank2 &
[< `one of 'rank3 & 'p1 one (* two - one = one *)
| `two of 'rank3 & 'p1 z ] (* two - two = zero *)
] as 'rank1
constraint 'parameters = 'p1
(** [(x,y,z,_) rank_diff] computes the rank of [x - y] and puts the result
inside [z]. Only one - one, two - one and two - two are accepted; any other
combination (a zero operand, or one - two) is a type error. *)
type ('rank1, 'rank2,'rank3,'dim1,'dim2,'dim3, 'parameters) sum =
  [< `zero of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 z * 'p2 one
          (* scalar + scalar ⇒ scalar *)
       | `one of 'rank3 * 'dim3 & 'p1 one * 'dim2
          (* scalar + vector ⇒ vector with the vector's dimension *)
       | `two of 'rank3 * 'dim3 & 'p1 two * 'dim2
          (* scalar + matrix ⇒ matrix with the matrix's dimension *) ]
  | `one of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 one * 'dim1
          (* vector + scalar ⇒ vector with the vector's dimension *)
       | `one of 'rank3 * 'dim1 * 'dim3 & 'p1 one * 'dim2 * 'dim2
          (* vector + vector ⇒ vector; ['dim1] and ['dim3] are unified
             with ['dim2] *) ]
  | `two of 'rank2 &
       [< `zero of 'rank3 * 'dim3 & 'p1 two * 'dim1
          (* matrix + scalar ⇒ matrix with the matrix's dimension *)
       | `two of 'rank3 * 'dim1 * 'dim3 & 'p1 two * 'dim2 * 'dim2
          (* matrix + matrix ⇒ matrix; as in the vector case, ['dim1] and
             ['dim3] are unified with ['dim2], so mismatched dimensions are
             rejected and the result dimension is fixed *) ]
  ] as 'rank1
  constraint 'parameters = 'p1 * 'p2 * 'p3
(** [(x,y,z,d1,d2,d3,_) sum] computes the rank of [x + y] and
    puts the result inside [z] (and its dimension inside [d3]).
    A vector and a matrix cannot be added to each other. *)
(** [(dim, res, _) cross] gives the type-level result of a cross product. It
    is only defined when [dim] carries the tag [`two] or [`three].
    [res] is unified with a pair:
    - for [`two], with [('p2 * 'p1 z)] (first component unconstrained,
      second component zero);
    - for [`three], with [('p2 three * 'p1 one)].
    NOTE(review): the pair presumably reads as (dimension, rank), i.e. a
    scalar in 2D and a 3-vector in 3D — confirm against the caller. *)
type ( 'dim, 'res, 'parameters ) cross =
[< `two of 'res & ('p2 * 'p1 z) | `three of 'res & ('p2 three * 'p1 one) ]
as 'dim
constraint 'parameters = 'p1 * 'p2
(** [(d1, d2, d3, p) simple_sum] unifies [d3] with [d1 + d2] on the small
    type-level naturals. Both operands range over one to three, and only sums
    of at most four are accepted (e.g. three + two is a type error). *)
type ('dim1,'dim2,'dim3,'p) simple_sum =
[< `one of 'dim2 &
[< `one of 'dim3 & 'p two (* 1 + 1 = 2 *)
| `two of 'dim3 & 'p three (* 1 + 2 = 3 *)
| `three of 'dim3 & 'p four (* 1 + 3 = 4 *)
]
| `two of 'dim2 &
[< `one of 'dim3 & 'p three (* 2 + 1 = 3 *)
| `two of 'dim3 & 'p four (* 2 + 2 = 4 *)
]
| `three of 'dim2 & [< `one of 'dim3 & 'p four] (* 3 + 1 = 4 *)
] as 'dim1
(** [(d1, d2, d3, p) nat_sum] unifies [d3] with [d1 + d2] for operands in one
    to three whose sum is at most four.
    NOTE(review): this table is identical to [simple_sum]; one of the two
    could be defined as an alias of the other. *)
type ('dim1,'dim2,'dim3,'p) nat_sum =
[< `one of 'dim2 &
[< `one of 'dim3 & 'p two (* 1 + 1 = 2 *)
| `two of 'dim3 & 'p three (* 1 + 2 = 3 *)
| `three of 'dim3 & 'p four ] (* 1 + 3 = 4 *)
| `two of 'dim2 &
[< `one of 'dim3 & 'p three (* 2 + 1 = 3 *)
| `two of 'dim3 & 'p four ] (* 2 + 2 = 4 *)
| `three of 'dim2 &
[< `one of 'dim3 & 'p four ] (* 3 + 1 = 4 *)
]
as 'dim1
(* (…comment truncated in the extracted source…) + rank ) + (dim,rank) *)
(** [(tensor_rank, index_rank, res_rank, dim, res_dim, len, _) superindexing]
    computes the result rank [res_rank] and result dimension [res_dim] from
    the tensor rank, the index rank, the tensor dimension [dim] and [len].
    Only tensor ranks [`two] and [`one] are accepted.
    NOTE(review): [len] presumably is the length of the index — confirm
    against the caller. *)
type ('tensor_rank,'index_rank,'res_rank,'dim,'res_dim, 'len, 'parameters)
superindexing =
[< `two of
(* tensor of rank two: the result depends on the index rank *)
'index_rank &
[< `two of
(* index rank two: the result is driven by [len] alone *)
'len &
[< `one of 'res_rank & 'p z (* len one: rank zero *)
| `two of 'res_dim * 'res_rank & 's two * 'p one (* rank one, res_dim two *)
| `three of 'res_dim * 'res_rank & 's three * 'p one (* rank one, res_dim three *)
| `four of 'res_dim * 'res_rank & 's four * 'p one (* rank one, res_dim four *)
]
| `one of
(* index rank one: res_dim is unified with dim, and len with the inner
   table below *)
'dim * 'len &
'res_dim *
[< `one of 'res_rank & 'p one (* len one: rank one *)
| `two of 'dim * 'res_rank & 'p two * 's two (* dim must be two; rank two *)
| `three of 'dim * 'res_rank & 'p three * 's two (* dim must be three; rank two *)
| `four of 'dim * 'res_rank & 'p four * 's two (* dim must be four; rank two *)
]
]
| `one of 'len &
(* tensor of rank one: the index rank is not inspected, only [len] *)
[< `one of 'res_rank & 'p z (* len one: rank zero *)
| `two of 'res_rank * 'res_dim & 'p one * 's two (* rank one, res_dim two *)
| `three of 'res_rank * 'res_dim & 'p one * 's three (* rank one, res_dim three *)
| `four of 'res_rank * 'res_dim & 'p one * 's four (* rank one, res_dim four *)
]
] as 'tensor_rank
constraint 'parameters = 'p * 's