Skip to content

Commit

Permalink
Remove unnecessary cases from Imp pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 1, 2024
1 parent 2ab0882 commit c0946f1
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 61 deletions.
3 changes: 2 additions & 1 deletion src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ applyProjectionsReduced (p:ps) x = do
ProjectProduct i -> reduceProj i x'
UnwrapNewtype -> reduceUnwrap x'

mkBlock :: (EnvReader m, IRRep r) => ToExpr e r => Abs (Decls r) e n -> m n (Expr r n)
mkBlock :: (EnvReader m, IRRep r, ToExpr e r) => Abs (Decls r) e n -> m n (Expr r n)
mkBlock (Abs Empty expr) = return $ toExpr expr
mkBlock (Abs decls body) = do
let block = Abs decls (toExpr body)
effTy <- blockEffTy block
Expand Down
63 changes: 6 additions & 57 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,6 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
scalarArgs <- liftM toList $ mapM fromScalarAtom xs
results <- impCall f scalarArgs
restructureScalarOrPairType resultTy results
TabApp _ f' x' -> do
x <- substM x'
f <- atomToRepVal =<< substM f'
repValAtom =<< indexRepVal f x
Atom x -> substM x
PrimOp op -> toImpOp op
Case e alts (EffTy _ unitResultTy) -> do
Expand All @@ -319,6 +315,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
extendSubst (b @> SubstVal (sink xs)) $
void $ translateExpr body
return UnitVal
TabApp _ _ _ -> error "Unexpected `TabApp` in Imp pass."
TabCon _ _ -> error "Unexpected `TabCon` in Imp pass."
Project _ i x -> reduceProj i =<< substM x

Expand Down Expand Up @@ -357,17 +354,13 @@ toImpVectorOp = \case
val' <- fromScalarAtom val
emitInstr (IVectorBroadcast val' $ toIVectorType vty) >>= returnIExprVal
VectorIota vty -> emitInstr (IVectorIota $ toIVectorType vty) >>= returnIExprVal
VectorIdx tbl' i vty -> do
-- VectorIdx requires that tbl' have a scalar element type, which is
-- ultimately enforced by `Lower.getVectorType` barfing on non-scalars.
tbl <- atomToRepVal tbl'
repValAtom =<< vectorIndexRepVal tbl i vty
VectorSubref ref i vty -> do
refDest <- atomToDest ref
refi <- destToAtom <$> indexDest refDest i
refi' <- fromScalarAtom refi
resultVal <- castPtrToVectorType refi' (toIVectorType vty)
repValAtom $ RepVal (RefTy vty) (Leaf resultVal)
VectorIdx _ _ _ -> error "Unexpected VectorIdx in Imp pass"
where
returnIExprVal x = return $ toScalarAtom x

Expand Down Expand Up @@ -399,6 +392,7 @@ toImpMiscOp op = case op of
RepVal _ tree <- atomToRepVal x
repValAtom (RepVal resultTy tree)
GarbageVal resultTy -> buildGarbageVal resultTy
NewRef _ -> error "not implemented"
Select p x y -> do
BaseTy _ <- return $ getType x
returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y)
Expand Down Expand Up @@ -491,7 +485,6 @@ type IdxNest r = Nest (IxBinder r)
data TypeCtxLayer (r::IR) (n::S) (l::S) where
TabCtx :: IxBinder r n l -> TypeCtxLayer r n l
DepPairCtx :: MaybeB (Binder r) n l -> TypeCtxLayer r n l
RefCtx :: TypeCtxLayer r n n

instance SinkableE Dest where
sinkingProofE = undefined
Expand Down Expand Up @@ -529,9 +522,6 @@ getElemTypeAndIdxStructure (LeafType ctxs baseTy) = case ctxs of
Just UnitB -> Just ixs
Nothing -> Nothing
(BoxedBuffer eltTy, ixs')
RefCtx -> (,Nothing) $ UnboxedValue $ hostPtrTy $ elemTypeToBaseType eltTy
where BufferType _ eltTy = getRefBufferType (LeafType rest baseTy)
where hostPtrTy ty = PtrType (CPU, ty)

allNothingBs :: Nest (MaybeB b) n l -> Maybe (UnitB n l)
allNothingBs Empty = Just UnitB
Expand Down Expand Up @@ -573,7 +563,7 @@ typeToTree tyTop = return $ go REmpty tyTop
go ctx (TyCon con) = case con of
BaseType b -> Leaf $ LeafType (unRNest ctx) b
TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy
RefType t -> go (RNest ctx RefCtx) t
RefType _ -> error "Unexpected ref type"
DepPairTy (DepPairType _ (b:>t1) (t2)) -> do
let tree1 = rec t1
let tree2 = go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2
Expand Down Expand Up @@ -607,7 +597,7 @@ valueToTree (RepVal tyTop valTop) = do
go ctx (TyCon ty) val = case ty of
BaseType b -> return $ Leaf $ LeafType (unRNest ctx) b
TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val
RefType t -> go (RNest ctx RefCtx) t val
RefType _ -> error "Unexpected ref type"
DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of
Branch [v1, v2] -> do
case allDepPairCtxs (unRNest ctx) of
Expand Down Expand Up @@ -754,10 +744,7 @@ atomToRepVal x = RepVal (getType x) <$> go x where
StuckProject i val -> do
Branch ts <- go =<< mkStuck val
return $ ts !! i
StuckTabApp f x' -> do
f' <- atomToRepVal =<< mkStuck f
RepVal _ t <- indexRepVal f' x'
return t
StuckTabApp _ _ -> error "unexpected tab app"

-- XXX: We used to have a function called `destToAtom` which loaded the value
-- from the dest. This version is not that. It just lifts a dest into an atom of
Expand Down Expand Up @@ -880,43 +867,6 @@ indexDest (Dest (TyCon (TabPi tabTy)) tree) i = do
indexDest _ _ = error "expected a reference to a table"
{-# INLINE indexDest #-}

-- TODO: de-dup with indexDest?
indexRepValParam :: Emits n
=> RepVal n -> SAtom n -> (SType n -> SType n)
-> (IExpr n -> SubstImpM i n (IExpr n))
-> SubstImpM i n (RepVal n)
indexRepValParam (RepVal (TyCon (TabPi tabTy)) vals) i tyFunc func = do
eltTy <- instantiate tabTy [i]
ord <- ordinalImp (tabIxType tabTy) i
leafTys <- typeToTree (toType tabTy)
vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do
BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy
offset <- computeOffsetImp ixStruct ord
ptr' <- impOffset ptr offset
-- we represent scalars by value, not by reference, so we do a load
-- if this is the last index in the table nest.
case ixStruct of
EmptyAbs (Nest _ Empty) -> func ptr' >>= load
_ -> func ptr'
-- `func` may have changed the types of the `vals'`. The caller must also
-- supply `tyFunc` to reflect that change in the SType.
return $ RepVal (tyFunc eltTy) vals'
indexRepValParam _ _ _ _ = error "expected table type"
{-# INLINE indexRepValParam #-}

indexRepVal :: Emits n => RepVal n -> SAtom n -> SubstImpM i n (RepVal n)
indexRepVal rep i = indexRepValParam rep i id return
{-# INLINE indexRepVal #-}

vectorIndexRepVal :: Emits n => RepVal n -> SAtom n -> SType n -> SubstImpM i n (RepVal n)
vectorIndexRepVal rep i vty =
-- Passing `const vty` here depends on knowing that `vectorIndexRepVal` is
-- only called on references of scalar base type, so that the give `vty` is,
-- actually, the type of the result of the indexing operation.
indexRepValParam rep i (const vty) action where
action ptr = castPtrToVectorType ptr (toIVectorType vty)
{-# INLINE vectorIndexRepVal #-}

projectDest :: Int -> Dest n -> Dest n
projectDest i (Dest (TyCon (ProdType tys)) (Branch ds)) =
Dest (tys!!i) (ds!!i)
Expand Down Expand Up @@ -1372,7 +1322,6 @@ instance Pretty (TypeCtxLayer SimpIR n l) where
TabCtx (PairB _ b) -> pretty b
DepPairCtx (RightB UnitB) -> "dep-pair-instantiated"
DepPairCtx (LeftB b) -> "dep-pair" <+> pretty b
RefCtx -> "refctx"

-- See Note [Confuse GHC] from Simplify.hs
confuseGHC :: EnvReader m => m n (DistinctEvidence n)
Expand Down
1 change: 1 addition & 0 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ whenOpt x act = getConfig <&> optLevel >>= \case
Optimize -> act x

evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n)
evalBlock (TopLam _ _ (LamExpr Empty (Atom result))) = return result
evalBlock typed = do
SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock typed
opt <- simpOptimizations simp
Expand Down
4 changes: 1 addition & 3 deletions src/lib/Types/Source.hs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ type CSDeclW = WithSrcs CSDecl
type SourceNameW = WithSrc SourceName

type BracketedGroup = WithSrcs [GroupW]
-- optional arrow, effects, result type
type ExplicitParams = BracketedGroup
type GivenClause = (BracketedGroup, Maybe BracketedGroup) -- implicits, classes
type WithClause = BracketedGroup -- no classes because we don't want to carry class dicts at runtime
Expand Down Expand Up @@ -644,7 +643,7 @@ data PrimName =
| UWhile | ULinearize | UTranspose
| UProjNewtype | UExplicitApply | UMonoLiteral
| UIndexRef | UApplyMethod Int
| UNat | UNatCon | UFin | UEffectRowKind
| UNat | UNatCon | UFin
| UTuple -- overloaded for type constructor and data constructor, resolved in inference
deriving (Show, Eq, Generic)

Expand Down Expand Up @@ -703,7 +702,6 @@ primNames = M.fromList
, ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type)
, ("Nat" , UNat)
, ("Fin" , UFin)
, ("EffKind" , UEffectRowKind)
, ("NatCon" , UNatCon)
, ("Ref" , UPrimTC $ P.RefType)
, ("indexRef" , UIndexRef)
Expand Down

0 comments on commit c0946f1

Please sign in to comment.