-
Notifications
You must be signed in to change notification settings - Fork 117
/
TH.hs
455 lines (409 loc) · 17.6 KB
/
TH.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module : Data.Array.Accelerate.Pattern.TH
-- Copyright : [2018..2020] The Accelerate Team
-- License : BSD3
--
-- Maintainer : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability : experimental
-- Portability : non-portable (GHC extensions)
--
module Data.Array.Accelerate.Pattern.TH (
mkPattern,
mkPatterns,
) where
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Type
import Control.Monad
import Data.Bits
import Data.Char
import Data.List ( (\\), foldl' )
import Language.Haskell.TH.Extra hiding ( Exp, Match, match )
import Numeric
import Text.Printf
import qualified Language.Haskell.TH.Extra as TH
import GHC.Stack
-- | Generate pattern synonyms for every type named in the list. Each name
-- is processed with 'mkPattern' in turn and the resulting declarations are
-- concatenated in the same order.
--
mkPatterns :: [Name] -> DecsQ
mkPatterns = fmap concat . mapM mkPattern
-- | Generate pattern synonyms for the given simple (Haskell'98) sum or
-- product data type, identified by the name of its type constructor.
--
-- Generated names are derived from the constructor and record selector
-- names by toggling a trailing underscore: it is appended when the name
-- does not already end in one, and removed when it does. The synonym for an
-- infix constructor is instead named by prepending a colon ':'. For
-- example, given:
--
-- > data Point = Point { xcoord_ :: Float, ycoord_ :: Float }
-- >   deriving (Generic, Elt)
--
-- this creates the pattern synonym:
--
-- > Point_ :: Exp Float -> Exp Float -> Exp Point
--
-- together with the selector functions:
--
-- > xcoord :: Exp Point -> Exp Float
-- > ycoord :: Exp Point -> Exp Float
--
-- Fails (at splice time) when 'reify' reports anything other than a
-- 'TyConI' for the name.
--
mkPattern :: Name -> DecsQ
mkPattern nm = reify nm >>= fromInfo
  where
    fromInfo (TyConI dec) = mkDec dec
    fromInfo _            = fail "mkPatterns: expected the name of a newtype or datatype"
-- | Dispatch on a reified declaration. Only @data@ and @newtype@
-- declarations are accepted; every other kind of declaration is rejected
-- with an error.
--
mkDec :: Dec -> DecsQ
mkDec (DataD _ nm tv _ cs _)   = mkDataD nm tv cs
mkDec (NewtypeD _ nm tv _ c _) = mkNewtypeD nm tv c
mkDec _                        = fail "mkPatterns: expected the name of a newtype or datatype"
-- | A newtype is handled exactly like a data type that has a single
-- constructor.
--
mkNewtypeD :: Name -> [TyVarBndr a] -> Con -> DecsQ
mkNewtypeD tn tvs = mkDataD tn tvs . (: [])
-- | Generate the pattern synonyms for all constructors of a data type,
-- preceded by a single COMPLETE pragma listing every generated synonym.
--
-- The arguments are the name of the type constructor, its type variable
-- binders, and its constructors. Generation fails for a type without
-- constructors and for constructors that are not written in plain prefix,
-- record or infix form.
--
mkDataD :: Name -> [TyVarBndr a] -> [Con] -> DecsQ
mkDataD tn tvs cs = do
  (pats, decs) <- unzip <$> go cs
  comp         <- pragCompleteD pats Nothing
  return $ comp : concat decs
  where
    -- For single-constructor types we create the pattern synonym for the
    -- type directly in terms of Pattern
    go []  = fail "mkPatterns: empty data declarations not supported"
    go [c] = return <$> mkConP tn tvs c
    go _   = do
      tys <- mapM fieldTys cs
      go' [] tys ctags cs

    -- For sum-types, when creating the pattern for an individual
    -- constructor we need to know about the types of the fields of all
    -- other constructors as well. 'prev' accumulates the field types of the
    -- constructors already visited, most recent first.
    go' prev (this:next) (tag:tags) (con:cons) = do
      r  <- mkConS tn tvs prev next tag con
      rs <- go' (this:prev) next tags cons
      return (r : rs)
    go' _ [] [] [] = return []
    go' _ _  _  _  = fail "mkPatterns: unexpected error"

    -- The field types of a constructor. This runs in 'Q' on purpose: in the
    -- list monad 'fail' is just the empty list, so an unsupported
    -- constructor would silently be treated as having no fields instead of
    -- being reported.
    fieldTys :: Con -> Q [Type]
    fieldTys (NormalC _ fs) = return (map snd fs)
    fieldTys (RecC _ fs)    = return (map (\(_,_,t) -> t) fs)
    fieldTys (InfixC a _ b) = return [snd a, snd b]
    fieldTys _              = fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"

    -- One tag per constructor, in constructor order.
    --
    -- TODO: The GTags class demonstrates a way to generate the tags for
    -- a given constructor, rather than backwards-engineering the structure
    -- as we've done here. We should use that instead!
    --
    ctags =
      let n = length cs
          m = n `quot` 2
          l = take m     (iterate (True:) [False])
          r = take (n-m) (iterate (True:) [True])
          --
          -- Read a list of bits, most significant first, as a number
          bitsToTag = foldl' f 0
            where
              f i False =         i `shiftL` 1
              f i True  = setBit (i `shiftL` 1) 0
      in
      map bitsToTag (l ++ r)
-- | Generate the pattern synonym for the constructor of a
-- single-constructor type. The synonym is implicitly bidirectional and is
-- defined directly in terms of 'Pattern' applied to a tuple of the fields.
--
-- The arguments are the type constructor name, its type variable binders
-- and the constructor. The result is the name of the generated synonym
-- together with its declarations. Requires the PatternSynonyms extension,
-- and only prefix, record and infix constructors are supported.
--
mkConP :: Name -> [TyVarBndr a] -> Con -> Q (Name, [Dec])
mkConP tn' tvs' con' = do
  checkExts [ PatternSynonyms ]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn tyvars (map snd fs)
    RecC cn fs    -> mkRecC tn' cn tyvars (map (rename . fst3) fs) (map thd3 fs)
    InfixC a cn b -> mkInfixC tn' cn tyvars [snd a, snd b]
    _             -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    tyvars :: [Name]
    tyvars = map tyVarBndrName tvs'

    -- Prefix constructor: one fresh variable per field
    mkNormalC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkNormalC tn cn tvs fs = do
      xs   <- replicateM (length fs) (newName "_x")
      decs <- synonym pat (prefixPatSyn xs) xs (conSig tn tvs fs)
      return (pat, decs)
      where
        pat = rename cn

    -- Record constructor: the (already renamed) field names double as the
    -- variables bound by the pattern
    mkRecC :: Name -> Name -> [Name] -> [Name] -> [Type] -> Q (Name, [Dec])
    mkRecC tn cn tvs xs fs = do
      decs <- synonym pat (recordPatSyn xs) xs (conSig tn tvs fs)
      return (pat, decs)
      where
        pat = rename cn

    -- Infix constructor: the synonym is the operator prefixed with ':' and
    -- inherits the fixity of the constructor, if one is known
    mkInfixC :: Name -> Name -> [Name] -> [Type] -> Q (Name, [Dec])
    mkInfixC tn cn tvs fs = do
      mf   <- reifyFixity cn
      _a   <- newName "_a"
      _b   <- newName "_b"
      decs <- synonym pat (infixPatSyn _a _b) [_a, _b] (conSig tn tvs fs)
      return (pat, maybe decs (\f -> InfixD f pat : decs) mf)
      where
        pat = mkName (':' : nameBase cn)

    -- The signature and definition of a synonym that matches 'Pattern'
    -- applied to the tuple of the given variables
    synonym pat args vars sig =
      sequence
        [ patSynSigD pat sig
        , patSynD pat args implBidir [p| Pattern $(tupP (map varP vars)) |]
        ]

    -- forall tvs. (Elt tv, ...) => Exp f1 -> ... -> Exp fn -> Exp (T tvs)
    conSig :: Name -> [Name] -> [Type] -> Q Type
    conSig tn tvs fs =
      forallT
        (map (`plainInvisTV` specifiedSpec) tvs)
        (cxt [ [t| Elt $(varT t) |] | t <- tvs ])
        (foldr (\arg res -> [t| $arg -> $res |]) resultTy argTys)
      where
        resultTy = [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
        argTys   = [ [t| Exp $(return t) |] | t <- fs ]
-- | Generate the pattern synonym for one constructor of a sum type.
--
-- The arguments are, in order: the type constructor name, its type variable
-- binders, the field types of the constructors preceding this one (the
-- builder reverses this list before use), the field types of the
-- constructors following this one, the tag of this constructor, and the
-- constructor itself.
--
-- The result is the name of the generated pattern synonym together with its
-- declarations, which include a generated builder and matcher function.
-- The synonym is explicitly bidirectional: construction goes through the
-- builder, matching through the matcher used as a view pattern.
--
mkConS :: Name -> [TyVarBndr a] -> [[Type]] -> [[Type]] -> Word8 -> Con -> Q (Name, [Dec])
mkConS tn' tvs' prev' next' tag' con' = do
  checkExts [GADTs, PatternSynonyms, ScopedTypeVariables, TypeApplications, ViewPatterns]
  case con' of
    NormalC cn fs -> mkNormalC tn' cn tag' (map tyVarBndrName tvs') prev' (map snd fs) next'
    RecC cn fs -> mkRecC tn' cn tag' (map tyVarBndrName tvs') (map (rename . fst3) fs) prev' (map thd3 fs) next'
    InfixC a cn b -> mkInfixC tn' cn tag' (map tyVarBndrName tvs') prev' [snd a, snd b] next'
    _ -> fail "mkPatterns: only constructors for \"vanilla\" syntax are supported"
  where
    -- Prefix constructor: the synonym is named via 'rename'; the builder and
    -- matcher are named after the original constructor.
    mkNormalC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkNormalC tn cn tag tvs ps fs ns = do
      let pat = rename cn
      (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
      dec_pat <- mkNormalC_pattern tn pat tvs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Record constructor: as for the prefix case, with 'xs' the (already
    -- renamed) field names to use as record selectors.
    mkRecC :: Name -> Name -> Word8 -> [Name] -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkRecC tn cn tag tvs xs ps fs ns = do
      let pat = rename cn
      (fun_build, dec_build) <- mkBuild tn (nameBase cn) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn (nameBase pat) (nameBase cn) tvs tag ps fs ns
      dec_pat <- mkRecC_pattern tn pat tvs xs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Infix constructor: the synonym is the operator prefixed with ':'. The
    -- operator name is z-encoded to form the suffix of the builder and
    -- matcher names, and shown parenthesised in the matcher's error message.
    mkInfixC :: Name -> Name -> Word8 -> [Name] -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkInfixC tn cn tag tvs ps fs ns = do
      let pat = mkName (':' : nameBase cn)
      (fun_build, dec_build) <- mkBuild tn (zencode (nameBase cn)) tvs tag ps fs ns
      (fun_match, dec_match) <- mkMatch tn ("(" ++ nameBase pat ++ ")") (zencode (nameBase cn)) tvs tag ps fs ns
      dec_pat <- mkInfixC_pattern tn cn pat tvs fs fun_build fun_match
      return $ (pat, concat [dec_pat, dec_build, dec_match])

    -- Signature and definition of a prefix pattern synonym: 'build' is the
    -- builder direction, and matching applies 'match' as a view pattern
    -- which must return 'Just' the tuple of fields.
    mkNormalC_pattern :: Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkNormalC_pattern tn pat tvs fs build match = do
      xs <- replicateM (length fs) (newName "_x")
      r <- sequence [ patSynSigD pat sig
                    , patSynD pat
                        (prefixPatSyn xs)
                        (explBidir [clause [] (normalB (varE build)) []])
                        (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
                    ]
      return r
      where
        -- forall tvs. (HasCallStack, Elt tv, ...) => Exp f1 -> ... -> Exp (T tvs)
        sig = forallT
          (map (`plainInvisTV` specifiedSpec) tvs)
          (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
          (foldr (\t ts -> [t| $t -> $ts |])
          [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
          (map (\t -> [t| Exp $(return t) |]) fs))

    -- As 'mkNormalC_pattern', but a record pattern synonym whose fields are
    -- named by 'xs'.
    mkRecC_pattern :: Name -> Name -> [Name] -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkRecC_pattern tn pat tvs xs fs build match = do
      r <- sequence [ patSynSigD pat sig
                    , patSynD pat
                        (recordPatSyn xs)
                        (explBidir [clause [] (normalB (varE build)) []])
                        (parensP $ viewP (varE match) [p| Just $(tupP (map varP xs)) |])
                    ]
      return r
      where
        sig = forallT
          (map (`plainInvisTV` specifiedSpec) tvs)
          (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
          (foldr (\t ts -> [t| $t -> $ts |])
          [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
          (map (\t -> [t| Exp $(return t) |]) fs))

    -- As 'mkNormalC_pattern', but an infix pattern synonym. If the original
    -- constructor 'cn' has a known fixity, the same fixity is declared for
    -- the synonym.
    mkInfixC_pattern :: Name -> Name -> Name -> [Name] -> [Type] -> Name -> Name -> Q [Dec]
    mkInfixC_pattern tn cn pat tvs fs build match = do
      mf <- reifyFixity cn
      _a <- newName "_a"
      _b <- newName "_b"
      r <- sequence [ patSynSigD pat sig
                    , patSynD pat
                        (infixPatSyn _a _b)
                        (explBidir [clause [] (normalB (varE build)) []])
                        (parensP $ viewP (varE match) [p| Just $(tupP [varP _a, varP _b]) |])
                    ]
      r' <- case mf of
        Nothing -> return r
        Just f -> return (InfixD f pat : r)
      return r'
      where
        sig = forallT
          (map (`plainInvisTV` specifiedSpec) tvs)
          (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
          (foldr (\t ts -> [t| $t -> $ts |])
          [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
          (map (\t -> [t| Exp $(return t) |]) fs))

    -- Generate the builder function @_build<cn>@ for a constructor. It takes
    -- one 'Exp' per field of this constructor and returns the tagged value:
    -- a 'Word8' constant holding the tag, paired with a nested pair holding
    -- one slot per field of every constructor of the type. Slots belonging
    -- to the other constructors ('fs0' before, 'fs1' after) are filled with
    -- 'undef' at the field's type.
    mkBuild :: Name -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkBuild tn cn tvs tag fs0 fs fs1 = do
      fun <- newName ("_build" ++ cn)
      xs <- replicateM (length fs) (newName "_x")
      let
        -- 'fs0' arrives most-recent-first, hence the 'reverse'
        vs = foldl' (\es e -> [| SmartExp ($es `Pair` $e) |]) [| SmartExp Nil |]
           $ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat (reverse fs0))
          ++ map varE xs
          ++ map (\t -> [| unExp $(varE 'undef `appTypeE` return t) |] ) (concat fs1)
        tagged = [| Exp $ SmartExp $ Pair (SmartExp (Const (SingleScalarType (NumSingleType (IntegralNumType TypeWord8))) $(litE (IntegerL (toInteger tag))))) $vs |]
        body = clause (map (\x -> [p| (Exp $(varP x)) |]) xs) (normalB tagged) []
      r <- sequence [ sigD fun sig
                    , funD fun [body]
                    ]
      return (fun, r)
      where
        -- forall tvs. (Elt tv, ...) => Exp f1 -> ... -> Exp fn -> Exp (T tvs)
        sig = forallT
          (map (`plainInvisTV` specifiedSpec) tvs)
          (cxt (map (\t -> [t| Elt $(varT t) |]) tvs))
          (foldr (\t ts -> [t| $t -> $ts |])
          [t| Exp $(foldl' appT (conT tn) (map varT tvs)) |]
          (map (\t -> [t| Exp $(return t) |]) fs))

    -- Generate the matcher function @_match<cn>@ for a constructor. 'pn' is
    -- the name of the pattern synonym as it should appear in the error
    -- message. The generated function inspects its argument and returns:
    --
    --   * 'Just' the fields, when the term is a 'Match' whose tag pattern is
    --     the tag of this constructor;
    --   * 'Nothing' for any other 'Match';
    --   * an 'error' carrying 'error_msg' for anything else.
    mkMatch :: Name -> String -> String -> [Name] -> Word8 -> [[Type]] -> [Type] -> [[Type]] -> Q (Name, [Dec])
    mkMatch tn pn cn tvs tag fs0 fs fs1 = do
      fun <- newName ("_match" ++ cn)
      e <- newName "_e"
      x <- newName "_x"
      (ps,es) <- extract vs [| Prj PairIdxRight $(varE x) |] [] []
      unbind <- isExtEnabled RebindableSyntax
      let
        -- With RebindableSyntax enabled, wrap the body in a let that binds
        -- the name "==" to the (==) in scope in this module.
        -- NOTE(review): presumably needed because the literal tag pattern
        -- is matched using whatever (==) is in scope at the splice site --
        -- confirm.
        eqE = if unbind then letE [funD (mkName "==") [clause [] (normalB (varE '(==))) []]] else id
        lhs = [p| (Exp $(varP e)) |]
        body = normalB $ eqE $ caseE (varE e)
          [ TH.match (conP 'SmartExp [(conP 'Match [matchP ps, varP x])]) (normalB [| Just $(tupE es) |]) []
          , TH.match (conP 'SmartExp [(recP 'Match [])]) (normalB [| Nothing |]) []
          , TH.match wildP (normalB [| error $error_msg |]) []
          ]
      r <- sequence [ sigD fun sig
                    , funD fun [clause [lhs] body []]
                    ]
      return (fun, r)
      where
        -- forall tvs. (HasCallStack, Elt tv, ...) => Exp (T tvs) -> Maybe (Exp f1, ..., Exp fn)
        sig = forallT
          (map (`plainInvisTV` specifiedSpec) tvs)
          (cxt ([t| HasCallStack |] : map (\t -> [t| Elt $(varT t) |]) tvs))
          [t| Exp $(foldl' appT (conT tn) (map varT tvs)) -> Maybe $(tupT (map (\t -> [t| Exp $(return t) |]) fs)) |]
        -- The tag pattern: this constructor's tag, followed by a left-nested
        -- chain of 'TagRpair' over the per-slot patterns, rooted at 'TagRunit'
        matchP us = [p| TagRtag $(litP (IntegerL (toInteger tag))) $pat |]
          where
            pat = [p| $(foldl (\ps p -> [p| TagRpair $ps $p |]) [p| TagRunit |] us) |]
        -- Walk the slot flags, accumulating (in reverse) one pattern per slot
        -- and one expression per slot of this constructor. A 'False' slot
        -- contributes a wildcard pattern only; a 'True' slot binds a fresh
        -- variable for the slot's tag and emits an 'Exp' that re-wraps the
        -- projected component in 'Match' with that tag. Each step moves 'x'
        -- one 'PairIdxLeft' projection further in.
        extract [] _ ps es = return (ps, es)
        extract (u:us) x ps es = do
          _u <- newName "_u"
          let x' = [| Prj PairIdxLeft (SmartExp $x) |]
          if not u
            then extract us x' (wildP:ps) es
            else extract us x' (varP _u:ps) ([| Exp (SmartExp (Match $(varE _u) (SmartExp (Prj PairIdxRight (SmartExp $x))))) |] : es)
        -- One flag per slot of the value, in reverse slot order: 'True' for
        -- the fields of this constructor, 'False' for those of the others
        vs = reverse
           $ [ False | _ <- concat fs0 ] ++ [ True | _ <- fs ] ++ [ False | _ <- concat fs1 ]
        error_msg =
          -- Placeholder variable names (a, b, c, ...), one per field. The
          -- leading empty string makes 'unwords' emit a separating space
          -- before the first name.
          let pv = unwords
                 $ take (length fs + 1)
                 $ concatMap (map reverse)
                 $ iterate (concatMap (\xs -> [ x:xs | x <- ['a'..'z'] ])) [""]
          in stringE $ unlines
            [ "Embedded pattern synonym used outside 'match' context."
            , ""
            , "To use case statements in the embedded language the case statement must"
            , "be applied as an n-ary function to the 'match' operator. For single"
            , "argument case statements this can be done inline using LambdaCase, for"
            , "example:"
            , ""
            , "> x & match \\case"
            , printf "> %s%s -> ..." pn pv
            , printf "> _%s -> ..." (replicate (length pn + length pv - 1) ' ')
            ]
-- | First component of a triple.
--
fst3 :: (a,b,c) -> a
fst3 (x, _, _) = x
-- | Third component of a triple.
--
thd3 :: (a,b,c) -> c
thd3 (_, _, z) = z
-- | Toggle the trailing underscore of a name: a name whose base ends in
-- @_@ has that final character removed, and any other name has @_@
-- appended. Only the unqualified base of the name is kept.
--
rename :: Name -> Name
rename nm =
  case reverse original of
    '_' : rest -> mkName (reverse rest)
    _          -> mkName (original ++ "_")
  where
    original = nameBase nm
-- | Check that all of the given language extensions are enabled in the
-- module being compiled. If any are missing, fail with a message listing a
-- LANGUAGE pragma for each missing extension.
--
checkExts :: [Extension] -> Q ()
checkExts req = do
  enabled <- extsEnabled
  case req \\ enabled of
    []      -> return ()
    missing -> fail (unlines (header : map pragma missing))
  where
    header :: String
    header = "You must enable the following language extensions to generate pattern synonyms:"

    pragma :: Extension -> String
    pragma ext = printf " {-# LANGUAGE %s #-}" (show ext)
-- A simplified version of the z-encoding from GHC/Utils/Encoding.hs
--
type EncodedString = String

-- | Z-encode a string so that it contains only alphanumeric characters.
--
-- Alphanumeric characters other than @z@ and @Z@ are kept unchanged. A
-- fixed set of symbols is mapped to two-character escapes beginning with
-- @z@ or @Z@ (for example @:@ becomes @ZC@ and @+@ becomes @zp@). Every
-- other character is written as @z@, its code point in hexadecimal (with a
-- leading @0@ if the hex string would otherwise start with a letter), then
-- @U@. A digit in the first position is also escaped in that hexadecimal
-- form. The empty string encodes to the empty string.
--
zencode :: String -> EncodedString
zencode []       = []
zencode (h:rest) = encode_digit h ++ concatMap encode_ch rest
  where
    -- Characters that are emitted as-is
    unencoded_char :: Char -> Bool
    unencoded_char 'z' = False
    unencoded_char 'Z' = False
    unencoded_char c   = isAlphaNum c

    -- Used for the first character only, where a digit must be escaped
    encode_digit :: Char -> EncodedString
    encode_digit c | isDigit c = encode_as_unicode_char c
                   | otherwise = encode_ch c

    encode_ch :: Char -> EncodedString
    encode_ch c | unencoded_char c = [c]      -- Common case first
    encode_ch '('  = "ZL"
    encode_ch ')'  = "ZR"
    encode_ch '['  = "ZM"
    encode_ch ']'  = "ZN"
    encode_ch ':'  = "ZC"
    encode_ch 'Z'  = "ZZ"
    encode_ch 'z'  = "zz"
    encode_ch '&'  = "za"
    encode_ch '|'  = "zb"
    encode_ch '^'  = "zc"
    encode_ch '$'  = "zd"
    encode_ch '='  = "ze"
    encode_ch '>'  = "zg"
    encode_ch '#'  = "zh"
    encode_ch '.'  = "zi"
    encode_ch '<'  = "zl"
    encode_ch '-'  = "zm"
    encode_ch '!'  = "zn"
    encode_ch '+'  = "zp"
    encode_ch '\'' = "zq"
    encode_ch '\\' = "zr"
    encode_ch '/'  = "zs"
    encode_ch '*'  = "zt"
    encode_ch '_'  = "zu"
    encode_ch '%'  = "zv"
    encode_ch c    = encode_as_unicode_char c

    -- Pattern match on the hex string rather than calling the partial
    -- 'head'; the fallback branch also covers the (unreachable) empty case.
    encode_as_unicode_char :: Char -> EncodedString
    encode_as_unicode_char c =
      case hex_str of
        d : _ | isDigit d -> 'z' : hex_str
        _                 -> 'z' : '0' : hex_str
      where
        hex_str = showHex (ord c) "U"