Skip to content

Commit

Permalink
Merge branch 'master' into flattening
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Dec 6, 2024
2 parents be27284 + daa7181 commit f71c51e
Show file tree
Hide file tree
Showing 17 changed files with 145 additions and 154 deletions.
6 changes: 3 additions & 3 deletions src/Futhark/AD/Fwd.hs
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ zeroFromSubExp (Var v) = do
letExp "zero" $ zeroExp t

fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM ()
fwdSOAC pat aux (Screma size xs (ScremaForm scs reds f)) = do
fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do
pat' <- bundleNewPat pat
xs' <- bundleTangents xs
f' <- fwdLambda f
scs' <- mapM fwdScan scs
reds' <- mapM fwdRed reds
f' <- fwdLambda f
addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm scs' reds' f'
addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds'
where
fwdScan :: Scan SOACS -> ADM (Scan SOACS)
fwdScan sc = do
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/AD/Rev/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ mapOp (Lambda [pa1, pa2] _ lam_body)
cs == mempty,
[map_stm] <- stmsToList (bodyStms lam_body),
(Let (Pat [pe]) _ (Op scrm)) <- map_stm,
(Screma _ [a1, a2] (ScremaForm [] [] map_lam)) <- scrm,
(Screma _ [a1, a2] (ScremaForm map_lam [] [])) <- scrm,
(a1 == paramName pa1 && a2 == paramName pa2) || (a1 == paramName pa2 && a2 == paramName pa1),
r == Var (patElemName pe) =
Just map_lam
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Analysis/HORep/MapNest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fromSOAC' ::
[Ident] ->
SOAC rep ->
m (Maybe (MapNest rep))
fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm [] [] lam)) = do
fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm lam [] [])) = do
maybenest <- case ( stmsToList $ bodyStms $ lambdaBody lam,
bodyResult $ lambdaBody lam
) of
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/Analysis/HORep/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ newWidth (inp : _) _ = arraySize 0 $ inputType inp
lambda :: SOAC rep -> Lambda rep
lambda (Stream _ _ _ lam) = lam
lambda (Scatter _len _ivs _spec lam) = lam
lambda (Screma _ _ (ScremaForm _ _ lam)) = lam
lambda (Screma _ _ (ScremaForm lam _ _)) = lam
lambda (Hist _ _ _ lam) = lam

-- | Set the lambda used in the SOAC.
Expand All @@ -444,8 +444,8 @@ setLambda lam (Stream w arrs nes _) =
Stream w arrs nes lam
setLambda lam (Scatter len arrs spec _lam) =
Scatter len arrs spec lam
setLambda lam (Screma w arrs (ScremaForm scan red _)) =
Screma w arrs (ScremaForm scan red lam)
setLambda lam (Screma w arrs (ScremaForm _ scan red)) =
Screma w arrs (ScremaForm lam scan red)
setLambda lam (Hist w ops inps _) =
Hist w ops inps lam

Expand Down
29 changes: 15 additions & 14 deletions src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -727,24 +727,25 @@ pSOAC pr =
<*> p
pScremaForm =
SOAC.ScremaForm
<$> braces (pScan pr `sepBy` pComma)
<$> pLambda pr
<* pComma
<*> braces (pReduce pr `sepBy` pComma)
<*> braces (pScan pr `sepBy` pComma)
<* pComma
<*> pLambda pr
<*> braces (pReduce pr `sepBy` pComma)
pRedomapForm =
SOAC.ScremaForm mempty
<$> braces (pReduce pr `sepBy` pComma)
SOAC.ScremaForm
<$> pLambda pr
<*> pure []
<* pComma
<*> pLambda pr
<*> braces (pReduce pr `sepBy` pComma)
pScanomapForm =
SOAC.ScremaForm
<$> braces (pScan pr `sepBy` pComma)
<$> pLambda pr
<* pComma
<*> pure mempty
<*> pLambda pr
<*> braces (pScan pr `sepBy` pComma)
<*> pure []
pMapForm =
SOAC.ScremaForm mempty mempty <$> pLambda pr
SOAC.ScremaForm <$> pLambda pr <*> pure mempty <*> pure mempty
pScatter =
keyword "scatter"
*> parens
Expand Down Expand Up @@ -797,19 +798,19 @@ pSOAC pr =
pVJP =
parens $
SOAC.VJP
<$> pLambda pr
<$> braces (pSubExp `sepBy` pComma)
<* pComma
<*> braces (pSubExp `sepBy` pComma)
<* pComma
<*> braces (pSubExp `sepBy` pComma)
<*> pLambda pr
pJVP =
parens $
SOAC.JVP
<$> pLambda pr
<$> braces (pSubExp `sepBy` pComma)
<* pComma
<*> braces (pSubExp `sepBy` pComma)
<* pComma
<*> braces (pSubExp `sepBy` pComma)
<*> pLambda pr

pSizeClass :: Parser GPU.SizeClass
pSizeClass =
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/SOACS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog)
expUsesAD (Op JVP {}) = True
expUsesAD (Op VJP {}) = True
expUsesAD (Op (Stream _ _ _ lam)) = lamUsesAD lam
expUsesAD (Op (Screma _ _ (ScremaForm scans reds lam))) =
expUsesAD (Op (Screma _ _ (ScremaForm lam scans reds))) =
lamUsesAD lam
|| any (lamUsesAD . scanLambda) scans
|| any (lamUsesAD . redLambda) reds
Expand Down
Loading

0 comments on commit f71c51e

Please sign in to comment.