Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file linear_algebra.ml
open Core

(** Vectors *)
module Vec = struct
  (** A dense column vector. *)
  type t = float array [@@deriving sexp]

  let copy = Array.copy

  (** [create0 len] is a fresh vector of [len] zeros. *)
  let create0 len = Array.create ~len 0.

  (** [sumsq t] is the sum of the squares of the entries of [t]. *)
  let sumsq t =
    let accum = ref 0. in
    for i = 0 to Array.length t - 1 do
      let t_i = t.(i) in
      accum := !accum +. (t_i *. t_i)
    done;
    !accum

  (** [norm t] is the Euclidean norm of [t]. *)
  let norm t = sqrt (sumsq t)

  (** [almost_equal ~tol t1 t2] is [true] iff [t1] and [t2] have the same length
      and every pair of corresponding entries differs by at most [tol] (any NaN
      makes the comparison fail). *)
  let almost_equal ~tol t1 t2 =
    Array.length t1 = Array.length t2
    && Array.for_all2_exn t1 t2 ~f:(fun x y ->
      Float.( <= ) (Float.abs (x -. y)) tol)
end

(** Matrices *)
module Mat = struct
  (** A dense matrix stored as an array of rows. *)
  type t = float array array [@@deriving sexp]

  (** [copy t] is a deep copy of [t] (every row is copied). *)
  let copy t = Array.map t ~f:Array.copy

  (** [create0 ~rows ~cols] is a fresh [rows] x [cols] matrix of zeros. *)
  let create0 ~rows ~cols =
    let make_row _ = Array.create ~len:cols 0. in
    Array.init rows ~f:make_row

  (** [create_per_row ~rows ~cols ~f] builds [rows] fresh rows, each equal to
      [Array.init cols ~f]; [f] receives the column index, so every row holds
      the same values. *)
  let create_per_row ~rows ~cols ~f =
    let make_row _ = Array.init cols ~f in
    Array.init rows ~f:make_row

  (** The norm of a column of a matrix. *)
  let col_norm t column =
    let accum = ref 0. in
    for i = 0 to Array.length t - 1 do
      let entry = t.(i).(column) in
      accum := !accum +. (entry *. entry)
    done;
    sqrt !accum

  (** The inner product of columns j1 and j2 *)
  let col_inner_prod t j1 j2 =
    let accum = ref 0. in
    for i = 0 to Array.length t - 1 do
      accum := !accum +. (t.(i).(j1) *. t.(i).(j2))
    done;
    !accum

  (** [get_column t j] is a fresh vector holding column [j] of [t]. *)
  let get_column t j =
    let len = Array.length t in
    Array.init len ~f:(fun i -> t.(i).(j))

  (** [almost_equal ~tol t1 t2] is [true] iff [t1] and [t2] have the same number
      of rows and each pair of corresponding rows is [Vec.almost_equal ~tol]. *)
  let almost_equal ~tol t1 t2 =
    Array.length t1 = Array.length t2
    && Array.for_all2_exn t1 t2 ~f:(Vec.almost_equal ~tol)
end

(** [qr_in_place a] computes the QR decomposition of the m x n matrix [a] by
    modified Gram-Schmidt: each column is normalized in turn, and its component
    is immediately subtracted from every later column. [a] is overwritten with
    Q (m x n) and the result is [(a, r)], where [r] is a fresh n x n
    upper-triangular matrix. The column count n is taken from [a.(0)], so [a] is
    assumed rectangular. A zero-norm column (rank-deficient [a]) causes a
    division by zero, and the resulting infinities/NaNs propagate into Q and R.
    An [a] with no rows yields [([||], [||])]. *)
let qr_in_place a =
  let m = Array.length a in
  if m = 0
  then ([||], [||]) (* empty QR decomposition *)
  else (
    let n = Array.length a.(0) in
    let r = Mat.create0 ~rows:n ~cols:n in
    for j = 0 to n - 1 do
      (* handle column j *)
      let alpha = Mat.col_norm a j in
      r.(j).(j) <- alpha;
      let one_over_alpha = 1. /. alpha in
      (* Rescale this column so that it's a unit vector. *)
      for i = 0 to m - 1 do
        a.(i).(j) <- a.(i).(j) *. one_over_alpha
      done;
      for j2 = j + 1 to n - 1 do
        let c = Mat.col_inner_prod a j j2 in
        r.(j).(j2) <- c;
        (* Now, subtract c * column j from column j2. *)
        for i = 0 to m - 1 do
          a.(i).(j2) <- a.(i).(j2) -. (c *. a.(i).(j))
        done
      done
    done;
    (a, r))

(** [qr A] returns the QR-decomposition of [A] as a pair (Q,R). [A] must have
    at least as many rows as columns and have full rank.
    If [in_place] (default: [false]) is [true], then [A] is overwritten with [Q];
    otherwise [A] is deep-copied first and left untouched. *)
let qr ?(in_place = false) a =
  let a = if in_place then a else Mat.copy a in
  qr_in_place a

(** [triu_solve R b] solves R x = b by back-substitution, where [R] is an m x m
    upper-triangular matrix and [b] is an m x 1 column vector. [b] is not
    modified; entries of [R] below the diagonal are never read.

    Returns [Error] if [R] does not have exactly m rows, if any row of [R] does
    not have exactly m entries, or if the result contains a NaN (e.g. from 0/0
    on a zero diagonal entry). Only NaN is detected: infinite entries (e.g. a
    zero diagonal entry with a non-zero right-hand side) are returned in [Ok]. *)
let triu_solve r b =
  let m = Array.length b in
  if m <> Array.length r
  then
    Or_error.error_string
      "triu_solve R b requires R to be square with same \
       number of rows as b"
  else if m = 0
  then Ok [||]
  else if Array.exists r ~f:(fun row -> Array.length row <> m)
  (* Check every row, not just the first: a short row would otherwise raise an
     index exception out of a function that promises an [Or_error.t]. *)
  then Or_error.error_string "triu_solve R b requires R to be square"
  else (
    let sol = Vec.copy b in
    for i = m - 1 downto 0 do
      sol.(i) <- sol.(i) /. r.(i).(i);
      (* Column-oriented back-substitution: eliminate x_i from rows above i. *)
      for j = 0 to i - 1 do
        sol.(j) <- sol.(j) -. (r.(j).(i) *. sol.(i))
      done
    done;
    if Array.exists sol ~f:Float.is_nan
    then Or_error.error_string "triu_solve detected NaN result"
    else Ok sol)

(* [mul A B] would compute the matrix product [A * B], or A' * B (A' denoting
   the transpose of A) when [transa] (default: [false]) is [true]. It has no
   implementation here; only its intended signature is recorded:
   val mul : ?transa:bool -> m -> m -> m *)

(** [mul_mv A x] computes the product [A * x] (where [A] is a matrix and [x] is
    a column vector). If [transa] (default: [false]) is [true], it computes
    A' * x instead, where A' denotes the transpose of A.
    The column count of [A] is taken from its first row, so [A] is assumed
    rectangular. If [A] has no rows, the result is [[||]] and [x] is not
    inspected.
    @raise Failure ["Dimension mismatch"] if [A] has rows and the length of [x]
    differs from the number of columns of [A] (of A' when [transa]). *)
let mul_mv ?(transa = false) a x =
  (* we let c denote either a or a', depending on whether transa is set. *)
  let rows = Array.length a in
  if rows = 0
  then [||]
  else (
    let cols = Array.length a.(0) in
    (* (m, n, c_get) will be (rows of c, columns of c, accessor for c). *)
    let m, n, c_get =
      if transa
      then (
        let c_get i j = a.(j).(i) in
        (cols, rows, c_get))
      else (
        let c_get i j = a.(i).(j) in
        (rows, cols, c_get))
    in
    if n <> Array.length x then failwith "Dimension mismatch";
    let result = Vec.create0 m in
    for i = 0 to m - 1 do
      result.(i)
      <- Array.foldi x ~init:0. ~f:(fun j accum x_j ->
           accum +. (c_get i j *. x_j))
    done;
    result)

(** [ols A b] computes the ordinary least-squares solution to A x = b.
    [A] must have at least as many rows as columns and have full rank.
    This can be used to compute solutions to non-singular square systems,
    but is somewhat sub-optimal for that purpose.
    The algorithm is to factor A = Q * R and solve R x = Q' b where Q' denotes
    the transpose of Q.
    If [in_place] (default: [false]) is [true], [A] is overwritten with Q.
    Returns [Error] if [b] does not have one entry per row of [A] (checked
    before [A] is touched), or under the conditions documented on
    [triu_solve]. *)
let ols ?(in_place = false) a b =
  (* Validate up front: otherwise [mul_mv] would raise [Failure] out of a
     function that returns an [Or_error.t], after [A] may already have been
     overwritten. *)
  if Array.length a <> Array.length b
  then
    Or_error.error_string
      "ols A b requires A and b to have the same number of rows"
  else (
    let q, r = qr ~in_place a in
    triu_solve r (mul_mv ~transa:true q b))

let%test_module _ =
  (module struct
    (* The examples and the correct reference values were generated in Octave. *)
    let mat_A =
      [| [| 1.5539829; -0.4525782; -1.1728152; 1.3674086; -1.1205482 |]
       ; [| 0.6792944; 1.1568534; 0.4154379; 0.9084153; -2.5703106 |]
       ; [| -0.5618483; -0.6781523; -0.5248221; -0.4142220; -0.7306068 |]
       ; [| -0.5998192; 1.3722146; -0.5557165; 0.0363979; 0.2204308 |]
       ; [| 0.9425094; -0.3673329; 0.0099052; -0.1091253; 0.9456771 |]
       ; [| -0.6091836; 0.6229814; 1.1498873; 1.6160578; 0.7104362 |]
       ; [| -0.1933751; 1.3707531; 0.6352440; 1.3795393; -1.1168355 |]
      |]
    ;;

    (* In case you want to enter the matrix back into Octave:
       1.5539829 -0.4525782 -1.1728152 1.3674086 -1.1205482
       0.6792944 1.1568534 0.4154379 0.9084153 -2.5703106
       -0.5618483 -0.6781523 -0.5248221 -0.4142220 -0.7306068
       -0.5998192 1.3722146 -0.5557165 0.0363979 0.2204308
       0.9425094 -0.3673329 0.0099052 -0.1091253 0.9456771
       -0.6091836 0.6229814 1.1498873 1.6160578 0.7104362
       -0.1933751 1.3707531 0.6352440 1.3795393 -1.1168355
    *)

    (* The known correct values for QR decomposition of A. *)
    let known_Q =
      [| [| 0.7057304; -0.0081434; -0.3701402; 0.5895472; 0.0512010 |]
       ; [| 0.3084968; 0.5535933; 0.1618997; -0.1938150; -0.5796110 |]
       ; [| -0.2551595; -0.3432622; -0.2917051; 0.2900536; -0.4904861 |]
       ; [| -0.2724037; 0.4956579; -0.6737123; -0.0924160; 0.3724562 |]
       ; [| 0.4280340; -0.0431215; 0.2265497; -0.2622756; 0.4346026 |]
       ; [| -0.2766564; 0.1864428; 0.4836269; 0.6300418; 0.2805885 |]
       ; [| -0.0878199; 0.5416106; 0.1121823; 0.2376066; -0.1204989 |]
      |]
    ;;

    (* 0.7057304 -0.0081434 -0.3701402 0.5895472 0.0512010
       0.3084968 0.5535933 0.1618997 -0.1938150 -0.5796110
       -0.2551595 -0.3432622 -0.2917051 0.2900536 -0.4904861
       -0.2724037 0.4956579 -0.6737123 -0.0924160 0.3724562
       0.4280340 -0.0431215 0.2265497 -0.2622756 0.4346026
       -0.2766564 0.1864428 0.4836269 0.6300418 0.2805885
       -0.0878199 0.5416106 0.1121823 0.2376066 -0.1204989
    *)
    let known_R =
      [| [| 2.2019498; -0.6132342; -0.7839086; 0.7260895; -1.1510467 |]
       ; [| 0.0000000; 2.4314496; 0.7022565; 1.7051660; -1.5669469 |]
       ; [| 0.0000000; 0.0000000; 1.6584752; 0.6488550; 0.4957819 |]
       ; [| 0.0000000; 0.0000000; 0.0000000; 1.8811696; -0.4605288 |]
       ; [| 0.0000000; 0.0000000; 0.0000000; 0.0000000; 2.6177719 |]
      |]
    ;;

    (* 2.2019498 -0.6132342 -0.7839086 0.7260895 -1.1510467
       0.0000000 2.4314496 0.7022565 1.7051660 -1.5669469
       0.0000000 0.0000000 1.6584752 0.6488550 0.4957819
       0.0000000 0.0000000 0.0000000 1.8811696 -0.4605288
       0.0000000 0.0000000 0.0000000 0.0000000 2.6177719
    *)
    let q, r = qr mat_A

    let%test "qr: correct Q" = Mat.almost_equal ~tol:1e-7 q known_Q
    let%test "qr: correct R" = Mat.almost_equal ~tol:1e-7 r known_R

    let v = [| -0.1970397; 1.1226276; 2.1068430; -1.0784432; -0.2012862 |]

    let known_A_times_v =
      [| -4.5343322; 1.5778238; -1.1625479; 0.4042440; -0.6498874; 1.3562139; 1.6523560 |]
    ;;

    let known_R_inverse_v =
      [| 0.7162378; 0.3869500; 1.5249890; -0.5921073; -0.0768922 |]
    ;;

    let%test "mul_mv" =
      Vec.almost_equal ~tol:1e-7 (mul_mv mat_A v) known_A_times_v
    ;;

    let%test "triu_solve" =
      match triu_solve r v with
      | Ok r_inverse_v -> Vec.almost_equal ~tol:1e-7 r_inverse_v known_R_inverse_v
      | Error _ -> false
    ;;

    let%test "triu_solve rejects a ragged R" =
      Result.is_error (triu_solve [| [| 1.; 2. |]; [| 3. |] |] [| 1.; 1. |])
    ;;

    let w =
      [| -0.5465946; 0.5402624; -1.9966324; -1.0315546; 0.0570856; 1.3310883; -0.1333566 |]
    ;;

    let known_A_backslash_w =
      (* like A \ w in MATLAB or Octave *)
      [| 7.08921663963053e-01
       ; 7.00514461674804e-02
       ; 1.25581479352235e+00
       ; -6.33055939677457e-04
       ; 2.55312671644131e-01
      |]
    ;;

    let%test "ols" =
      match ols mat_A w with
      | Ok a_backslash_w ->
        Vec.almost_equal ~tol:1e-14 a_backslash_w known_A_backslash_w
      | Error _ -> false
    ;;

    let%test "ols rejects b of the wrong length" =
      Result.is_error (ols mat_A [| 1.; 2. |])
    ;;
  end)
;;