Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize through user-defined index sets #1316

Merged
merged 4 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/lib/ImpToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -902,16 +902,33 @@ withWidthOfFP x template = case typeOf template of
L.FloatingPointType L.FloatFP -> litVal $ Float32Lit $ realToFrac x
_ -> error $ "Unsupported floating point type: " ++ show (typeOf template)

-- If we are accessing a `L.Type` from a Dex array, what memory alignment (in
-- bytes) can we guarantee? This is probably better expressed in Dex types, but
-- we would need to plumb them to do it that way. 1-byte alignment should
-- always be safe, but we can promise higher-performance alignments for some
-- types.
dexAlignment :: L.Type -> Word32
dexAlignment = \case
L.IntegerType bits | bits `mod` 8 == 0 -> bits `div` 8
L.IntegerType _ -> 1
L.PointerType _ _ -> 4
L.FloatingPointType L.FloatFP -> 4
L.FloatingPointType L.DoubleFP -> 8
L.VectorType _ eltTy -> dexAlignment eltTy
_ -> 1

store :: LLVMBuilder m => Operand -> Operand -> m ()
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing 0 []
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing alignment [] where
alignment = dexAlignment $ typeOf x

load :: LLVMBuilder m => L.Type -> Operand -> m Operand
load pointeeTy ptr =
#if MIN_VERSION_llvm_hs(15,0,0)
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing 0 []
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing alignment []
#else
emitInstr pointeeTy $ L.Load False ptr Nothing 0 []
emitInstr pointeeTy $ L.Load False ptr Nothing alignment []
#endif
where alignment = dexAlignment pointeeTy

ilt :: LLVMBuilder m => Operand -> Operand -> m Operand
ilt x y = emitInstr i1 $ L.ICmp IP.SLT x y []
Expand Down
94 changes: 59 additions & 35 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ import Util (allM, zipWithZ)
-- TODO: Local vector values? We might want to pack short and pure for loops into vectors,
-- to support things like float3 etc.
data Stability
= Uniform -- constant across vectorized dimension
| Varying -- varying across vectorized dimension
| Contiguous -- varying, but contiguous across vectorized dimension
-- Constant across vectorized dimension, represented as a scalar
= Uniform
-- Varying across vectorized dimension, represented as a vector
| Varying
-- Varying, but contiguous across vectorized dimension; represented as a
-- scalar carrying the first value
| Contiguous
| ProdStability [Stability]
deriving (Eq, Show)

Expand Down Expand Up @@ -168,25 +172,27 @@ vectorizeLoopsExpr expr = do
narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
case expr of
PrimOp (DAMOp (Seq effs dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal n))) dest body))
| n `mod` loopWidth == 0 -> (do
safe <- vectorSafeEffect effs
if safe
then (do
Distinct <- getDistinct
let vn = n `div` loopWidth
body' <- vectorizeSeq loopWidth body
dest' <- renameM dest
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
return $ PrimOp $ DAMOp seqOp)
else renameM expr)
`catchErr` \errs -> do
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
ctx = mempty { messageCtx = [msg] }
errs' = prependCtxToErrs ctx errs
modify (<> LiftE errs')
recurSeq expr
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
sz <- simplifyIxSize =<< renameM ixty
case sz of
Just n | n `mod` loopWidth == 0 -> (do
safe <- vectorSafeEffect effs
if safe
then (do
Distinct <- getDistinct
let vn = n `div` loopWidth
body' <- vectorizeSeq loopWidth ixty body
dest' <- renameM dest
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
return $ PrimOp $ DAMOp seqOp)
else renameM expr)
`catchErr` \errs -> do
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
ctx = mempty { messageCtx = [msg] }
errs' = prependCtxToErrs ctx errs
modify (<> LiftE errs')
recurSeq expr
_ -> recurSeq expr
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
item' <- renameM item
itemTy <- return $ getType item'
Expand Down Expand Up @@ -218,6 +224,15 @@ vectorizeLoopsExpr expr = do
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
recurSeq _ = error "Impossible"

simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m)
=> IxType SimpIR n -> m n (Maybe Word32)
simplifyIxSize ixty = do
sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size []
cheapReduce sizeMethod >>= \case
Just (IdxRepVal n) -> return $ Just n
_ -> return Nothing
{-# INLINE simplifyIxSize #-}

-- Really we should check this by seeing whether there is an instance for a
-- `Commutative` class, or something like that, but for now just pattern-match
-- to detect scalar addition as the only monoid we recognize as commutative.
Expand Down Expand Up @@ -300,22 +315,27 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
safe _ = return False

vectorizeSeq :: Word32 -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
vectorizeSeq loopWidth (UnaryLamExpr (b:>ty) body) = do
(_, ty') <- case ty of
ProdTy [ixTy, ref] -> do
ixTy' <- renameM ixTy
vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i
-> TopVectorizeM i o (LamExpr SimpIR o)
vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do
newLoopTy <- case ty of
ProdTy [_ixType, ref] -> do
ref' <- renameM ref
return (ixTy', ProdTy [IdxRepTy, ref'])
return $ ProdTy [IdxRepTy, ref']
_ -> error "Unexpected seq binder type"
ixty' <- renameM ixty
liftVectorizeM loopWidth $
buildUnaryLamExpr (getNameHint b) ty' \ci -> do
-- XXX: we're assuming `Fin n` here
buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do
-- The per-tile loop iterates on `Fin`
(viOrd, dest) <- fromPair $ Var ci
iOrd <- imul viOrd $ IdxRepVal loopWidth
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal iOrd dest)) $
-- TODO: It would be nice to cancel this UnsafeFromOrdinal with the
-- Ordinal that will be taken later when indexing, but that should
-- probably be a separate pass.
i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd]
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $
vectorizeBlock body $> UnitVal
vectorizeSeq _ _ = error "expected a unary lambda expression"
vectorizeSeq _ _ _ = error "expected a unary lambda expression"

newtype VectorizeM i o a =
VectorizeM { runVectorizeM ::
Expand Down Expand Up @@ -467,9 +487,13 @@ vectorizePrimOp op = case op of
BinOp opk arg1 arg2 -> do
sx@(VVal vx x) <- vectorizeAtom arg1
sy@(VVal vy y) <- vectorizeAtom arg2
let v = case (vx, vy) of (Uniform, Uniform) -> Uniform; _ -> Varying
x' <- if vx /= v then ensureVarying sx else return x
y' <- if vy /= v then ensureVarying sy else return y
let v = case (opk, vx, vy) of
(_, Uniform, Uniform) -> Uniform
(IAdd, Uniform, Contiguous) -> Contiguous
(IAdd, Contiguous, Uniform) -> Contiguous
_ -> Varying
x' <- if v == Varying then ensureVarying sx else return x
y' <- if v == Varying then ensureVarying sy else return y
VVal v <$> emitOp (BinOp opk x' y')
MiscOp (CastOp tyArg arg) -> do
ty <- vectorizeType tyArg
Expand Down
41 changes: 38 additions & 3 deletions tests/opt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ _ = for i:(Fin 20) j:(Fin 4). ordinal j
"vectorizing int binary op"
-- CHECK-LABEL: vectorizing int binary op
%passes vect
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) * 2
-- CHECK: seq (RawFin 0x10)
-- CHECK: [[i0:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[i1:v#[0-9]+]]:<16xInt32> = viota
-- CHECK: [[i2:v#[0-9]+]]:<16xInt32> = %iadd [[i0]] [[i1]]
-- CHECK: [[ones:v#[0-9]+]]:<16xInt32> = vbroadcast 1
-- CHECK: %iadd [[i2]] [[ones]]
-- CHECK: [[twos:v#[0-9]+]]:<16xInt32> = vbroadcast 2
-- CHECK: %imul [[i2]] [[twos]]

"vectorizing float binary op"
-- CHECK-LABEL: vectorizing float binary op
Expand Down Expand Up @@ -211,3 +211,38 @@ _ = yield_accum (AddMonoid Int32) \result.
-- CHECK: [[mat1:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[prodj:v#[0-9]+]]:<16xInt32> = %imul [[mat1]] [[mat2j]]
-- CHECK: extend [[refj]] [[prodj]]

"vectorizing through the `tile` combinator and its funny index set"
-- CHECK-LABEL: vectorizing through the `tile` combinator and its funny index set

%passes vect
_ = yield_accum (AddMonoid Int32) \result.
tile((Fin 256), 32) \set.
for_ i:set.
ix = inject(i, to=(Fin 256))
result!ix += xs[ix]
-- CHECK: seq (RawFin 0x8)
-- CHECK: seq (RawFin 0x2)
-- CHECK: [[refix:v#[0-9]+]]:(Ref {{v#[0-9]+}} <16xInt32>) = vrefslice
-- CHECK: [[xsix:v#[0-9]+]]:<16xInt32> =
-- CHECK-NEXT: vslice
-- CHECK: extend [[refix]] [[xsix]]

"Non-aligned"
-- CHECK-LABEL: Non-aligned

-- This is a regression test. We are checking that Dex-side
-- vectorization does not end up assuming that arrays are aligned on
-- the size of the vectors, only on the size of the underlying
-- scalars.

non_aligned = for i:(Fin 7). for j:(Fin 257). +0

%passes llvm
_ = yield_accum (AddMonoid Int32) \result.
tile((Fin 257), 32) \set.
for_ i:set.
ix = inject(i, to=(Fin 257))
result!(6@(Fin 7))!ix += non_aligned[6@_][ix]
-- CHECK: load <16 x i32>, <16 x i32>* %"v#{{[0-9]+}}", align 4
-- CHECK: store <16 x i32> %"v#{{[0-9]+}}", <16 x i32>* %"v#{{[0-9]+}}", align 4