Skip to content

Commit

Permalink
Remove ACase and TabLam
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 1, 2024
1 parent c0946f1 commit 3e0cbe9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 198 deletions.
24 changes: 1 addition & 23 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ module CheapReduction
, bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst
, repValAtom, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp
, reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck
, liftSimpAtom, reduceACase)
, liftSimpAtom)
where

import Control.Applicative
Expand Down Expand Up @@ -64,10 +64,6 @@ reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n)
reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x
{-# INLINE reduceProj #-}

reduceACase :: EnvReader m => SAtom n -> [Abs SBinder CAtom n] -> CType n -> m n (CAtom n)
reduceACase scrut alts resultTy = liftM fromJust $ liftReducerM $ reduceACaseM scrut alts resultTy
{-# INLINE reduceACase #-}

reduceUnwrap :: EnvReader m => CAtom n -> m n (CAtom n)
reduceUnwrap x = liftM fromJust $ liftReducerM $ reduceUnwrapM x
{-# INLINE reduceUnwrap #-}
Expand Down Expand Up @@ -138,14 +134,6 @@ reduceApp f xs = do
Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body
_ -> empty

reduceACaseM :: SAtom n -> [Abs SBinder CAtom n] -> CType n -> ReducerM i n (CAtom n)
reduceACaseM scrut alts resultTy = case scrut of
Con (SumCon _ i arg) -> do
Abs b body <- return $ alts !! i
applySubst (b@>SubstVal arg) body
Con _ -> error "not a sum type"
Stuck _ scrut' -> mkStuck $ ACase scrut' alts resultTy

reduceProjM :: IRRep r => Int -> Atom r o -> ReducerM i o (Atom r o)
reduceProjM i x = case x of
Con con -> case con of
Expand Down Expand Up @@ -199,10 +187,6 @@ queryStuckType = \case
SuperclassProj i s -> superclassProjType i =<< queryStuckType s
LiftSimp t _ -> return t
LiftSimpFun t _ -> return $ toType t
-- TabLam and ACase are just defunctionalization tools. The result type
-- in both cases should *not* be `Data`.
TabLam (PairE t _) -> return $ toType t
ACase _ _ resultTy -> return resultTy

projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n)
projType i x = case getType x of
Expand Down Expand Up @@ -637,12 +621,6 @@ reduceStuck = \case
s' <- reduceStuck s
liftSimpAtom t' s'
LiftSimpFun t f -> mkStuck =<< (LiftSimpFun <$> substM t <*> substM f)
TabLam lam -> mkStuck =<< (TabLam <$> substM lam)
ACase scrut alts resultTy -> do
scrut' <- reduceStuck scrut
resultTy' <- substM resultTy
alts' <- mapM substM alts
reduceACaseM scrut' alts' resultTy'

liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n)
liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type"
Expand Down
2 changes: 0 additions & 2 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,6 @@ instance IRRep r => CheckableE r (Stuck r) where
RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check
LiftSimp t x -> LiftSimp <$> checkE t <*> renameM x -- TODO: check
LiftSimpFun t x -> LiftSimpFun <$> checkE t <*> renameM x -- TODO: check
ACase scrut alts resultTy -> ACase <$> renameM scrut <*> mapM renameM alts <*> checkE resultTy -- TODO: check
TabLam lam -> TabLam <$> renameM lam -- TODO: check

depPairLeftTy :: DepPairType r n -> Type r n
depPairLeftTy (DepPairType _ (_:>ty) _) = ty
Expand Down
165 changes: 5 additions & 160 deletions src/lib/Simplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,10 @@ tryAsDataAtom atom = do
data WithSubst (e::E) (o::S) where
WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o

type ACase = SStuck `PairE` ListE (Abs SBinder CAtom) `PairE` CType

data ConcreteCAtom (n::S) =
CCCon (WithSubst (Con CoreIR) n)
| CCLiftSimp (CType n) (Stuck SimpIR n)
| CCFun (ConcreteCFun n)
| CCTabLam (WithSubst TabLamExpr n)
| CCACase (WithSubst ACase n)

data ConcreteCFun (n::S) =
CCLiftSimpFun (CorePiType n) (LamExpr SimpIR n)
Expand Down Expand Up @@ -112,41 +108,25 @@ forceStuck stuck = withDistinct case stuck of
Rename v' -> v'
x' <- runSubstReaderT s' $ renameM x
returnLifted x'
-- We "thunk" ACase rather than forcing it because different use-cases require different ways to force it
ACase e alts resultTy -> do
subst <- getSubst
return $ CCACase $ WithSubst subst $ e `PairE` ListE alts `PairE` resultTy
TabLam e -> do
subst <- getSubst
return $ CCTabLam $ WithSubst subst e
StuckProject i x -> forceStuck x >>= \case
CCLiftSimp _ x' -> returnLifted $ StuckProject i x'
CCCon (WithSubst s con) -> withSubst s case con of
ProdCon xs -> forceConstructor (xs!!i)
DepPair l r _ -> forceConstructor ([l, r]!!i)
_ -> error "not a product"
CCACase x' -> pushUnderACase x' \x'' -> reduceProj i x''
CCFun _ -> error "not a product"
CCTabLam _ -> error "not a product"
StuckTabApp f x -> forceStuck f >>= \case
CCLiftSimp _ f' -> do
x' <- toDataAtom x
returnLifted $ StuckTabApp f' x'
CCTabLam (WithSubst s (PairE _ (Abs b body))) -> do
x' <- toDataAtom x
result <- withSubst s $ extendSubst (b@>SubstVal x') $ substM body
dropSubst $ forceConstructor result
CCACase f' -> pushUnderACase f' \f'' -> reduceTabApp f'' =<< substM x
CCCon _ -> error "not a table"
CCFun _ -> error "not a table"
StuckUnwrap x -> forceStuck x >>= \case
CCCon (WithSubst s con) -> case con of
NewtypeCon _ x' -> withSubst s $ forceConstructor x'
_ -> error "not a newtype"
CCLiftSimp _ x' -> returnLifted x'
CCACase x' -> pushUnderACase x' \x'' -> reduceUnwrap x''
CCFun _ -> error "not a newtype"
CCTabLam _ -> error "not a newtype"
InstantiatedGiven _ _ -> error "shouldn't have this left"
SuperclassProj _ _ -> error "shouldn't have this left"
PtrVar ty p -> do
Expand All @@ -159,27 +139,6 @@ forceStuck stuck = withDistinct case stuck of
resultTy <- getType <$> substMStuck stuck
return $ CCLiftSimp resultTy s

pushUnderACase
:: WithSubst ACase o
-> (forall o'. DExt o o' => CAtom o' -> SimplifyM i o' (CAtom o'))
-> SimplifyM i o (ConcreteCAtom o)
pushUnderACase _ _ = undefined
-- pushUnderACase (WithSubst s (scrut `PairE` ListE alts `PairE` resultTy)) cont = undefined
-- TODO: make a buildACase to use here and elsewhere in Simplify. Maybe in CheapReduce too?


forceACase
:: Emits o => WithSubst ACase o
-> (forall o'. (Emits o', DExt o o') => ConcreteCAtom o' -> SimplifyM i o' (CAtom o'))
-> SimplifyM i o (CAtom o)
forceACase (WithSubst subst (scrut `PairE` ListE alts `PairE` resultTy)) cont = do
resultTy' <- withSubst subst $ substM resultTy
scrut' <- withSubst subst $ substMStuck scrut
defuncCase scrut' resultTy' \i x -> do
Abs b body <- return $ alts !! i
body' <- withSubst (sink subst) $ extendSubst (b@>SubstVal x) $ forceConstructor body
cont body'

tryGetRepType :: Type CoreIR n -> SimplifyM i n (Maybe (SType n))
tryGetRepType t = isData t >>= \case
False -> return Nothing
Expand Down Expand Up @@ -232,8 +191,6 @@ toDataAtom (Stuck _ stuck) = forceStuck stuck >>= \case
CCCon (WithSubst s con) -> withSubst s $ toDataAtom (Con con)
CCLiftSimp _ e -> mkStuck e
CCFun _ -> notData
CCACase _ -> notData -- TODO: make sure we observe this invariant"
CCTabLam _ -> notData -- TODO: make sure we observe this invariant"
where notData = error $ "Not runtime-representable data"

toDataAtomAssumeNoDecls :: CAtom i -> SimplifyM i o (SAtom o)
Expand Down Expand Up @@ -374,7 +331,7 @@ simpDeclsSubst !s = \case
simpDeclsSubst (s <>> (b@>SubstVal x)) rest

simplifyExpr :: Emits o => Expr CoreIR i -> SimplifyM i o (CAtom o)
simplifyExpr expr = confuseGHC >>= \_ -> case expr of
simplifyExpr = \case
Block _ (Abs decls body) -> simplifyDecls decls $ simplifyExpr body
App (EffTy _ ty) f xs -> do
ty' <- substM ty
Expand Down Expand Up @@ -409,7 +366,7 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of
tryAsDataAtom x' >>= \case
Just (x'', _) -> liftSimpAtom ty' =<< proj i x''
Nothing -> requireReduced $ Project ty' i x'
Unwrap _ _ -> requireReduced =<< substM expr
Unwrap ty x -> requireReduced =<< substM (Unwrap ty x)

requireReduced :: CExpr o -> SimplifyM i o (CAtom o)
requireReduced expr = reduceExpr expr >>= \case
Expand Down Expand Up @@ -463,43 +420,7 @@ defuncCase scrut resultTy cont = do
dropSubst $ toDataAtom ans
caseExpr <- mkCase scrut resultTyData alts'
emit caseExpr >>= liftSimpAtom resultTy
Nothing -> do
split <- splitDataComponents resultTy
(alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do
simplifyAlt split bTy $ cont i
let closureSumTy = TyCon $ SumType closureTys
let newNonDataTy = nonDataTy split
alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt
caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts''
caseResult <- emit $ caseExpr
(dataVal, sumVal) <- fromPair caseResult
reconAlts <- forM (zip closureTys recons) \(ty, recon) ->
buildAbs noHint ty \v -> applyRecon (sink recon) (toAtom v)
nonDataVal <- reduceACase sumVal reconAlts newNonDataTy
Distinct <- getDistinct
fromSplit split dataVal nonDataVal

simplifyAlt
:: SplitDataNonData n
-> SType o
-> (forall o'. (Emits o', DExt o o') => SAtom o' -> SimplifyM i o' (CAtom o'))
-> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o)
simplifyAlt split ty cont = do
withFreshBinder noHint ty \b -> do
ab <- buildScoped $ cont $ sink $ toAtom $ binderVar b
(body, recon) <- refreshAbs ab \decls result -> do
let locals = toScopeFrag b >>> toScopeFrag decls
-- TODO: this might be too cautious. The type only needs to
-- be hoistable above the decls. In principle it can still
-- mention vars from the lambda binders.
Distinct <- getDistinct
(resultData, resultNonData) <- toSplit split result
(newResult, reconAbs) <- telescopicCapture locals resultNonData
return (Abs decls (PairVal resultData newResult), LamRecon reconAbs)
body' <- mkBlock body
PairTy _ nonDataType <- return $ getType body'
let nonDataType' = ignoreHoistFailure $ hoist b nonDataType
return (Abs b body', nonDataType', recon)
Nothing -> error "not data"

simplifyApp :: Emits o => CType o -> ConcreteCAtom o -> [CAtom o] -> SimplifyM i o (CAtom o)
simplifyApp resultTy f xs = case f of
Expand All @@ -515,8 +436,6 @@ simplifyApp resultTy f xs = case f of
CCFFIFun _ f' -> do
xs' <- dropSubst $ mapM toDataAtom xs
liftSimpAtom resultTy =<< topApp f' xs'
CCACase aCase -> forceACase aCase \f' -> simplifyApp (sink resultTy) f' (sink <$> xs)
CCTabLam _ -> error "not a function"
CCLiftSimp _ _ -> error "not a function"

simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n)
Expand Down Expand Up @@ -567,10 +486,6 @@ simplifyTabApp f x = case f of
resultTy <- typeOfTabApp fTy x
x' <- dropSubst $ toDataAtom x
liftSimpAtom resultTy =<< tabApp f'' x'
CCACase aCase -> forceACase aCase \f' -> simplifyTabApp f' (sink x)
CCTabLam (WithSubst s (PairE _ (Abs b ab))) -> do
x' <- dropSubst $ toDataAtom x
withSubst s $ extendSubst (b@>(SubstVal x')) $ substM ab
_ -> error "not a table"

simplifyIxDict :: Dict CoreIR i -> SimplifyM i o (SDict o)
Expand All @@ -580,8 +495,6 @@ simplifyIxDict (StuckDict _ stuck) = forceStuck stuck >>= \case
_ -> error "not a dict"
CCLiftSimp _ _ -> error "not a dict"
CCFun _ -> error "not a dict"
CCTabLam _ -> error "not a dict"
CCACase _ -> error "not implemented" -- TODO: consider what to do about this
simplifyIxDict (DictCon con) = case con of
IxFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n
IxRawFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n
Expand Down Expand Up @@ -641,41 +554,6 @@ simplifyLam (LamExpr bsTop body) = case bsTop of
SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body
return (LamExpr Empty body', Abs Empty recon)

data SplitDataNonData n = SplitDataNonData
{ dataTy :: Type SimpIR n
, nonDataTy :: Type CoreIR n
, toSplit :: forall i l . CAtom l -> SimplifyM i l (SAtom l, CAtom l)
, fromSplit :: forall i l . DExt n l => SAtom l -> CAtom l -> SimplifyM i l (CAtom l) }

-- bijection between that type and a (data, non-data) pair type.
splitDataComponents :: Type CoreIR n -> SimplifyM i n (SplitDataNonData n)
splitDataComponents = \case
TyCon (ProdType tys) -> do
splits <- mapM splitDataComponents tys
return $ SplitDataNonData
{ dataTy = TyCon $ ProdType $ map dataTy splits
, nonDataTy = TyCon $ ProdType $ map nonDataTy splits
, toSplit = \xProd -> do
xs <- getUnpackedReduced xProd
(ys, zs) <- unzip <$> forM (zip xs splits) \(x, split) -> toSplit split x
return (Con $ ProdCon ys, Con $ ProdCon zs)
, fromSplit = \xsProd ysProd -> do
xs <- getUnpackedReduced xsProd
ys <- getUnpackedReduced ysProd
zs <- forM (zip (zip xs ys) splits) \((x, y), split) -> fromSplit split x y
return $ Con $ ProdCon zs }
ty -> tryGetRepType ty >>= \case
Just repTy -> return $ SplitDataNonData
{ dataTy = repTy
, nonDataTy = UnitTy
, toSplit = \x -> (,UnitVal) <$> (dropSubst $ toDataAtomAssumeNoDecls x)
, fromSplit = \x _ -> liftSimpAtom (sink ty) x }
Nothing -> return $ SplitDataNonData
{ dataTy = UnitTy
, nonDataTy = ty
, toSplit = \x -> return (UnitVal, x)
, fromSplit = \_ x -> return x }

buildSimplifiedBlock
:: (forall o'. (Emits o', DExt o o') => SimplifyM i o' (CAtom o'))
-> SimplifyM i o (SimplifiedBlock o)
Expand Down Expand Up @@ -764,18 +642,11 @@ applyDictMethod resultTy d i methodArgs = case d of
simplifyHof :: Emits o => CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o)
simplifyHof resultTy = \case
For d (IxType ixTy ixDict) lam -> do
(lam', Abs (UnaryNest bIx) recon) <- simplifyLam lam
(lam', CoerceReconAbs) <- simplifyLam lam
ixTy' <- getRepType ixTy
ixDict' <- simplifyIxDict ixDict
ans <- emitHof $ For d (IxType ixTy' ixDict') lam'
case recon of
CoerceRecon _ -> liftSimpAtom resultTy ans
LamRecon (Abs bsClosure reconResult) -> do
ab <- buildAbs noHint ixTy' \i -> do
xs <- unpackTelescope bsClosure =<< reduceTabApp (sink ans) (toAtom i)
applySubst (bIx@>Rename (atomVarName i) <.> bsClosure @@> map SubstVal xs) reconResult
TyCon (TabPi resultTy') <- return resultTy
mkStuck $ TabLam $ resultTy' `PairE` ab
liftSimpAtom resultTy ans
While body -> do
SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body
result <- emitHof $ While body'
Expand Down Expand Up @@ -948,29 +819,3 @@ instance RenameE ReconstructAtom
instance Pretty (ReconstructAtom n) where
pretty (CoerceRecon ty) = "Coercion reconstruction: " <> pretty ty
pretty (LamRecon ab) = "Reconstruction abs: " <> pretty ab

-- === GHC performance hacks ===

-- Note [Confuse GHC]
-- I can't explain this, but for some reason using this function in strategic
-- places makes GHC produce significantly better code. If we define
--
-- simplifyAtom = \case
-- ...
-- Con con -> traverse simplifyAtom con
-- ...
--
-- then GHC is reluctant to generate a fast-path worker function for simplifyAtom
-- that would return unboxed tuples, because (at least that's my guess) it's afraid
-- that it will have to allocate a reader closure for the traverse, which does not
-- get inlined. For some reason writing the `confuseGHC >>= \_ -> case atom of ...`
-- makes GHC do the right thing, i.e. generate unboxed worker + a tiny wrapper that
-- allocates -- a closure to be passed into traverse.
--
-- What's so special about this, I don't know. `return ()` is insufficient and doesn't
-- make the optimization go through. I'll just take the win for now...
--
-- NB: We should revise this whenever we upgrade to a newer GHC version.
confuseGHC :: SimplifyM i o (DistinctEvidence o)
confuseGHC = getDistinct
{-# INLINE confuseGHC #-}
Loading

0 comments on commit 3e0cbe9

Please sign in to comment.