Source file utcp_mirage.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
open Lwt.Infix
(* Log source for this module; [Log] is the logging interface bound to it. *)
let src = Logs.Src.create ~doc:"TCP mirage" "tcp.mirage"
module Log = (val (Logs.src_log src) : Logs.LOG)
module Make (Ip : Tcpip.Ip.S with type ipaddr = Ipaddr.t) = struct
(* Current timestamp as an [Mtime.t], built from the nanosecond counter
   returned by [Mirage_mtime.elapsed_ns]. *)
let now () = Mirage_mtime.elapsed_ns () |> Mtime.of_uint64_ns
(* Read and write error types, with their printers, are re-exported
   unchanged from [Tcpip.Tcp]. *)
type error = Tcpip.Tcp.error
let pp_error = Tcpip.Tcp.pp_error
type write_error = Tcpip.Tcp.write_error
let pp_write_error = Tcpip.Tcp.pp_write_error
(* Addresses are [Ipaddr.t], the same type the functor argument [Ip] is
   constrained to use. *)
type ipaddr = Ipaddr.t
(* Finite maps keyed by integer port number, ordered by integer comparison. *)
module Port_map = Map.Make (Int)
(* State of one TCP stack instance. *)
type t = {
(* Current Utcp state. The functions below overwrite it with the new state
   returned by each Utcp call. The conditions it is parameterised with carry
   [Ok ()], [Error `Eof] or [Error (`Msg _)]. *)
mutable tcp : (unit, [ `Eof | `Msg of string ]) result Lwt_condition.t Utcp.state ;
(* IP layer used to send segments and to pick the source address of
   outgoing connections. *)
ip : Ip.t ;
(* Callbacks registered through [listen], keyed by port. *)
mutable listeners : (flow -> unit Lwt.t) Port_map.t ;
}
(* A flow is the stack it belongs to paired with the Utcp flow identifier. *)
and flow = t * Utcp.flow
(* Destination endpoint (address, port) of a flow: the second component of
   the pair returned by [Utcp.peers]. *)
let dst (_, flow) = snd (Utcp.peers flow)
(* Source endpoint (address, port) of a flow: the first component of the
   pair returned by [Utcp.peers]. *)
let src (_, flow) = fst (Utcp.peers flow)
(* Send one segment from [src] to [dst] through the IP layer.
   The segment is encoded and checksummed directly into the buffer handed
   over by [Ip.write]; the result of [Ip.write] is returned as is. *)
let output_ip t (src, dst, seg) =
  let size = Utcp.Segment.length seg in
  Log.debug (fun m -> m "output to %a: %a" Ipaddr.pp dst Utcp.Segment.pp seg);
  (* Fills the buffer and reports how many bytes were written. *)
  let fill buf =
    Utcp.Segment.encode_and_checksum_into (now ()) buf ~src ~dst seg;
    size
  in
  Ip.write t.ip ~src dst `TCP ~size fill []
(* Send the segments one after the other, in list order.
   A failed send is logged and otherwise ignored, so the returned promise
   carries no error information. *)
let output_ign t segs =
  let send_logged ((_, dst, _) as seg) =
    output_ip t seg >|= function
    | Ok () -> ()
    | Error e ->
      Log.err (fun m -> m "error sending data to %a: %a" Ipaddr.pp dst Ip.pp_error e)
  in
  (* Each send starts only once the previous one has completed. *)
  let rec drain = function
    | [] -> Lwt.return_unit
    | seg :: rest -> send_logged seg >>= fun () -> drain rest
  in
  drain segs
(* Read from a flow.
   If [Utcp.recv] has data available it is returned at once as a single
   [`Data] buffer. Otherwise the function waits on the condition returned by
   [Utcp.recv] and then calls [Utcp.recv] a second time; an empty result from
   that second call is reported as [`Eof]. [`Msg] errors are logged and
   mapped to [`Refused], as is [`Not_found]. *)
let read (t, flow) =
  (* Install the new TCP state, then transmit the segments that came with it. *)
  let commit tcp segs =
    t.tcp <- tcp;
    output_ign t segs
  in
  (* Concatenate the received strings into one buffer. *)
  let data_of chunks = `Data (Cstruct.of_string (String.concat "" chunks)) in
  match Utcp.recv t.tcp (now ()) flow with
  | Error `Not_found -> Lwt.return (Error `Refused)
  | Error `Eof -> Lwt.return (Ok `Eof)
  | Error (`Msg msg) ->
    Log.err (fun m -> m "%a error while read %s" Utcp.pp_flow flow msg);
    Lwt.return (Error `Refused)
  | Ok (tcp, (_ :: _ as chunks), _cond, segs) ->
    commit tcp segs >>= fun () ->
    Lwt.return (Ok (data_of chunks))
  | Ok (tcp, [], cond, segs) ->
    (* Nothing buffered: wait for a signal, then try once more. *)
    commit tcp segs >>= fun () ->
    Lwt_condition.wait cond >>= function
    | Error `Eof -> Lwt.return (Ok `Eof)
    | Error (`Msg msg) ->
      Log.err (fun m -> m "%a error %s from condition while recv" Utcp.pp_flow flow msg);
      Lwt.return (Error `Refused)
    | Ok () ->
      match Utcp.recv t.tcp (now ()) flow with
      | Error `Not_found -> Lwt.return (Error `Refused)
      | Error `Eof -> Lwt.return (Ok `Eof)
      | Error (`Msg msg) ->
        Log.err (fun m -> m "%a error while read (second recv) %s" Utcp.pp_flow flow msg);
        Lwt.return (Error `Refused)
      | Ok (tcp, chunks, _cond, segs) ->
        commit tcp segs >>= fun () ->
        Lwt.return
          (match chunks with
           | [] -> Ok `Eof
           | _ :: _ -> Ok (data_of chunks))
(* Write the string [buf] to a flow.
   [Utcp.send] may accept only a prefix of [buf]; in that case the function
   waits on the returned condition and retries with the remaining suffix
   until everything has been accepted. An error from the condition or a
   [`Msg] error from [Utcp.send] yields [`Closed]; [`Not_found] yields
   [`Refused]. *)
let rec write (t, flow) buf =
  match Utcp.send t.tcp (now ()) flow buf with
  | Error `Not_found -> Lwt.return (Error `Refused)
  | Error (`Msg msg) ->
    Log.err (fun m -> m "%a error while write %s" Utcp.pp_flow flow msg);
    Lwt.return (Error `Closed)
  | Ok (tcp, bytes_sent, cond, segs) ->
    t.tcp <- tcp;
    output_ign t segs >>= fun () ->
    let total = String.length buf in
    if bytes_sent >= total then Lwt.return (Ok ())
    else
      Lwt_condition.wait cond >>= function
      | Error `Eof -> Lwt.return (Error `Closed)
      | Error (`Msg msg) ->
        Log.err (fun m -> m "%a error %s from condition while sending" Utcp.pp_flow flow msg);
        Lwt.return (Error `Closed)
      | Ok () ->
        (* Retry with the part that was not accepted. *)
        let remaining = String.sub buf bytes_sent (total - bytes_sent) in
        write (t, flow) remaining
(* [Cstruct.t] front ends for the string-based [write] defined just above.
   Order matters: [writev] must come first, because the second definition
   shadows that string-based [write]. *)
let writev flow bufs = bufs |> Cstruct.concat |> Cstruct.to_string |> write flow
let write flow buf = buf |> Cstruct.to_string |> write flow
(* Close a flow: on success the new TCP state is installed and the
   resulting segments are sent. Failures never reach the caller; a [`Msg]
   error is logged and [`Not_found] is silently ignored. *)
let close (t, flow) =
  match Utcp.close t.tcp (now ()) flow with
  | Error `Not_found -> Lwt.return_unit
  | Error (`Msg msg) ->
    Log.err (fun m -> m "%a error in close: %s" Utcp.pp_flow flow msg);
    Lwt.return_unit
  | Ok (tcp, segs) ->
    t.tcp <- tcp;
    output_ign t segs
(* Shut down a flow in the given [mode], which is passed through to
   [Utcp.shutdown]. On success the new TCP state is installed and the
   resulting segments are sent. Failures never reach the caller; a [`Msg]
   error is logged and [`Not_found] is silently ignored. *)
let shutdown (t, flow) mode =
  match Utcp.shutdown t.tcp (now ()) flow mode with
  | Error `Not_found -> Lwt.return_unit
  | Error (`Msg msg) ->
    Log.err (fun m -> m "%a error in shutdown: %s" Utcp.pp_flow flow msg);
    Lwt.return_unit
  | Ok (tcp, segs) ->
    t.tcp <- tcp;
    output_ign t segs
(* The nodelay variants have no behaviour of their own: both go through the
   ordinary [write]. *)
let write_nodelay = write
let writev_nodelay flow bufs = bufs |> Cstruct.concat |> write flow
(* Open a connection to [(dst, dst_port)].
   The source address is chosen by [Ip.src]. The first segment produced by
   [Utcp.connect] is sent, then the function waits on the condition returned
   by [Utcp.connect]. A failed send yields [`Refused]; any error delivered
   through the condition yields [`Timeout]. The [keepalive] argument is
   ignored. *)
let create_connection ?keepalive:_ t (dst, dst_port) =
  let src = Ip.src t.ip ~dst in
  let tcp, id, cond, syn = Utcp.connect ~src ~dst ~dst_port t.tcp (now ()) in
  t.tcp <- tcp;
  (* Turn the value delivered through the condition into the final result. *)
  let handshake_result = function
    | Ok () -> Ok (t, id)
    | Error `Eof ->
      Log.err (fun m -> m "%a error establishing connection (timeout)" Utcp.pp_flow id);
      Error `Timeout
    | Error (`Msg msg) ->
      Log.err (fun m -> m "%a error establishing connection: %s" Utcp.pp_flow id msg);
      Error `Timeout
  in
  output_ip t syn >>= function
  | Ok () -> Lwt_condition.wait cond >|= handshake_result
  | Error e ->
    Log.err (fun m -> m "%a error sending syn: %a" Utcp.pp_flow id Ip.pp_error e);
    Lwt.return (Error `Refused)
(* Process one incoming buffer.
   The buffer goes to [Utcp.handle_buf], the new state is installed, the
   optional event is dispatched, and the resulting segments are sent.
   - [`Established] with a condition signals that condition with [Ok ()];
     without one, the callback registered for the port taken from the first
     peer of [Utcp.peers] is started asynchronously.
   - [`Drop] signals [Error `Eof] on every listed condition, then [Ok ()]
     on the optional one.
   - [`Signal] signals [Ok ()] on every listed condition. *)
let input t ~src ~dst data =
  let tcp, ev, segs = Utcp.handle_buf t.tcp (now ()) ~src ~dst data in
  t.tcp <- tcp;
  let notify value c = Lwt_condition.signal c value in
  (match ev with
   | None -> ()
   | Some (`Signal (_id, conds)) -> List.iter (notify (Ok ())) conds
   | Some (`Drop (_id, c_opt, cs)) ->
     List.iter (notify (Error `Eof)) cs;
     Option.iter (notify (Ok ())) c_opt
   | Some (`Established (_, Some cond)) -> notify (Ok ()) cond
   | Some (`Established (id, None)) ->
     let (_, port), _ = Utcp.peers id in
     (match Port_map.find_opt port t.listeners with
      | Some cb -> Lwt.async (fun () -> cb (t, id))
      | None ->
        Log.debug (fun m -> m "%a not found in waiting or listeners"
                      Utcp.pp_flow id)));
  output_ign t segs
(* Create a stack instance on top of [ip] and start its timer loop.
   [id] is passed to [Utcp.empty] and appears in the start-up log line.
   The loop runs in the background through [Lwt.async] and repeats every
   100 ms: it calls [Utcp.timer], installs the new state, signals an error
   on both conditions of every dropped connection, and sends the segments
   the timer produced. Nothing in this function ever stops the loop. *)
let connect id ip =
  Log.info (fun m -> m "starting µTCP on %S" id);
  let tcp = Utcp.empty Lwt_condition.create id in
  let t = { tcp ; ip ; listeners = Port_map.empty } in
  (* Wake both waiters of a dropped connection with the matching error. *)
  let signal_drop (_id, err, rcv, snd) =
    let err = Error (match err with
        | `Retransmission_exceeded -> `Msg "retransmission exceeded"
        | `Timer_2msl -> `Eof
        | `Timer_connection_established -> `Eof
        | `Timer_fin_wait_2 -> `Eof)
    in
    Lwt_condition.signal rcv err;
    Lwt_condition.signal snd err
  in
  (* Send one timer-produced segment. A failure is logged with the same
     message the other send paths use, instead of being dropped silently. *)
  let send_logged ((_, dst, _) as seg) =
    output_ip t seg >|= function
    | Ok () -> ()
    | Error e ->
      Log.err (fun m -> m "error sending data to %a: %a" Ipaddr.pp dst Ip.pp_error e)
  in
  let rec timer () =
    let tcp, drops, outs = Utcp.timer t.tcp (now ()) in
    t.tcp <- tcp;
    List.iter signal_drop drops;
    (* The sends are started together, as before. *)
    Lwt_list.iter_p send_logged outs >>= fun () ->
    Mirage_sleep.ns (Duration.of_ms 100) >>= fun () ->
    (timer [@tailcall]) ()
  in
  Lwt.async timer;
  t
(* Start listening on [port] in the TCP state and register [callback] for
   it, replacing any callback already registered for that port. The
   [keepalive] argument is ignored. *)
let listen t ~port ?keepalive:_ callback =
  t.tcp <- Utcp.start_listen t.tcp port;
  t.listeners <- Port_map.add port callback t.listeners
(* Stop listening on [port] in the TCP state and remove the callback
   registered for it, if any. *)
let unlisten t ~port =
  t.tcp <- Utcp.stop_listen t.tcp port;
  t.listeners <- Port_map.remove port t.listeners
(* Does nothing and returns an already resolved promise.
   NOTE(review): the timer loop started by [connect] has no stop condition
   in this file, so it appears to keep running after [disconnect] --
   confirm whether that is intended. *)
let disconnect _t =
Lwt.return_unit
end