Skip to content

Commit

Permalink
Purify fold and move it into Tar with a GADT, use it then for Tar_gz …
Browse files Browse the repository at this point in the history
…which will produce an other GADT value and Tar_{,lwt_}unix which evaluate our GADT
  • Loading branch information
dinosaure committed Feb 7, 2024
1 parent 2b49b1f commit 2388f62
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 148 deletions.
41 changes: 41 additions & 0 deletions lib/tar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -815,3 +815,44 @@ let encode_header ?level header =

let encode_global_extended_header ?level global =
encode_extended_header ?level `Global global

type ('a, 'err) t =
| Really_read : int -> (string, 'err) t
| Read : int -> (string, 'err) t
| Seek : int -> (int, 'err) t
| Bind : ('a, 'err) t * ('a -> ('b, 'err) t) -> ('b, 'err) t
| Return : ('a, 'err) result -> ('a, 'err) t

let ( let* ) x f = Bind (x, f)
let return x = Return x
let really_read n = Really_read n
let read n = Read n
let seek n = Seek n

type ('a, 'err) fold = (?global:Header.Extended.t -> Header.t -> 'a -> ('a, 'err) result) -> 'a -> ('a, 'err) t

let fold f init =
let rec go t ?global ?data acc =
let* data = match data with
| None -> really_read Header.length
| Some data -> return (Ok data) in
match decode t data with
| Ok (t, Some `Header hdr, g) ->
let global = Option.fold ~none:global ~some:(fun g -> Some g) g in
let* acc' = return (f ?global hdr acc) in
let* _off = seek (Header.compute_zero_padding_length hdr) in
go t ?global acc'
| Ok (t, Some `Skip n, g) ->
let global = Option.fold ~none:global ~some:(fun g -> Some g) g in
let* _off = seek n in
go t ?global acc
| Ok (t, Some `Read n, g) ->
let global = Option.fold ~none:global ~some:(fun g -> Some g) g in
let* data = really_read n in
go t ?global ~data acc
| Ok (t, None, g) ->
let global = Option.fold ~none:global ~some:(fun g -> Some g) g in
go t ?global acc
| Error `Eof -> return (Ok acc)
| Error `Fatal _ as e -> return e in
go (decode_state ()) init
25 changes: 22 additions & 3 deletions lib/tar.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
{e %%VERSION%% - {{:%%PKG_HOMEPAGE%% }homepage}} *)

(** The type of errors that may occur. *)
type error = [`Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string]
type error = [ `Checksum_mismatch | `Corrupt_pax_header | `Zero_block | `Unmarshal of string ]

(** [pp_error ppf e] pretty prints the error [e] on the formatter [ppf]. *)
val pp_error : Format.formatter -> [< error] -> unit
Expand Down Expand Up @@ -123,7 +123,7 @@ module Header : sig
(** Unmarshal a header block, returning [None] if it's all zeroes.
This header block may be preceded by an [?extended] block which
will override some fields. *)
val unmarshal : ?extended:Extended.t -> string -> (t, [`Zero_block | `Checksum_mismatch | `Unmarshal of string]) result
val unmarshal : ?extended:Extended.t -> string -> (t, [> `Zero_block | `Checksum_mismatch | `Unmarshal of string]) result

(** Marshal a header block, computing and inserting the checksum. *)
val marshal : ?level:compatibility -> bytes -> t -> (unit, [> `Msg of string ]) result
Expand Down Expand Up @@ -157,7 +157,7 @@ val decode_state : ?global:Header.Extended.t -> unit -> decode_state
further decoding until [`Eof] (or an error) occurs. *)
val decode : decode_state -> string ->
(decode_state * [ `Read of int | `Skip of int | `Header of Header.t ] option * Header.Extended.t option,
[ `Eof | `Fatal of [ `Checksum_mismatch | `Corrupt_pax_header | `Unmarshal of string ] ])
[ `Eof | `Fatal of error ])
result

(** [encode_header ~level hdr] encodes the header with the provided [level]
Expand All @@ -170,3 +170,22 @@ val encode_header : ?level:Header.compatibility ->
(** [encode_global_extended_header hdr] encodes the global extended header as
a list of strings. *)
val encode_global_extended_header : ?level:Header.compatibility -> Header.Extended.t -> (string list, [> `Msg of string ]) result

(** {1 Pure implementation of [fold].} *)

type ('a, 'err) t =
| Really_read : int -> (string, 'err) t
| Read : int -> (string, 'err) t
| Seek : int -> (int, 'err) t
| Bind : ('a, 'err) t * ('a -> ('b, 'err) t) -> ('b, 'err) t
| Return : ('a, 'err) result -> ('a, 'err) t

val really_read : int -> (string, _) t
val read : int -> (string, _) t
val seek : int -> (int, _) t
val ( let* ) : ('a, 'err) t -> ('a -> ('b, 'err) t) -> ('b, 'err) t
val return : ('a, 'err) result -> ('a, 'err) t

type ('a, 'err) fold = (?global:Header.Extended.t -> Header.t -> 'a -> ('a, 'err) result) -> 'a -> ('a, 'err) t

val fold : ('a, [> `Fatal of error ]) fold
161 changes: 86 additions & 75 deletions lib/tar_gz.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

module type READER = sig
type in_channel
type 'a io
val read : in_channel -> bytes -> int io
end

external ba_get_int32_ne : De.bigstring -> int -> int32 = "%caml_bigstring_get32"
external ba_set_int32_ne : De.bigstring -> int -> int32 -> unit = "%caml_bigstring_set32"

(*
let bigstring_to_string ?(off= 0) ?len ba =
let len = match len with
| Some len -> len
Expand All @@ -41,6 +36,7 @@ let bigstring_to_string ?(off= 0) ?len ba =
Bytes.set res i v
done;
Bytes.unsafe_to_string res
*)

let bigstring_blit_string src ~src_off dst ~dst_off ~len =
let len0 = len land 3 in
Expand Down Expand Up @@ -71,6 +67,89 @@ let bigstring_blit_bytes src ~src_off dst ~dst_off ~len =
Bytes.set dst (dst_off + i) v
done

type decoder =
{ mutable gz : Gz.Inf.decoder
; ic_buffer : De.bigstring
; oc_buffer : De.bigstring
; tp_length : int
; mutable pos : int }

let really_read_through_gz
: decoder -> bytes -> (unit, 'err) Tar.t
= fun ({ ic_buffer; oc_buffer; tp_length; _ } as state) res ->
let open Tar in
let rec until_full_or_end gz (res, res_off, res_len) =
match Gz.Inf.decode gz with
| `Flush gz ->
let max = De.bigstring_length oc_buffer - Gz.Inf.dst_rem gz in
let len = min res_len max in
bigstring_blit_bytes oc_buffer ~src_off:0 res ~dst_off:res_off ~len;
if len < max
then ( state.pos <- len
; state.gz <- gz
; return (Ok ()) )
else until_full_or_end (Gz.Inf.flush gz) (res, res_off + len, res_len - len)
| `End gz ->
let max = De.bigstring_length oc_buffer - Gz.Inf.dst_rem gz in
let len = min res_len max in
bigstring_blit_bytes oc_buffer ~src_off:0 res ~dst_off:res_off ~len;
if res_len > len
then return (Error `Eof)
else ( state.pos <- len
; state.gz <- gz
; return (Ok ()) )
| `Await gz ->
let* tp_buffer = Tar.read tp_length in
let len = String.length tp_buffer in
bigstring_blit_string tp_buffer ~src_off:0 ic_buffer ~dst_off:0 ~len;
let gz = Gz.Inf.src gz ic_buffer 0 len in
until_full_or_end gz (res, res_off, res_len)
| `Malformed err -> return (Error (`Gz err)) in
let max = (De.bigstring_length oc_buffer - Gz.Inf.dst_rem state.gz) - state.pos in
let len = min (Bytes.length res) max in
bigstring_blit_bytes oc_buffer ~src_off:state.pos res ~dst_off:0 ~len;
if len < max
then ( state.pos <- state.pos + len
; return (Ok ()) )
else until_full_or_end (Gz.Inf.flush state.gz) (res, len, Bytes.length res - len)

let really_read_through_gz decoder len =
let open Tar in
let res = Bytes.create len in
let* () = really_read_through_gz decoder res in
Tar.return (Ok (Bytes.unsafe_to_string res))

type error = [ `Fatal of Tar.error | `Eof | `Gz of string ]

let seek_through_gz : decoder -> int -> (int, [> error ]) Tar.t = fun state len ->
let open Tar in
let* _buf = really_read_through_gz state len in
Tar.return (Ok 0 (* XXX(dinosaure): actually, [fold] ignores the result. *))

type 'err run = { run : 'a 'err. ('a, 'err) Tar.t -> ('a, 'err) result } [@@unboxed]

let fold_with_gz
: run:[> error ] run -> _ -> _ -> _
= fun ~run:{ run } f init ->
let rec go : type a. decoder -> (a, [> error ] as 'err) Tar.t -> (a, 'err) Tar.t = fun decoder -> function
| Tar.Really_read len -> really_read_through_gz decoder len
| Tar.Read _len -> assert false (* XXX(dinosaure): actually does not emit [Tar.Read]. *)
| Tar.Seek len -> seek_through_gz decoder len
| Tar.Return v -> Tar.return v
| Tar.Bind (x, f) ->
match run x with
| Ok value -> go decoder (f value)
| Error _ as err -> Tar.return err in
let decoder =
let oc_buffer = De.bigstring_create 0x1000 in
{ gz= Gz.Inf.decoder `Manual ~o:oc_buffer
; oc_buffer
; ic_buffer= De.bigstring_create 0x1000
; tp_length= 0x1000
; pos= 0 } in
go decoder (Tar.fold f init)

(*
module Make
(Async : Tar.ASYNC)
(Writer : Tar.WRITER with type 'a io = 'a Async.t)
Expand Down Expand Up @@ -108,75 +187,6 @@ module Make
go gz (str, 0, String.length str)
end
module Gz_reader = struct
type in_channel =
{ mutable gz : Gz.Inf.decoder
; ic_buffer : De.bigstring
; oc_buffer : De.bigstring
; tp_buffer : bytes
; in_channel : Reader.in_channel
; mutable pos : int }

type 'a io = 'a Async.t

let really_read
: in_channel -> bytes -> unit io
= fun ({ ic_buffer; oc_buffer; in_channel; tp_buffer; _ } as state) res ->
let rec until_full_or_end gz (res, res_off, res_len) =
match Gz.Inf.decode gz with
| `Flush gz ->
let max = De.bigstring_length oc_buffer - Gz.Inf.dst_rem gz in
let len = min res_len max in
bigstring_blit_bytes oc_buffer ~src_off:0 res ~dst_off:res_off ~len;
if len < max
then ( state.pos <- len
; state.gz <- gz
; Async.return () )
else until_full_or_end (Gz.Inf.flush gz) (res, res_off + len, res_len - len)
| `End gz ->
let max = De.bigstring_length oc_buffer - Gz.Inf.dst_rem gz in
let len = min res_len max in
bigstring_blit_bytes oc_buffer ~src_off:0 res ~dst_off:res_off ~len;
if res_len > len
then raise End_of_file
else ( state.pos <- len
; state.gz <- gz
; Async.return () )
| `Await gz ->
Reader.read in_channel tp_buffer >>= fun len ->
bigstring_blit_string (Bytes.unsafe_to_string tp_buffer) ~src_off:0 ic_buffer ~dst_off:0 ~len;
let gz = Gz.Inf.src gz ic_buffer 0 len in
until_full_or_end gz (res, res_off, res_len)
| `Malformed err -> failwith ("gzip: " ^ err) in
let max = (De.bigstring_length oc_buffer - Gz.Inf.dst_rem state.gz) - state.pos in
let len = min (Bytes.length res) max in
bigstring_blit_bytes oc_buffer ~src_off:state.pos res ~dst_off:0 ~len;
if len < max
then ( state.pos <- state.pos + len
; Async.return () )
else until_full_or_end (Gz.Inf.flush state.gz) (res, len, Bytes.length res - len)

let skip : in_channel -> int -> unit io = fun state len ->
let res = Bytes.create len in
really_read state res
end

module HeaderWriter = Tar.HeaderWriter (Async) (Gz_writer)
module HeaderReader = Tar.HeaderReader (Async) (Gz_reader)

type in_channel = Gz_reader.in_channel

let of_in_channel ~internal:oc_buffer in_channel =
{ Gz_reader.gz= Gz.Inf.decoder `Manual ~o:oc_buffer
; oc_buffer
; ic_buffer= De.bigstring_create 0x1000
; tp_buffer= Bytes.create 0x1000
; in_channel
; pos= 0 }

let really_read = Gz_reader.really_read
let skip = Gz_reader.skip

type out_channel = Gz_writer.out_channel
let of_out_channel ?bits:(w_bits= 15) ?q:(q_len= 0x1000) ~level ~mtime os out_channel =
Expand Down Expand Up @@ -230,3 +240,4 @@ module Make
| `End _gz -> Async.return () in
until_end (Gz.Def.src state.gz De.bigstring_empty 0 0)
end
*)
8 changes: 8 additions & 0 deletions lib/tar_gz.mli
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*)

type error = [ `Fatal of Tar.error | `Eof | `Gz of string ]

type 'err run = { run : 'a 'err. ('a, 'err) Tar.t -> ('a, 'err) result } [@@unboxed]

val fold_with_gz : run:[> error ] run -> ('a, [> error]) Tar.fold

(*
module type READER = sig
type in_channel
type 'a io
Expand Down Expand Up @@ -72,3 +79,4 @@ module Make
module HeaderWriter :
Tar.HEADERWRITER with type out_channel = out_channel and type 'a io = 'a Async.t
end
*)
Loading

0 comments on commit 2388f62

Please sign in to comment.