package pythonlib

  1. Overview
  2. Docs

Source file class_wrapper.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
open Base
open Import

(** Opaque identifiers handed out in increasing order.

    The signature hides both the integer representation and the counter, so
    the only way to obtain a value is [create]. The first call returns the
    identifier that prints as 1, the next one 2, and so on. *)
module Id : sig
  type t

  val create : unit -> t
  val to_string : t -> string
end = struct
  type t = int

  (* Last identifier handed out; 0 means that none has been issued yet. *)
  let last_issued = ref 0

  let create () =
    Int.incr last_issued;
    !last_issued
  ;;

  let to_string id = Int.to_string id
end

(* Name of the python attribute that holds the encapsulated OCaml value on
   every instance: [wrap] and the generated [__init__] write it, [unwrap_exn]
   reads it back. *)
let content_field = "_content"

(* Handle on a python class whose instances carry an OCaml value of type ['a].

   - [wrap] / [unwrap]: the conversion pair returned by [Py.Capsule.make] in
     [make]; they turn an OCaml value into a python object and back.
   - [name]: the class name, used when creating the class, when registering it
     in a module and in error messages.
   - [cls_object]: the python class object. It starts as [None] and is filled
     in by [set_cls_object_exn], which refuses to overwrite an existing value. *)
type 'a t =
  { wrap : 'a -> Py.Object.t
  ; unwrap : Py.Object.t -> 'a
  ; name : string
  ; mutable cls_object : Py.Object.t option
  }

(** [set_cls_object_exn t pyobject] records [pyobject] as the python class
    object of [t].

    Raises [Failure] if a class object has already been recorded for [t]. *)
let set_cls_object_exn t pyobject =
  match t.cls_object with
  | Some _ -> Printf.failwithf "cls_object for %s has already been set" t.name ()
  | None -> t.cls_object <- Some pyobject
;;

(* Description of the constructor of a wrapped class, consumed by [make] to
   build the python [__init__] method. *)
module Init = struct
  type 'a cls = 'a t

  (* [fn] builds the OCaml value from the class handle and the arguments that
     follow [self] in the python call; [make] encapsulates its result and
     stores it on the new instance. [docstring], when present, is attached to
     the generated [__init__]. *)
  type 'a t =
    { fn : 'a cls -> args:pyobject list -> 'a
    ; docstring : string option
    }
  [@@deriving fields]

  (* [create ?docstring fn] packs a constructor function and its optional
     docstring. *)
  let create ?docstring fn = { docstring; fn }
end

(** Methods that [make] attaches to a wrapped class. *)
module Method = struct
  type 'a cls = 'a t

  (* A method body. Both variants receive the class handle, [self] as a pair
     of the unwrapped OCaml value and the python object it came from, and the
     remaining positional arguments; [With_keywords] also receives the keyword
     arguments as a map indexed by name. *)
  type 'a fn =
    | No_keywords of ('a cls -> self:'a * pyobject -> args:pyobject list -> pyobject)
    | With_keywords of
        ('a cls
         -> self:'a * pyobject
         -> args:pyobject list
         -> keywords:(string, pyobject, String.comparator_witness) Map.t
         -> pyobject)

  type 'a t =
    { name : string
    ; fn : 'a fn
    ; docstring : string option
    }
  [@@deriving fields]

  (* [create ?docstring name fn] is a method taking positional arguments only. *)
  let create ?docstring name fn =
    let fn = No_keywords fn in
    { name; fn; docstring }
  ;;

  (* [create_with_keywords ?docstring name fn] is a method that also receives
     keyword arguments. *)
  let create_with_keywords ?docstring name fn =
    let fn = With_keywords fn in
    { name; fn; docstring }
  ;;

  (* [defunc ?docstring name fn] is a keyword-accepting method whose arguments
     are handled by [Defunc.apply]: [fn cls ~self] supplies the first argument
     of [Defunc.apply], followed by the positional arguments as an array and
     the keyword map. *)
  let defunc ?docstring name fn =
    create_with_keywords ?docstring name (fun cls ~self ~args ~keywords ->
      Defunc.apply (fn cls ~self) (Array.of_list args) keywords)
  ;;
end

(** [wrap_capsule t obj] converts [obj] with the [wrap] function stored in
    [t]. Unlike [wrap], this yields the bare encapsulated value and not an
    instance of the python class. *)
let wrap_capsule { wrap; _ } obj = wrap obj

(** [unwrap_exn t pyobj] extracts the OCaml value stored in the python
    instance [pyobj].

    Raises [Failure] when [pyobj] has no [content_field] attribute or when
    that attribute does not pass [Py.Capsule.check]. The final conversion is
    delegated to the [unwrap] function stored in [t]. *)
let unwrap_exn t pyobj =
  match Py.Object.get_attr_string pyobj content_field with
  | None -> Printf.failwithf "no %s field in object" content_field ()
  | Some capsule ->
    if Py.Capsule.check capsule
    then t.unwrap capsule
    else failwith "not an ocaml capsule"
;;

(** [unwrap t pyobj] is [Some v] when [unwrap_exn t pyobj] returns [v], and
    [None] when it raises. Every exception is turned into [None], whatever
    its cause. *)
let unwrap t pyobj =
  match unwrap_exn t pyobj with
  | value -> Some value
  | exception _ -> None
;;

(** [wrap t obj] creates a python instance of the class of [t] holding [obj].

    The class object is called without arguments and the encapsulated [obj]
    is then stored in the [content_field] attribute of the resulting
    instance, which is returned. Raises if the class object of [t] has not
    been set.

    NOTE(review): calling the class with no argument presumably runs the
    registered [__init__] with an empty argument list before the content is
    replaced; confirm this is acceptable for classes whose constructor
    expects arguments. *)
let wrap t obj =
  let instance =
    Py.Object.call_function_obj_args (Option.value_exn t.cls_object) [||]
  in
  let capsule = wrap_capsule t obj in
  Py.Object.set_attr_string instance content_field capsule;
  instance
;;

(** [make ?to_string_repr ?to_string ?eq ?init name ~methods] creates a python
    class called [name] whose instances carry an OCaml value, and returns the
    handle used to wrap and unwrap such values.

    - [to_string], when given, becomes [__str__].
    - [to_string_repr], when given, becomes [__repr__]; when it is absent,
      [to_string] is used for [__repr__] as well.
    - [eq], when given, becomes [__eq__]. It expects exactly one argument,
      which is unwrapped with [unwrap_exn] before [eq] is called.
    - [init], when given, is run by [__init__] on the arguments following
      [self], and its result is stored on the instance. Without [init], the
      content is left as python [None].
    - [methods] are added after the special methods above.

    The encapsulated values use a capsule type named after [name] and a fresh
    [Id], so each call to [make] gets its own capsule name.

    In every generated callable, [Py.Err] exceptions are passed through
    unchanged and any other exception is reported as a python [ValueError]
    whose message contains the OCaml exception. *)
let make ?to_string_repr ?to_string ?eq ?init name ~methods =
  let id = Id.create () in
  let t =
    let wrap, unwrap = Py.Capsule.make (Printf.sprintf !"%s-%{Id}" name id) in
    { wrap; unwrap; cls_object = None; name }
  in
  (* Runs [f], letting [Py.Err] through and converting anything else into a
     python ValueError. *)
  let value_error_on_exn f =
    try f () with
    | Py.Err _ as pyerr -> raise pyerr
    | exn ->
      let msg = Printf.sprintf !"ocaml error %{Exn#mach}" exn in
      raise (Py.Err (ValueError, msg))
  in
  (* Splits the arguments of a call into [self] and the remaining ones. *)
  let self_and_args = function
    | [] -> failwith "empty input"
    | self :: args -> self, args
  in
  let methods =
    let to_string =
      Option.map to_string ~f:(fun fn t ~self ~args:_ ->
        fn t (fst self) |> Py.String.of_string)
    in
    let to_string_repr =
      Option.map to_string_repr ~f:(fun fn t ~self ~args:_ ->
        fn t (fst self) |> Py.String.of_string)
    in
    (* [__repr__] falls back to the [__str__] implementation. *)
    let to_string_repr = Option.first_some to_string_repr to_string in
    let eq =
      Option.map eq ~f:(fun fn t ~self ~args ->
        let rhs =
          match args with
          | [] -> failwith "eq with no argument"
          | _ :: _ :: _ ->
            Printf.failwithf "eq with %d arguments" (List.length args) ()
          | [ rhs ] -> rhs
        in
        fn t (fst self) (unwrap_exn t rhs) |> Py.Bool.of_bool)
    in
    List.filter_map
      [ "__str__", to_string; "__repr__", to_string_repr; "__eq__", eq ]
      ~f:(fun (name, fn) -> Option.map fn ~f:(fun fn -> Method.create name fn))
    @ methods
  in
  let methods =
    List.map methods ~f:(fun { Method.name; fn; docstring } ->
      let fn =
        match (fn : _ Method.fn) with
        | No_keywords fn ->
          Py.Callable.of_function ?docstring (fun args ->
            (* The self/args split happens inside the protected region so that
               an empty argument array is reported as a ValueError, exactly as
               in the keyword-accepting case below. *)
            value_error_on_exn (fun () ->
              let self, args = self_and_args (Array.to_list args) in
              fn t ~self:(unwrap_exn t self, self) ~args))
        | With_keywords fn ->
          Py.Callable.of_function_with_keywords ?docstring (fun args keywords ->
            value_error_on_exn (fun () ->
              let self, args = self_and_args (Array.to_list args) in
              let keywords =
                Py_module.keywords_of_python keywords |> Or_error.ok_exn
              in
              fn t ~self:(unwrap_exn t self, self) ~args ~keywords))
      in
      name, fn)
  in
  let init =
    let fn =
      let docstring = Option.bind init ~f:Init.docstring in
      Py.Callable.of_function_as_tuple ?docstring (fun tuple ->
        value_error_on_exn (fun () ->
          let self, args = self_and_args (Py.Tuple.to_list tuple) in
          let content =
            match init with
            | Some init -> init.fn t ~args |> wrap_capsule t
            | None -> Py.none
          in
          Py.Object.set_attr_string self content_field content;
          Py.none))
    in
    "__init__", fn
  in
  let cls_object =
    Py.Class.init name ~fields:[ content_field, Py.none ] ~methods:(init :: methods)
  in
  set_cls_object_exn t cls_object;
  t
;;

(** [register_in_module t modl] binds the python class object of [t] in
    [modl] under the class name. Raises if the class object of [t] has not
    been set. *)
let register_in_module t modl =
  let cls = Option.value_exn t.cls_object in
  Py_module.set_value modl t.name cls
;;