package janestreet_csv

  1. Overview
  2. Docs

Source file join.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
(* Join CSV files on one or more common key fields. *)
open Core
module Csv = Csvlib.Csv

(* The kinds of join supported, following SQL:
   - [Full]: emit every key that occurs in any file;
   - [Inner]: emit only keys that occur in every file;
   - [Left]: emit only keys that occur in the first file.
   A file that lacks an emitted key contributes empty fields for it. *)
module T = struct
  type t =
    | Full
    | Inner
    | Left
  [@@deriving compare, enumerate, sexp]
end

(* Re-export the constructors at top level; [T] itself is also handed to
   [Enum.make_param] in [param]. *)
include T

(* Command-line parameter selecting the join kind.

   [-join] chooses the kind explicitly; without it the join is [Inner].  The older
   [-keys-need-not-occur-in-all-files] flag is kept as a deprecated alias for
   [-join full].  Passing both flags raises. *)
let param =
  let join_switch = "-join" in
  let keys_need_not_occur_in_all_files_switch = "-keys-need-not-occur-in-all-files" in
  let%map_open.Command requested_kind =
    Enum.make_param join_switch (module T) ~doc:"as in SQL (default: inner)" ~f:optional
  and legacy_full_alias =
    flag
      keys_need_not_occur_in_all_files_switch
      no_arg
      ~doc:" deprecated alias for -join full"
  in
  if not legacy_full_alias
  then Option.value requested_kind ~default:Inner
  else (
    (* The deprecated alias means [Full]; it cannot be combined with [-join]. *)
    match requested_kind with
    | None -> Full
    | Some _ ->
      raise_s
        [%message "cannot specify both" join_switch keys_need_not_occur_in_all_files_switch])
;;

(* One CSV record: its field values, in column order. *)
module Row = struct
  type t = string list [@@deriving sexp]
end

(* Composite join key: the values of a row's key columns.  The type is abstract, so the
   only ways to build or inspect a key are [create] and [to_list]; it can be used as a
   map/set key and as a hash-table key. *)
module Key : sig
  type t

  include Comparable.S with type t := t
  include Hashable.S with type t := t

  (* [create values] wraps [values] without copying it.  NOTE(review): callers must not
     mutate the array afterwards, since comparison and hashing read it directly. *)
  val create : string array -> t

  (* [to_list t] returns the key values in the same order as the array given to
     [create]. *)
  val to_list : t -> string list
end = struct
  module T = struct
    open Ppx_hash_lib.Std.Hash.Builtin

    type 'a array_frozen = 'a array [@@deriving compare, sexp]

    (* [array_frozen] can derive hash.  We don't expose any mutation. *)
    (* The [open] above supplies [hash_fold_array_frozen], which hashes an array's
       contents; that is only sound because nothing mutates a [t] after creation. *)
    type t = string array_frozen [@@deriving compare, hash, sexp]
  end

  include T
  include Comparable.Make (T)
  include Hashable.Make (T)

  let create t = t
  let to_list = Array.to_list
end

module Rows_by_key = struct
  (* The contents of one CSV file, grouped by key.  Both [header] and every [Row.t] in
     [data_by_key] omit the key columns and list the remaining columns in the same
     order, so [header] names the fields of each row. *)
  type t =
    { data_by_key : Row.t list Key.Map.t
    (* Non-key fields of each data row, grouped by the row's key; rows that share a key
       keep their order from the file. *)
    ; header : Row.t (* Names of the non-key columns, one per field of each row. *)
    }
  [@@deriving fields ~getters, sexp]

  (* Reads every record of [file] (header included) using separator [sep].
     Raises if the records do not all have the same number of fields. *)
  let load_rows file ~sep =
    In_channel.with_file file ~f:(fun channel ->
      let rows = Csv.load_in ~separator:sep channel in
      match
        List.dedup_and_sort
          (List.map rows ~f:(List.length :> _ -> _))
          ~compare:[%compare: int]
      with
      | [] | [ _ ] -> rows
      | _ -> failwithf "rows in %s have different lengths" file ())
  ;;

  (* [load ~file_name ~key_fields ~sep] parses [file_name] (its first record is the
     header) and groups its data rows by the values of the [key_fields] columns, taken
     in the order [key_fields] lists them.

     Raises if the file has no records, if its records differ in length, if a column
     name is repeated in the header, or if any of [key_fields] is not a column. *)
  let load ~file_name ~key_fields ~sep =
    match load_rows ~sep file_name with
    | [] -> failwithf "file %s is empty" file_name ()
    | header :: rows ->
      (* Column name -> position; building it also rejects repeated column names. *)
      let hmap =
        match String.Map.of_alist (List.mapi header ~f:(fun i h -> h, i)) with
        | `Ok map -> map
        | `Duplicate_key h -> failwithf "repeated column %s in %s" h file_name ()
      in
      let key_indices =
        Array.map key_fields ~f:(fun key_field ->
          match Map.find hmap key_field with
          | Some i -> i
          | None -> failwithf "No %s column in %s" key_field file_name ())
      in
      (* Positions of every non-key column, in header order. *)
      let data_indices =
        let key_fields = String.Set.of_array key_fields in
        List.filter_mapi header ~f:(fun i h ->
          if Set.mem key_fields h then None else Some i)
      in
      (* Indexing by position is safe: [load_rows] checked that every record has the
         header's length. *)
      let data_by_key =
        List.map rows ~f:(fun row ->
          let row = Array.of_list row in
          ( Key.create @@ Array.map key_indices ~f:(Array.get row)
          , List.map data_indices ~f:(Array.get row) ))
        |> Key.Map.of_alist_multi
      in
      { data_by_key
      ; header = List.map data_indices ~f:(Array.get (Array.of_list header))
      }
  ;;
end

module Join_result : sig
  type t

  (* Starting points for each join kind: one output row, with no fields yet, per key
     the join will emit.  Left uses the keys of the first file, inner the keys common
     to all files, and full the keys of any file.  Each raises if given no files. *)
  val empty_for_left_join : Rows_by_key.t list -> t
  val empty_for_inner_join : Rows_by_key.t list -> t
  val empty_for_full_join : Rows_by_key.t list -> t

  (* Any join can be expressed as a left join with the correct keys on the left side. *)

  (* [do_left_join t rows] extends every output row in [t] with the fields that [rows]
     holds for that row's key.  This yields one output row per matching row of [rows],
     or a single row of empty strings when [rows] lacks the key.  The keys of [t] are
     unchanged. *)
  val do_left_join : t -> Rows_by_key.t -> t

  (* [to_rows t] yields the final output rows, each prefixed by its key values. *)
  val to_rows : t -> Row.t Sequence.t
end = struct
  (* For each key to emit, in increasing key order, the non-key fields accumulated so
     far for each output row with that key. *)
  type t = (Key.t * Row.t Sequence.t) Sequence.t

  (* Shared failure for every [empty_for_*_join] when there is nothing to join. *)
  let no_files () = failwith "join requires at least one csv."

  let empty_of_keys keys =
    Set.to_sequence keys |> Sequence.map ~f:(fun key -> key, Sequence.singleton [])
  ;;

  let empty_for_left_join = function
    | [] -> no_files ()
    | car :: _ -> empty_of_keys (Map.key_set (Rows_by_key.data_by_key car : _ Key.Map.t))
  ;;

  (* Combines the key sets of all files with [f].  [[]] is rejected explicitly, so
     inner and full joins fail with the same message as left joins rather than a
     generic [reduce_exn] error. *)
  let reduce_keys maps ~f =
    match List.map maps ~f:(fun rows -> Map.key_set (Rows_by_key.data_by_key rows)) with
    | [] -> no_files ()
    | first :: rest -> List.fold rest ~init:first ~f |> empty_of_keys
  ;;

  let empty_for_inner_join = reduce_keys ~f:Set.inter
  let empty_for_full_join = reduce_keys ~f:Set.union

  let do_left_join t rows =
    (* Padding used when [rows] has no entry for a key: one "" per non-key column. *)
    let empty_right_side_of_rows =
      Rows_by_key.header rows |> List.map ~f:(const "") |> Sequence.singleton
    in
    Sequence.map t ~f:(fun (key, left_side_of_rows) ->
      let right_side_of_rows : Row.t Sequence.t =
        match Map.find (Rows_by_key.data_by_key rows) key with
        | None -> empty_right_side_of_rows
        | Some right_rows -> Sequence.of_list right_rows
      in
      (* Cartesian product of the rows accumulated so far with this file's rows. *)
      ( key
      , Sequence.concat_map left_side_of_rows ~f:(fun left_side_of_row ->
          Sequence.map right_side_of_rows ~f:(fun right_side_of_row ->
            left_side_of_row @ right_side_of_row)) ))
  ;;

  let to_rows join_result =
    Sequence.concat_map join_result ~f:(fun (key, rows) ->
      Sequence.map rows ~f:(fun row -> List.append (Key.to_list key) row))
  ;;
end

(* [join t files ~key_fields ~sep] joins the CSV [files] on the [key_fields] columns
   using join kind [t].  It returns the output rows preceded by the combined header:
   the key fields, then each file's non-key columns in file order.

   Raises if a name occurs more than once in that combined header (a non-key column
   shared by several files, or a repeated key field), if [files] is empty, or if any
   file fails to load. *)
let join t files ~key_fields ~sep =
  let tables =
    List.map files ~f:(fun file_name -> Rows_by_key.load ~file_name ~key_fields ~sep)
  in
  let combined_header =
    Array.to_list key_fields @ List.concat_map tables ~f:Rows_by_key.header
  in
  (match List.find_a_dup combined_header ~compare:String.compare with
   | None -> ()
   | Some duplicate ->
     raise_s
       [%message
         "Only key fields may appear in multiple files."
           (duplicate : string)
           (combined_header : string list)]);
  (* Every join kind is a left join starting from the appropriate key skeleton. *)
  let empty_result =
    let start_from =
      match t with
      | Full -> Join_result.empty_for_full_join
      | Inner -> Join_result.empty_for_inner_join
      | Left -> Join_result.empty_for_left_join
    in
    start_from tables
  in
  let data_rows =
    Join_result.to_rows
      (List.fold tables ~init:empty_result ~f:Join_result.do_left_join)
  in
  Sequence.shift_right data_rows combined_header
;;