Skip to content

Commit

Permalink
Represent the huge pile of first-order ops as a simple functor.
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed May 15, 2024
1 parent 7f895d4 commit d0334fc
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 605 deletions.
1 change: 0 additions & 1 deletion dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ library
, Types.Core
, Types.Imp
, Types.Primitives
, Types.OpNames
, Types.Source
, Types.Top
, QueryType
Expand Down
4 changes: 2 additions & 2 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ import Err
import Name
import PPrint
import Types.Primitives
import qualified Types.Source as S
import Types.Source
import qualified Types.OpNames as P
import Util

-- === Converting concrete syntax to abstract syntax ===
Expand Down Expand Up @@ -521,7 +521,7 @@ charExpr :: Char -> (UExpr' VoidS)
charExpr c = ULit $ Word8Lit $ fromIntegral $ fromEnum c

unitExpr :: SrcId -> UExpr VoidS
unitExpr sid = WithSrcE sid $ UPrim (UCon $ P.ProdCon) []
unitExpr sid = WithSrcE sid $ UPrim (UCon $ S.ProdCon) []

-- === Builders ===

Expand Down
9 changes: 4 additions & 5 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import QueryType
import Types.Core
import Types.Imp
import Types.Primitives
import Types.Source
import Types.Source hiding (TCName (..), ConName (..))
import Types.Top
import Util (enumerate, transitiveClosureM, bindM2, toSnocList, popList)

Expand All @@ -51,7 +51,6 @@ instance ToExpr (Expr r) r where toExpr = id
instance ToExpr (Atom r) r where toExpr = Atom
instance ToExpr (Con r) r where toExpr = Atom . Con
instance ToExpr (AtomVar r) r where toExpr = toExpr . toAtom
instance IRRep r => ToExpr (MemOp r) r where toExpr op = PrimOp (getType op) (MemOp op)
instance ToExpr (TypedHof r) r where toExpr = Hof

-- === Ordinary (local) builder class ===
Expand Down Expand Up @@ -102,7 +101,7 @@ emitBinOp :: (Builder r m, Emits n) => BinOp -> Atom r n -> Atom r n -> m n (Ato
emitBinOp op x y = emit $ PrimOp resultTy $ BinOp op x y
where resultTy = TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x

emitRefOp :: (Builder r m, Emits n) => Atom r n -> RefOp r n -> m n (Atom r n)
emitRefOp :: (Builder r m, Emits n) => Atom r n -> RefOp r (Atom r n) -> m n (Atom r n)
emitRefOp ref op = undefined

emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n)
Expand Down Expand Up @@ -1092,11 +1091,11 @@ naryIndexRef ref is = foldM indexRef ref is

ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
ptrOffset x (IdxRepVal 0) = return x
ptrOffset x i = emit $ PtrOffset x i
ptrOffset x i = undefined -- emit $ PtrOffset x i
{-# INLINE ptrOffset #-}

unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
unsafePtrLoad x = emit . PtrLoad =<< sinkM x
unsafePtrLoad x = undefined -- emit . PtrLoad =<< sinkM x

mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (Expr r n)
mkIndexRef ref i = do
Expand Down
26 changes: 1 addition & 25 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,6 @@ visitAlt (Abs b body) = do
LamExpr (UnaryNest b') body' -> return $ Abs b' body'
_ -> error "not an alt"

traverseOpTerm
:: (GenericOp e, Visitor m r i o, OpConst e r ~ OpConst e r)
=> e r i -> m (e r o)
traverseOpTerm e = traverseOp e visitGeneric visitGeneric

visitTypeDefault
:: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m)
=> Type r i -> m i o (Type r o)
Expand Down Expand Up @@ -392,22 +387,13 @@ instance IRRep r => VisitGeneric (Expr r) r where
return $ Case x' alts' effTy'
Atom x -> Atom <$> visitGeneric x
TabCon t xs -> TabCon <$> visitGeneric t <*> mapM visitGeneric xs
PrimOp t op -> PrimOp <$> visitGeneric t <*> visitGeneric op
PrimOp t op -> PrimOp <$> visitGeneric t <*> mapM visitGeneric op
App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs
ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs
Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x
Unwrap t x -> Unwrap <$> visitGeneric t <*> visitGeneric x
Hof op -> Hof <$> visitGeneric op

instance IRRep r => VisitGeneric (PrimOp r) r where
visitGeneric = \case
UnOp op x -> UnOp op <$> visitGeneric x
BinOp op x y -> BinOp op <$> visitGeneric x <*> visitGeneric y
MemOp op -> MemOp <$> visitGeneric op
VectorOp op -> VectorOp <$> visitGeneric op
MiscOp op -> MiscOp <$> visitGeneric op
RefOp r op -> RefOp <$> visitGeneric r <*> traverseOp op visitGeneric visitGeneric

instance IRRep r => VisitGeneric (TypedHof r) r where
visitGeneric (TypedHof eff hof) = TypedHof <$> visitGeneric eff <*> visitGeneric hof

Expand Down Expand Up @@ -536,10 +522,6 @@ instance IRRep r => VisitGeneric (Dict r) r where
StuckDict ty s -> fromJust <$> toMaybeDict <$> visitGeneric (Stuck ty s)
DictCon con -> DictCon <$> visitGeneric con

instance VisitGeneric (MiscOp r) r where visitGeneric = traverseOpTerm
instance VisitGeneric (VectorOp r) r where visitGeneric = traverseOpTerm
instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm

-- === SubstE/SubstB instances ===
-- These live here, as orphan instances, because we normalize as we substitute.

Expand Down Expand Up @@ -631,14 +613,8 @@ instance IRRep r => SubstE AtomSubstVal (Hof r)
instance IRRep r => SubstE AtomSubstVal (TyCon r)
instance IRRep r => SubstE AtomSubstVal (DictCon r)
instance IRRep r => SubstE AtomSubstVal (Con r)
instance IRRep r => SubstE AtomSubstVal (MiscOp r)
instance IRRep r => SubstE AtomSubstVal (VectorOp r)
instance IRRep r => SubstE AtomSubstVal (MemOp r)
instance IRRep r => SubstE AtomSubstVal (PrimOp r)
instance IRRep r => SubstE AtomSubstVal (RefOp r)
instance IRRep r => SubstE AtomSubstVal (EffTy r)
instance IRRep r => SubstE AtomSubstVal (Expr r)
instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r)
instance SubstE AtomSubstVal InstanceBody
instance SubstE AtomSubstVal DictType
instance IRRep r => SubstE AtomSubstVal (LamExpr r)
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import Types.Core
import Types.Top
import Types.Imp
import Types.Primitives
import Types.Source
import Types.Source hiding (ProdCon, ProdType)

-- === Typeclasses for monads ===

Expand Down
18 changes: 9 additions & 9 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,10 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
Hof hof -> toImpTypedHof hof

toImpRefOp :: Emits o
=> SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o)
=> SAtom i -> RefOp SimpIR (SAtom i) -> SubstImpM i o (SAtom o)
toImpRefOp refDest' m = do
refDest <- atomToDest =<< substM refDest'
substM m >>= \case
mapM substM m >>= \case
MPut x -> storeAtom refDest x >> return UnitVal
MGet -> do
Dest resultTy _ <- return refDest
Expand All @@ -335,23 +335,23 @@ toImpRefOp refDest' m = do
IndexRef i -> destToAtom <$> indexDest refDest i
ProjRef ~(ProjectProduct i) -> return $ destToAtom $ projectDest i refDest

toImpOp :: forall i o . Emits o => SType i -> PrimOp SimpIR i -> SubstImpM i o (SAtom o)
toImpOp :: forall i o . Emits o => SType i -> PrimOp SimpIR (SAtom i) -> SubstImpM i o (SAtom o)
toImpOp resultTy op = case op of
RefOp refDest eff -> toImpRefOp refDest eff
BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y)
UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x)
MemOp op' -> toImpMemOp =<< substM op'
MemOp op' -> toImpMemOp =<< mapM substM op'
MiscOp op' -> do
resultTy' <- substM resultTy
toImpMiscOp resultTy' =<< substM op'
toImpMiscOp resultTy' =<< mapM substM op'
VectorOp op' -> do
resultTy' <- substM resultTy
toImpVectorOp resultTy' =<< substM op'
toImpVectorOp resultTy' =<< mapM substM op'
where
fsa x = substM x >>= fromScalarAtom
returnIExprVal x = return $ toScalarAtom x

toImpVectorOp :: Emits o => SType o -> VectorOp SimpIR o -> SubstImpM i o (SAtom o)
toImpVectorOp :: Emits o => SType o -> VectorOp SimpIR (SAtom o) -> SubstImpM i o (SAtom o)
toImpVectorOp vty = \case
VectorBroadcast val -> do
val' <- fromScalarAtom val
Expand All @@ -373,7 +373,7 @@ castPtrToVectorType ptr vty = do
let PtrType (addrSpace, _) = getIType ptr
cast ptr (PtrType (addrSpace, vty))

toImpMiscOp :: forall i o . Emits o => SType o -> MiscOp SimpIR o -> SubstImpM i o (SAtom o)
toImpMiscOp :: forall i o . Emits o => SType o -> MiscOp SimpIR (SAtom o) -> SubstImpM i o (SAtom o)
toImpMiscOp resultTy op = case op of
ThrowError -> do
emitStatement IThrowError
Expand Down Expand Up @@ -424,7 +424,7 @@ toImpMiscOp resultTy op = case op of
fsa = fromScalarAtom
returnIExprVal x = return $ toScalarAtom x

toImpMemOp :: forall i o . Emits o => MemOp SimpIR o -> SubstImpM i o (SAtom o)
toImpMemOp :: forall i o . Emits o => MemOp SimpIR (SAtom o) -> SubstImpM i o (SAtom o)
toImpMemOp op = case op of
IOAlloc n -> do
n' <- fsa n
Expand Down
42 changes: 21 additions & 21 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ import QueryType
import Types.Core
import Types.Imp
import Types.Primitives
import Types.Source
import qualified Types.Source as S
import Types.Source hiding (ConName (..), TCName (..))
import Types.Top
import qualified Types.OpNames as P
import Util hiding (group)

-- === Top-level interface ===
Expand Down Expand Up @@ -1030,13 +1030,13 @@ matchPrimApp = \case
UBaseType b -> \case ~[] -> return $ toAtomR $ BaseType b
UNatCon -> \case ~[x] -> return $ toAtom $ NewtypeCon NatCon x
UPrimTC tc -> case tc of
P.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts
P.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts
P.RefType -> \case ~[h, a] -> undefined -- return $ toAtom $ RefType h (fromJust $ toMaybeType a)
P.TypeKind -> \case ~[] -> return $ toAtom $ Kind $ TypeKind
S.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts
S.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts
S.RefType -> \case ~[h, a] -> undefined -- return $ toAtom $ RefType h (fromJust $ toMaybeType a)
S.TypeKind -> \case ~[] -> return $ toAtom $ Kind $ TypeKind
UCon con -> case con of
P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs
P.SumCon _ -> error "not supported"
S.ProdCon -> \xs -> return $ toAtom $ ProdCon xs
S.SumCon _ -> error "not supported"
-- UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x
-- UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x
UBinOp op -> \case ~[x, y] -> emitBinOp op x y
Expand All @@ -1059,19 +1059,19 @@ matchPrimApp = \case
ExplicitCoreLam (UnaryNest b) body <- return x
return $ UnaryLamExpr b body

matchGenericOp :: GenericOp op => OpConst op CoreIR -> [CAtom n] -> InfererM i n (op CoreIR n)
matchGenericOp op xs = do
(tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do
case getType x of
TyCon (Kind TypeKind) -> do
Just x' <- return $ toMaybeType x
return $ Left x'
_ -> return $ Right x
let tyArgs' = case tyArgs of
[] -> Nothing
[t] -> Just t
_ -> error "Expected at most one type arg"
return $ fromJust $ toOp $ GenericOpRep op tyArgs' dataArgs
-- matchGenericOp :: GenericOp op => OpConst op CoreIR -> [CAtom n] -> InfererM i n (op CoreIR n)
-- matchGenericOp op xs = do
-- (tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do
-- case getType x of
-- TyCon (Kind TypeKind) -> do
-- Just x' <- return $ toMaybeType x
-- return $ Left x'
-- _ -> return $ Right x
-- let tyArgs' = case tyArgs of
-- [] -> Nothing
-- [t] -> Just t
-- _ -> error "Expected at most one type arg"
-- return $ fromJust $ toOp $ GenericOpRep op tyArgs' dataArgs

pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n
pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body)))
Expand Down
2 changes: 1 addition & 1 deletion src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import Data.Functor ((<&>))

import Types.Primitives
import Types.Core
import Types.Source
import Types.Source hiding (TCName (..))
import Types.Top
import Types.Imp
import IRVariants
Expand Down
78 changes: 39 additions & 39 deletions src/lib/QueryTypePure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ getTypeBaseType e = case getType e of
TyCon (BaseType b) -> b
ty -> error $ "Expected a base type. Got: " ++ show ty

instance IRRep r => HasType r (MemOp r) where
getType = \case
IOAlloc _ -> PtrTy (CPU, Scalar Word8Type)
IOFree _ -> UnitTy
PtrOffset arr _ -> getType arr
PtrLoad ptr -> do
let PtrTy (_, t) = getType ptr
toType $ BaseType t
PtrStore _ _ -> UnitTy
-- instance IRRep r => HasType r (MemOp r) where
-- getType = \case
-- IOAlloc _ -> PtrTy (CPU, Scalar Word8Type)
-- IOFree _ -> UnitTy
-- PtrOffset arr _ -> getType arr
-- PtrLoad ptr -> do
-- let PtrTy (_, t) = getType ptr
-- toType $ BaseType t
-- PtrStore _ _ -> UnitTy

rawStrType :: IRRep r => Type r n
rawStrType = case newName "n" of
Expand Down Expand Up @@ -216,7 +216,7 @@ instance IRRep r => HasEffects (Expr r) r where
Case _ _ (EffTy effs _) -> effs
TabCon _ _ -> Pure
ApplyMethod (EffTy eff _) _ _ _ -> eff
PrimOp _ primOp -> getEffects primOp
-- PrimOp _ primOp -> getEffects primOp
Project _ _ _ -> Pure
Unwrap _ _ -> Pure
Hof (TypedHof (EffTy eff _) _) -> eff
Expand All @@ -225,32 +225,32 @@ instance IRRep r => HasEffects (DeclBinding r) r where
getEffects (DeclBinding _ expr) = getEffects expr
{-# INLINE getEffects #-}

instance IRRep r => HasEffects (PrimOp r) r where
getEffects = \case
UnOp _ _ -> Pure
BinOp _ _ _ -> Pure
VectorOp _ -> Pure
MemOp op -> case op of
IOAlloc _ -> Effectful
IOFree _ -> Effectful
PtrLoad _ -> Effectful
PtrStore _ _ -> Effectful
PtrOffset _ _ -> Pure
MiscOp op -> case op of
Select _ _ _ -> Pure
ThrowError -> Pure
CastOp _ -> Pure
UnsafeCoerce _ -> Pure
GarbageVal -> Pure
BitcastOp _ -> Pure
SumTag _ -> Pure
ToEnum _ -> Pure
OutputStream -> Pure
ShowAny _ -> Pure
ShowScalar _ -> Pure
RefOp _ m -> case m of
MGet -> Effectful
MPut _ -> Effectful
IndexRef _ -> Pure
ProjRef _ -> Pure
{-# INLINE getEffects #-}
-- instance IRRep r => HasEffects (PrimOp r) r where
-- getEffects = \case
-- UnOp _ _ -> Pure
-- BinOp _ _ _ -> Pure
-- VectorOp _ -> Pure
-- MemOp op -> case op of
-- IOAlloc _ -> Effectful
-- IOFree _ -> Effectful
-- PtrLoad _ -> Effectful
-- PtrStore _ _ -> Effectful
-- PtrOffset _ _ -> Pure
-- MiscOp op -> case op of
-- Select _ _ _ -> Pure
-- ThrowError -> Pure
-- CastOp _ -> Pure
-- UnsafeCoerce _ -> Pure
-- GarbageVal -> Pure
-- BitcastOp _ -> Pure
-- SumTag _ -> Pure
-- ToEnum _ -> Pure
-- OutputStream -> Pure
-- ShowAny _ -> Pure
-- ShowScalar _ -> Pure
-- RefOp _ m -> case m of
-- MGet -> Effectful
-- MPut _ -> Effectful
-- IndexRef _ -> Pure
-- ProjRef _ -> Pure
-- {-# INLINE getEffects #-}
Loading

0 comments on commit d0334fc

Please sign in to comment.