Skip to content

Commit

Permalink
Remove one-off type arguments from primops and make a result type uni…
Browse files Browse the repository at this point in the history
…formly available instead
  • Loading branch information
dougalm committed May 15, 2024
1 parent 3853784 commit 7f895d4
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 245 deletions.
10 changes: 5 additions & 5 deletions dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ library
, Builder
, CUDA
, CheapReduction
, CheckType
-- , CheckType
, ConcreteSyntax
, Core
, DPS
Expand All @@ -56,9 +56,9 @@ library
, Inference
-- , Inline
, IRVariants
, JAX.Concrete
, JAX.Rename
, JAX.ToSimp
-- , JAX.Concrete
-- , JAX.Rename
-- , JAX.ToSimp
, LLVM.Link
, LLVM.Compile
, LLVM.CUDA
Expand All @@ -75,7 +75,7 @@ library
, PPrint
, RawName
, Runtime
, RuntimePrint
-- , RuntimePrint
, Serialize
, Simplify
, Subst
Expand Down
54 changes: 39 additions & 15 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ import Util (enumerate, transitiveClosureM, bindM2, toSnocList, popList)
peepholeExpr :: a -> a
peepholeExpr = id

-- === ToExpr ===

class ToExpr (e::E) (r::IR) | e -> r where
toExpr :: e n -> Expr r n

instance ToExpr (Expr r) r where toExpr = id
instance ToExpr (Atom r) r where toExpr = Atom
instance ToExpr (Con r) r where toExpr = Atom . Con
instance ToExpr (AtomVar r) r where toExpr = toExpr . toAtom
instance IRRep r => ToExpr (MemOp r) r where toExpr op = PrimOp (getType op) (MemOp op)
instance ToExpr (TypedHof r) r where toExpr = Hof

-- === Ordinary (local) builder class ===

class (EnvReader m, Fallible1 m, IRRep r)
Expand Down Expand Up @@ -81,6 +93,18 @@ emit e = case toExpr e of
return $ toAtom v
{-# INLINE emit #-}

emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n)
emitUnOp op x = emit $ PrimOp resultTy $ UnOp op x
where resultTy = TyCon $ BaseType $ typeUnOp op $ getTypeBaseType x


emitBinOp :: (Builder r m, Emits n) => BinOp -> Atom r n -> Atom r n -> m n (Atom r n)
emitBinOp op x y = emit $ PrimOp resultTy $ BinOp op x y
where resultTy = TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x

emitRefOp :: (Builder r m, Emits n) => Atom r n -> RefOp r n -> m n (Atom r n)
emitRefOp ref op = undefined

emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n)
emitToVar expr = emit expr >>= \case
Stuck _ (Var v) -> return v
Expand Down Expand Up @@ -823,7 +847,7 @@ maybeTangentType' ty = case ty of
addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n)
addTangent x y = do
case getTyCon x of
BaseType (Scalar _) -> emit $ BinOp FAdd x y
BaseType (Scalar _) -> emitBinOp FAdd x y
ProdType _ -> do
xs <- getUnpacked x
ys <- getUnpacked y
Expand Down Expand Up @@ -855,22 +879,22 @@ symbolicTangentNonZero val = do
-- === builder versions of common local ops ===

fadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
fadd x y = emit $ BinOp FAdd x y
fadd x y = emitBinOp FAdd x y

fsub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
fsub x y = emit $ BinOp FSub x y
fsub x y = emitBinOp FSub x y

fmul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
fmul x y = emit $ BinOp FMul x y
fmul x y = emitBinOp FMul x y

fdiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
fdiv x y = emit $ BinOp FDiv x y
fdiv x y = emitBinOp FDiv x y

iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
iadd x y = emit $ BinOp IAdd x y
iadd x y = emitBinOp IAdd x y

imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
imul x y = emit $ BinOp IMul x y
imul x y = emitBinOp IMul x y

fLitLike :: Double -> SAtom n -> SAtom n
fLitLike x t = case getTyCon t of
Expand All @@ -893,7 +917,7 @@ getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n)
getProjRef i r = emit =<< mkProjRef r i

newUninitializedRef :: (SBuilder m, Emits o) => SType o -> m o (SAtom o)
newUninitializedRef ty = emit $ NewRef ty
newUninitializedRef ty = emit $ PrimOp ty $ MiscOp NewRef

-- XXX: getUnpacked must reduce its argument to enforce the invariant that
-- ProjectElt atoms are always fully reduced (to avoid type errors between two
Expand Down Expand Up @@ -1068,21 +1092,21 @@ naryIndexRef ref is = foldM indexRef ref is

ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n)
ptrOffset x (IdxRepVal 0) = return x
ptrOffset x i = emit $ MemOp $ PtrOffset x i
ptrOffset x i = emit $ PtrOffset x i
{-# INLINE ptrOffset #-}

unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n)
unsafePtrLoad x = emit . MemOp . PtrLoad =<< sinkM x
unsafePtrLoad x = emit . PtrLoad =<< sinkM x

mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n)
mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (Expr r n)
mkIndexRef ref i = do
resultTy <- typeOfIndexRef (getType ref) i
return $ RefOp ref $ IndexRef resultTy i
return $ PrimOp resultTy $ RefOp ref $ IndexRef i

mkProjRef :: (EnvReader m, IRRep r) => Atom r n -> Projection -> m n (PrimOp r n)
mkProjRef :: (EnvReader m, IRRep r) => Atom r n -> Projection -> m n (Expr r n)
mkProjRef ref i = do
resultTy <- typeOfProjRef (getType ref) i
return $ RefOp ref $ ProjRef resultTy i
return $ PrimOp resultTy $ RefOp ref $ ProjRef i

-- === index set type class ===

Expand Down Expand Up @@ -1127,7 +1151,7 @@ emitIf :: (Emits n, ScopableBuilder r m)
-> (forall l. (Emits l, DExt n l) => m l (Atom r l))
-> m n (Atom r n)
emitIf predicate resultTy trueCase falseCase = do
predicate' <- emit $ ToEnum (TyCon (SumType [UnitTy, UnitTy])) predicate
predicate' <- emit $ PrimOp (TyCon (SumType [UnitTy, UnitTy])) $ MiscOp (ToEnum predicate)
buildCase predicate' resultTy \i _ ->
case i of
0 -> falseCase
Expand Down
6 changes: 3 additions & 3 deletions src/lib/CheapReduction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ reduceExprM = \case
withInstantiated def args \(PairE _ (InstanceBody _ methods)) -> do
reduceApp (methods !! i) explicitArgs'
_ -> empty
PrimOp (MiscOp (CastOp ty' val')) -> do
PrimOp ty' (MiscOp (CastOp val')) -> do
ty <- substM ty'
val <- substM val'
case (ty, val) of
Expand All @@ -124,7 +124,7 @@ reduceExprM = \case
TopApp _ _ _ -> empty
Case _ _ _ -> empty
TabCon _ _ -> empty
PrimOp _ -> empty
PrimOp _ _ -> empty

reduceApp :: CAtom i -> [CAtom o] -> ReducerM i o (CAtom o)
reduceApp f xs = do
Expand Down Expand Up @@ -392,7 +392,7 @@ instance IRRep r => VisitGeneric (Expr r) r where
return $ Case x' alts' effTy'
Atom x -> Atom <$> visitGeneric x
TabCon t xs -> TabCon <$> visitGeneric t <*> mapM visitGeneric xs
PrimOp op -> PrimOp <$> visitGeneric op
PrimOp t op -> PrimOp <$> visitGeneric t <*> visitGeneric op
App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs
ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs
Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x
Expand Down
4 changes: 2 additions & 2 deletions src/lib/DPS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ storeDest :: Emits o => Dest o -> SAtom o -> DestM i o ()
storeDest dest val = do
RefTy (TyCon tycon) <- return $ getType dest
case tycon of
BaseType _ -> void $ emit $ RefOp dest $ MPut val
BaseType _ -> undefined -- void $ emit $ RefOp dest $ MPut val

-- The dps pass carries a non-type-preserving substitution in which arrays are
-- replaced with refs to arrays. So it's incorrect to directly apply the
Expand Down Expand Up @@ -121,7 +121,7 @@ dpsExpr maybeDest expr = case expr of
return UnitVal
Atom x -> lowerAtom x >>= returnResult
TabCon _ _ -> undefined
PrimOp _ -> undefined
PrimOp _ _ -> undefined
Project _ _ _ -> undefined

where
Expand Down
63 changes: 33 additions & 30 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import qualified Control.Monad.State.Strict as MTL

import Builder
import CheapReduction
import CheckType (CheckableE (..))
-- import CheckType (CheckableE (..))
import Core
import Err
import IRVariants
Expand Down Expand Up @@ -289,7 +289,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of
results <- impCall f scalarArgs
restructureScalarOrPairType resultTy results
Atom x -> substM x
PrimOp op -> toImpOp op
PrimOp ty op -> toImpOp ty op
Case e alts (EffTy _ unitResultTy) -> do
e' <- substM e
case unitResultTy of
Expand Down Expand Up @@ -332,34 +332,38 @@ toImpRefOp refDest' m = do
-- than to go through a general purpose atom.
storeAtom dest =<< loadAtom refDest
loadAtom dest
IndexRef _ i -> destToAtom <$> indexDest refDest i
ProjRef _ ~(ProjectProduct i) -> return $ destToAtom $ projectDest i refDest
IndexRef i -> destToAtom <$> indexDest refDest i
ProjRef ~(ProjectProduct i) -> return $ destToAtom $ projectDest i refDest

toImpOp :: forall i o . Emits o => PrimOp SimpIR i -> SubstImpM i o (SAtom o)
toImpOp op = case op of
toImpOp :: forall i o . Emits o => SType i -> PrimOp SimpIR i -> SubstImpM i o (SAtom o)
toImpOp resultTy op = case op of
RefOp refDest eff -> toImpRefOp refDest eff
BinOp binOp x y -> returnIExprVal =<< emitInstr =<< (IBinOp binOp <$> fsa x <*> fsa y)
UnOp unOp x -> returnIExprVal =<< emitInstr =<< (IUnOp unOp <$> fsa x)
MemOp op' -> toImpMemOp =<< substM op'
MiscOp op' -> toImpMiscOp =<< substM op'
VectorOp op' -> toImpVectorOp =<< substM op'
MiscOp op' -> do
resultTy' <- substM resultTy
toImpMiscOp resultTy' =<< substM op'
VectorOp op' -> do
resultTy' <- substM resultTy
toImpVectorOp resultTy' =<< substM op'
where
fsa x = substM x >>= fromScalarAtom
returnIExprVal x = return $ toScalarAtom x

toImpVectorOp :: Emits o => VectorOp SimpIR o -> SubstImpM i o (SAtom o)
toImpVectorOp = \case
VectorBroadcast val vty -> do
toImpVectorOp :: Emits o => SType o -> VectorOp SimpIR o -> SubstImpM i o (SAtom o)
toImpVectorOp vty = \case
VectorBroadcast val -> do
val' <- fromScalarAtom val
emitInstr (IVectorBroadcast val' $ toIVectorType vty) >>= returnIExprVal
VectorIota vty -> emitInstr (IVectorIota $ toIVectorType vty) >>= returnIExprVal
VectorSubref ref i vty -> do
VectorIota -> emitInstr (IVectorIota $ toIVectorType vty) >>= returnIExprVal
VectorSubref ref i -> do
refDest <- atomToDest ref
refi <- destToAtom <$> indexDest refDest i
refi' <- fromScalarAtom refi
resultVal <- castPtrToVectorType refi' (toIVectorType vty)
repValAtom $ RepVal (RefTy vty) (Leaf resultVal)
VectorIdx _ _ _ -> error "Unexpected VectorIdx in Imp pass"
VectorIdx _ _ -> error "Unexpected VectorIdx in Imp pass"
where
returnIExprVal x = return $ toScalarAtom x

Expand All @@ -369,29 +373,29 @@ castPtrToVectorType ptr vty = do
let PtrType (addrSpace, _) = getIType ptr
cast ptr (PtrType (addrSpace, vty))

toImpMiscOp :: forall i o . Emits o => MiscOp SimpIR o -> SubstImpM i o (SAtom o)
toImpMiscOp op = case op of
ThrowError resultTy -> do
toImpMiscOp :: forall i o . Emits o => SType o -> MiscOp SimpIR o -> SubstImpM i o (SAtom o)
toImpMiscOp resultTy op = case op of
ThrowError -> do
emitStatement IThrowError
buildGarbageVal resultTy
CastOp destTy x -> do
CastOp x -> do
BaseTy _ <- return $ getType x
BaseTy bt <- return destTy
BaseTy bt <- return resultTy
x' <- fsa x
returnIExprVal =<< cast x' bt
BitcastOp destTy x -> do
BaseTy bt <- return destTy
BitcastOp x -> do
BaseTy bt <- return resultTy
returnIExprVal =<< emitInstr =<< (IBitcastOp bt <$> fsa x)
UnsafeCoerce resultTy x -> do
UnsafeCoerce x -> do
srcTy <- return $ getType x
srcRep <- getRepBaseTypes srcTy
destRep <- getRepBaseTypes resultTy
assertEq srcRep destRep $
"representation types don't match: " ++ pprint srcRep ++ " != " ++ pprint destRep
RepVal _ tree <- atomToRepVal x
repValAtom (RepVal resultTy tree)
GarbageVal resultTy -> buildGarbageVal resultTy
NewRef _ -> error "not implemented"
GarbageVal -> buildGarbageVal resultTy
NewRef -> error "not implemented"
Select p x y -> do
BaseTy _ <- return $ getType x
returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y)
Expand All @@ -401,15 +405,14 @@ toImpMiscOp op = case op of
RepVal _ (Branch (tag:_)) <- return dRepVal
return $ toAtom $ RepVal (TagRepTy :: SType o) tag
_ -> error $ "Not a data constructor: " ++ pprint con
ToEnum ty i -> case ty of
ToEnum i -> case resultTy of
TyCon (SumType cases) -> do
i' <- fromScalarAtom i
return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases
_ -> error $ "Not an enum: " ++ pprint ty
return $ toAtom $ RepVal resultTy $ Branch $ Leaf i' : map (const (Branch [])) cases
_ -> error $ "Not an enum: " ++ pprint resultTy
OutputStream -> returnIExprVal =<< emitInstr IOutputStream
ShowAny _ -> error "Shouldn't have ShowAny in simplified IR"
ShowScalar x -> do
resultTy <- return $ getType $ PrimOp $ MiscOp op
Dest (PairTy sizeTy tabTy) (Branch [sizeTree, tabTree@(Leaf tabPtr)]) <- allocDest resultTy
xScalar <- fromScalarAtom x
size <- emitInstr $ IShowScalar tabPtr xScalar
Expand Down Expand Up @@ -1234,8 +1237,8 @@ impInstrTypes instr = case instr of
IShowScalar _ _ -> return [Scalar Word32Type]
where hostPtrTy ty = PtrType (CPU, ty)

instance CheckableE SimpIR ImpFunction where
checkE = renameM -- TODO
-- instance CheckableE SimpIR ImpFunction where
-- checkE = renameM -- TODO

-- TODO: Don't use Core Envs for Imp!
instance BindsEnv ImpDecl where
Expand Down
24 changes: 8 additions & 16 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import GHC.Generics (Generic (..))

import Builder
import CheapReduction
import CheckType
-- import CheckType
import Core
import Err
import IRVariants
Expand Down Expand Up @@ -1037,12 +1037,12 @@ matchPrimApp = \case
UCon con -> case con of
P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs
P.SumCon _ -> error "not supported"
UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x
UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x
UBinOp op -> \case ~[x, y] -> emit $ BinOp op x y
UUnOp op -> \case ~[x] -> emit $ UnOp op x
UMGet -> \case ~[r] -> emit $ RefOp r MGet
UMPut -> \case ~[r, x] -> emit $ RefOp r $ MPut x
-- UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x
-- UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x
UBinOp op -> \case ~[x, y] -> emitBinOp op x y
UUnOp op -> \case ~[x] -> emitUnOp op x
UMGet -> \case ~[r] -> emitRefOp r MGet
UMPut -> \case ~[r, x] -> emitRefOp r $ MPut x
UIndexRef -> \case ~[r, i] -> indexRef r i
UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args
ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x
Expand Down Expand Up @@ -1116,7 +1116,7 @@ buildNthOrderedAlt alts _ resultTy i v = do
case lookup i [(idx, alt) | IndexedAlt idx alt <- alts] of
Nothing -> do
resultTy' <- sinkM resultTy
emit $ ThrowError resultTy'
emit $ PrimOp resultTy' $ MiscOp ThrowError
Just alt -> applyAbs alt (SubstVal v) >>= emit

buildMonomorphicCase
Expand Down Expand Up @@ -2213,14 +2213,6 @@ instance PrettyE e => Pretty (UDeclInferenceResult e l) where
instance SinkableE e => SinkableE (UDeclInferenceResult e) where
sinkingProofE = todoSinkableProof

instance (RenameE e, CheckableE CoreIR e) => CheckableE CoreIR (UDeclInferenceResult e) where
checkE = \case
UDeclResultDone e -> UDeclResultDone <$> checkE e
UDeclResultBindName ann block ab ->
UDeclResultBindName ann <$> checkE block <*> renameM ab -- TODO: check result
UDeclResultBindPattern hint block recon ->
UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon

instance GenericE SynthType where
type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType))
fromE (SynthDictType d) = Case0 d
Expand Down
Loading

0 comments on commit 7f895d4

Please sign in to comment.