package mehari

  1. Overview
  2. Docs

Source file router_impl.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
(** Interface of a router: combinators to build routes, scopes and
    middleware pipelines, and to dispatch requests to handlers. *)
module type S = sig
  module IO : Types.IO
  (** IO module in which handler results are produced. *)

  type route
  (** A (possibly empty) group of routes, built with [route], [scope] or
      [no_route] and consumed by [router]. *)

  type rate_limiter
  (** Rate limiter that can be attached to a route or a scope. *)

  type addr
  (** Address type carried by requests (the type returned for the request's
      IP). *)

  type handler = addr Handler.Make(IO).t
  (** Request handler. *)

  type middleware = handler -> handler
  (** A transformation wrapping a handler into a new handler. *)

  val no_middleware : middleware
  (** Middleware that leaves the handler unchanged. *)

  val pipeline : middleware list -> middleware
  (** [pipeline [m1; ...; mn] h] is [m1 (... (mn h))]: the first middleware of
      the list is the outermost wrapper. [pipeline []] leaves [h] unchanged. *)

  val router : route list -> handler
  (** [router routes] dispatches each request to the first route of [routes]
      whose path matches the request target. A request matching no route
      receives a [not_found] response with an empty body. *)

  val route :
    ?rate_limit:rate_limiter ->
    ?mw:middleware ->
    ?regex:bool ->
    string ->
    handler ->
    route
  (** [route ?rate_limit ?mw ?regex path handler] serves [handler], wrapped by
      [mw], at [path].

      When [regex] is [true] (default [false]), [path] is a Perl-style regular
      expression. NOTE(review): it is searched for anywhere in the target, so
      it is not implicitly anchored. Otherwise [path] is compared literally,
      ignoring a single trailing ['/'] on either side. [rate_limit], when
      given, is checked before [handler] runs. *)

  val scope :
    ?rate_limit:rate_limiter -> ?mw:middleware -> string -> route list -> route
  (** [scope ?rate_limit ?mw prefix routes] prepends [prefix] to the path of
      every route of [routes] and wraps each handler with [mw]. *)

  val no_route : route
  (** A route that matches nothing. *)

  val virtual_hosts :
    ?meth:[ `ByURL | `SNI ] -> (string * handler) list -> handler
  (** [virtual_hosts ?meth domains] dispatches each request to the handler
      registered for the request's host name. The host is taken from the TLS
      SNI with [`SNI] (the default in [Make]) or from the request URI with
      [`ByURL]. *)
end

(** Router implementation.

    It is parameterised over a rate limiter, which also fixes the IO module
    and the address type, and over a logger used to report rate-limited
    clients. *)
module Make (RateLimiter : Rate_limiter_impl.S) (Logger : Logger_impl.S) :
  S
    with module IO = RateLimiter.IO
     and type rate_limiter := RateLimiter.t
     and type addr := RateLimiter.Addr.t = struct
  module IO = RateLimiter.IO
  module Addr = RateLimiter.Addr

  type handler = Addr.t Handler.Make(IO).t
  type middleware = handler -> handler

  (* A [route] is a flat list of entries so that [scope] can splice nested
     route groups into a single list. *)
  type route = route' list

  and route' = {
    (* Matching strategy, plus the raw path/pattern it was built from. The raw
       string is kept so [scope] can recompile a regex with a prefix. *)
    route : [ `Regex of Re.re | `Literal ] * string;
    (* Handler with every middleware of the route and its scopes applied. *)
    handler : handler;
    (* Limiter checked before [handler] runs, if any. *)
    rate_limit : RateLimiter.t option;
  }

  let no_route = []

  (* Compile a Perl-style pattern; shared by [route] and [scope]. *)
  let compile_regex r = Re.Perl.re r |> Re.Perl.compile

  let route ?rate_limit ?(mw = Fun.id) ?(regex = false) r handler =
    let kind = if regex then `Regex (compile_regex r) else `Literal in
    [ { route = (kind, r); handler = mw handler; rate_limit } ]

  (* [compare_url u u'] is [true] when both paths are equal up to a single
     trailing '/' on one side. [""] and ["/"] are treated as the same root,
     while [""] never matches any other path. *)
  let compare_url u u' =
    match (u, u') with
    | "", "/" | "/", "" | "", "" -> true
    | "", _ | _, "" -> false
    | _, _ when String.equal u u' -> true
    | _, _ when String.ends_with ~suffix:"/" u ->
        String.equal (String.sub u 0 (String.length u - 1)) u'
    | _, _ when String.ends_with ~suffix:"/" u' ->
        String.equal (String.sub u' 0 (String.length u' - 1)) u
    | _, _ -> false

  let router routes =
    (* Flatten once when the router is built, not on every request. *)
    let routes = List.concat routes in
    (* First route matching [path], with its limiter and the regex groups
       (for regex routes) to attach as request parameters. *)
    let rec find_route path = function
      | [] -> None
      | { route = `Regex re, _; handler; rate_limit } :: rs -> (
          match Re.exec_opt re path with
          | None -> find_route path rs
          | Some _ as grp -> Some (handler, rate_limit, grp))
      | { route = `Literal, r; handler; rate_limit } :: _
        when compare_url r path ->
          Some (handler, rate_limit, None)
      | { route = `Literal, _; _ } :: rs -> find_route path rs
    in
    fun req ->
      match find_route (Request.target req) routes with
      | None -> Response.(response Status.not_found "") |> IO.return
      | Some (handler, limit_opt, params) -> (
          let req = Request.attach_params req params in
          match limit_opt with
          | None -> handler req
          | Some limiter -> (
              (* [check] returns [Some resp] when the client must be turned
                 away; in that case [resp] replaces the handler's response. *)
              match RateLimiter.check limiter req with
              | None -> handler req
              | Some resp ->
                  Logger.info (fun log ->
                      log "'%a' is rate limited" Addr.pp (Request.ip req));
                  resp))

  let scope ?rate_limit ?(mw = Fun.id) prefix routes =
    List.concat routes
    |> List.map (fun { route = kind, r; handler; rate_limit = inner_limit } ->
           let r = prefix ^ r in
           let kind =
             match kind with
             (* Recompile the regex so it includes the prefix. *)
             | `Regex _ -> `Regex (compile_regex r)
             | `Literal as l -> l
           in
           (* The scope's limiter takes precedence when given. Otherwise keep
              the route's own limiter instead of silently dropping it. *)
           let rate_limit =
             match rate_limit with None -> inner_limit | Some _ -> rate_limit
           in
           { route = (kind, r); handler = mw handler; rate_limit })

  (* Dispatch on the request host, read from the SNI or the request URI.
     [Option.get] and [List.find] would raise ([Invalid_argument] /
     [Not_found]) on a missing or unknown host. NOTE(review): per the
     original comments, [Protocol.make_request] rules this out; confirm. *)
  let virtual_hosts ?(meth = `SNI) domains_handler req =
    let req_host =
      match meth with
      | `SNI -> Request.sni req
      | `ByURL ->
          Request.uri req |> Uri.host
          |> Option.get (* Guaranteed by [Protocol.make_request]. *)
    in
    let _, handler =
      (* Guaranteed by [Protocol.make_request]. *)
      List.find (fun (d, _) -> String.equal d req_host) domains_handler
    in
    handler req

  let no_middleware = Fun.id

  (* [pipeline [m1; ...; mn] h = m1 (... (mn h))]: the head of the list is the
     outermost wrapper. *)
  let pipeline mws handler = List.fold_right (fun m acc -> m acc) mws handler
end
OCaml

Innovation. Community. Security.