diff --git a/.github/workflows/normalize-cabal.yml b/.github/workflows/normalize-cabal.yml index 0f5eaef85..be8c19459 100644 --- a/.github/workflows/normalize-cabal.yml +++ b/.github/workflows/normalize-cabal.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - - uses: tfausak/cabal-gild-setup-action@v1 + - uses: tfausak/cabal-gild-setup-action@v2 with: - version: 0.3.0.1 + version: 1.3.0.1 - run: cabal-gild --input swarm.cabal --mode check diff --git a/src/swarm-lang/Swarm/Effect/Unify.hs b/src/swarm-lang/Swarm/Effect/Unify.hs new file mode 100644 index 000000000..e0b0a1bde --- /dev/null +++ b/src/swarm-lang/Swarm/Effect/Unify.hs @@ -0,0 +1,53 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} + +-- | +-- SPDX-License-Identifier: BSD-3-Clause +-- Description: This module defines an effect signature for +-- computations that support doing unification. The intention is for +-- code needing unification to use the operations defined in this +-- module, and then import 'Swarm.Effect.Unify.Fast' to dispatch the +-- 'Unification' effects. +module Swarm.Effect.Unify where + +import Control.Algebra +import Data.Kind (Type) +import Data.Set (Set) +import Swarm.Language.Types hiding (Type) + +-- | Data type representing available unification operations. +data Unification (m :: Type -> Type) k where + Unify :: UType -> UType -> Unification m (Either UnificationError UType) + ApplyBindings :: UType -> Unification m UType + FreshIntVar :: Unification m IntVar + FreeUVars :: UType -> Unification m (Set IntVar) + +-- | Unify two types, returning a type equal to both, or a 'UnificationError' if +-- the types definitely do not unify. +(=:=) :: Has Unification sig m => UType -> UType -> m (Either UnificationError UType) +t1 =:= t2 = send (Unify t1 t2) + +-- | Substitute for all the unification variables that are currently +-- bound. It is guaranteed that any unification variables remaining +-- in the result are not currently bound, /i.e./ we have learned no +-- information about them. +applyBindings :: Has Unification sig m => UType -> m UType +applyBindings = send . ApplyBindings + +-- | Compute the set of free unification variables of a type (after +-- substituting away any which are already bound). +freeUVars :: Has Unification sig m => UType -> m (Set IntVar) +freeUVars = send . FreeUVars + +-- | Generate a fresh unification variable. +freshIntVar :: Has Unification sig m => m IntVar +freshIntVar = send FreshIntVar + +-- | An error that occurred while running the unifier. +data UnificationError where + -- | Occurs check failure, i.e. the solution to some unification + -- equations was an infinite term. + Infinite :: IntVar -> UType -> UnificationError + -- | Mismatch error between the given terms. + UnifyErr :: TypeF UType -> TypeF UType -> UnificationError + deriving (Show) diff --git a/src/swarm-lang/Swarm/Effect/Unify/Common.hs b/src/swarm-lang/Swarm/Effect/Unify/Common.hs new file mode 100644 index 000000000..b7e618ef8 --- /dev/null +++ b/src/swarm-lang/Swarm/Effect/Unify/Common.hs @@ -0,0 +1,60 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralisedNewtypeDeriving #-} + +-- | +-- SPDX-License-Identifier: BSD-3-Clause +-- +-- Description: Common definitions used in both the naive and fast +-- implementations of unification. +module Swarm.Effect.Unify.Common where + +import Control.Algebra +import Control.Effect.State (State, get) +import Data.Map (Map) +import Data.Map qualified as M +import Data.Set (Set) +import Prelude hiding (lookup) + +------------------------------------------------------------ +-- Substitutions + +-- | A value of type @Subst n a@ is a substitution which maps +-- names of type @n@ (the /domain/, see 'dom') to values of type +-- @a@. Substitutions can be /applied/ to certain terms (see +-- 'subst'), replacing any free occurrences of names in the +-- domain with their corresponding values. Thus, substitutions can +-- be thought of as functions of type @Term -> Term@ (for suitable +-- @Term@s that contain names and values of the right type). +-- +-- Concretely, substitutions are stored using a @Map@. +newtype Subst n a = Subst {getSubst :: Map n a} + deriving (Eq, Ord, Show, Functor) + +-- | The domain of a substitution is the set of names for which the +-- substitution is defined. +dom :: Subst n a -> Set n +dom = M.keysSet . getSubst + +-- | The identity substitution, /i.e./ the unique substitution with an +-- empty domain, which acts as the identity function on terms. +idS :: Subst n a +idS = Subst M.empty + +-- | Construct a singleton substitution, which maps the given name to +-- the given value. +(|->) :: n -> a -> Subst n a +x |-> t = Subst (M.singleton x t) + +-- | Insert a new name/value binding into the substitution. +insert :: Ord n => n -> a -> Subst n a -> Subst n a +insert n a (Subst m) = Subst (M.insert n a m) + +-- | Look up the value a particular name maps to under the given +-- substitution; or return @Nothing@ if the name being looked up is +-- not in the domain. +lookup :: Ord n => n -> Subst n a -> Maybe a +lookup x (Subst m) = M.lookup x m + +-- | Look up a name in a substitution stored in a state effect. +lookupS :: (Ord n, Has (State (Subst n a)) sig m) => n -> m (Maybe a) +lookupS x = lookup x <$> get diff --git a/src/swarm-lang/Swarm/Effect/Unify/Fast.hs b/src/swarm-lang/Swarm/Effect/Unify/Fast.hs new file mode 100644 index 000000000..d34faccf3 --- /dev/null +++ b/src/swarm-lang/Swarm/Effect/Unify/Fast.hs @@ -0,0 +1,238 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | +-- SPDX-License-Identifier: BSD-3-Clause +-- +-- Description: Fast yet purely functional implementation of +-- unification, using a map as a lazy substitution, i.e. a +-- manually-maintained "functional shared memory". +-- +-- See Dijkstra, Middelkoop, & Swierstra, "Efficient Functional +-- Unification and Substitution", Utrecht University tech report +-- UU-CS-2008-027 (section 5) for the basic idea, and Peyton Jones et +-- al, "Practical type inference for arbitrary-rank types" +-- (pp. 74--75) for a correct implementation of unification via +-- references. +module Swarm.Effect.Unify.Fast where + +import Control.Algebra +import Control.Applicative (Alternative) +import Control.Carrier.State.Strict (StateC, evalState) +import Control.Carrier.Throw.Either (ThrowC, runThrow) +import Control.Category ((>>>)) +import Control.Effect.State (State, get, gets, modify) +import Control.Effect.Throw (Throw, throwError) +import Control.Monad.Free +import Control.Monad.Trans (MonadIO) +import Data.Function (on) +import Data.Map qualified as M +import Data.Map.Merge.Lazy qualified as M +import Data.Set qualified as S +import Swarm.Effect.Unify +import Swarm.Effect.Unify.Common +import Swarm.Language.Types hiding (Type) +import Prelude hiding (lookup) + +------------------------------------------------------------ +-- Substitutions + +-- | Compose two substitutions. Applying @s1 \@\@ s2@ is the same as +-- applying first @s2@, then @s1@; that is, semantically, +-- composition of substitutions corresponds exactly to function +-- composition when they are considered as functions on terms. +-- +-- As one would expect, composition is associative and has 'idS' as +-- its identity. +-- +-- Note that we do /not/ apply @s1@ to all the values in @s2@, since +-- the substitution is maintained lazily; we do not need to maintain +-- the invariant that values in the mapping do not contain any of +-- the keys. This makes composition much faster, at the cost of +-- making application more complex. +(@@) :: (Ord n, Substitutes n a a) => Subst n a -> Subst n a -> Subst n a +(Subst s1) @@ (Subst s2) = Subst (s2 `M.union` s1) + +-- | Class of things supporting substitution. @Substitutes n b a@ means +-- that we can apply a substitution of type @Subst n b@ to a +-- value of type @a@, replacing all the free names of type @n@ +-- inside the @a@ with values of type @b@, resulting in a new value +-- of type @a@. +-- +-- We also do a lazy occurs-check during substitution application, +-- so we need the ability to throw a unification error. +class Substitutes n b a where + subst :: Has (Throw UnificationError) sig m => Subst n b -> a -> m a + +-- | We can perform substitution on terms built up as the free monad +-- over a structure functor @f@. +instance Substitutes IntVar UType UType where + subst s = go S.empty + where + go seen (Pure x) = case lookup x s of + Nothing -> pure $ Pure x + Just t + | S.member x seen -> throwError $ Infinite x t + | otherwise -> go (S.insert x seen) t + go seen (Free t) = Free <$> goF seen t + + goF _ t@(TyBaseF {}) = pure t + goF _ t@(TyVarF {}) = pure t + goF seen (TySumF t1 t2) = TySumF <$> go seen t1 <*> go seen t2 + goF seen (TyProdF t1 t2) = TyProdF <$> go seen t1 <*> go seen t2 + goF seen (TyRcdF m) = TyRcdF <$> mapM (go seen) m + goF seen (TyCmdF c) = TyCmdF <$> go seen c + goF seen (TyDelayF c) = TyDelayF <$> go seen c + goF seen (TyFunF t1 t2) = TyFunF <$> go seen t1 <*> go seen t2 + +------------------------------------------------------------ +-- Carrier type + +-- Note: this carrier type and the runUnification function are +-- identical between this module and Swarm.Effect.Unify.Naive, but it +-- seemed best to duplicate it, so we can modify the carriers +-- independently in the future if we want. + +-- | Carrier type for unification: we maintain a current substitution, +-- a counter for generating fresh unification variables, and can +-- throw unification errors. +newtype UnificationC m a = UnificationC + { unUnificationC :: + StateC (Subst IntVar UType) (StateC FreshVarCounter (ThrowC UnificationError m)) a + } + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO) + +-- | Counter for generating fresh unification variables. +newtype FreshVarCounter = FreshVarCounter {getFreshVarCounter :: Int} + deriving (Eq, Ord, Enum) + +-- | Run a 'Unification' effect via the 'UnificationC' carrier. +runUnification :: Algebra sig m => UnificationC m a -> m (Either UnificationError a) +runUnification = + unUnificationC >>> evalState idS >>> evalState (FreshVarCounter 0) >>> runThrow + +------------------------------------------------------------ +-- Unification + +-- The idea here (using an explicit substitution as a sort of +-- "functional shared memory", instead of directly using IORefs), is +-- based on Dijkstra et al. Unfortunately, their implementation of +-- unification is subtly wrong; fortunately, a single integration test +-- in the Swarm test suite failed, leading to discovering the bug. +-- The basic issue is that when unifying an equation between two +-- variables @x = y@, we must look up *both* to see whether they are +-- already mapped by the substitution (and if so, replace them by +-- their referent and keep recursing). Dijkstra et al. only look up +-- @x@ and simply map @x |-> y@ if x is not in the substitution, but +-- this can lead to cycles where e.g. x is mapped to y, and later we +-- unify @y = x@ resulting in both @x |-> y@ and @y |-> x@ in the +-- substitution, which at best leads to a spurious infinite type +-- error, and at worst leads to infinite recursion in the unify function. +-- +-- Peyton Jones et al. show how to do it correctly: when unifying x = y and +-- x is not mapped in the substitution, we must also look up y. + +-- | Implementation of the 'Unification' effect in terms of the +-- 'UnificationC' carrier. +instance Algebra sig m => Algebra (Unification :+: sig) (UnificationC m) where + alg hdl sig ctx = UnificationC $ case sig of + L (Unify t1 t2) -> (<$ ctx) <$> runThrow (unify t1 t2) + L (ApplyBindings t) -> do + s <- get @(Subst IntVar UType) + (<$ ctx) <$> subst s t + L FreshIntVar -> do + v <- IntVar <$> gets getFreshVarCounter + modify @FreshVarCounter succ + return $ v <$ ctx + L (FreeUVars t) -> do + s <- get @(Subst IntVar UType) + (<$ ctx) . fuvs <$> subst s t + R other -> alg (unUnificationC . hdl) (R (R (R other))) ctx + +-- | Unify two types, returning a unified type equal to both. Note +-- that for efficiency we /don't/ do an occurs check here, but +-- instead lazily during substitution. +unify :: + ( Has (Throw UnificationError) sig m + , Has (State (Subst IntVar UType)) sig m + ) => + UType -> + UType -> + m UType +unify ty1 ty2 = case (ty1, ty2) of + (Pure x, Pure y) | x == y -> pure (Pure x) + (Pure x, y) -> do + mxv <- lookupS x + case mxv of + Nothing -> unifyVar x y + Just xv -> unify xv y + (x, Pure y) -> unify (Pure y) x + (Free t1, Free t2) -> Free <$> unifyF t1 t2 + +-- | Unify a unification variable which /is not/ bound by the current +-- substitution with another term. If the other term is also a +-- variable, we must look it up as well to see if it is bound. +unifyVar :: + ( Has (Throw UnificationError) sig m + , Has (State (Subst IntVar UType)) sig m + ) => + IntVar -> + UType -> + m UType +unifyVar x (Pure y) = do + myv <- lookupS y + case myv of + -- x = y but the variable y is not bound: just add (x |-> y) to + -- the current Subst + -- + -- Note, as an optimization we just call e.g. insert x (Pure y) + -- instead of building a singleton Subst with @(|->)@ and then + -- composing, since composition doesn't need to apply the newly + -- created binding to all the other values bound in the Subst. + Nothing -> modify @(Subst IntVar UType) (insert x (Pure y)) >> pure (Pure y) + -- x = y and y is bound to v: recurse on x = v. + Just yv -> unify (Pure x) yv + +-- x = t for a non-variable t: just add (x |-> t) to the Subst. +unifyVar x t = modify (insert x t) >> pure t + +-- | Perform unification on two non-variable terms: check that they +-- have the same top-level constructor and recurse on their +-- contents. +unifyF :: + ( Has (Throw UnificationError) sig m + , Has (State (Subst IntVar UType)) sig m + ) => + TypeF UType -> + TypeF UType -> + m (TypeF UType) +unifyF t1 t2 = case (t1, t2) of + (TyBaseF b1, TyBaseF b2) -> case b1 == b2 of + True -> pure t1 + False -> unifyErr + (TyBaseF {}, _) -> unifyErr + -- Note that *type variables* are not the same as *unification variables*. + -- Type variables must match exactly. + (TyVarF v1, TyVarF v2) -> case v1 == v2 of + True -> pure t1 + False -> unifyErr + (TyVarF {}, _) -> unifyErr + (TySumF t11 t12, TySumF t21 t22) -> TySumF <$> unify t11 t21 <*> unify t12 t22 + (TySumF {}, _) -> unifyErr + (TyProdF t11 t12, TyProdF t21 t22) -> TyProdF <$> unify t11 t21 <*> unify t12 t22 + (TyProdF {}, _) -> unifyErr + (TyRcdF m1, TyRcdF m2) -> + case ((==) `on` M.keysSet) m1 m2 of + False -> unifyErr + _ -> fmap TyRcdF . sequence $ M.merge M.dropMissing M.dropMissing (M.zipWithMatched (const unify)) m1 m2 + (TyRcdF {}, _) -> unifyErr + (TyCmdF c1, TyCmdF c2) -> TyCmdF <$> unify c1 c2 + (TyCmdF {}, _) -> unifyErr + (TyDelayF c1, TyDelayF c2) -> TyDelayF <$> unify c1 c2 + (TyDelayF {}, _) -> unifyErr + (TyFunF t11 t12, TyFunF t21 t22) -> TyFunF <$> unify t11 t21 <*> unify t12 t22 + (TyFunF {}, _) -> unifyErr + where + unifyErr = throwError $ UnifyErr t1 t2 diff --git a/src/swarm-lang/Swarm/Effect/Unify/Naive.hs b/src/swarm-lang/Swarm/Effect/Unify/Naive.hs new file mode 100644 index 000000000..1bb88af82 --- /dev/null +++ b/src/swarm-lang/Swarm/Effect/Unify/Naive.hs @@ -0,0 +1,183 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | +-- SPDX-License-Identifier: BSD-3-Clause +-- +-- Description: Naive (slow) substitution-based implementation of +-- unification. Uses a simple but expensive-to-maintain invariant on +-- substitutions, and returns a substitution from unification which +-- must then be composed with the substitution being tracked. +-- +-- Not used in Swarm, but useful for testing/comparison. +module Swarm.Effect.Unify.Naive where + +import Control.Algebra +import Control.Applicative (Alternative) +import Control.Carrier.State.Strict (StateC, evalState) +import Control.Carrier.Throw.Either (ThrowC, runThrow) +import Control.Category ((>>>)) +import Control.Effect.State (get, gets, modify) +import Control.Effect.Throw (Throw, throwError) +import Control.Monad.Free +import Control.Monad.Trans (MonadIO) +import Data.Function (on) +import Data.Map ((!?)) +import Data.Map qualified as M +import Data.Map.Merge.Lazy qualified as M +import Data.Maybe (fromMaybe) +import Data.Set qualified as S +import Swarm.Effect.Unify +import Swarm.Effect.Unify.Common +import Swarm.Language.Types hiding (Type) + +------------------------------------------------------------ +-- Substitutions + +-- | Class of things supporting substitution. @Substitutes n b a@ means +-- that we can apply a substitution of type @Subst n b@ to a +-- value of type @a@, replacing all the free names of type @n@ +-- inside the @a@ with values of type @b@, resulting in a new value +-- of type @a@. +class Substitutes n b a where + subst :: Subst n b -> a -> a + +-- | We can perform substitution on terms built up as the free monad +-- over a structure functor @f@. +instance (Show n, Ord n, Functor f) => Substitutes n (Free f n) (Free f n) where + subst s f = f >>= \n -> fromMaybe (Pure n) (getSubst s !? n) + +-- | Compose two substitutions. Applying @s1 \@\@ s2@ is the same as +-- applying first @s2@, then @s1@; that is, semantically, +-- composition of substitutions corresponds exactly to function +-- composition when they are considered as functions on terms. +-- +-- As one would expect, composition is associative and has 'idS' as +-- its identity. +(@@) :: (Ord n, Substitutes n a a) => Subst n a -> Subst n a -> Subst n a +(Subst s1) @@ (Subst s2) = Subst (M.map (subst (Subst s1)) s2 `M.union` s1) + +-- | Compose a whole container of substitutions. For example, +-- @compose [s1, s2, s3] = s1 \@\@ s2 \@\@ s3@. +compose :: (Ord n, Substitutes n a a, Foldable t) => t (Subst n a) -> Subst n a +compose = foldr (@@) idS + +------------------------------------------------------------ +-- Carrier type + +-- Note: this carrier type and the runUnification function are +-- identical between this module and Swarm.Effect.Unify.Fast, but it +-- seemed best to duplicate it, so we can modify the carriers +-- independently in the future if we want. + +-- | Carrier type for unification: we maintain a current substitution, +-- a counter for generating fresh unification variables, and can +-- throw unification errors. +newtype UnificationC m a = UnificationC + { unUnificationC :: + StateC (Subst IntVar UType) (StateC FreshVarCounter (ThrowC UnificationError m)) a + } + deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO) + +-- | Counter for generating fresh unification variables. +newtype FreshVarCounter = FreshVarCounter {getFreshVarCounter :: Int} + deriving (Eq, Ord, Enum) + +-- | Run a 'Unification' effect via the 'UnificationC' carrier. +runUnification :: Algebra sig m => UnificationC m a -> m (Either UnificationError a) +runUnification = + unUnificationC >>> evalState idS >>> evalState (FreshVarCounter 0) >>> runThrow + +------------------------------------------------------------ +-- Unification + +-- | Naive implementation of the 'Unification' effect in terms of the +-- 'UnificationC' carrier. +-- +-- We maintain an invariant on the current @Subst@ that map keys +-- never show up in any of the values. For example, we could have +-- @{x -> a+5, y -> 5}@ but not @{x -> a+y, y -> 5}@. +instance Algebra sig m => Algebra (Unification :+: sig) (UnificationC m) where + alg hdl sig ctx = UnificationC $ case sig of + L (Unify t1 t2) -> do + s1 <- get @(Subst IntVar UType) + let t1' = subst s1 t1 + t2' = subst s1 t2 + s2 <- unify t1' t2' + modify (s2 @@) + return $ Right (subst s2 t1') <$ ctx + L (ApplyBindings t) -> do + s <- get @(Subst IntVar UType) + return $ subst s t <$ ctx + L FreshIntVar -> do + v <- IntVar <$> gets getFreshVarCounter + modify @FreshVarCounter succ + return $ v <$ ctx + L (FreeUVars t) -> do + s <- get @(Subst IntVar UType) + return $ fuvs (subst s t) <$ ctx + R other -> alg (unUnificationC . hdl) (R (R (R other))) ctx + +-- | Unify two types and return the mgu, i.e. the smallest +-- substitution which makes them equal. +unify :: + Has (Throw UnificationError) sig m => + UType -> + UType -> + m (Subst IntVar UType) +unify ty1 ty2 = case (ty1, ty2) of + (Pure x, Pure y) + | x == y -> return idS + | otherwise -> return $ x |-> Pure y + (Pure x, y) + | x `S.member` fuvs y -> throwError $ Infinite x y + | otherwise -> return $ x |-> y + (y, Pure x) + | x `S.member` fuvs y -> throwError $ Infinite x y + | otherwise -> return $ x |-> y + (Free t1, Free t2) -> unifyF t1 t2 + +-- | Unify two non-variable terms and return an mgu, i.e. the smallest +-- substitution which makes them equal. +unifyF :: + Has (Throw UnificationError) sig m => + TypeF UType -> + TypeF UType -> + m (Subst IntVar UType) +unifyF t1 t2 = case (t1, t2) of + (TyBaseF b1, TyBaseF b2) -> case b1 == b2 of + True -> return idS + False -> unifyErr + (TyBaseF {}, _) -> unifyErr + (TyVarF v1, TyVarF v2) -> case v1 == v2 of + True -> return idS + False -> unifyErr + (TyVarF {}, _) -> unifyErr + (TySumF t11 t12, TySumF t21 t22) -> do + s1 <- unify t11 t21 + s2 <- unify t12 t22 + return $ s1 @@ s2 + (TySumF {}, _) -> unifyErr + (TyProdF t11 t12, TyProdF t21 t22) -> do + s1 <- unify t11 t21 + s2 <- unify t12 t22 + return $ s1 @@ s2 + (TyProdF {}, _) -> unifyErr + (TyRcdF m1, TyRcdF m2) -> + case ((==) `on` M.keysSet) m1 m2 of + False -> unifyErr + _ -> (fmap compose . sequence) (M.merge M.dropMissing M.dropMissing (M.zipWithMatched (const unify)) m1 m2) + (TyRcdF {}, _) -> unifyErr + (TyCmdF c1, TyCmdF c2) -> unify c1 c2 + (TyCmdF {}, _) -> unifyErr + (TyDelayF c1, TyDelayF c2) -> unify c1 c2 + (TyDelayF {}, _) -> unifyErr + (TyFunF t11 t12, TyFunF t21 t22) -> do + s1 <- unify t11 t21 + s2 <- unify t12 t22 + return $ s1 @@ s2 + (TyFunF {}, _) -> unifyErr + where + unifyErr = throwError $ UnifyErr t1 t2 diff --git a/src/swarm-lang/Swarm/Language/Context.hs b/src/swarm-lang/Swarm/Language/Context.hs index 3a6a5f07d..008cad504 100644 --- a/src/swarm-lang/Swarm/Language/Context.hs +++ b/src/swarm-lang/Swarm/Language/Context.hs @@ -7,9 +7,10 @@ -- types, values, or capability sets) used throughout the codebase. module Swarm.Language.Context where +import Control.Algebra (Has) +import Control.Effect.Reader (Reader, local) import Control.Lens.Empty (AsEmpty (..)) import Control.Lens.Prism (prism) -import Control.Monad.Reader (MonadReader, local) import Data.Aeson (FromJSON, ToJSON) import Data.Data (Data) import Data.Map (Map) @@ -70,10 +71,10 @@ union :: Ctx t -> Ctx t -> Ctx t union (Ctx c1) (Ctx c2) = Ctx (c2 `M.union` c1) -- | Locally extend the context with an additional binding. -withBinding :: MonadReader (Ctx t) m => Var -> t -> m a -> m a +withBinding :: Has (Reader (Ctx t)) sig m => Var -> t -> m a -> m a withBinding x ty = local (addBinding x ty) -- | Locally extend the context with an additional context of -- bindings. -withBindings :: MonadReader (Ctx t) m => Ctx t -> m a -> m a +withBindings :: Has (Reader (Ctx t)) sig m => Ctx t -> m a -> m a withBindings ctx = local (`union` ctx) diff --git a/src/swarm-lang/Swarm/Language/Pretty.hs b/src/swarm-lang/Swarm/Language/Pretty.hs index 642fc5fb9..f6da84170 100644 --- a/src/swarm-lang/Swarm/Language/Pretty.hs +++ b/src/swarm-lang/Swarm/Language/Pretty.hs @@ -10,10 +10,9 @@ module Swarm.Language.Pretty where import Control.Lens.Combinators (pattern Empty) -import Control.Unification -import Control.Unification.IntVar +import Control.Monad.Free (Free (..)) import Data.Bool (bool) -import Data.Functor.Fixedpoint (Fix, unFix) +import Data.Fix import Data.List.NonEmpty ((<|)) import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M @@ -25,6 +24,7 @@ import Data.Text qualified as T import Prettyprinter import Prettyprinter.Render.String qualified as RS import Prettyprinter.Render.Text qualified as RT +import Swarm.Effect.Unify (UnificationError (..)) import Swarm.Language.Capability import Swarm.Language.Context import Swarm.Language.Parse (getLocRange) @@ -136,16 +136,16 @@ instance UnchainableFun Type where unchainFun (a :->: ty) = a <| unchainFun ty unchainFun ty = pure ty -instance UnchainableFun (UTerm TypeF ty) where - unchainFun (UTerm (TyFunF ty1 ty2)) = ty1 <| unchainFun ty2 +instance UnchainableFun (Free TypeF ty) where + unchainFun (Free (TyFunF ty1 ty2)) = ty1 <| unchainFun ty2 unchainFun ty = pure ty instance (PrettyPrec (t (Fix t))) => PrettyPrec (Fix t) where prettyPrec p = prettyPrec p . unFix -instance (PrettyPrec (t (UTerm t v)), PrettyPrec v) => PrettyPrec (UTerm t v) where - prettyPrec p (UTerm t) = prettyPrec p t - prettyPrec p (UVar v) = prettyPrec p v +instance (PrettyPrec (t (Free t v)), PrettyPrec v) => PrettyPrec (Free t v) where + prettyPrec p (Free t) = prettyPrec p t + prettyPrec p (Pure v) = prettyPrec p v instance ((UnchainableFun t), (PrettyPrec t)) => PrettyPrec (TypeF t) where prettyPrec _ (TyBaseF b) = ppr b @@ -331,8 +331,7 @@ prettyTypeErr code (CTE l tcStack te) = instance PrettyPrec TypeErr where prettyPrec _ = \case - UnifyErr ty1 ty2 -> - "Can't unify" <+> ppr ty1 <+> "and" <+> ppr ty2 + UnificationErr ue -> ppr ue Mismatch Nothing (getJoin -> (ty1, ty2)) -> "Type mismatch: expected" <+> ppr ty1 <> ", but got" <+> ppr ty2 Mismatch (Just t) (getJoin -> (ty1, ty2)) -> @@ -349,8 +348,6 @@ instance PrettyPrec TypeErr where "Skolem variable" <+> pretty x <+> "would escape its scope" UnboundVar x -> "Unbound variable" <+> pretty x - Infinite x uty -> - "Infinite type:" <+> ppr x <+> "=" <+> ppr uty DefNotTopLevel t -> "Definitions may only be at the top level:" <+> pprCode t CantInfer t -> @@ -367,6 +364,13 @@ instance PrettyPrec TypeErr where pprCode :: PrettyPrec a => a -> Doc ann pprCode = bquote . ppr +instance PrettyPrec UnificationError where + prettyPrec _ = \case + Infinite x uty -> + "Infinite type:" <+> ppr x <+> "=" <+> ppr uty + UnifyErr ty1 ty2 -> + "Can't unify" <+> ppr ty1 <+> "and" <+> ppr ty2 + -- | Given a type and its source, construct an appropriate description -- of it to go in a type mismatch error message. typeDescription :: Source -> UType -> Doc a @@ -385,11 +389,11 @@ hasAnyUVars = ucata (const True) or -- | Check whether a type consists of a top-level type constructor -- immediately applied to unification variables. isTopLevelConstructor :: UType -> Maybe (TypeF ()) -isTopLevelConstructor (UTyCmd (UVar {})) = Just $ TyCmdF () -isTopLevelConstructor (UTyDelay (UVar {})) = Just $ TyDelayF () -isTopLevelConstructor (UTySum (UVar {}) (UVar {})) = Just $ TySumF () () -isTopLevelConstructor (UTyProd (UVar {}) (UVar {})) = Just $ TyProdF () () -isTopLevelConstructor (UTyFun (UVar {}) (UVar {})) = Just $ TyFunF () () +isTopLevelConstructor (UTyCmd (Pure {})) = Just $ TyCmdF () +isTopLevelConstructor (UTyDelay (Pure {})) = Just $ TyDelayF () +isTopLevelConstructor (UTySum (Pure {}) (Pure {})) = Just $ TySumF () () +isTopLevelConstructor (UTyProd (Pure {}) (Pure {})) = Just $ TyProdF () () +isTopLevelConstructor (UTyFun (Pure {}) (Pure {})) = Just $ TyFunF () () isTopLevelConstructor _ = Nothing -- | Return an English noun phrase describing things with the given diff --git a/src/swarm-lang/Swarm/Language/Typecheck.hs b/src/swarm-lang/Swarm/Language/Typecheck.hs index 9b5c12671..1da3f7264 100644 --- a/src/swarm-lang/Swarm/Language/Typecheck.hs +++ b/src/swarm-lang/Swarm/Language/Typecheck.hs @@ -2,9 +2,6 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} - --- For 'Ord IntVar' instance -- | -- SPDX-License-Identifier: BSD-3-Clause @@ -29,17 +26,11 @@ module Swarm.Language.Typecheck ( LocatedTCFrame (..), TCStack, withFrame, - getTCStack, -- * Typechecking monad - TC, - runTC, fresh, -- * Unification - substU, - unify, - HasBindings (..), instantiate, skolemize, generalize, @@ -54,24 +45,17 @@ module Swarm.Language.Typecheck ( ) where import Control.Arrow ((***)) +import Control.Carrier.Error.Either (ErrorC, runError) +import Control.Carrier.Reader (ReaderC, runReader) import Control.Category ((>>>)) +import Control.Effect.Catch (Catch, catchError) +import Control.Effect.Error (Error) +import Control.Effect.Reader +import Control.Effect.Throw import Control.Lens ((^.)) import Control.Lens.Indexed (itraverse) -import Control.Monad (forM_, void, when, (<=<)) -import Control.Monad.Except ( - ExceptT, - MonadError (catchError, throwError), - runExceptT, - ) -import Control.Monad.Reader ( - MonadReader (ask, local), - ReaderT (runReaderT), - mapReaderT, - ) -import Control.Monad.Trans.Class (MonadTrans (lift)) -import Control.Unification hiding (applyBindings, unify, (=:=)) -import Control.Unification qualified as U -import Control.Unification.IntVar +import Control.Monad (forM_, when, (<=<), (>=>)) +import Control.Monad.Free (Free (..)) import Data.Data (Data, gmapM) import Data.Foldable (fold) import Data.Functor.Identity @@ -82,12 +66,14 @@ import Data.Maybe import Data.Set (Set, (\\)) import Data.Set qualified as S import Data.Text qualified as T +import Swarm.Effect.Unify (Unification, UnificationError, (=:=)) +import Swarm.Effect.Unify qualified as U +import Swarm.Effect.Unify.Fast qualified as U import Swarm.Language.Context hiding (lookup) import Swarm.Language.Context qualified as Ctx import Swarm.Language.Module import Swarm.Language.Parse.QQ (tyQ) import Swarm.Language.Syntax -import Swarm.Language.Typecheck.Unify import Swarm.Language.Types import Prelude hiding (lookup) @@ -113,6 +99,10 @@ data LocatedTCFrame = LocatedTCFrame SrcLoc TCFrame -- middle of doing during typechecking. type TCStack = [LocatedTCFrame] +-- | Push a frame on the typechecking stack. +withFrame :: Has (Reader TCStack) sig m => SrcLoc -> TCFrame -> m a -> m a +withFrame l f = local (LocatedTCFrame l f :) + ------------------------------------------------------------ -- Type source @@ -155,67 +145,84 @@ getJoin :: Join a -> (a, a) getJoin (Join j) = (j Expected, j Actual) ------------------------------------------------------------ --- Type checking monad - --- | The concrete monad used for type checking. 'IntBindingT' is a --- monad transformer provided by the @unification-fd@ library which --- supports various operations such as generating fresh variables --- and unifying things. --- --- Note that we are sort of constrained to use a concrete monad stack by --- @unification-fd@, which has some strange types on some of its exported --- functions that actually require various monad transformers to be stacked --- in certain ways. For example, see . I don't really see a way --- to use "capability style" like we do elsewhere in the codebase. -type TC = ReaderT UCtx (ReaderT TCStack (ExceptT ContextualTypeErr (IntBindingT TypeF Identity))) - --- | Push a frame on the typechecking stack within a local 'TC' --- computation. -withFrame :: SrcLoc -> TCFrame -> TC a -> TC a -withFrame l f = mapReaderT (local (LocatedTCFrame l f :)) - --- | Get the current typechecking stack. -getTCStack :: TC TCStack -getTCStack = lift ask - ------------------------------------------------------------- +-- Type checking + +fromUModule :: + ( Has Unification sig m + , Has (Reader UCtx) sig m + , Has (Throw ContextualTypeErr) sig m + ) => + UModule -> + m TModule +fromUModule (Module u uctx) = + Module + <$> mapM (checkPredicative <=< (fmap fromU . generalize)) u + <*> checkPredicative (fromU uctx) + +finalizeUModule :: + ( Has Unification sig m + , Has (Reader UCtx) sig m + , Has (Throw ContextualTypeErr) sig m + ) => + UModule -> + m TModule +finalizeUModule = applyBindings >=> fromUModule + +-- | Version of 'runTC' which is generic in the base monad. +runTC' :: + Algebra sig m => + TCtx -> + ReaderC UCtx (ReaderC TCStack (ErrorC ContextualTypeErr (U.UnificationC m))) UModule -> + m (Either ContextualTypeErr TModule) +runTC' ctx = + (>>= finalizeUModule) + >>> runReader (toU ctx) + >>> runReader [] + >>> runError + >>> U.runUnification + >>> fmap reportUnificationError -- | Run a top-level inference computation, returning either a --- 'TypeErr' or a fully resolved 'TModule'. -runTC :: TCtx -> TC UModule -> Either ContextualTypeErr TModule -runTC ctx = - (>>= applyBindings) - >>> ( >>= - \(Module u uctx) -> - Module - <$> mapM (checkPredicative <=< (fmap fromU . generalize)) u - <*> checkPredicative (fromU uctx) - ) - >>> flip runReaderT (toU ctx) - >>> flip runReaderT [] - >>> runExceptT - >>> evalIntBindingT - >>> runIdentity - -checkPredicative :: Maybe a -> TC a +-- 'ContextualTypeErr' or a fully resolved 'TModule'. +runTC :: + TCtx -> + ReaderC UCtx (ReaderC TCStack (ErrorC ContextualTypeErr (U.UnificationC Identity))) UModule -> + Either ContextualTypeErr TModule +runTC tctx = runTC' tctx >>> runIdentity + +checkPredicative :: Has (Throw ContextualTypeErr) sig m => Maybe a -> m a checkPredicative = maybe (throwError (mkRawTypeErr Impredicative)) pure +reportUnificationError :: Either UnificationError (Either ContextualTypeErr a) -> Either ContextualTypeErr a +reportUnificationError = either (Left . mkRawTypeErr . UnificationErr) id + -- | Look up a variable in the ambient type context, either throwing -- an 'UnboundVar' error if it is not found, or opening its -- associated 'UPolytype' with fresh unification variables via -- 'instantiate'. -lookup :: SrcLoc -> Var -> TC UType +lookup :: + ( Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + , Has (Reader UCtx) sig m + , Has Unification sig m + ) => + SrcLoc -> + Var -> + m UType lookup loc x = do - ctx <- getCtx + ctx <- ask @UCtx maybe (throwTypeErr loc $ UnboundVar x) instantiate (Ctx.lookup x ctx) --- | Get the current type context. -getCtx :: TC UCtx -getCtx = ask - -- | Catch any thrown type errors and re-throw them with an added source -- location. -addLocToTypeErr :: SrcLoc -> TC a -> TC a +addLocToTypeErr :: + ( Has (Throw ContextualTypeErr) sig m + , Has (Catch ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + SrcLoc -> + m a -> + m a addLocToTypeErr l m = m `catchError` \case CTE NoLoc _ te -> throwTypeErr l te @@ -225,30 +232,25 @@ addLocToTypeErr l m = -- Dealing with variables: free variables, fresh variables, -- substitution --- | @unification-fd@ does not provide an 'Ord' instance for 'IntVar', --- so we must provide our own, in order to be able to store --- 'IntVar's in a 'Set'. -deriving instance Ord IntVar - -- | A class for getting the free unification variables of a thing. -class FreeVars a where - freeVars :: a -> TC (Set IntVar) +class FreeUVars a where + freeUVars :: Has Unification sig m => a -> m (Set IntVar) -- | We can get the free unification variables of a 'UType'. -instance FreeVars UType where - freeVars ut = fmap S.fromList . lift . lift . lift $ getFreeVars ut +instance FreeUVars UType where + freeUVars = U.freeUVars -- | We can also get the free variables of a polytype. -instance (FreeVars t) => FreeVars (Poly t) where - freeVars (Forall _ t) = freeVars t +instance (FreeUVars t) => FreeUVars (Poly t) where + freeUVars (Forall _ t) = freeUVars t -- | We can get the free variables in any polytype in a context. -instance FreeVars UCtx where - freeVars = fmap S.unions . mapM freeVars . M.elems . unCtx +instance FreeUVars UCtx where + freeUVars = fmap S.unions . mapM freeUVars . M.elems . unCtx -- | Generate a fresh unification variable. -fresh :: TC UType -fresh = UVar <$> (lift . lift . lift $ freeVar) +fresh :: Has Unification sig m => m UType +fresh = Pure <$> U.freshIntVar -- | Perform a substitution over a 'UType', substituting for both type -- and unification variables. Note that since 'UType's do not have @@ -257,14 +259,21 @@ fresh = UVar <$> (lift . lift . lift $ freeVar) substU :: Map (Either Var IntVar) UType -> UType -> UType substU m = ucata - (\v -> fromMaybe (UVar v) (M.lookup (Right v) m)) + (\v -> fromMaybe (Pure v) (M.lookup (Right v) m)) ( \case TyVarF v -> fromMaybe (UTyVar v) (M.lookup (Left v) m) - f -> UTerm f + f -> Free f ) -- | Make sure no skolem variables escape. -noSkolems :: SrcLoc -> UPolytype -> TC () +noSkolems :: + ( Has Unification sig m + , Has (Reader TCStack) sig m + , Has (Throw ContextualTypeErr) sig m + ) => + SrcLoc -> + Poly UType -> + m () noSkolems l (Forall xs upty) = do upty' <- applyBindings upty let tyvs = @@ -285,41 +294,33 @@ noSkolems l (Forall xs upty) = do -- doing the throwTypeErr either zero or one time, depending on -- whether lookupMin returns Nothing or Just. ------------------------------------------------------------- --- Lifted stuff from unification-fd - -infix 4 =:= - -- | @unify t expTy actTy@ ensures that the given two types are equal. -- If we know the actual term @t@ which is supposed to have these -- types, we can use it to generate better error messages. --- --- We first do a quick-and-dirty check to see whether we know for --- sure the types either are or cannot be equal, generating an --- equality constraint for the unifier as a last resort. -unify :: Maybe Syntax -> TypeJoin -> TC UType -unify ms j = case unifyCheck expected actual of - Apart -> throwTypeErr NoLoc $ Mismatch ms j - Equal -> return expected - MightUnify -> lift . lift $ expected U.=:= actual +unify :: + ( Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + Maybe Syntax -> + TypeJoin -> + m UType +unify ms j = do + res <- expected =:= actual + case res of + Left _ -> throwTypeErr NoLoc $ Mismatch ms j + Right ty -> return ty where (expected, actual) = getJoin j --- | Ensure two types are the same. -(=:=) :: UType -> UType -> TC UType -ty1 =:= ty2 = unify Nothing (joined ty1 ty2) - --- | @unification-fd@ provides a function 'U.applyBindings' which --- fully substitutes for any bound unification variables (for --- efficiency, it does not perform such substitution as it goes --- along). The 'HasBindings' class is for anything which has +-- | The 'HasBindings' class is for anything which has -- unification variables in it and to which we can usefully apply --- 'U.applyBindings'. +-- 'applyBindings'. class HasBindings u where - applyBindings :: u -> TC u + applyBindings :: Has Unification sig m => u -> m u instance HasBindings UType where - applyBindings = lift . lift . U.applyBindings + applyBindings = U.applyBindings instance HasBindings UPolytype where applyBindings (Forall xs u) = Forall xs <$> applyBindings u @@ -336,13 +337,13 @@ instance (HasBindings u, Data u) => HasBindings (Syntax' u) where instance HasBindings UModule where applyBindings (Module u uctx) = Module <$> applyBindings u <*> applyBindings uctx ------------------------------------------------------------- --- Converting between mono- and polytypes +-- ------------------------------------------------------------ +-- -- Converting between mono- and polytypes -- | To 'instantiate' a 'UPolytype', we generate a fresh unification -- variable for each variable bound by the `Forall`, and then -- substitute them throughout the type. -instantiate :: UPolytype -> TC UType +instantiate :: Has Unification sig m => UPolytype -> m UType instantiate (Forall xs uty) = do xs' <- mapM (const fresh) xs return $ substU (M.fromList (zip (map Left xs) xs')) uty @@ -352,12 +353,12 @@ instantiate (Forall xs uty) = do -- variables cannot unify with anything other than themselves. This -- is used when checking something with a polytype explicitly -- specified by the user. -skolemize :: UPolytype -> TC UType +skolemize :: Has Unification sig m => UPolytype -> m UType skolemize (Forall xs uty) = do xs' <- mapM (const fresh) xs return $ substU (M.fromList (zip (map Left xs) (map toSkolem xs'))) uty where - toSkolem (UVar v) = UTyVar (mkVarName "s" v) + toSkolem (Pure v) = UTyVar (mkVarName "s" v) toSkolem x = error $ "Impossible! Non-UVar in skolemize.toSkolem: " ++ show x -- | 'generalize' is the opposite of 'instantiate': add a 'Forall' @@ -365,12 +366,12 @@ skolemize (Forall xs uty) = do -- -- Pick nice type variable names instead of reusing whatever fresh -- names happened to be used for the free variables. -generalize :: UType -> TC UPolytype +generalize :: (Has Unification sig m, Has (Reader UCtx) sig m) => UType -> m UPolytype generalize uty = do uty' <- applyBindings uty - ctx <- getCtx - tmfvs <- freeVars uty' - ctxfvs <- freeVars ctx + ctx <- ask @UCtx + tmfvs <- freeUVars uty' + ctxfvs <- freeUVars ctx let fvs = S.toList $ tmfvs \\ ctxfvs alphabet = ['a' .. 'z'] -- Infinite supply of pretty names a, b, ..., z, a0, ... z0, a1, ... z1, ... @@ -399,9 +400,15 @@ mkTypeErr :: SrcLoc -> TCStack -> TypeErr -> ContextualTypeErr mkTypeErr = CTE -- | Throw a 'ContextualTypeErr'. -throwTypeErr :: SrcLoc -> TypeErr -> TC a +throwTypeErr :: + ( Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + SrcLoc -> + TypeErr -> + m a throwTypeErr l te = do - stk <- getTCStack + stk <- ask @TCStack throwError $ mkTypeErr l stk te -- | Errors that can occur during type checking. The idea is that @@ -415,10 +422,8 @@ data TypeErr | -- | A Skolem variable escaped its local context. EscapedSkolem Var | -- | Occurs check failure, i.e. infinite type. - Infinite IntVar UType - | -- | Error generated by the unifier. - UnifyErr (TypeF UType) (TypeF UType) - | -- | Type mismatch caught by 'unifyCheck'. The given term was + UnificationErr UnificationError + | -- | Type mismatch caught by 'unify'. The given term was -- expected to have a certain type, but has a different type -- instead. Mismatch (Maybe Syntax) TypeJoin @@ -459,17 +464,20 @@ data InvalidAtomicReason LongConst deriving (Show) -instance Fallible TypeF IntVar ContextualTypeErr where - occursFailure v t = mkRawTypeErr (Infinite v t) - mismatchFailure t1 t2 = mkRawTypeErr (UnifyErr t1 t2) - ------------------------------------------------------------ -- Type decomposition -- | Decompose a type that is supposed to be a delay type. Also take -- the term which is supposed to have that type, for use in error -- messages. -decomposeDelayTy :: Syntax -> Sourced UType -> TC UType +decomposeDelayTy :: + ( Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + Syntax -> + Sourced UType -> + m UType decomposeDelayTy _ (_, UTyDelay a) = return a decomposeDelayTy t ty = do a <- fresh @@ -479,7 +487,14 @@ decomposeDelayTy t ty = do -- | Decompose a type that is supposed to be a command type. Also take -- the term which is supposed to have that type, for use in error -- messages. -decomposeCmdTy :: Syntax -> Sourced UType -> TC UType +decomposeCmdTy :: + ( Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + Syntax -> + Sourced UType -> + m UType decomposeCmdTy _ (_, UTyCmd a) = return a decomposeCmdTy t ty = do a <- fresh @@ -489,7 +504,14 @@ decomposeCmdTy t ty = do -- | Decompose a type that is supposed to be a function type. Also take -- the term which is supposed to have that type, for use in error -- messages. -decomposeFunTy :: Syntax -> Sourced UType -> TC (UType, UType) +decomposeFunTy :: + ( Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + Syntax -> + Sourced UType -> + m (UType, UType) decomposeFunTy _ (_, UTyFun ty1 ty2) = return (ty1, ty2) decomposeFunTy t ty = do ty1 <- fresh @@ -500,7 +522,14 @@ decomposeFunTy t ty = do -- | Decompose a type that is supposed to be a product type. Also take -- the term which is supposed to have that type, for use in error -- messages. -decomposeProdTy :: Syntax -> Sourced UType -> TC (UType, UType) +decomposeProdTy :: + ( Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + , Has (Reader TCStack) sig m + ) => + Syntax -> + Sourced UType -> + m (UType, UType) decomposeProdTy _ (_, UTyProd ty1 ty2) = return (ty1, ty2) decomposeProdTy t ty = do ty1 <- fresh @@ -508,8 +537,8 @@ decomposeProdTy t ty = do _ <- unify (Just t) (mkJoin ty (UTyProd ty1 ty2)) return (ty1, ty2) ------------------------------------------------------------- --- Type inference / checking +-- ------------------------------------------------------------ +-- -- Type inference / checking -- | Top-level type inference function: given a context of definition -- types and a top-level term, either return a type error or its @@ -519,7 +548,14 @@ inferTop ctx = runTC ctx . inferModule -- | Infer the signature of a top-level expression which might -- contain definitions. -inferModule :: Syntax -> TC UModule +inferModule :: + ( Has Unification sig m + , Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has (Error ContextualTypeErr) sig m + ) => + Syntax -> + m UModule inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of -- For definitions with no type signature, make up a fresh type -- variable for the body, infer the body under an extended context, @@ -604,7 +640,14 @@ inferModule s@(Syntax l t) = addLocToTypeErr l $ case t of -- -- For most everything else we prefer 'check' because it can often -- result in better and more localized type error messages. -infer :: Syntax -> TC (Syntax' UType) +infer :: + ( Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has Unification sig m + , Has (Error ContextualTypeErr) sig m + ) => + Syntax -> + m (Syntax' UType) infer s@(Syntax l t) = addLocToTypeErr l $ case t of -- Primitives, i.e. things for which we immediately know the only -- possible correct type, and knowing an expected type would provide @@ -719,7 +762,7 @@ infer s@(Syntax l t) = addLocToTypeErr l $ case t of uty <- skolemize upty _ <- check c uty -- Make sure no skolem variables have escaped. - getCtx >>= mapM_ (noSkolems l) + ask @UCtx >>= mapM_ (noSkolems l) -- If check against skolemized polytype is successful, -- instantiate polytype with unification variables. -- Free variables should be able to unify with anything in @@ -853,7 +896,15 @@ inferConst c = case c of -- -- We try to stay in checking mode as far as possible, decomposing -- the expected type as we go and pushing it through the recursion. -check :: Syntax -> UType -> TC (Syntax' UType) +check :: + ( Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has Unification sig m + , Has (Error ContextualTypeErr) sig m + ) => + Syntax -> + UType -> + m (Syntax' UType) check s@(Syntax l t) expected = addLocToTypeErr l $ case t of -- if t : ty, then {t} : {ty}. -- Note that in theory, if the @Maybe Var@ component of the @SDelay@ @@ -881,15 +932,15 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of SLam x mxTy body -> do (argTy, resTy) <- decomposeFunTy s (Expected, expected) case toU mxTy of - Just xTy -> case unifyCheck argTy xTy of - -- Generate a special error when the explicit type annotation - -- on a lambda doesn't match the expected type, - -- e.g. (\x:int. x + 2) : text -> int, since the usual - -- "expected/but got" language would probably be confusing. - Apart -> throwTypeErr l $ LambdaArgMismatch (joined argTy xTy) - -- Otherwise, make sure to unify the annotation with the - -- expected argument type. - _ -> void $ argTy =:= xTy + Just xTy -> do + res <- argTy =:= xTy + case res of + -- Generate a special error when the explicit type annotation + -- on a lambda doesn't match the expected type, + -- e.g. (\x:int. x + 2) : text -> int, since the usual + -- "expected/but got" language would probably be confusing. + Left _ -> throwTypeErr l $ LambdaArgMismatch (joined argTy xTy) + Right _ -> return () Nothing -> return () body' <- withBinding (lvVar x) (Forall [] argTy) $ check body resTy return $ Syntax' l (SLam x mxTy body') (UTyFun argTy resTy) @@ -939,7 +990,7 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of t2' <- withBinding (lvVar x) upty $ check t2 expected -- Make sure no skolem variables have escaped. - getCtx >>= mapM_ (noSkolems l) + ask @UCtx >>= mapM_ (noSkolems l) -- Return the annotated let. return $ Syntax' l (SLet r x mxTy t1' t2') expected @@ -1015,7 +1066,14 @@ check s@(Syntax l t) expected = addLocToTypeErr l $ case t of -- i.e. contains at most one tangible command. For example, @atomic -- (move; move)@ is invalid, since that would allow robots to move -- twice as fast as usual by doing both actions in one tick. -validAtomic :: Syntax -> TC () +validAtomic :: + ( Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + ) => + Syntax -> + m () validAtomic s@(Syntax l t) = do n <- analyzeAtomic S.empty s when (n > 1) $ throwTypeErr l $ InvalidAtomic (TooManyTicks n) t @@ -1023,7 +1081,15 @@ validAtomic s@(Syntax l t) = do -- | Analyze an argument to @atomic@: ensure it contains no nested -- atomic blocks and no references to external variables, and count -- how many tangible commands it will execute. -analyzeAtomic :: Set Var -> Syntax -> TC Int +analyzeAtomic :: + ( Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + ) => + Set Var -> + Syntax -> + m Int analyzeAtomic locals (Syntax l t) = case t of -- Literals, primitives, etc. that are fine and don't require a tick -- to evaluate @@ -1062,7 +1128,14 @@ analyzeAtomic locals (Syntax l t) = case t of SBind mx s1 s2 -> (+) <$> analyzeAtomic locals s1 <*> analyzeAtomic (maybe id (S.insert . lvVar) mx locals) s2 SRcd m -> sum <$> mapM analyzeField (M.assocs m) where - analyzeField :: (Var, Maybe Syntax) -> TC Int + analyzeField :: + ( Has (Reader UCtx) sig m + , Has (Reader TCStack) sig m + , Has Unification sig m + , Has (Throw ContextualTypeErr) sig m + ) => + (Var, Maybe Syntax) -> + m Int analyzeField (x, Nothing) = analyzeAtomic locals (STerm (TVar x)) analyzeField (_, Just s) = analyzeAtomic locals s SProj {} -> return 0 @@ -1070,7 +1143,7 @@ analyzeAtomic locals (Syntax l t) = case t of TVar x | x `S.member` locals -> return 0 | otherwise -> do - mxTy <- Ctx.lookup x <$> getCtx + mxTy <- Ctx.lookup x <$> ask @UCtx case mxTy of -- If the variable is undefined, return 0 to indicate the -- atomic block is valid, because we'd rather have the error @@ -1123,5 +1196,5 @@ isSimpleUType = \case UTyCmd {} -> False UTyDelay {} -> False -- Make the pattern-match coverage checker happy - UVar {} -> False - UTerm {} -> False + Pure {} -> False + Free {} -> False diff --git a/src/swarm-lang/Swarm/Language/Typecheck/Unify.hs b/src/swarm-lang/Swarm/Language/Typecheck/Unify.hs deleted file mode 100644 index c68308ebf..000000000 --- a/src/swarm-lang/Swarm/Language/Typecheck/Unify.hs +++ /dev/null @@ -1,90 +0,0 @@ --- | --- SPDX-License-Identifier: BSD-3-Clause --- --- Utilities related to type unification. -module Swarm.Language.Typecheck.Unify ( - UnifyStatus (..), - unifyCheck, -) where - -import Control.Unification -import Data.Foldable qualified as F -import Data.Function (on) -import Data.Map qualified as M -import Data.Map.Merge.Lazy qualified as M -import Swarm.Language.Types - --- | The result of doing a unification check on two types. -data UnifyStatus - = -- | The two types are definitely not equal; they will never unify - -- no matter how any unification variables get filled in. For - -- example, (int * u0) and (u1 -> u2) are apart: the first is a - -- product type and the second is a function type. - Apart - | -- | The two types might unify, depending on how unification - -- variables get filled in, but we're not sure. For example, - -- (int * u0) and (u1 * bool). - MightUnify - | -- | The two types are most definitely equal, and we don't need to - -- bother generating a constraint to make them so. For example, - -- (int * text) and (int * text). - Equal - deriving (Eq, Ord, Read, Show) - --- | The @Semigroup@ instance for @UnifyStatus@ is used to combine --- results for compound types. -instance Semigroup UnifyStatus where - -- If either part of a compound type is apart, then the whole thing is. - Apart <> _ = Apart - _ <> Apart = Apart - -- Otherwise, if we're unsure about either part of a compound type, - -- then we're unsure about the whole thing. - MightUnify <> _ = MightUnify - _ <> MightUnify = MightUnify - -- Finally, if both parts are definitely equal then the whole thing is. - Equal <> Equal = Equal - -instance Monoid UnifyStatus where - mempty = Equal - --- | Given two types, try hard to prove either that (1) they are --- 'Apart', i.e. cannot possibly unify, or (2) they are definitely --- 'Equal'. In case (1), we can generate a much better error --- message at the instant the two types come together than we could --- if we threw a constraint into the unifier. In case (2), we don't --- have to bother with generating a trivial constraint. If we don't --- know for sure whether they will unify, return 'MightUnify'. -unifyCheck :: UType -> UType -> UnifyStatus -unifyCheck ty1 ty2 = case (ty1, ty2) of - (UVar x, UVar y) - | x == y -> Equal - | otherwise -> MightUnify - (UVar _, _) -> MightUnify - (_, UVar _) -> MightUnify - (UTerm t1, UTerm t2) -> unifyCheckF t1 t2 - -unifyCheckF :: TypeF UType -> TypeF UType -> UnifyStatus -unifyCheckF t1 t2 = case (t1, t2) of - (TyBaseF b1, TyBaseF b2) -> case b1 == b2 of - True -> Equal - False -> Apart - (TyBaseF {}, _) -> Apart - (TyVarF v1, TyVarF v2) -> case v1 == v2 of - True -> Equal - False -> Apart - (TyVarF {}, _) -> Apart - (TySumF t11 t12, TySumF t21 t22) -> unifyCheck t11 t21 <> unifyCheck t12 t22 - (TySumF {}, _) -> Apart - (TyProdF t11 t12, TyProdF t21 t22) -> unifyCheck t11 t21 <> unifyCheck t12 t22 - (TyProdF {}, _) -> Apart - (TyRcdF m1, TyRcdF m2) -> - case ((==) `on` M.keysSet) m1 m2 of - False -> Apart - _ -> F.fold (M.merge M.dropMissing M.dropMissing (M.zipWithMatched (const unifyCheck)) m1 m2) - (TyRcdF {}, _) -> Apart - (TyCmdF c1, TyCmdF c2) -> unifyCheck c1 c2 - (TyCmdF {}, _) -> Apart - (TyDelayF c1, TyDelayF c2) -> unifyCheck c1 c2 - (TyDelayF {}, _) -> Apart - (TyFunF t11 t12, TyFunF t21 t22) -> unifyCheck t11 t21 <> unifyCheck t12 t22 - (TyFunF {}, _) -> Apart diff --git a/src/swarm-lang/Swarm/Language/Types.hs b/src/swarm-lang/Swarm/Language/Types.hs index 0d7f1fbc4..5ce808301 100644 --- a/src/swarm-lang/Swarm/Language/Types.hs +++ b/src/swarm-lang/Swarm/Language/Types.hs @@ -1,8 +1,6 @@ {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} - --- for the Data IntVar instance -- | -- SPDX-License-Identifier: BSD-3-Clause @@ -35,6 +33,7 @@ module Swarm.Language.Types ( pattern TyDelay, -- * @UType@ + IntVar (..), UType, pattern UTyBase, pattern UTyVar, @@ -56,6 +55,7 @@ module Swarm.Language.Types ( -- ** Utilities ucata, mkVarName, + fuvs, -- * Polytypes Poly (..), @@ -71,18 +71,15 @@ module Swarm.Language.Types ( WithU (..), ) where -import Control.Monad (guard) -import Control.Unification -import Control.Unification.IntVar +import Control.Monad.Free import Data.Aeson (FromJSON, ToJSON) +import Data.Aeson.TH (defaultOptions, deriveFromJSON1, deriveToJSON1) import Data.Data (Data) +import Data.Eq.Deriving (deriveEq1) +import Data.Fix import Data.Foldable (fold) -import Data.Function (on) -import Data.Functor.Fixedpoint import Data.Kind qualified -import Data.Map.Merge.Strict qualified as M import Data.Map.Strict (Map) -import Data.Map.Strict qualified as M import Data.Set (Set) import Data.Set qualified as S import Data.String (IsString (..)) @@ -90,6 +87,7 @@ import Data.Text (Text) import Data.Text qualified as T import GHC.Generics (Generic, Generic1) import Swarm.Language.Context +import Text.Show.Deriving (deriveShow1) import Witch ------------------------------------------------------------ @@ -122,8 +120,8 @@ data BaseTy -- | A "structure functor" encoding the shape of type expressions. -- Actual types are then represented by taking a fixed point of this -- functor. We represent types in this way, via a "two-level type", --- so that we can work with the @unification-fd@ package (see --- https://byorgey.wordpress.com/2021/09/08/implementing-hindley-milner-with-the-unification-fd-library/). +-- so that we can easily use generic recursion schemes to implement +-- things like substitution. data TypeF t = -- | A base type. TyBaseF BaseTy @@ -142,18 +140,12 @@ data TypeF t TyFunF t t | -- | Record type. TyRcdF (Map Var t) - deriving (Show, Eq, Functor, Foldable, Traversable, Generic, Generic1, Unifiable, Data, FromJSON, ToJSON) - --- | Unify two Maps by insisting they must have exactly the same keys, --- and if so, simply matching up corresponding values to be --- recursively unified. There could be other reasonable --- implementations, but in our case we will use this for unifying --- record types, and we do not have any subtyping, so record types --- will only unify if they have exactly the same keys. -instance Ord k => Unifiable (Map k) where - zipMatch m1 m2 = do - guard $ ((==) `on` M.keysSet) m1 m2 - pure $ M.merge M.dropMissing M.dropMissing (M.zipWithMatched (\_ a1 a2 -> Right (a1, a2))) m1 m2 + deriving (Show, Eq, Functor, Foldable, Traversable, Generic, Generic1, Data, FromJSON, ToJSON) + +deriveEq1 ''TypeF +deriveShow1 ''TypeF +deriveFromJSON1 defaultOptions ''TypeF +deriveToJSON1 defaultOptions ''TypeF -- | @Type@ is now defined as the fixed point of 'TypeF'. It would be -- annoying to manually apply and match against 'Fix' constructors @@ -163,37 +155,35 @@ type Type = Fix TypeF -- | Get all the type variables contained in a 'Type'. tyVars :: Type -> Set Var -tyVars = cata (\case TyVarF x -> S.singleton x; f -> fold f) +tyVars = foldFix (\case TyVarF x -> S.singleton x; f -> fold f) --- The derived Data instance is so we can make a quasiquoter for types. -deriving instance Data Type +newtype IntVar = IntVar Int + deriving (Show, Data, Eq, Ord) -- | 'UType's are like 'Type's, but also contain unification --- variables. 'UType' is defined via 'UTerm', which is also a kind --- of fixed point (in fact, 'UType' is the /free monad/ over 'TypeF'). +-- variables. 'UType' is defined via 'Free', which is also a kind +-- of fixed point (in fact, @Free TypeF@ is the /free monad/ over +-- 'TypeF'). -- -- Just as with 'Type', we provide a bunch of pattern synonyms for -- working with 'UType' as if it were defined directly. -type UType = UTerm TypeF IntVar - --- The derived Data instances are so we can make a quasiquoter for --- types. -deriving instance Data UType -deriving instance Data IntVar +type UType = Free TypeF IntVar --- | A generic /fold/ for things defined via 'UTerm' (including, in --- particular, 'UType'). This probably belongs in the --- @unification-fd@ package, but since it doesn't provide one, we --- define it here. -ucata :: Functor t => (v -> a) -> (t a -> a) -> UTerm t v -> a -ucata f _ (UVar v) = f v -ucata f g (UTerm t) = g (fmap (ucata f g) t) +-- | A generic /fold/ for things defined via 'Free' (including, in +-- particular, 'UType'). +ucata :: Functor t => (v -> a) -> (t a -> a) -> Free t v -> a +ucata f _ (Pure v) = f v +ucata f g (Free t) = g (fmap (ucata f g) t) -- | A quick-and-dirty method for turning an 'IntVar' (used internally -- as a unification variable) into a unique variable name, by -- appending a number to the given name. mkVarName :: Text -> IntVar -> Var -mkVarName nm (IntVar v) = T.append nm (from @String (show (v + (maxBound :: Int) + 1))) +mkVarName nm (IntVar v) = T.append nm (from @String (show v)) + +-- | Get all the free unification variables in a 'UType'. +fuvs :: UType -> Set IntVar +fuvs = ucata S.singleton fold -- | For convenience, so we can write /e.g./ @"a"@ instead of @TyVar "a"@. instance IsString Type where @@ -264,8 +254,8 @@ class WithU t where -- | 'Type' is an instance of 'WithU', with associated type 'UType'. instance WithU Type where type U Type = UType - toU = unfreeze - fromU = freeze + toU = foldFix Free + fromU = ucata (const Nothing) (fmap wrapFix . sequence) -- | A 'WithU' instance can be lifted through any functor (including, -- in particular, 'Ctx' and 'Poly'). @@ -333,57 +323,52 @@ pattern TyDelay :: Type -> Type pattern TyDelay ty1 = Fix (TyDelayF ty1) pattern UTyBase :: BaseTy -> UType -pattern UTyBase b = UTerm (TyBaseF b) +pattern UTyBase b = Free (TyBaseF b) pattern UTyVar :: Var -> UType -pattern UTyVar v = UTerm (TyVarF v) +pattern UTyVar v = Free (TyVarF v) pattern UTyVoid :: UType -pattern UTyVoid = UTerm (TyBaseF BVoid) +pattern UTyVoid = Free (TyBaseF BVoid) pattern UTyUnit :: UType -pattern UTyUnit = UTerm (TyBaseF BUnit) +pattern UTyUnit = Free (TyBaseF BUnit) pattern UTyInt :: UType -pattern UTyInt = UTerm (TyBaseF BInt) +pattern UTyInt = Free (TyBaseF BInt) pattern UTyText :: UType -pattern UTyText = UTerm (TyBaseF BText) +pattern UTyText = Free (TyBaseF BText) pattern UTyDir :: UType -pattern UTyDir = UTerm (TyBaseF BDir) +pattern UTyDir = Free (TyBaseF BDir) pattern UTyBool :: UType -pattern UTyBool = UTerm (TyBaseF BBool) +pattern UTyBool = Free (TyBaseF BBool) pattern UTyActor :: UType -pattern UTyActor = UTerm (TyBaseF BActor) +pattern UTyActor = Free (TyBaseF BActor) pattern UTyKey :: UType -pattern UTyKey = UTerm (TyBaseF BKey) +pattern UTyKey = Free (TyBaseF BKey) pattern UTySum :: UType -> UType -> UType -pattern UTySum ty1 ty2 = UTerm (TySumF ty1 ty2) +pattern UTySum ty1 ty2 = Free (TySumF ty1 ty2) pattern UTyProd :: UType -> UType -> UType -pattern UTyProd ty1 ty2 = UTerm (TyProdF ty1 ty2) +pattern UTyProd ty1 ty2 = Free (TyProdF ty1 ty2) pattern UTyFun :: UType -> UType -> UType -pattern UTyFun ty1 ty2 = UTerm (TyFunF ty1 ty2) +pattern UTyFun ty1 ty2 = Free (TyFunF ty1 ty2) pattern UTyRcd :: Map Var UType -> UType -pattern UTyRcd m = UTerm (TyRcdF m) +pattern UTyRcd m = Free (TyRcdF m) pattern UTyCmd :: UType -> UType -pattern UTyCmd ty1 = UTerm (TyCmdF ty1) +pattern UTyCmd ty1 = Free (TyCmdF ty1) pattern UTyDelay :: UType -> UType -pattern UTyDelay ty1 = UTerm (TyDelayF ty1) +pattern UTyDelay ty1 = Free (TyDelayF ty1) pattern PolyUnit :: Polytype pattern PolyUnit = Forall [] (TyCmd TyUnit) - --- Derive aeson instances for type serialization -deriving instance Generic Type -deriving instance ToJSON Type -deriving instance FromJSON Type diff --git a/src/swarm-util/Swarm/Util.hs b/src/swarm-util/Swarm/Util.hs index ba8bdd30e..48a1671d6 100644 --- a/src/swarm-util/Swarm/Util.hs +++ b/src/swarm-util/Swarm/Util.hs @@ -594,4 +594,4 @@ smallHittingSet ss = go fixed (filter (S.null . S.intersection fixed) choices) -- Given a nonempty collection of sets, find an element which is shared among -- as many of them as possible. mostCommon :: Ord a => [Set a] -> a - mostCommon = fst . maximumBy (comparing snd) . M.assocs . M.fromListWith (+) . map (,1 :: Int) . concatMap S.toList + mostCommon = fst . maximumBy (comparing snd) . M.assocs . histogram . concatMap S.toList diff --git a/swarm.cabal b/swarm.cabal index 087cc362b..76255114a 100644 --- a/swarm.cabal +++ b/swarm.cabal @@ -64,7 +64,6 @@ flag ci common common if flag(ci) ghc-options: -Werror - ghc-options: -Wall -Wcompat @@ -90,32 +89,36 @@ common ghc2021-extensions default-extensions: -- Note we warn on prequalified -- Not GHC2021, but until we get \cases we use \case a lot + BangPatterns + DeriveAnyClass + DeriveDataTypeable + DeriveFunctor + DeriveGeneric + DeriveTraversable + ExplicitForAll + FlexibleContexts + FlexibleInstances + GADTSyntax + ImportQualifiedPost + LambdaCase MultiParamTypeClasses + NumericUnderscores RankNTypes ScopedTypeVariables - FlexibleContexts - FlexibleInstances - BangPatterns StandaloneDeriving - TypeOperators - GADTSyntax - DeriveDataTypeable - DeriveGeneric TupleSections - LambdaCase - ExplicitForAll - DeriveFunctor - DeriveTraversable - DeriveAnyClass TypeApplications - NumericUnderscores - ImportQualifiedPost + TypeOperators library swarm-lang import: stan-config, common, ghc2021-extensions visibility: public -- cabal-gild: discover src/swarm-lang exposed-modules: + Swarm.Effect.Unify + Swarm.Effect.Unify.Common + Swarm.Effect.Unify.Fast + Swarm.Effect.Unify.Naive Swarm.Language.Capability Swarm.Language.Context Swarm.Language.Direction @@ -135,7 +138,6 @@ library swarm-lang Swarm.Language.Syntax.CommandMetadata Swarm.Language.Text.Markdown Swarm.Language.Typecheck - Swarm.Language.Typecheck.Unify Swarm.Language.Typed Swarm.Language.Types Swarm.Language.Value @@ -148,7 +150,11 @@ library swarm-lang commonmark >=0.2 && <0.3, commonmark-extensions >=0.2 && <0.3, containers, + data-fix >=0.3 && <0.4, + deriving-compat >=0.6 && <0.7, extra, + free >=5.2 && <5.3, + fused-effects, hashable, lens, lsp >=2.4 && <2.5, @@ -161,8 +167,6 @@ library swarm-lang template-haskell, text, text-rope >=0.2 && <0.3, - transformers, - unification-fd >=0.11 && <0.12, vector, vty, witch, @@ -689,7 +693,6 @@ library Swarm.Language.Syntax.CommandMetadata, Swarm.Language.Text.Markdown, Swarm.Language.Typecheck, - Swarm.Language.Typecheck.Unify, Swarm.Language.Typed, Swarm.Language.Types, Swarm.Language.Value,