Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source
Source file owl_neural_compiler.ml
# 1 "src/base/neural/owl_neural_compiler.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

(** Compilation of neural networks into static computation graphs.

    [Make (E)] wraps the engine [E] with [Owl_computation_engine.Flatten],
    instantiates [Owl_neural_generic.Make] on top of it and provides functions
    that turn a network into a graph of engine nodes. *)
module Make (E : Owl_types_computation_engine.Sig) = struct
  module Engine = Owl_computation_engine.Flatten (E)
  module Neural = Owl_neural_generic.Make (Engine)
  open Neural
  open Algodiff

  (* ---------------------------------------------------------------------- *)
  (* Internal helpers shared by the compilation functions below.            *)
  (* ---------------------------------------------------------------------- *)

  (* Map [Algodiff] values wrapping engine arrays to their graph nodes. *)
  let _to_nodes vs = Array.map (fun v -> unpack_arr v |> Engine.arr_to_node) vs

  (* Name the node behind the [i]-th element of [vs] with [fmt] applied to
     [i]; [fmt] is a format taking a single integer, e.g. ["x%i"]. *)
  let _name_nodes fmt vs =
    Array.iteri
      (fun i v ->
        let node = unpack_arr v |> Engine.arr_to_node in
        Owl_graph.set_name node (Printf.sprintf fmt i))
      vs


  (* Batch size selected by [params.batch]; [full_size] is used for [Full]. *)
  let _batch_size (params : Params.typ) full_size =
    match params.batch with
    | Full -> full_size
    | Mini n | Sample n -> n
    | Stochastic -> 1


  (* Initialise [network], then replace every parameter with a fresh unnamed
     engine variable of the same shape that is assigned the evaluated value of
     the original parameter. *)
  let _init_weights network =
    Graph.init network;
    Graph.mkpar network
    |> Owl_utils.aarr_map (fun v ->
           let v = Algodiff.unpack_arr v in
           (* the value must be evaluated before it can be read back *)
           Engine.eval_arr [| v |];
           let u = Engine.var_arr "" ~shape:(Engine.shape v) in
           Engine.(assign_arr u (unpack_arr v));
           Algodiff.pack_arr u)
    |> Graph.update network


  (* Allocate one engine variable per element of [ws], with the same shape as
     that element and named by [fmt] applied to its index. *)
  let _alloc_vars fmt ws =
    Array.mapi
      (fun i w ->
        let name = Printf.sprintf fmt i in
        let shape = Engine.shape (unpack_arr w) in
        Engine.var_arr name ~shape |> pack_arr)
      ws


  (** Naive compilation function; the loss function has to be passed in.

      [compile_simple network input_shape loss_fun] initialises [network] and
      replaces each of its parameters with a fresh unnamed engine variable of
      the same shape (no value is assigned to these variables here). It then
      builds the forward graph from a variable ["x"] of shape [input_shape],
      computes [loss_fun y y'] where [y'] is the network output and ["y"] is a
      variable of the same shape as [y'], and runs [Graph.backward] on it.

      Nodes are named ["loss"], ["x<i>"] for the first component of the
      [Graph.backward] result and ["x<i>'"] for the second one.

      @return [(xt, yt, pri, adj)]: the engine arrays behind ["x"] and ["y"],
      followed by the flattened and unpacked first and second components of
      the [Graph.backward] result. *)
  let compile_simple network input_shape loss_fun =
    Graph.init network;
    Graph.mkpar network
    |> Owl_utils.aarr_map (fun v ->
           Engine.var_arr "" ~shape:(unpack_arr v |> Engine.shape) |> pack_arr)
    |> Graph.update network;

    (* derive the computation graph in reverse mode *)
    let x = Engine.var_arr "x" ~shape:input_shape |> pack_arr in
    let y' = Graph.forward network x |> fst in
    let output_shape = unpack_arr y' |> Engine.shape in
    let y = Engine.var_arr "y" ~shape:output_shape |> pack_arr in
    let loss = loss_fun y y' in
    let z = Graph.(backward network loss) in
    let pri = Owl_utils_array.flatten (fst z) in
    let adj = Owl_utils_array.flatten (snd z) in

    (* assign loss variable name *)
    Owl_graph.set_name (unpack_elt loss |> Engine.elt_to_node) "loss";
    (* assign input variable names *)
    _name_nodes "x%i" pri;
    (* assign output variable names *)
    _name_nodes "x%i'" adj;

    let xt = unpack_arr x in
    let yt = unpack_arr y in
    let pri = Array.map unpack_arr pri in
    let adj = Array.map unpack_arr adj in
    xt, yt, pri, adj


  (** Shallow compilation function; the graph only includes the gradient.

      [compile_shallow params network full_size] takes the loss function from
      [params.loss] and the batch size from [params.batch] ([full_size] is
      used for [Full]). The input shape is the batch size prepended to
      [Graph.input_shape network]. The loss is divided by [Mat.row_num y]
      before [Graph.backward] is run.

      Nodes are named as in {!compile_simple}. The ancestors of the loss node
      and of all nodes in both components of the [Graph.backward] result are
      frozen with [Engine.freeze_ancestors].

      @return [(x, y, pri, adj, loss)] as [Algodiff] values. *)
  let compile_shallow (params : Params.typ) network full_size =
    (* extract configurations of a network *)
    let loss_fun = Loss.run params.loss in

    (* infer input shape from batch size and network shape *)
    let batch = _batch_size params full_size in
    let network_shape = Graph.input_shape network in
    let input_shape = Array.append [| batch |] network_shape in

    (* initialise the network weight *)
    _init_weights network;

    (* derive the computation graph in forward mode *)
    let x = Engine.var_arr "x" ~shape:input_shape |> pack_arr in
    let y' = Graph.forward network x |> fst in
    let output_shape = unpack_arr y' |> Engine.shape in
    let y = Engine.var_arr "y" ~shape:output_shape |> pack_arr in
    let loss = loss_fun y y' in
    let loss = Maths.(loss / (_f (Mat.row_num y |> float_of_int))) in

    (* derive the computation graph in reverse mode *)
    let z = Graph.(backward network loss) in
    let pri = Owl_utils_array.flatten (fst z) in
    let adj = Owl_utils_array.flatten (snd z) in

    (* assign loss variable name *)
    Owl_graph.set_name (unpack_elt loss |> Engine.elt_to_node) "loss";
    (* assign input variable names *)
    _name_nodes "x%i" pri;
    (* assign output variable names *)
    _name_nodes "x%i'" adj;

    (* freeze the graph *)
    let a0 = [| unpack_elt loss |> Engine.elt_to_node |] in
    let a1 = _to_nodes pri in
    let a2 = _to_nodes adj in
    let a3 = Owl_utils_array.(a0 @ a1 @ a2) in
    (* FIXME: experimental *)
    Engine.freeze_ancestors a3;

    (* return key parameters *)
    x, y, pri, adj, loss


  (** Deep compilation function; the graph includes gs, us, ps, ch, and the
      new weights.

      [compile_deep params network full_size] builds the forward graph and the
      loss as {!compile_shallow} does, adds the sum of
      [Regularisation.run params.regularisation w] over all values of
      [Graph.mkpri network] unless [params.regularisation] is
      [Regularisation.None], and runs [Graph.backward]. It then allocates the
      variables ["gs<i>"], ["ps<i>"], ["us<i>"], ["cha<i>"] and ["chb<i>"]
      (all assigned zeros before returning) and expresses the updated weights
      ["ws'<i>"] in terms of them using the gradient, learning-rate, momentum
      and clipping functions selected by [params].

      @return [(loss, x, y, cgraph)] where [cgraph] is the graph created by
      [Engine.make_graph]; its outputs are the outputs kept by
      [Engine.remove_unused_iopair] followed by the loss node. *)
  let compile_deep (params : Params.typ) network full_size =
    (* extract configurations of a network *)
    let loss_fun = Loss.run params.loss in
    let grad_fun = Gradient.run params.gradient in
    let rate_fun = Learning_Rate.run params.learning_rate in
    let regl_fun = Regularisation.run params.regularisation in
    let momt_fun = Momentum.run params.momentum in
    let upch_fun = Learning_Rate.update_ch params.learning_rate in
    let clip_fun = Clipping.run params.clipping in

    (* infer input shape from batch size and network shape *)
    let batch = _batch_size params full_size in
    let network_shape = Graph.input_shape network in
    let input_shape = Array.append [| batch |] network_shape in

    (* initialise the network weight *)
    _init_weights network;

    (* derive the computation graph in forward mode *)
    let x = Engine.var_arr "x" ~shape:input_shape |> pack_arr in
    let y' = Graph.forward network x |> fst in
    let output_shape = unpack_arr y' |> Engine.shape in
    let y = Engine.var_arr "y" ~shape:output_shape |> pack_arr in
    let loss = loss_fun y y' in
    let loss = Maths.(loss / (_f (Mat.row_num y |> float_of_int))) in

    (* add regularisation term if necessary *)
    let ws = Owl_utils_array.flatten (Graph.mkpri network) in
    let reg =
      match params.regularisation with
      | Regularisation.None -> _f 0.
      | _ -> Array.fold_left (fun a w -> Maths.(a + regl_fun w)) (_f 0.) ws
    in
    let loss = Maths.(loss + reg) in

    (* assign loss variable name *)
    Owl_graph.set_name (unpack_elt loss |> Engine.elt_to_node) "loss";

    (* derive the computation graph in reverse mode *)
    let z = Graph.(backward network loss) in
    let ws = Owl_utils_array.flatten (fst z) in
    let gs' = Owl_utils_array.flatten (snd z) in

    (* assign input/output variable names; the names go to the raw
       (unclipped) gradient nodes *)
    _name_nodes "ws%i" ws;
    _name_nodes "gs'%i'" gs';

    (* allocate variables for optimisation engine *)
    let gs = _alloc_vars "gs%i" ws in
    let ps = _alloc_vars "ps%i" ws in
    let us = _alloc_vars "us%i" ws in
    let ch =
      Array.mapi
        (fun i w ->
          let name1 = Printf.sprintf "cha%i" i in
          let name2 = Printf.sprintf "chb%i" i in
          let shape = Engine.shape (unpack_arr w) in
          let ch1 = Engine.var_arr name1 ~shape |> pack_arr in
          let ch2 = Engine.var_arr name2 ~shape |> pack_arr in
          [| ch1; ch2 |])
        ws
    in

    (* calculate the new weights of the network *)

    (* clip the gradient if necessary *)
    let gs' = Array.map clip_fun gs' in
    (* calculate gradient descent *)
    let ps' = Owl_utils_array.map4 (grad_fun (fun a -> a)) ws gs ps gs' in
    (* update gcache if necessary *)
    let ch' = Owl_utils_array.map2 upch_fun gs' ch in
    (* adjust direction based on learning_rate *)
    let us' =
      Owl_utils_array.map3
        (fun p' g' c ->
          (* FIXME: 999 is just place holder *)
          Maths.(p' * rate_fun 999 g' c))
        ps' gs' ch'
    in
    (* adjust direction based on momentum *)
    let us' = Owl_utils_array.map2 momt_fun us us' in
    (* update the weight *)
    let ws' = Owl_utils_array.map2 (fun w u -> Maths.(w + u)) ws us' in

    (* assign output variable names *)
    _name_nodes "ws'%i" ws';
    _name_nodes "ps'%i" ps';
    _name_nodes "us'%i" us';
    Array.iteri
      (fun i a ->
        let c0 = unpack_arr a.(0) |> Engine.arr_to_node in
        let c1 = unpack_arr a.(1) |> Engine.arr_to_node in
        let s0 = Printf.sprintf "cha'%i" i in
        let s1 = Printf.sprintf "chb'%i" i in
        Owl_graph.set_name c0 s0;
        Owl_graph.set_name c1 s1)
      ch';

    (* construct a computation graph with inputs and outputs *)
    let network_name = Graph.get_network_name network in
    let ch, ch' = Owl_utils_array.(flatten ch, flatten ch') in
    let raw_i = Owl_utils_array.(ws @ gs @ ps @ us @ ch) |> _to_nodes in
    let raw_o = Owl_utils_array.(ws' @ gs' @ ps' @ us' @ ch') |> _to_nodes in
    let param_i, param_o = Engine.remove_unused_iopair raw_i raw_o in
    let output = Array.append param_o [| unpack_elt loss |> Engine.elt_to_node |] in
    let cgraph = Engine.make_graph ~input:param_i ~output network_name in
    Engine.make_iopair cgraph param_i param_o;

    (* initialise values of remaining variables *)
    Owl_utils.aarr_iter
      (fun v ->
        let u = Algodiff.unpack_arr v in
        let shape = Engine.shape u in
        Engine.assign_arr u (Engine.A.zeros shape))
      [| gs; ps; us; ch |];

    (* return key parameters *)
    loss, x, y, cgraph


  (** [make_eval_fun loss xt yt cgraph] returns a function [eval xt' yt'] that
      evaluates [xt'] and [yt'], assigns their values to the graph variables
      [xt] and [yt] with [Engine.unsafe_assign_arr], evaluates [cgraph] and
      returns [loss]. *)
  let make_eval_fun loss xt yt cgraph =
    let xt = Algodiff.unpack_arr xt in
    let yt = Algodiff.unpack_arr yt in
    let _eval xt' yt' =
      let xt' = Algodiff.unpack_arr xt' in
      let yt' = Algodiff.unpack_arr yt' in
      Engine.eval_arr [| xt'; yt' |];
      let xt' = Engine.unpack_arr xt' in
      let yt' = Engine.unpack_arr yt' in
      Engine.unsafe_assign_arr xt xt';
      Engine.unsafe_assign_arr yt yt';
      Engine.eval_graph cgraph;
      loss
    in
    _eval


  (** [make_update_fun cgraph] returns a function that calls
      [Engine.update_iopair cgraph] when applied to [()]. *)
  let make_update_fun cgraph =
    let _update () = Engine.update_iopair cgraph in
    _update


  (** [train ?state ?params network x y] compiles [network] with
      {!compile_deep}, using the first dimension of [x] as the full size and
      [Params.default ()] when [params] is omitted. The graph is saved to
      ["<network name>_raw.cgd"], optimised with [Engine.optimise] and saved
      again to ["<network name>_opt.cgd"]; the call then hands over to
      [Optimise.minimise_compiled_network]. The [save] callback passed to the
      optimiser does nothing. *)
  let train ?state ?params network x y =
    let params =
      match params with
      | Some p -> p
      | None -> Params.default ()
    in
    let network_name = Graph.get_network_name network in
    Owl_log.info "compile network %s into static graph ..." network_name;

    (* compile network into static graph *)
    let x_size = (unpack_arr x |> Engine.shape).(0) in
    let loss, xt, yt, cgraph = compile_deep params network x_size in
    let eval = make_eval_fun loss xt yt cgraph in
    let update = make_update_fun cgraph in
    let save _fname = () in

    (* Experimental: optimise graph structure *)
    Engine.save_graph cgraph (network_name ^ "_raw.cgd");
    Engine.optimise cgraph;
    Engine.save_graph cgraph (network_name ^ "_opt.cgd");

    Owl_log.info "start training %s ..." network_name;
    Optimise.minimise_compiled_network ?state params eval update save x y


  (** Multi-input/output version of {!model}.

      [model_inputs ?optimise ?batch_size network] creates one variable
      ["input_<i>"] per entry of [Graph.input_shapes network], with
      [batch_size] (default [1]) prepended to its shape, builds the graph of
      [Graph.run_inputs] on these variables and optimises it unless
      [optimise] is [false] (default [true]).

      The returned function takes an array of inputs, processes them in
      chunks of at most [batch_size] samples along the first axis and
      concatenates the per-chunk outputs along axis 0.

      @raise Invalid_argument if the returned function is applied to an empty
      array.
      @raise Failure if an input is not an [Arr] value. *)
  let model_inputs ?(optimise = true) ?(batch_size = 1) network =
    let network_name = Graph.get_network_name network in
    Owl_log.info "compile network %s into static graph ..." network_name;

    let input_shapes = Graph.input_shapes network in
    let inputs =
      Array.mapi
        (fun i sh ->
          Engine.var_arr ("input_" ^ string_of_int i)
            ~shape:(Array.append [| batch_size |] sh)
          |> pack_arr)
        input_shapes
    in
    let outputs = Graph.run_inputs inputs network in
    let i, o = _to_nodes inputs, _to_nodes outputs in
    let cgraph = Engine.make_graph ~input:i ~output:o network_name in
    (* optimise graph structure *)
    if optimise then Engine.optimise cgraph;

    (* evaluate the graph on one batch of inputs *)
    let eval xt' =
      let xt = Array.map (fun x -> Algodiff.unpack_arr x) inputs in
      let xt' = Array.map (fun x' -> Algodiff.unpack_arr x') xt' in
      Engine.eval_arr xt';
      let xt' = Array.map (fun x' -> Engine.unpack_arr x') xt' in
      Array.iter2 (fun x x' -> Engine.unsafe_assign_arr x x') xt xt';
      Engine.eval_graph cgraph;
      outputs
    in

    let results xt =
      (* the sample count is read from the first input, so there must be one *)
      if Array.length xt = 0
      then invalid_arg "Owl_neural_compiler.model_inputs: empty input array";
      let n = Optimise.Utils.sample_num xt.(0) in
      (* first index, last index and length of the [i]-th chunk *)
      let chunk_size i =
        let a = i * batch_size in
        let b = min n (a + batch_size) - 1 in
        let c = b - a + 1 in
        a, b, c
      in
      let get_chunk a b x =
        match x with
        | Arr x -> Arr (A.get_slice [ [ a; b ] ] x)
        | _ ->
          failwith ("Owl_neural_compiler.model_inputs: get_chunk: " ^ type_info x)
      in
      let iterate i =
        (* perform the computation on one batch *)
        let a, b, c = chunk_size i in
        let xt = Array.map (get_chunk a b) xt in
        let result =
          Array.map
            (fun x ->
              let y = Algodiff.unpack_arr x in
              (* a shorter last chunk only keeps its first [c] rows *)
              let y = if c <> batch_size then A.get_slice [ [ 0; c - 1 ] ] y else y in
              A.copy y)
            (eval xt)
        in
        Engine.eval_arr result;
        result
      in
      let nb_iterations = ((n - 1) / batch_size) + 1 in
      (* compute results for each batch *)
      let result = Array.init nb_iterations (fun i -> iterate i) in
      (* put the results back together *)
      let slice i = Array.init nb_iterations (fun j -> result.(j).(i)) in
      let result =
        Array.init
          (Array.length result.(0))
          (fun i -> A.concatenate ~axis:0 (slice i))
      in
      Engine.eval_arr result;
      Array.map Algodiff.pack_arr result
    in
    results


  (** [model network] transforms the network into a computation graph and
      optimises it. Returns a function that takes the input of the network as
      an argument and returns the output.

      It wraps {!model_inputs}: the input is passed as a one-element array
      and the first output is returned. *)
  let model ?optimise ?batch_size network =
    let eval = model_inputs ?optimise ?batch_size network in
    fun xt' -> (eval [| xt' |]).(0)
end

(* Make functor ends *)