diff --git a/lib/github.com/diku-dk/segmented/segmented.fut b/lib/github.com/diku-dk/segmented/segmented.fut
index 32e7f27..2a598f9 100644
--- a/lib/github.com/diku-dk/segmented/segmented.fut
+++ b/lib/github.com/diku-dk/segmented/segmented.fut
@@ -42,12 +42,9 @@ def segmented_reduce [n] 't (op: t -> t -> t) (ne: t)
 -- returns the array [0,0,1,1,1,2].
 
 def replicated_iota [n] (reps:[n]i64) : []i64 =
-  let s1 = scan (+) 0 reps
-  let s2 = map2 (\i x -> if i==0 then 0 else x)
-                (iota n) (rotate (-1) s1)
-  let tmp = reduce_by_index (replicate (reduce (+) 0 reps) 0) i64.max 0 s2 (iota n)
-  let flags = map (>0) tmp
-  in segmented_scan (+) 0 flags tmp
+  let offsets = scan (+) 0 reps
+  in hist (+) 0 (i64.sum reps) offsets (replicate n 1)
+     |> scan (+) 0
 
 -- | Segmented iota. Given a flags array, the function returns an
 -- array of index sequences, each of which is reset according to the
@@ -59,6 +56,22 @@ def segmented_iota [n] (flags:[n]bool) : [n]i64 =
   let iotas = segmented_scan (+) 0 flags (replicate n 1)
   in map (\x -> x-1) iotas
 
+-- | Replicated and segmented iota generated together in a slightly
+-- more efficient way. Each segment in the segmented iota corresponds
+-- to a segment in the replicated iota, so both results have the same
+-- size. As an example, repl_segm_iota [2,3,1] returns the arrays
+-- [0,0,1,1,1,2] and [0,1,0,1,2,0].
+
+def repl_segm_iota [n] (reps:[n]i64) : ?[m].([m]i64, [m]i64) =
+  -- Only the scalar size is guarded, so both results share the size sz.
+  let offsets = scan (+) 0 reps
+  let start_idx = map2 (-) offsets reps
+  let sz = if n == 0 then 0 else last offsets
+  let repl = hist (+) 0 sz offsets (replicate n 1)
+             |> scan (+) 0
+  let segm = tabulate sz (\i -> i - start_idx[repl[i]])
+  in (repl, segm)
+
 -- | Generic expansion function. The function expands a source array
 -- into a target array given (1) a function that determines, for each
 -- source element, how many target elements it expands to and (2) a
@@ -69,8 +82,7 @@ def segmented_iota [n] (flags:[n]bool) : [n]i64 =
 
 def expand 'a 'b (sz: a -> i64) (get: a -> i64 -> b) (arr:[]a) : []b =
   let szs = map sz arr
-  let idxs = replicated_iota szs
-  let iotas = segmented_iota (map2 (!=) idxs (rotate (-1) idxs))
+  let (idxs, iotas) = repl_segm_iota szs
   in map2 (\i j -> get arr[i] j) idxs iotas
 
 -- | Expansion function equivalent to performing a segmented reduction