package catala

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file utils.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
(* This file is part of the Catala compiler, a specification language for tax
   and social benefits computation rules. Copyright (C) 2024 Inria, contributor:
   Louis Gesbert <louis.gesbert@inria.fr>

   Licensed under the Apache License, Version 2.0 (the "License"); you may not
   use this file except in compliance with the License. You may obtain a copy of
   the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
   WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
   License for the specific language governing permissions and limitations under
   the License. *)

open Catala_utils
open Shared_ast
open Ast
module D = Dcalc.Ast
module L = Lcalc.Ast

(** [get_vars e] is the set of variable names carried by the [EVar] nodes found
    anywhere inside the expression [e]. [EFunc], [ELit], [EPosLit] and
    [EExternal] nodes contribute no variables; every other node contributes the
    variables of its sub-expressions. *)
let rec get_vars e =
  (* Fold the variables of every expression in [exprs] into the set [init]. *)
  let union_over init exprs =
    List.fold_left
      (fun vars sub -> VarName.Set.union vars (get_vars sub))
      init exprs
  in
  match Mark.remove e with
  | EVar v -> VarName.Set.singleton v
  | EFunc _ | ELit _ | EPosLit | EExternal _ -> VarName.Set.empty
  | EStructFieldAccess { e1; _ } | ETupleAccess { e1; _ } | EInj { e1; _ } ->
    get_vars e1
  | EStruct { fields; _ } ->
    StructField.Map.fold
      (fun _field sub vars -> VarName.Set.union (get_vars sub) vars)
      fields VarName.Set.empty
  | ETuple exprs | EArray exprs | EAppOp { args = exprs; _ } ->
    union_over VarName.Set.empty exprs
  | EApp { f; args; _ } -> union_over (get_vars f) args

(** [subst_expr v e within_expr] returns [within_expr] with every [EVar v] node
    replaced by the expression [e]. The replacement is inserted as-is, keeping
    its own mark. Every rebuilt node keeps the mark of the node it replaces.
    Leaves other than [EVar v] ([EFunc], [ELit], [EPosLit], [EExternal], other
    variables) are returned physically unchanged. *)
let rec subst_expr v e within_expr =
  let recur = subst_expr v e in
  let rebuild node = node, Mark.get within_expr in
  match Mark.remove within_expr with
  | EVar v1 -> if VarName.equal v v1 then e else within_expr
  | EFunc _ | ELit _ | EPosLit | EExternal _ -> within_expr
  | ETuple el -> rebuild (ETuple (List.map recur el))
  | EArray el -> rebuild (EArray (List.map recur el))
  | EStruct str ->
    rebuild
      (EStruct { str with fields = StructField.Map.map recur str.fields })
  | EStructFieldAccess sfa ->
    rebuild (EStructFieldAccess { sfa with e1 = recur sfa.e1 })
  | ETupleAccess ta -> rebuild (ETupleAccess { ta with e1 = recur ta.e1 })
  | EInj inj -> rebuild (EInj { inj with e1 = recur inj.e1 })
  | EApp app ->
    rebuild
      (EApp { app with f = recur app.f; args = List.map recur app.args })
  | EAppOp op -> rebuild (EAppOp { op with args = List.map recur op.args })

(** [subst_stmt v e stmt] substitutes the expression [e] for the variable [v]
    in every expression of [stmt]. It recurses into nested blocks: both
    branches of an [SIfThenElse], every [SSwitch] case, and the body of an
    [SInnerFuncDef]. The names bound by [SLocalDecl], [SLocalInit] and
    [SLocalDef] are never rewritten.

    @raise Exit
      if [stmt] or a nested statement is an [SSwitch] whose [switch_var] is
      [v] while [e] is not a bare [EVar]: the switch variable is a variable
      name, so it can only be renamed, not replaced by an arbitrary
      expression. *)
let rec subst_stmt v e stmt =
  let on_expr = subst_expr v e in
  let on_block = subst_block v e in
  match stmt with
  | SLocalDecl _ -> stmt
  | SLocalInit init -> SLocalInit { init with expr = on_expr init.expr }
  | SLocalDef def -> SLocalDef { def with expr = on_expr def.expr }
  | SReturn ret -> SReturn (on_expr ret)
  | SAssert { pos_expr; expr } ->
    SAssert { pos_expr = on_expr pos_expr; expr = on_expr expr }
  | SFatalError fatal ->
    SFatalError { fatal with pos_expr = on_expr fatal.pos_expr }
  | SInnerFuncDef inner ->
    let func =
      { inner.func with func_body = on_block inner.func.func_body }
    in
    SInnerFuncDef { inner with func }
  | SIfThenElse { if_expr; then_block; else_block } ->
    SIfThenElse
      {
        if_expr = on_expr if_expr;
        then_block = on_block then_block;
        else_block = on_block else_block;
      }
  | SSwitch sw ->
    (* The scrutinee is stored as a variable name: substituting [v] is only
       possible when [e] is itself a variable. Otherwise, abort with [Exit]. *)
    let switch_var =
      if not (VarName.equal sw.switch_var v) then sw.switch_var
      else match e with EVar renamed, _ -> renamed | _ -> raise Exit
    in
    let switch_cases =
      List.map
        (fun case -> { case with case_block = on_block case.case_block })
        sw.switch_cases
    in
    SSwitch { sw with switch_var; switch_cases }
  | _ -> .

(** [subst_block v e block] applies [subst_stmt v e] to each statement of
    [block], keeping every statement's position. Propagates [Exit] from
    [subst_stmt]. *)
and subst_block v e block =
  List.map (fun (s, pos) -> subst_stmt v e s, pos) block

(** [subst_block v expr typ pos block] replaces the variable [v] by [expr]
    throughout [block].

    When the substitution is impossible, [block] is kept unmodified and a local
    initialisation binding [v] (of type [typ]) to [expr] is prepended to it.
    This happens when [block] contains an [SSwitch] on [v] and [expr] is not a
    bare variable. The new statement and its binder are both given position
    [pos]. *)
let subst_block v expr typ pos block =
  match subst_block v expr block with
  | substituted -> substituted
  | exception Exit ->
    let init = SLocalInit { name = v, pos; typ; expr } in
    (init, pos) :: block

(** [find_block pred block] returns the first statement of [block], paired
    with its position, that satisfies [pred]. It returns [None] if there is no
    such statement.

    The search is depth-first and pre-order: a statement is tested before its
    sub-blocks. The [then] branch of an [SIfThenElse] is searched before its
    [else] branch, and [SSwitch] cases are searched in order. Bodies of
    [SInnerFuncDef] statements are not searched. *)
let rec find_block pred block =
  match block with
  | [] -> None
  | stmt :: rest -> (
    if pred stmt then Some stmt
    else
      let found_inside =
        match stmt with
        | SIfThenElse { then_block; else_block; _ }, _ -> (
          match find_block pred then_block with
          | Some _ as hit -> hit
          | None -> find_block pred else_block)
        | SSwitch { switch_cases; _ }, _ ->
          List.find_map
            (fun case -> find_block pred case.case_block)
            switch_cases
        | _ -> None
      in
      match found_inside with
      | Some _ -> found_inside
      | None -> find_block pred rest)

(** [filter_map_block pred block] applies [pred] to every statement of [block]
    and of its nested blocks, and collects the [Some] results in pre-order.
    Pre-order means a statement comes before the statements nested in it; the
    [then] branch of an [SIfThenElse] comes before its [else] branch; [SSwitch]
    cases come in order. Bodies of [SInnerFuncDef] statements are not visited. *)
let rec filter_map_block pred block =
  (* [List.concat_map] iterates over sibling statements without consuming
     stack per element, so stack usage only grows with nesting depth, not with
     the length of a block. *)
  List.concat_map
    (fun ((s, _) as stmt) ->
      (* Test the statement itself before descending into its sub-blocks, so
         [pred] is called in the same pre-order as the results are listed. *)
      let here = Option.to_list (pred stmt) in
      let nested =
        match s with
        | SIfThenElse { then_block; else_block; _ } ->
          let in_then = filter_map_block pred then_block in
          in_then @ filter_map_block pred else_block
        | SSwitch { switch_cases; _ } ->
          List.concat_map
            (fun case -> filter_map_block pred case.case_block)
            switch_cases
        | _ -> []
      in
      here @ nested)
    block