Skip to content

Commit

Permalink
Start implementing explicit DPS pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Apr 24, 2024
1 parent e49f882 commit 2ab0882
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 16 deletions.
1 change: 1 addition & 0 deletions dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ library
, CheckType
, ConcreteSyntax
, Core
, DPS
, Err
, Generalize
, Imp
Expand Down
29 changes: 21 additions & 8 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import Types.Imp
import Types.Primitives
import Types.Source
import Types.Top
import Util (enumerate, transitiveClosureM, bindM2, toSnocList)
import Util (enumerate, transitiveClosureM, bindM2, toSnocList, popList)

-- === Ordinary (local) builder class ===

Expand Down Expand Up @@ -661,6 +661,16 @@ buildTopLamFromPi
buildTopLamFromPi piTy@(PiType bs _) cont =
TopLam False piTy <$> buildLamExpr (EmptyAbs bs) cont

buildTopDestLamFromPi
:: ScopableBuilder r m
=> PiType r n
-> (forall l. (Emits l, Distinct l, DExt n l) => [AtomVar r l] -> AtomVar r l -> m l (Atom r l))
-> m n (TopLam r n)
buildTopDestLamFromPi piTy@(PiType bs _) cont =
TopLam True piTy <$> buildLamExpr (EmptyAbs bs) \argsAndDest -> do
let (args, dest) = popList argsAndDest
cont args dest

buildAlt
:: ScopableBuilder r m
=> Type r n
Expand Down Expand Up @@ -878,6 +888,9 @@ applyProjectionsRef (i:is) ref = getProjRef i =<< applyProjectionsRef is ref
getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n)
getProjRef i r = emit =<< mkProjRef r i

newUninitializedRef :: (SBuilder m, Emits o) => SType o -> m o (SAtom o)
newUninitializedRef ty = emit $ NewRef ty

-- XXX: getUnpacked must reduce its argument to enforce the invariant that
-- ProjectElt atoms are always fully reduced (to avoid type errors between two
-- equivalent types spelled differently).
Expand Down Expand Up @@ -1020,17 +1033,17 @@ naryApp :: (CBuilder m, Emits n) => CAtom n -> [CAtom n] -> m n (CAtom n)
naryApp f xs= mkApp f xs >>= emit
{-# INLINE naryApp #-}

naryTopApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n)
naryTopApp f xs = emit =<< mkTopApp f xs
{-# INLINE naryTopApp #-}
topApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n)
topApp f xs = emit =<< mkTopApp f xs
{-# INLINE topApp #-}

naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n)
naryTopAppInlined f xs = do
topAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n)
topAppInlined f xs = do
TopFunBinding f' <- lookupEnv f
case f' of
DexTopFun _ lam _ -> instantiate lam xs >>= emit
_ -> naryTopApp f xs
{-# INLINE naryTopAppInlined #-}
_ -> topApp f xs
{-# INLINE topAppInlined #-}

tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
tabApp x i = mkTabApp x i >>= emit
Expand Down
151 changes: 151 additions & 0 deletions src/lib/DPS.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
-- Copyright 2022 Google LLC
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE UndecidableInstances #-}

module DPS (dpsPass) where

import Prelude hiding ((.))
import Data.Functor
import Data.Maybe (fromJust)
import Control.Category
import Control.Monad.Reader
import Unsafe.Coerce

import Builder
import Core
import Imp
import CheapReduction
import IRVariants
import Name
import Subst
import PPrint
import QueryType
import Types.Core
import Types.Top
import Types.Primitives
import Util (enumerate, popList, SMaybe (..), Not)

dpsPass :: EnvReader m => STopLam n -> m n (STopLam n)
dpsPass (TopLam False piTy (LamExpr bs body)) = do
liftEnvReaderM $ liftAtomSubstBuilder do
dpsPiTy <- computeDPSPiTy piTy
buildTopDestLamFromPi dpsPiTy \args dest -> do
extendSubst (bs @@> (SubstVal . toAtom <$> args)) do
SNothing <- dpsExpr (SJust (toAtom dest)) body
return UnitVal
dpsPass (TopLam True _ _) = error "already in destination style"

computeDPSPiTy :: PiType SimpIR i -> DestM i o (PiType SimpIR o)
computeDPSPiTy (PiType bs resultTy) = case bs of
Empty -> do
destTy <- computeDestTy =<< dpsSubstType resultTy
withFreshBinder "ans" destTy \bDest ->
return $ PiType (UnaryNest bDest) UnitTy
Nest (b:>ty) bsRest -> do
repTy <- computeRepTy =<< dpsSubstType ty
withFreshBinder (getNameHint b) repTy \b' ->
extendSubst (b@>Rename (binderName b')) do
PiType bsRest' resultTy' <- computeDPSPiTy (PiType bsRest resultTy)
return $ PiType (Nest b' bsRest') resultTy'

type Dest = SAtom
type MaybeDest d n = SMaybe d (Dest n)
type MaybeResult d n = SMaybe (Not d) (SAtom n)

data DPSTag
type DestM = AtomSubstBuilder DPSTag SimpIR

computeRepTy :: EnvReader m => SType n -> m n (SType n)
computeRepTy ty = case ty of
TyCon con -> case con of
BaseType _ -> return ty

computeDestTy :: EnvReader m => SType n -> m n (SType n)
computeDestTy ty = case ty of
TyCon con -> case con of
BaseType _ -> return $ RefTy ty

lowerAtom :: SAtom i -> DestM i o (SAtom o)
lowerAtom = substM

getDPSFun :: TopFunName o -> DestM i o (TopFunName o)
getDPSFun = undefined

loadIfScalar :: Emits o => SAtom o -> DestM i o (SAtom o)
loadIfScalar = undefined

loadDest :: Emits o => SAtom o -> DestM i o (SAtom o)
loadDest = undefined

storeDest :: Emits o => Dest o -> SAtom o -> DestM i o ()
storeDest dest val = do
RefTy (TyCon tycon) <- return $ getType dest
case tycon of
BaseType _ -> void $ emit $ RefOp dest $ MPut val

-- The dps pass carries a non-type-preserving substitution in which arrays are
-- replaced with refs to arrays. So it's incorrect to directly apply the
-- substitution as we do here. It's fine as long as arrays don't appear in types
-- but it's not clear what we should otherwise. Maybe it's ok to dereference an
-- read-only ref inside a type? reference?
dpsSubstType :: SType i -> DestM i o (SType o)
dpsSubstType = substM

dpsExpr
:: forall i o d. Emits o
=> MaybeDest d o
-> SExpr i
-> DestM i o (MaybeResult d o)
dpsExpr maybeDest expr = case expr of
Block _ block -> dpsBlock maybeDest block
TopApp _ f args -> withDest \dest -> do
f' <- substM f >>= getDPSFun
args' <- mapM lowerAtom args
void $ topApp f' (args' ++ [dest]) >>= emit
TabApp _ xs i -> do
xs' <- lowerAtom xs
i' <- lowerAtom i
x <- indexRef xs' i'
returnResult =<< loadIfScalar x
Case scrut alts _ -> withDest \dest -> do
scrut' <- lowerAtom scrut
void $ buildCase scrut' UnitTy \i x -> do
Abs b body <- return $ alts!!i
extendSubst (b@>SubstVal x) do
SNothing <- dpsExpr (SJust $ sink dest) body
return ()
return UnitVal
Atom x -> lowerAtom x >>= returnResult
TabCon _ _ -> undefined
PrimOp _ -> undefined
Project _ _ _ -> undefined

where
returnResult :: SAtom o -> DestM i o (MaybeResult d o)
returnResult result = do
case maybeDest of
SJust dest -> storeDest dest result >> return SNothing
SNothing -> return $ SJust result

withDest :: (Dest o -> DestM i o ()) -> DestM i o (MaybeResult d o)
withDest cont = do
case maybeDest of
SJust dest -> cont dest >> return SNothing
SNothing -> do
destTy <- dpsSubstType $ RefTy (getType expr)
dest <- newUninitializedRef destTy
cont dest
result <- loadDest dest
return $ SJust result

dpsBlock :: Emits o => MaybeDest d o -> SBlock i -> DestM i o (MaybeResult d o)
dpsBlock maybeDest (Abs decls result) = case decls of
Empty -> dpsExpr maybeDest result
Nest (Let b (DeclBinding _ expr)) declsRest -> do
SJust x <- dpsExpr SNothing expr
extendSubst (b@>SubstVal x) $
dpsBlock maybeDest (Abs declsRest result)
4 changes: 2 additions & 2 deletions src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ linearizeExpr expr = case expr of
-- bindings and then only hoisting the final result.
Just (PairE fPrimal fTan) <- liftTopBuilderAndEmit $
liftM toPairE $ linearizeTopFun (sink $ LinearizationSpec f' (map isJust ts))
(ans, residuals) <- fromPair =<< naryTopApp fPrimal xs'
(ans, residuals) <- fromPair =<< topApp fPrimal xs'
return $ WithTangent ans do
ts' <- forM (catMaybes ts) \(WithTangent UnitE t) -> t
naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals, Con $ ProdCon ts'])
topApp (sink fTan) (sinkList xs' ++ [sink residuals, Con $ ProdCon ts'])
where
unitLike :: e n -> UnitE n
unitLike _ = UnitE
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ simplifyApp resultTy f xs = case f of
CCNoInlineFun v _ _ -> simplifyTopFunApp v xs
CCFFIFun _ f' -> do
xs' <- dropSubst $ mapM toDataAtom xs
liftSimpAtom resultTy =<< naryTopApp f' xs'
liftSimpAtom resultTy =<< topApp f' xs'
CCACase aCase -> forceACase aCase \f' -> simplifyApp (sink resultTy) f' (sink <$> xs)
CCTabLam _ -> error "not a function"
CCLiftSimp _ _ -> error "not a function"
Expand All @@ -529,7 +529,7 @@ simplifyTopFunApp fName xs = do
let spec = AppSpecialization fName xsGeneralized
Just specializedFunction <- getSpecializedFunction spec >>= emitHoistedEnv
runtimeArgs' <- dropSubst $ mapM toDataAtom runtimeArgs
liftSimpAtom resultTy =<< naryTopApp specializedFunction runtimeArgs'
liftSimpAtom resultTy =<< topApp specializedFunction runtimeArgs'
False ->
-- TODO: we should probably just fall back to inlining in this case,
-- especially if we want make everything @noinline by default.
Expand Down
4 changes: 3 additions & 1 deletion src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import CheckType (checkTypes)
import Core
import ConcreteSyntax
import CheapReduction
import DPS
import Err
import IRVariants
import Imp
Expand Down Expand Up @@ -478,7 +479,8 @@ evalBlock typed = do
simpResult <- case opt of
TopLam _ _ (LamExpr Empty (Atom result)) -> return result
_ -> do
lOpt <- checkPass OptPass $ loweredOptimizations opt
dps <- checkPass LowerPass $ dpsPass opt
lOpt <- checkPass OptPass $ loweredOptimizations dps
cc <- getEntryFunCC
impOpt <- checkPass ImpPass $ toImpFunction cc lOpt
llvmOpt <- packageLLVMCallable impOpt
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Transpose.hs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ transposeExpr expr ct = case expr of
Just fT <- getTransposedTopFun =<< substNonlin f
(xsNonlin, [xLin]) <- return $ splitAt (length xs - 1) xs
xsNonlin' <- mapM substNonlin xsNonlin
ct' <- naryTopApp fT (xsNonlin' ++ [ct])
ct' <- topApp fT (xsNonlin' ++ [ct])
transposeAtom xLin ct'
PrimOp op -> transposeOp op ct
Case e alts _ -> do
Expand Down
5 changes: 4 additions & 1 deletion src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ data MiscOp (r::IR) (n::S) =
| CastOp (Type r n) (Atom r n) -- (2) Type, then value. See CheckType.hs for valid coercions.
| BitcastOp (Type r n) (Atom r n) -- (2) Type, then value. See CheckType.hs for valid coercions.
| UnsafeCoerce (Type r n) (Atom r n) -- type, then value. Assumes runtime representation is the same.
| GarbageVal (Type r n) -- type of value (assume `Data` constraint)
| GarbageVal (Type r n) -- type of value (assume `Data` constraint) (TODO: redundant with NewRef)
| NewRef (Type r n)
| ThrowError (Type r n) -- (1) Hard error (parameterized by result type)
-- Tag of a sum type
| SumTag (Atom r n)
Expand Down Expand Up @@ -1165,6 +1166,7 @@ instance GenericOp MiscOp where
BitcastOp t x -> GenericOpRep P.BitcastOp [t] [x] []
UnsafeCoerce t x -> GenericOpRep P.UnsafeCoerce [t] [x] []
GarbageVal t -> GenericOpRep P.GarbageVal [t] [] []
NewRef t -> GenericOpRep P.NewRef [t] [] []
ThrowError t -> GenericOpRep P.ThrowError [t] [] []
SumTag x -> GenericOpRep P.SumTag [] [x] []
ToEnum t x -> GenericOpRep P.ToEnum [t] [x] []
Expand All @@ -1178,6 +1180,7 @@ instance GenericOp MiscOp where
GenericOpRep P.BitcastOp [t] [x] [] -> Just $ BitcastOp t x
GenericOpRep P.UnsafeCoerce [t] [x] [] -> Just $ UnsafeCoerce t x
GenericOpRep P.GarbageVal [t] [] [] -> Just $ GarbageVal t
GenericOpRep P.NewRef [t] [] [] -> Just $ NewRef t
GenericOpRep P.ThrowError [t] [] [] -> Just $ ThrowError t
GenericOpRep P.SumTag [] [x] [] -> Just $ SumTag x
GenericOpRep P.ToEnum [t] [x] [] -> Just $ ToEnum t x
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Types/OpNames.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ data CmpOp = Less | Greater | Equal | LessEqual | GreaterEqual
data MemOp = IOAlloc | IOFree | PtrOffset | PtrLoad | PtrStore

data MiscOp =
Select | CastOp | BitcastOp | UnsafeCoerce | GarbageVal | Effects
Select | CastOp | BitcastOp | UnsafeCoerce | GarbageVal | NewRef | Effects
| ThrowError | ThrowException | Tag | SumTag | Create | ToEnum
| OutputStream | ShowAny | ShowScalar

Expand Down
17 changes: 17 additions & 0 deletions src/lib/Util.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ onSnd f (x, y) = (x, f y)
onSndM :: (Functor m) => (a -> m b) -> (c, a) -> m (c, b)
onSndM f (x, y) = (x,) <$> f y

popList :: [a] -> ([a], a)
popList xs = case drop (n-1) xs of
[x] -> (xsPrefix, x)
_ -> error "empty list"
where n = length xs
xsPrefix = take (n-1) xs

unsnocNonempty :: NonEmpty a -> ([a], a)
unsnocNonempty (x:|xs) = case reverse (x:xs) of
(y:ys) -> (reverse ys, y)
Expand Down Expand Up @@ -387,3 +394,13 @@ readFileWithHash path = liftIO $ addHash <$> BS.readFile path
sameConstructor :: a -> a -> Bool
sameConstructor x y = tagToEnum# (getTag x ==# getTag y)
{-# INLINE sameConstructor #-}

-- === static-case version of Maybe ===

type family Not (x::Bool) where
Not True = False
Not False = True

data SMaybe (isJust::Bool) (a:: *) where
SNothing :: SMaybe False a
SJust :: a -> SMaybe True a

0 comments on commit 2ab0882

Please sign in to comment.