Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_utils_ndarray.ml
# 1 "src/base/misc/owl_utils_ndarray.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Bigarray


(* convert an element of elt type to string.
   Floats and complex components use "%G", i.e. Printf's default precision of
   six significant digits, so the conversion is not lossless. Complex numbers
   are printed as "(re, imi)". *)
let elt_to_str : type a b. (a, b) kind -> (a -> string) = function
  | Char           -> fun v -> Printf.sprintf "%c" v
  | Nativeint      -> fun v -> Printf.sprintf "%nd" v
  | Int8_signed    -> fun v -> Printf.sprintf "%i" v
  | Int8_unsigned  -> fun v -> Printf.sprintf "%i" v
  | Int16_signed   -> fun v -> Printf.sprintf "%i" v
  | Int16_unsigned -> fun v -> Printf.sprintf "%i" v
  | Int            -> fun v -> Printf.sprintf "%i" v
  | Int32          -> fun v -> Printf.sprintf "%ld" v
  | Int64          -> fun v -> Printf.sprintf "%Ld" v
  | Float32        -> fun v -> Printf.sprintf "%G" v
  | Float64        -> fun v -> Printf.sprintf "%G" v
  | Complex32      -> fun v -> Printf.sprintf "(%G, %Gi)" Complex.(v.re) Complex.(v.im)
  | Complex64      -> fun v -> Printf.sprintf "(%G, %Gi)" Complex.(v.re) Complex.(v.im)


(* convert an element of string to elt type.
   Complex values are expected in the "(re, imi)" form written by
   [elt_to_str]. The small integer kinds are parsed with [int_of_string] and
   are NOT range-checked here. Malformed input raises whatever the underlying
   parser raises (e.g. [Failure] from [int_of_string], or a [Scanf]
   exception). *)
let elt_of_str : type a b. (a, b) kind -> (string -> a) = function
  | Char           -> fun v -> Scanf.sscanf v "%c%!" (fun c -> c)
  | Nativeint      -> fun v -> Nativeint.of_string v
  | Int8_signed    -> fun v -> int_of_string v
  | Int8_unsigned  -> fun v -> int_of_string v
  | Int16_signed   -> fun v -> int_of_string v
  | Int16_unsigned -> fun v -> int_of_string v
  | Int            -> fun v -> int_of_string v
  | Int32          -> fun v -> Int32.of_string v
  | Int64          -> fun v -> Int64.of_string v
  | Float32        -> fun v -> float_of_string v
  | Float64        -> fun v -> float_of_string v
  | Complex32      -> fun v -> Scanf.sscanf v "(%f, %fi)%!" (fun re im -> {Complex.re; im})
  | Complex64      -> fun v -> Scanf.sscanf v "(%f, %fi)%!" (fun re im -> {Complex.re; im})


(* calculate the number of elements in an ndarray, i.e. the product of its
   dimensions (1 for a 0-dimensional ndarray). *)
let numel x = Array.fold_left ( * ) 1 (Genarray.dims x)


(* calculate the stride of a ndarray, s is the shape.
   r.(k) is the product of all dimensions after k (the last entry is 1).
   for [x] of shape [|2;3;4|], the return is [|12;4;1|] *)
let calc_stride s =
  let d = Array.length s in
  let r = Array.make d 1 in
  (* fill right-to-left so each entry reuses its right neighbour *)
  for i = 1 to d - 1 do
    r.(d - i - 1) <- s.(d - i) * r.(d - i)
  done;
  r


(* calculate the slice size in each dimension, s is the shape.
   r.(k) is the product of dimension k and all dimensions after it.
   for [x] of shape [|2;3;4|], the return is [|24;12;4|]
   An empty shape (0-dimensional ndarray) yields [||] instead of raising
   [Invalid_argument] on s.(-1). *)
let calc_slice s =
  let d = Array.length s in
  if d = 0 then [||]
  else begin
    let r = Array.make d s.(d - 1) in
    for i = d - 2 downto 0 do
      r.(i) <- s.(i) * r.(i + 1)
    done;
    r
  end


(* c layout index translation: 1d -> nd
   i is one-dimensional index;
   j is n-dimensional index, written in place;
   s is the stride.
   the space of j needs to be pre-allocated (at least [Array.length s]
   slots). s must be non-empty, since j.(0) is written unconditionally. *)
let index_1d_nd i j s =
  j.(0) <- i / s.(0);
  for k = 1 to Array.length s - 1 do
    (* remove the part of i covered by the outer dimension, then scale *)
    j.(k) <- (i mod s.(k - 1)) / s.(k)
  done


(* c layout index translation: nd -> 1d
   j is n-dimensional index;
   s is the stride.
   returns the sum of j.(k) * s.(k) over all k. *)
let index_nd_1d j s =
  let i = ref 0 in
  Array.iteri (fun k a -> i := !i + (a * s.(k))) j;
  !i


(* given ndarray [x] and 1d index, return nd index. *)
let ind x i_1d =
  let shape = Genarray.dims x in
  let stride = calc_stride shape in
  (* fresh array of the right length to receive the nd index *)
  let i_nd = Array.copy stride in
  index_1d_nd i_1d i_nd stride;
  i_nd


(* given ndarray [x] and nd index, return 1d index. *)
let i1d x i_nd =
  let shape = Genarray.dims x in
  let stride = calc_stride shape in
  index_nd_1d i_nd stride


(* Adjust the index according to the [0, m). m is the boundary, i can be
   negative: an index in [-m, 0) counts from the end and is mapped to i + m.
   Raises [Owl_exception.INDEX_OUT_OF_BOUND] when i is outside [-m, m). *)
let adjust_index i m =
  if i >= 0 && i < m then i
  else if i < 0 && i >= -m then i + m
  else raise Owl_exception.INDEX_OUT_OF_BOUND


(* prepare the parameters for reduce/fold operation, [a] is axis (may be
   negative, see [adjust_index]).
   Returns (m, n, o, shape) where
   - m is the product of the dimensions before axis a,
   - n is the slice size at axis a (product of dims from a onward),
   - o is the stride of axis a (product of dims after a),
   - shape is the shape of [x] with dimension a set to 1.
   m is computed directly rather than as [numel x / n], which raised
   [Division_by_zero] whenever a dimension at or after [a] was 0; the two
   agree whenever the division is defined. *)
let reduce_params a x =
  let d = Genarray.num_dims x in
  let a = adjust_index a d in
  let _shape = Genarray.dims x in
  let _stride = calc_stride _shape in
  let _slicez = calc_slice _shape in
  let m = Array.fold_left ( * ) 1 (Array.sub _shape 0 a) in
  let n = _slicez.(a) in
  let o = _stride.(a) in
  _shape.(a) <- 1;
  m, n, o, _shape


(* check whether two shapes are broadcastable.
   The shorter shape is left-padded with 1s via [Owl_utils_array.align]; the
   shapes are broadcastable when, in every position, the two dimensions are
   equal or one of them is 1. Aligned arrays of different lengths are
   reported as not broadcastable. *)
let broadcastable s0 s1 =
  let sa, sb = Owl_utils_array.align `Left 1 s0 s1 in
  let len = Array.length sa in
  let compatible p q = p = 1 || q = 1 || p = q in
  let rec all_compatible k =
    k >= len || (compatible sa.(k) sb.(k) && all_compatible (k + 1))
  in
  len = Array.length sb && all_compatible 0

(* ends here *)