Skip to content

Commit

Permalink
Faster expand and replicated_iota (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
FluxusMagna authored Dec 2, 2023
1 parent a46152f commit 9653ba1
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions lib/github.com/diku-dk/segmented/segmented.fut
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,9 @@ def segmented_reduce [n] 't (op: t -> t -> t) (ne: t)
-- returns the array [0,0,1,1,1,2].

def replicated_iota [n] (reps:[n]i64) : []i64 =
let s1 = scan (+) 0 reps
let s2 = map2 (\i x -> if i==0 then 0 else x)
(iota n) (rotate (-1) s1)
let tmp = reduce_by_index (replicate (reduce (+) 0 reps) 0) i64.max 0 s2 (iota n)
let flags = map (>0) tmp
in segmented_scan (+) 0 flags tmp
let offsets = scan (+) 0 reps
in hist (+) 0 (i64.sum reps) offsets (replicate n 1)
|> scan (+) 0

-- | Segmented iota. Given a flags array, the function returns an
-- array of index sequences, each of which is reset according to the
Expand All @@ -59,6 +56,22 @@ def segmented_iota [n] (flags:[n]bool) : [n]i64 =
let iotas = segmented_scan (+) 0 flags (replicate n 1)
in map (\x -> x-1) iotas

-- | Replicated and segemented iota generated together
-- in a slighly more efficient way.
-- each segment in the segmented iota corresponds to a segment
-- in the replicated iota. As an example repl_segm_iota [2,3,1]
-- returns the arrays [0,0,1,1,1,2] and [0,1,0,1,2,0].

def repl_segm_iota [n] (reps:[n]i64) : ([]i64, []i64) =
if n == 0 then ([], []) else
let offsets = scan (+) 0 reps
let start_idx = map2 (-) offsets reps
let sz = last offsets
let repl = hist (+) 0 sz offsets (replicate n 1)
|> scan (+) 0
let segm = tabulate sz (\i -> i - start_idx[repl[i]])
in (repl, segm)

-- | Generic expansion function. The function expands a source array
-- into a target array given (1) a function that determines, for each
-- source element, how many target elements it expands to and (2) a
Expand All @@ -69,8 +82,7 @@ def segmented_iota [n] (flags:[n]bool) : [n]i64 =

def expand 'a 'b (sz: a -> i64) (get: a -> i64 -> b) (arr:[]a) : []b =
let szs = map sz arr
let idxs = replicated_iota szs
let iotas = segmented_iota (map2 (!=) idxs (rotate (-1) idxs))
let (idxs, iotas) = repl_segm_iota szs
in map2 (\i j -> get arr[i] j) idxs iotas

-- | Expansion function equivalent to performing a segmented reduction
Expand Down

0 comments on commit 9653ba1

Please sign in to comment.