package mssql

  1. Overview
  2. Docs

Source file client.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
open Core_kernel
open Async_kernel
open Async_unix
open Freetds
open Mssql_error

type t =
  { (* The connection, wrapped in a sequencer. The sequencer prevents concurrent
       use of the DB connection, and also prevents queries during unrelated
       transactions. [close] sets this to [None] so a closed handle can never
       reach the dbprocess again (to prevent null pointer crashes). *)
    mutable conn : Dblib.dbprocess Sequencer.t option
  ; (* ID used to detect deadlocks when attempting to use an outer DB handle
       inside of with_transaction *)
    transaction_id : Bigint.t
  ; (* Months are sometimes 0-based and sometimes 1-based. See:
       http://www.pymssql.org/en/stable/freetds_and_dates.html
       Detected once by [create] and passed to [Row.create_exn] for every row. *)
    month_offset : int
  }

(* Hands out a new ID on every call, counting up from zero. Each DB handle (and
   each transaction's inner handle) gets its own ID so [sequencer_enqueue] can
   recognise when it is being re-entered. *)
let next_transaction_id =
  let counter = ref Bigint.zero in
  fun () ->
    let issued = !counter in
    counter := Bigint.( + ) issued Bigint.one;
    issued
;;

(* Execution-context local holding the [transaction_id]s of every handle whose
   [with_transaction] the current Async job is running inside. *)
let parent_transactions_key =
  let sexp_of_ids = Bigint.Set.sexp_of_t in
  Univ_map.Key.create ~name:"mssql_parent_transactions" sexp_of_ids
;;

(* Queue [f] to run on [t]'s dbprocess once [t]'s sequencer is free.

   Raises immediately if [t] has been closed. It also raises if the current
   execution context is inside [with_transaction] on this same handle. That
   transaction holds the sequencer until it finishes, so the queued job could
   never start. *)
let sequencer_enqueue t f =
  match t.conn with
  | None -> failwith [%here] "Attempt to use closed DB"
  | Some conn ->
    let inside_own_transaction =
      match Scheduler.find_local parent_transactions_key with
      | Some parent_transactions -> Set.mem parent_transactions t.transaction_id
      | None -> false
    in
    if inside_own_transaction
    then
      failwith
        [%here]
        "Attempted to use outer DB handle inside of with_transaction. This would have \
         lead to a deadlock."
    else Throttle.enqueue conn f
;;

(* Substitute [params] into [query], replacing the placeholder [$n] with the
   escaped form of the n-th parameter (1-based). Raises if a placeholder is
   below [$1] or refers past the end of [params]. *)
let format_query query params =
  let escaped_params = List.map params ~f:Db_field.to_string_escaped |> Array.of_list in
  let param_count = Array.length escaped_params in
  let render = function
    | Query_parser_types.Other text -> text
    | Query_parser_types.Param n when n < 1 ->
      failwithf
        [%here]
        ~query
        ~params
        "Query has param $%d but params should start at $1."
        n
    | Query_parser_types.Param n when n > param_count ->
      failwithf
        [%here]
        ~query
        ~params
        "Query has param $%d but there are only %d params."
        n
        param_count
    | Query_parser_types.Param n -> escaped_params.(n - 1)
  in
  Lexing.from_string query
  |> Query_parser.main Query_lexer.token
  |> List.map ~f:render
  |> String.concat ~sep:""
;;

(* Run [formatted_query] on [t]'s connection and pass its result sets to [f].

   The job is queued with [sequencer_enqueue], so it is serialized with other
   work on the same handle. The blocking FreeTDS calls run on a separate thread
   via [In_thread.run], and [f] is called on that thread inside the
   [Mssql_error.with_wrap] callback.

   [f] receives a lazy iterator of result sets, and each result set is a lazy
   iterator of rows. Both are built with [Iter.from_fun] over stateful
   [Dblib.results] / [Dblib.nextrow] calls. They are therefore single-pass and
   only usable while [f] is running.

   [query] and [params] are only passed through to [Mssql_error.with_wrap]
   (presumably for error context). [month_offset] is passed to [Row.create_exn]
   for each row. *)
let execute' ?params ~query ~formatted_query ({ month_offset; _ } as t) ~f =
  sequencer_enqueue t
  @@ fun conn ->
  Logger.debug !"Executing query: %s" formatted_query;
  In_thread.run
  @@ fun () ->
  Mssql_error.with_wrap ~query ?params ~formatted_query [%here] (fun () ->
      (* Presumably discards anything still pending from a previous command on
         this connection before the new one is sent. *)
      Dblib.cancel conn;
      Dblib.sqlexec conn formatted_query;
      (* Outer iterator, one step per [Dblib.results] call:
         - [None] ends the iteration.
         - [Some None] marks a column-less result set (dropped by the
           [filter_map] below).
         - [Some (Some rows)] is a real result set. *)
      Iter.from_fun (fun () ->
          if Dblib.results conn
          then
            Dblib.numcols conn
            |> List.range 0
            (* Column indexes passed to [Dblib.colname] are 1-based. *)
            |> List.map ~f:(fun i -> Dblib.colname conn (i + 1))
            |> function
            | [] ->
              (* Skip this result set if there are no columns, since this indicates results from things
                 like inserts with no row data *)
              Some None
            | colnames ->
              Iter.from_fun (fun () ->
                  (* [Dblib.nextrow] raising [Not_found] ends this result set.
                     NOTE(review): the handler also covers [Row.create_exn], so
                     a [Not_found] raised while decoding a row would silently
                     truncate the result set. Confirm this is intended. *)
                  try
                    let row = Dblib.nextrow conn in
                    let row = Row.create_exn ~month_offset ~colnames row in
                    Some row
                  with
                  | Caml.Not_found -> None)
              |> Option.some
              |> Option.some
          else None)
      |> IterLabels.filter_map ~f:Fn.id
      |> f)
;;

(* Format [query] with [params] (default: no params) and hand it to [execute'].
   The result still expects the [~f] callback that consumes the result sets.
   Formatting happens when this is applied, before anything is queued on the
   connection. *)
let execute_multi_result' ?(params = []) conn query =
  execute' conn ~query ~params ~formatted_query:(format_query query params)
;;

(* Run [query] and return every result set as a list of rows. Column-less
   result sets are already dropped by [execute']. *)
let execute_multi_result ?params conn query =
  let collect result_sets =
    Iter.to_list (IterLabels.map result_sets ~f:Iter.to_list)
  in
  execute_multi_result' ?params conn query ~f:collect
;;

(* Execute [f iter] for the first result set iterator and throw an exception if
   there is more than one result set.

   [f] is an Async function, but it is called from the FreeTDS worker thread
   that [execute'] runs its callback on. It is therefore wrapped with
   [Scheduler.preserve_execution_context'] and run to completion with
   [Thread_safe.block_on_async_exn]. If the query yields no result sets, [f]
   receives an empty iterator.

   Note that [f] has already run, with all its side effects, by the time the
   "more than one result set" error is raised. *)
let execute_f' ?params ~f conn query =
  let f = Scheduler.preserve_execution_context' f |> Staged.unstage in
  execute_multi_result' ?params conn query ~f:(fun result_sets ->
      let result =
        (* [head] stops after the first result set. Presumably the stateful
           underlying iterator (see [execute']) means the traversal below
           resumes at the second result set.
           NOTE(review): rows of the first result set that [f] does not consume
           are not drained here. Confirm every caller consumes them. *)
        let input =
          IterLabels.head result_sets |> Option.value ~default:IterLabels.empty
        in
        Thread_safe.block_on_async_exn (fun () -> f input)
      in
      (* Need to ensure we consume the results or we'll get errors about results pending *)
      match
        IterLabels.map result_sets ~f:(fun result_set ->
            IterLabels.iter result_set ~f:ignore)
        |> IterLabels.length
      with
      | 0 -> result
      | n ->
        (* [n] counts only the extra result sets; add one for the first. *)
        failwithf
          [%here]
          ~query
          ?params
          "Mssql.execute expected one result set but got %d result sets"
          (n + 1))
;;

(* Like [execute_f'], but [f] is a plain (non-Async) function of the rows. *)
let execute_f ?params ~f conn query =
  execute_f' ?params conn query ~f:(fun rows -> return (f rows))
;;

let execute ?params conn query = execute_f ?params conn query ~f:Iter.to_list

(* Call [f] on each row of [query]'s single result set, in order. *)
let execute_iter ?params ~f conn query =
  execute_f ?params conn query ~f:(fun rows -> IterLabels.iter rows ~f)
;;

(* Fold [f] over the rows of [query]'s single result set, starting from
   [init], and return the final accumulator. *)
let execute_fold ?params ~init ~f conn query =
  execute_f ?params conn query ~f:(fun rows -> IterLabels.fold rows ~init ~f)
;;

(* Apply [f] to each row of [query]'s single result set and collect the
   results in a list. *)
let execute_map ?params ~f conn query =
  execute_f ?params conn query ~f:(fun rows -> IterLabels.map rows ~f |> Iter.to_list)
;;

(* Stream the rows of [query]'s single result set into a pipe.

   The writer is closed in [Monitor.protect]'s [finally], whether the query
   succeeds or raises. With [~close_on_exception:false], presumably the reader
   then just sees end-of-pipe while the error surfaces through the enclosing
   monitor. TODO confirm.

   NOTE(review): [IterLabels.fold] walks every row synchronously and only
   chains the writes as deferreds. All rows are therefore fetched from the
   server up front, whatever the pipe's pushback; pushback only delays when
   each write happens. *)
let execute_pipe ?params conn query =
  Pipe.create_reader ~close_on_exception:false
  @@ fun writer ->
  Monitor.protect
    (fun () ->
      execute_f' ?params conn query ~f:(fun rows ->
          (* Sequence the writes so rows reach the pipe in order. If the reader
             has closed the pipe, [write_if_open] drops the row. *)
          IterLabels.fold rows ~init:Deferred.unit ~f:(fun acc row ->
              acc >>= fun () -> Pipe.write_if_open writer row)))
    ~finally:(fun () ->
      Pipe.close writer;
      Deferred.unit)
;;

(* Run [query] and expect it to produce no rows.

   Raises (via [failwithf]) if the result set contains any rows. More than one
   result set is already rejected by [execute]. *)
let execute_unit ?params conn query =
  execute ?params conn query
  >>| function
  | [] -> ()
  | rows ->
    (* [failwithf] raises after consuming the format arguments. No trailing
       [()] is needed, matching the other [failwithf] calls in this file. *)
    failwithf
      [%here]
      ~query
      ?params
      ~results:[ rows ]
      "Mssql.execute_unit expected no rows but result set has %d rows"
      (List.length rows)
;;

(* Run [query] and return its only row, or [None] if it returned no rows.
   Raises if the result set has two or more rows. *)
let execute_single ?params conn query =
  let%map rows = execute ?params conn query in
  match rows with
  | [ row ] -> Some row
  | [] -> None
  | _ ->
    failwithf
      [%here]
      ~query
      ?params
      ~results:[ rows ]
      "Mssql.execute_single expected 0 or 1 results but got %d rows"
      (List.length rows)
;;

(* Format [query] once per parameter list in [params] and send all the copies
   as one ";"-separated batch. Returns the batch's result sets, each as a list
   of rows. *)
let execute_many ~params conn query =
  let statements = List.map params ~f:(fun param_list -> format_query query param_list) in
  let formatted_query = String.concat statements ~sep:";" in
  let all_params = List.concat params in
  execute' conn ~query ~params:all_params ~formatted_query ~f:(fun result_sets ->
    Iter.to_list (IterLabels.map result_sets ~f:Iter.to_list))
;;

(* Transaction control statements. Each must return no rows. *)
let begin_transaction db = execute_unit db "BEGIN TRANSACTION"
let commit db = execute_unit db "COMMIT"
let rollback db = execute_unit db "ROLLBACK"

(* Run [f] inside a SQL transaction on [t].

   The whole transaction is queued as one job on [t]'s sequencer. No other user
   of this handle (or of record copies sharing its sequencer) can run queries
   until it finishes. [f] gets a new handle that wraps the same dbprocess in
   its own sub-sequencer and has a fresh [transaction_id], so [f]'s own queries
   can still run.

   While [f] runs, [t.transaction_id] is added to the execution-context set
   under [parent_transactions_key]. If [f] accidentally uses the outer [t],
   [sequencer_enqueue] uses this set to raise instead of deadlocking.

   The transaction is committed if [f] returns [Ok _] and rolled back if it
   returns [Error _]; [res] is returned either way. If [f] raises instead of
   returning a result, neither COMMIT nor ROLLBACK is issued here. The callers
   in this file wrap [f] so that it returns a result. *)
let with_transaction' t f =
  (* Use the sequencer to prevent any other copies of this DB handle from
     executing during the transaction *)
  sequencer_enqueue t
  @@ fun conn ->
  (* Add the outer handle's ID to any enclosing transactions' IDs. *)
  Scheduler.find_local parent_transactions_key
  |> Option.value ~default:Bigint.Set.empty
  |> Fn.flip Set.add t.transaction_id
  |> Option.some
  |> Scheduler.with_local parent_transactions_key ~f:(fun () ->
         (* Make a new sub-sequencer so our own queries can continue *)
         let t =
           { t with
             conn = Sequencer.create ~continue_on_error:true conn |> Option.some
           ; transaction_id = next_transaction_id ()
           }
         in
         let%bind () = begin_transaction t in
         let%bind res = f t in
         let%map () =
           match res with
           | Ok _ -> commit t
           | Error _ -> rollback t
         in
         res)
;;

(* Run [f] in a transaction on [t]. It commits if [f]'s deferred is determined
   normally. If [f] raises, it rolls back and re-raises the exception. *)
let with_transaction t f =
  let run_caught inner = Monitor.try_with ~here:[%here] (fun () -> f inner) in
  let%map outcome = with_transaction' t run_caught in
  match outcome with
  | Ok value -> value
  | Error exn -> raise exn
;;

(* Run the [Or_error]-returning [f] in a transaction on [t]. It commits on
   [Ok _] and rolls back on [Error _]. An exception raised by [f] is turned
   into [Error _] by [Monitor.try_with_join_or_error], so it also rolls back. *)
let with_transaction_or_error t f =
  let run_guarded inner =
    Monitor.try_with_join_or_error ~here:[%here] (fun () -> f inner)
  in
  with_transaction' t run_guarded
;;

(* Open a FreeTDS connection to [host] (on [port] if given) and switch to
   database [db]. This call blocks; callers run it via [In_thread.run].

   If connecting or selecting the database raises, the failure is logged and
   the attempt is retried immediately, up to [tries] more times. The default
   [tries = 5] therefore allows up to 6 attempts. Once no retries remain, the
   last exception is re-raised. A zero or negative [tries] means no retries;
   previously a negative value retried forever. *)
let rec connect ?(tries = 5) ~host ~db ~user ~password ?port () =
  try
    let conn =
      Dblib.connect
        ~user
        ~password
        (* We have issues with anything higher than this *)
        ~version:Dblib.V70
          (* Clifford gives FreeTDS conversion errors if we choose anything else,
           eg:
           ("Error(CONVERSION, \"Some character(s) could not be converted into
           client's character set.  Unconverted bytes were changed to question
           marks ('?')\")") *)
        ~charset:"CP1252"
        (* You set ports in FreeTDS by appending them to the host name:
           http://www.freetds.org/userguide/portoverride.htm *)
        (match port with
        | None -> host
        | Some port -> sprintf "%s:%d" host port)
    in
    Dblib.use conn db;
    conn
  with
  | exn when tries <= 0 -> raise exn
  | exn ->
    Logger.info "Retrying Mssql.connect due to exn: %s" (Exn.to_string exn);
    connect ~tries:(tries - 1) ~host ~db ~user ~password ?port ()
;;

(* Turn on the session options below for every new connection. Without them,
   SQL Server rejects some statements, eg: DELETE failed because the following
   SET options have incorrect settings: 'ANSI_NULLS, QUOTED_IDENTIFIER,
   CONCAT_NULL_YIELDS_NULL, ANSI_WARNINGS, ANSI_PADDING'. Verify that SET
   options are correct for use with indexed views and/or indexes on computed
   columns and/or filtered indexes and/or query notifications and/or XML data
   type methods and/or spatial index operations. *)
let init_conn c =
  let%map (_ : _ list) =
    execute_multi_result
      c
      "SET QUOTED_IDENTIFIER ON\n\
      \     SET ANSI_NULLS ON\n\
      \     SET ANSI_WARNINGS ON\n\
      \     SET ANSI_PADDING ON\n\
      \     SET CONCAT_NULL_YIELDS_NULL ON"
  in
  ()
;;

(* Close [t]. Closing an already-closed handle does nothing.

   [t.conn] is cleared right away, so later queries on [t] fail in
   [sequencer_enqueue]. The actual [Dblib.close] is queued on the sequencer
   behind any work already enqueued there. *)
let close t =
  match t.conn with
  | None -> Deferred.unit
  | Some sequencer ->
    t.conn <- None;
    Throttle.enqueue sequencer (fun dbprocess ->
      In_thread.run (fun () -> Dblib.close dbprocess))
;;

(* Connect to [db] on [host] and return a ready-to-use handle.

   After connecting, this runs a probe query to learn whether this FreeTDS
   build reports months 0-based or 1-based. It records the correction in
   [month_offset] and then applies the session settings from [init_conn]. If
   any of that fails, the connection is closed and the exception re-raised. *)
let create ~host ~db ~user ~password ?port () =
  let%bind conn =
    In_thread.run (connect ~host ~db ~user ~password ?port)
    >>| fun dbprocess ->
    { conn = Some (Sequencer.create ~continue_on_error:true dbprocess)
    ; transaction_id = next_transaction_id ()
    ; month_offset = 0
    }
  in
  (* Since FreeTDS won't tell us if it was compiled with 0-based or 1-based
     months, ask for a known February date and see which month comes back. *)
  let query = "SELECT CAST('2017-02-02' AS DATETIME) AS x" in
  let month_offset_of_row row =
    match Row.datetime_exn row "x" |> Time.(to_date ~zone:Zone.utc) |> Date.month with
    | Month.Feb -> 0
    | Month.Jan -> 1
    | month ->
      failwithf
        [%here]
        ~query
        "Expected month index workaround query to return February as either Jan or \
         Feb but got %s"
        (Month.to_string month)
  in
  let probe_and_init () =
    let%bind first_row = execute_single conn query in
    match first_row with
    | None ->
      failwith
        [%here]
        ~query
        "Expected month index workaround query to return one row but got none"
    | Some row ->
      let conn = { conn with month_offset = month_offset_of_row row } in
      let%map () = init_conn conn in
      conn
  in
  let%bind outcome = Monitor.try_with ~here:[%here] probe_and_init in
  match outcome with
  | Ok res -> return res
  | Error exn ->
    let%map () = close conn in
    raise exn
;;

(* Open a connection with [create], pass it to [f], and close it once [f]'s
   deferred is determined or [f] raises. *)
let with_conn ~host ~db ~user ~password ?port f =
  create ~host ~db ~user ~password ?port ()
  >>= fun conn ->
  Monitor.protect ~finally:(fun () -> close conn) (fun () -> f conn)
;;