package protocol-9p

  1. Overview
  2. Docs

Source file protocol_9p_buffered9PReader.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
(*
 * Copyright (C) 2015 David Scott <dave.scott@unikernel.com>
 *
 * Permission to use, copy, modify, and distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 *)

open Rresult
open Protocol_9p_error
open Lwt.Infix

(* Largest value accepted in the 4-byte size field of an incoming message:
   640 KiB, which should be plenty given that Linux's limit is 32 KB. *)
let max_message_size = Int32.mul 640l 1024l

(** [Make (Log) (FLOW)] builds a reader that splits a [FLOW] into
    length-prefixed messages. Each message on the wire begins with a 4-byte
    little-endian length that includes those 4 bytes themselves; [read]
    returns the bytes that follow the prefix. *)
module Make(Log: Protocol_9p_s.LOG)(FLOW: Mirage_flow.S) = struct
  module C = Mirage_channel.Make(FLOW)

  type t = {
    channel: C.t;          (* buffered channel wrapping the flow *)
    read_m: Lwt_mutex.t;   (* held by [read] for the whole of one frame *)
    mutable input_buffer: Cstruct.t;
    (* NOTE(review): initialised by [create] but never read or updated
       anywhere in this module. *)
  }

  (** [create flow] wraps [flow] in a buffered channel with a fresh read
      mutex. *)
  let create flow =
    let channel = C.create flow in
    let read_m = Lwt_mutex.create () in
    let input_buffer = Cstruct.create 0 in
    { channel; read_m; input_buffer }

  (** [read_exactly ~len c] reads exactly [len] bytes from [c] and returns
      them as a single contiguous buffer. End of input becomes
      [Error `Eof]; channel errors are pretty-printed into
      [Error (`Msg _)]. *)
  let read_exactly ~len c =
    C.read_exactly ~len c >>= function
    | Ok (`Data bufs) -> Lwt.return (Ok (Cstruct.concat bufs))
    | Ok `Eof -> Lwt.return (Error `Eof)
    | Error e -> Lwt.return (Error (`Msg (Fmt.str "%a" C.pp_error e)))

  (** [read_must_have_lock t] reads one frame from [t.channel] and returns
      its payload, without the 4-byte length prefix. The caller must hold
      [t.read_m], so that concurrent readers cannot interleave their reads
      of a prefix and a body. *)
  let read_must_have_lock t =
    let len_size = 4 in
    read_exactly ~len:len_size t.channel >>= function
    | Error e -> Lwt.return (Error e)
    | Ok length_buffer ->
      let length = Cstruct.LE.get_uint32 length_buffer 0 in
      (* [length] is a signed int32, so a wire value >= 2^31 appears
         negative and is rejected by the lower-bound test. It is still
         printed as unsigned via %lu. *)
      if length < Int32.of_int len_size || length > max_message_size then
        Lwt.return (error_msg "Message size %lu out of range" length)
      else
        (* The prefix counts itself, so the body is [length - 4] bytes. *)
        read_exactly ~len:(Int32.to_int length - len_size) t.channel

  (** [read t] returns the payload of the next frame on [t]'s flow.
      Concurrent calls are serialised on [t.read_m]. End of input and
      underlying flow errors are both reported as [Error (`Msg _)]. *)
  let read t =
    Lwt_mutex.with_lock t.read_m (fun () ->
      read_must_have_lock t >|= function
      | Ok _ as ok -> ok
      | Error `Eof -> error_msg "Caught EOF on underlying FLOW"
      | Error (`Msg _) as err ->
        (* [~replace:true] is needed because the default also keeps the
           original message. The new message already embeds [msg], so
           without it the text would appear twice. *)
        R.reword_error_msg ~replace:true (fun msg ->
            R.msgf "Unexpected error on underlying FLOW: %s" msg) err
    )
end