In previous articles, we talked about how to implement a type inference algorithm: one that can infer the types of complex expressions without type annotations and provide a validation layer on top of our code with no effort on the user's part.

Those articles were about validating expressions and expression definitions. But what about type definitions? How can we help the user catch errors when the types they define are inconsistent or don't make sense?

We want to be able to catch errors such as:

Tree a =
  | Node a Tree Tree
    -- ^ should be: Node a (Tree a) (Tree a)

And:

Rec f a =
  | Rec f (f a)
    -- ^ f is used both as a saturated type and a type that takes a parameter

And more, while still allowing the user to define complex types without annotations, such as:

Cofree f a =
  | Cofree a (f (Cofree f a))

As we'll soon see, we can use the exact same unification-based constraint solving approach we used for type inference in the previous article to infer the type of a type, or, as we call it in the Haskell world, the kind of a type.

Getting started

If you prefer to skip the explanations and jump straight to the code, the complete source is available in the gist mentioned in the summary at the end of this article.

In this article we will implement a kind inference engine in Haskell for a simple type system. We'll start with the script header, language extensions, and imports for our Haskell module:

#!/usr/bin/env cabal
{- cabal:
build-depends: base, mtl, containers, uniplate
ghc-options: -Wall
-}

This first part lets us run this file as a script if we have ghc and cabal installed. Just chmod +x kinds.hs and run it.

We will use the GHC2021 set of extensions and LambdaCase, as well as a few additional modules that will come into play later.

-- | An example of a kind inference for data types using
-- unification-based constraint solving.
--
-- See the blog post:
-- <https://gilmi.me/blog/post/2023/09/30/kind-inference>

{-# Language GHC2021 #-}
{-# Language LambdaCase #-}

import Data.Data (Data)
import GHC.Generics (Generic)
import Data.Tuple (swap)
import Data.Maybe (listToMaybe)
import Data.Foldable (for_)
import Data.Traversable (for)
import Control.Monad (foldM)
import Control.Monad.State qualified as Mtl
import Control.Monad.Except qualified as Mtl
import Data.Generics.Uniplate.Data qualified as Uniplate (universe, transformBi)
import Data.Map qualified as Map

Models

Now we can start by defining our models. What are types? What do type definitions look like? What are kinds?

Let's start with a data type definition. We'll support ML-style data definitions like the ones in Haskell. For example, the following data type:

Option a =
  | Some a
  | None

A data type definition starts with the type's name and its type parameters, followed by a list of variants, each consisting of a constructor name and zero or more types.

We'll represent that using the following types:

-- | The representation of a data type definition.
data Datatype a
  = Datatype
    { -- | A place to put kind annotation in.
      dtAnn :: a
    , -- | The name of the data type.
      dtName :: TypeName
    , -- | Type parameters.
      dtParameters :: [TypeVar]
    , -- | Alternative variants.
      dtVariants :: [Variant a]
    }
  deriving (Show, Eq, Data, Generic, Functor, Foldable, Traversable)

-- | A Variant of a data type definition.
data Variant a
  = Variant
    { -- | The name of the data constructor.
      vTypeConstructor :: String
    , -- | A list of types.
      vTypes :: [Type a]
    }
  deriving (Show, Eq, Data, Generic, Functor, Foldable, Traversable)

-- | A name of known types.
newtype TypeName = MkTypeName { getTypeName :: String }
  deriving (Show, Eq, Ord, Data, Generic)

-- | A type variable.
newtype TypeVar = MkTypeVar { getTypeVar :: String }
  deriving (Show, Eq, Ord, Data, Generic)

The polymorphic parameter a is where we will store kind annotations.

The types we are going to support in our type system are fairly simple. We support type names, such as Int and Option; type variables, such as a and t; and type application, which lets us apply higher-kinded types such as Option to other types, as in Option Int, Either e a, and f a.

-- | A representation of a type with a place for kind annotation.
data Type a
  = -- | A type variable.
    TypeVar a TypeVar
  | -- | A named type.
    TypeName a TypeName
  | -- | An application of two types, of the form `t1 t2`.
    TypeApp a (Type a) (Type a)
  deriving (Show, Eq, Data, Generic, Functor, Foldable, Traversable)

For example, the data type Option we defined earlier will be represented as a Datatype in the following way:

option =
  Datatype ()
    (MkTypeName "Option")
    [MkTypeVar "a"]
    [ Variant "Some" [TypeVar () $ MkTypeVar "a"]
    , Variant "None" []
    ]
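Type application is curried: Either e a is (Either e) a, so applying a type to several arguments produces TypeApp nodes nested to the left. As a sketch, a hypothetical helper applyTypes (not part of the article's code) can build these nested applications. Here it encodes the corrected Node a (Tree a) (Tree a) variant from the introduction, using simplified copies of the types above:

```haskell
-- Simplified copies of the article's types (deriving only Show and Eq).
newtype TypeName = MkTypeName String
  deriving (Show, Eq)

newtype TypeVar = MkTypeVar String
  deriving (Show, Eq)

data Type a
  = TypeVar a TypeVar
  | TypeName a TypeName
  | TypeApp a (Type a) (Type a)
  deriving (Show, Eq)

data Variant a = Variant String [Type a]
  deriving (Show, Eq)

-- | Hypothetical helper: apply a type to several arguments.
--   Application nests to the left, so `T x y` becomes `TypeApp (TypeApp T x) y`.
applyTypes :: Type () -> [Type ()] -> Type ()
applyTypes = foldl (TypeApp ())

-- | The corrected variant `Node a (Tree a) (Tree a)`.
nodeVariant :: Variant ()
nodeVariant = Variant "Node" [a, treeOfA, treeOfA]
  where
    a = TypeVar () (MkTypeVar "a")
    treeOfA = applyTypes (TypeName () (MkTypeName "Tree")) [a]
```

The buggy Node a Tree Tree from the introduction would instead use the bare TypeName for Tree in the last two fields.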

Now let's talk about kinds. As we said before, kinds are the types of types. They describe whether a type can be applied to other types, how many arguments it takes, and what kinds of types can be placed in each argument slot.

For example, Option has the kind Type -> Type: it can be applied to a type of kind Type, such as Int, but not to Option, nor to two Ints.

There are also scenarios where a type variable can have any kind, for example in the following data type:

Proxy t =
  | Proxy

Since t is not used anywhere, we can apply any type to Proxy: we can have Proxy Int, but also Proxy Option. Let's define kinds as a data type:

-- | A representation of a kind.
data Kind
  = -- | For types like `Int`.
    Type
  | -- | For types like `Option`.
    KindFun Kind Kind
  | -- | For polymorphic kinds.
    KindVar KindVar
  | -- | For closing over polymorphic kinds.
    KindScheme [KindVar] Kind
  deriving (Show, Eq, Data, Generic)

-- | A kind variable.
newtype KindVar = MkKindVar { getKindVar :: String }
  deriving (Show, Eq, Ord, Data, Generic)
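To make these values easier to read, here is a sketch of a hypothetical pretty printer that renders them in the Type -> Type notation used throughout this article (the gist has its own pretty printing code; this is not it). Since -> associates to the right, only a function kind or a scheme in argument position needs parentheses:

```haskell
newtype KindVar = MkKindVar { getKindVar :: String }
  deriving (Show, Eq)

data Kind
  = Type
  | KindFun Kind Kind
  | KindVar KindVar
  | KindScheme [KindVar] Kind
  deriving (Show, Eq)

-- | Hypothetical pretty printer for kinds.
--   `->` associates to the right, so only a function kind (or a scheme)
--   in argument position needs parentheses.
prettyKind :: Kind -> String
prettyKind Type = "Type"
prettyKind (KindVar var) = getKindVar var
prettyKind (KindFun arg result) = argument arg <> " -> " <> prettyKind result
  where
    argument k@KindFun{} = "(" <> prettyKind k <> ")"
    argument k@KindScheme{} = "(" <> prettyKind k <> ")"
    argument k = prettyKind k
-- A scheme that closes over nothing prints as its body.
prettyKind (KindScheme [] kind) = prettyKind kind
prettyKind (KindScheme vars kind) =
  "forall " <> unwords (map getKindVar vars) <> ". " <> prettyKind kind
```

With it, KindFun (KindFun Type Type) (KindFun Type Type) prints as (Type -> Type) -> Type -> Type, the kind of Cofree, and the kind of Proxy can be written as forall k1. k1 -> Type.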

The types we've defined so far represent the model of our language. During inference, we take a list of data types and a mapping from named types that might appear in these data types to their kinds. We then infer the kinds of the data types and return them annotated with their kinds, or return an error if there was a problem.

We can capture this operation in this type signature:

infer :: Map.Map TypeName Kind -> [Datatype ()] -> Either Error [Datatype Kind]

Let's dive in and see how we can implement infer.

Kind Inference

Our kind inference algorithm, like our previous type inference algorithm, has 6 important parts:

  1. Topologically order definitions and group those that depend on one another (which we will not cover here).
  2. Elaboration and constraint generation
  3. Constraint solving
  4. Instantiation
  5. Substitution
  6. Generalization

The general process is as follows: we sort and group definitions by their dependencies, then elaborate the data types by giving each type we meet a unique kind variable and collecting constraints on those kind variables according to their usage and placement. We then solve these constraints using unification, instantiating the polymorphic kinds we run into, and build a substitution, which is a mapping from kind variables to kinds. Next, we replace the kind variables we gave each type during elaboration with the kinds the substitution maps them to. Finally, we generalize the kinds of the data type definitions and close over their free kind variables.

In code, this looks roughly like this:

-- | Infer the kind of a group of data types that should be solved together
--   (because they are mutually recursive).
infer :: Map.Map TypeName Kind -> [Datatype ()] -> Either Error [Datatype Kind]
infer kindEnv datatypes =
  -- initialize our `InferenceM` which is State + Except
  flip Mtl.evalState (initialState kindEnv) $ Mtl.runExceptT $ do
    -- Invent a kind variable for each data type
    for_ datatypes $ \(Datatype _ name _ _) -> do
      kindvar <- freshKindVar
      declareNamedType name kindvar
    -- Elaborate all of the data types
    datatypes' <- traverse elaborate datatypes
    -- Solve the constraints
    solveConstraints
    for datatypes' $ \(Datatype kindvar name vars variants) -> do
      -- Substitute the kind variable for a kind
      -- for the data type
      kind <- lookupKindVarInSubstitution kindvar
      -- ... and for all types
      variants' <- for variants $ traverse lookupKindVarInSubstitution
      -- generalize the data type's kind, and return.
      pure (Datatype (generalize kind) name vars variants')

Let's unpack all of that, step by step.

InferenceM

Two capabilities that will help us write less verbose code are managing state and throwing exceptions. We will define a type that combines them:

-- | We combine the capabilities of Except and State
--   For our kind inference code.
type InferenceM a = Mtl.ExceptT Error (Mtl.State State) a

And the types representing the errors we can throw, and the state we keep throughout the inference process:

-- | The errors that can be thrown in the process.
data Error
  = UnboundVar TypeVar
  | UnboundName TypeName
  | UnificationFailed Kind Kind
  | OccursCheckFailed (Maybe (Type ())) KindVar Kind
  deriving (Show)

-- | The state we keep during an inference cycle
data State = State
  { -- | Mapping from named types or type variables to kind variables.
    -- When we declare a new data type or a type variable, we'll add it here.
    -- When we run into a type variable or a type name during elaboration,
    -- we look up its matching kind variable here.
    env :: Map.Map (Either TypeName TypeVar) KindVar
  , -- | Mapping from existing named types to their kinds.
    -- Kinds for types that are supplied before the inference process can be found here.
    kindEnv :: Map.Map TypeName Kind
  , -- | Used for generating fresh kind variables.
    counter :: Int
  , -- | When we learn information about kinds during elaboration, we'll add it here.
    constraints :: [Constraint]
  , -- | The constraint solving process will generate this mapping from
    -- the kind variables we collected to the kind they should represent.
    -- If we don't find the kind variable in the substitution, that means
    -- it is a free variable we should close over.
    substitution :: Map.Map KindVar Kind
  }
  deriving (Show, Eq, Data, Generic)

-- | The state at the start of the process.
initialState :: Map.Map TypeName Kind -> State
initialState kindEnv =
  State mempty kindEnv 0 mempty mempty

-- | A constraint on kinds.
data Constraint
  = Equality Kind Kind
    -- ^ The two kinds should unify.
    -- If one of the kinds is a kind scheme, we will instantiate it, and
    -- add an equality constraint of the other kind with the instantiated kind.
  deriving (Show, Eq, Data, Generic)

We'll write small utility functions for interacting with this state as we need them.
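The order of the layers in InferenceM matters. With ExceptT on top of State, state changes made before an error are kept, and running the stack yields both the result and the final state. Here is a minimal sketch using the transformers package directly; the names M and bumpThenFail are illustrative only:

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE)
import Control.Monad.Trans.State.Strict (State, modify, runState)

-- | Same layering as InferenceM: errors on top of state.
type M a = ExceptT String (State Int) a

-- | Increment the counter, then fail.
bumpThenFail :: M ()
bumpThenFail = do
  lift (modify (+ 1))
  throwE "failed after bumping"
```

Running runState (runExceptT bumpThenFail) 0 gives (Left "failed after bumping", 1): the error and the incremented counter. The article's infer uses evalState, so it discards the final state either way; the layering only matters if you want to inspect the state after a failure.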

Elaboration and constraint generation

In this section we want to traverse a data type, annotate the types with fresh kind variables, and generate constraints according to the types' location and usage.

-- | Invent kind variables for types we don't know and add constraints
--   on them according to their usage.
elaborate :: Datatype () -> InferenceM (Datatype KindVar)
elaborate (Datatype _ datatypeName vars variants) = do
  -- We go over each of the data type parameters and
  -- generate a fresh kind variable for them.
  varKinds <- for vars $ \var -> do
    kindvar <- freshKindVar
    declareTypeVar var kindvar
    pure kindvar

  -- We go over the variants, elaborate each field,
  -- and return the elaborated variants.
  variants' <- for variants $ \(Variant name fields) -> do
    Variant name <$>
      for fields
        ( \field -> do
          field' <- elaborateType field
          -- a constraint on fields: their kind must be `Type`.
          newEqualityConstraint (KindVar $ getAnn field') Type
          pure field'
        )

  -- We grab the kind variable of the data type
  -- so we can add a constraint on it.
  datatypeKindvar <- lookupNameKindVar datatypeName
  -- A type of the form `T a b c ... =` has the kind:
  -- `aKind -> bKind -> cKind -> ... -> Type`.
  -- We add that as a constraint.
  let kind = foldr KindFun Type $ map KindVar varKinds
  newEqualityConstraint (KindVar datatypeKindvar) kind

  -- We return the elaborated data type after annotating
  -- all types with kind variables and generating constraints.
  pure (Datatype datatypeKindvar datatypeName vars variants')
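The foldr near the end of elaborate builds the kind a declaration implies from its parameters' kind variables. Pulled out into a standalone sketch (the name datatypeKind is ours), a two-parameter type such as Either e a gets k0 -> k1 -> Type, and a type without parameters gets plain Type:

```haskell
newtype KindVar = MkKindVar String
  deriving (Show, Eq)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

-- | The kind `k1 -> k2 -> ... -> Type` of a data type whose
--   parameters were given the kind variables k1, k2, ...
--   Equivalent to `foldr KindFun Type (map KindVar vars)`.
datatypeKind :: [KindVar] -> Kind
datatypeKind = foldr (KindFun . KindVar) Type
```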

There are a couple of utility functions used in elaborate:

We generate a fresh kind variable for each declared type variable. After constraint solving, we'll look up that kind variable again to learn which kind should take its place.

We'll also save the kind variable we generated for the type variable in the environment, so we can find it later when it is used.

-- | Generate a fresh kind variable.
freshKindVar :: InferenceM KindVar
freshKindVar = do
  s <- Mtl.get
  let kindvar = MkKindVar ("k" <> show (counter s))
  Mtl.put s { counter = 1 + counter s }
  pure kindvar

-- | Insert declared type variables into the environment.
declareTypeVar :: TypeVar -> KindVar -> InferenceM ()
declareTypeVar var kindvar =
  Mtl.modify $ \s ->
    s { env = Map.insert (Right var) kindvar (env s) }
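To see the naming scheme of freshKindVar in isolation, here is a reduced sketch that uses the State monad from the transformers package with only the counter as state (the article keeps the counter inside its larger State record, together with the environment and constraints):

```haskell
import Control.Monad (replicateM)
import Control.Monad.Trans.State.Strict (State, evalState, get, put)

newtype KindVar = MkKindVar String
  deriving (Show, Eq)

-- | Produce `k<n>` and bump the counter, so every call yields a new name.
freshKindVar :: State Int KindVar
freshKindVar = do
  n <- get
  put (n + 1)
  pure (MkKindVar ("k" <> show n))
```

For example, evalState (replicateM 3 freshKindVar) 0 returns [MkKindVar "k0", MkKindVar "k1", MkKindVar "k2"].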

We used freshKindVar and the following declareNamedType earlier in infer, when declaring the data types.

-- | Insert declared type names into the environment.
declareNamedType :: TypeName -> KindVar -> InferenceM ()
declareNamedType name kindvar =
  Mtl.modify $ \s ->
    s { env = Map.insert (Left name) kindvar (env s) }

We also fetch the kind variable we annotated a type with using getAnn:

-- | Get the annotation of a type.
getAnn :: Type a -> a
getAnn = \case
  TypeVar a _ -> a
  TypeName a _ -> a
  TypeApp a _ _ -> a

Another important utility function is for adding constraints:

-- | Add a new equality constraint to the state.
newEqualityConstraint :: Kind -> Kind -> InferenceM ()
newEqualityConstraint k1 k2 =
  Mtl.modify $ \s ->
    s { constraints = Equality k1 k2 : constraints s }

Elaborating types

The next part is elaborating types. As a reminder, we support named types, type variables, and applications of a type to a type.

Type variables were added to the environment when we saw them declared, so we look them up there. If one is missing, that's an error.

-- | Find the kind variable of a type variable in the environment.
lookupVarKindVar :: TypeVar -> InferenceM KindVar
lookupVarKindVar var =
  maybe
    (Mtl.throwError $ UnboundVar var)
    pure
    . Map.lookup (Right var)
    . env =<< Mtl.get

For named types, there are two cases. Either they were supplied to the inference stage, in which case we invent a new kind variable for this particular use, or they were declared as part of the current group of data types, in which case we look them up in the environment.

-- | Find the kind variable of a named type in the environment.
lookupNameKindVar :: TypeName -> InferenceM KindVar
lookupNameKindVar name = do
  state <- Mtl.get
  -- We first look up the named type in the supplied kind env.
  case Map.lookup name (kindEnv state) of
    -- If we find it, we generate a new kind variable for it
    -- and constrain it to be this kind, so that each use has its own
    -- kind variable (later used for instantiation).
    Just kind -> do
      kindvar <- freshKindVar
      newEqualityConstraint (KindVar kindvar) kind
      pure kindvar
    -- If it's not a supplied type, it means we are actively inferring it,
    -- and we need to use the same kind variable for all uses.
    -- We'll look it up in our environment of declared types.
    Nothing ->
      maybe
        -- If we still can't find it, we error.
        (Mtl.throwError $ UnboundName name)
        pure
        . Map.lookup (Left name)
        $ env state

For a type application of t1 to t2, we elaborate both types and invent a kind variable for the application. We then constrain the kind of t1 to equal a function kind that takes the kind of t2 and returns the kind of the type application.

This is the rest of the code for elaborating types:

-- | Elaborate a type with a kind variable and add constraints
--   according to usage.
elaborateType :: Type () -> InferenceM (Type KindVar)
elaborateType = \case
  -- for type variables and type names,
  -- we lookup the kind variables we generated when we ran into
  -- the declaration of them.
  TypeVar () var ->
    fmap (\kindvar -> TypeVar kindvar var) (lookupVarKindVar var)

  TypeName () name ->
    fmap (\kindvar -> TypeName kindvar name) (lookupNameKindVar name)

  -- for type application
  TypeApp () t1 t2 -> do
    -- we elaborate both types
    t1Kindvar <- elaborateType t1
    t2Kindvar <- elaborateType t2
    -- then we generate a kind variable for the type application
    typeAppKindvar <- freshKindVar
    -- then we constrain the kind of t1:
    -- it should unify with `t2Kind -> typeAppKind`.
    newEqualityConstraint
      (KindVar $ getAnn t1Kindvar)
      (KindFun (KindVar $ getAnn t2Kindvar) (KindVar typeAppKindvar))

    pure (TypeApp typeAppKindvar t1Kindvar t2Kindvar)
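To see which constraint a type application produces without the monadic plumbing, here is a pure sketch restricted to type variables and applications. It threads the counter and the constraint list explicitly and returns Nothing for an unbound variable in place of throwError. The names Ty, TVar, TApp, Acc, and elab are ours:

```haskell
import qualified Data.Map as Map

newtype KindVar = MkKindVar String
  deriving (Show, Eq, Ord)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

data Constraint = Equality Kind Kind
  deriving (Show, Eq)

-- | Types with only variables and applications.
data Ty = TVar String | TApp Ty Ty
  deriving (Show, Eq)

-- | The counter for fresh names and the constraints collected so far.
type Acc = (Int, [Constraint])

-- | Return the kind variable of a type, adding constraints for applications.
elab :: Map.Map String KindVar -> Ty -> Acc -> Maybe (KindVar, Acc)
elab env (TVar v) acc = fmap (\kv -> (kv, acc)) (Map.lookup v env)
elab env (TApp t1 t2) acc0 = do
  (k1, acc1) <- elab env t1 acc0
  (k2, (n, cs)) <- elab env t2 acc1
  -- A fresh kind variable for the whole application...
  let kApp = MkKindVar ("k" <> show n)
      -- ...and the rule: kind of t1 ~ kind of t2 -> kind of the application.
      c = Equality (KindVar k1) (KindFun (KindVar k2) (KindVar kApp))
  pure (kApp, (n + 1, c : cs))
```

With f mapped to kf and a mapped to ka, elaborating f a from (0, []) yields k0 together with the single constraint kf ~ ka -> k0, which is exactly the shape elaborateType adds.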

And that's it for the elaboration phase. After giving each type a kind variable and collecting constraints about them, we are ready for the next stage, where we can ignore the data type definitions and focus on the constraints we generated.

Constraint solving and generating a substitution

In this phase we go through the constraints one at a time and decide whether each one is trivial (an equality between Type and Type) or needs to be reduced to simpler constraints that will be checked later (like matching the argument kinds and the result kinds of two KindFuns).

When we run into a kind variable, we replace it with the kind on the other side in the rest of the equality constraints, record the mapping in a side table we call the "substitution", and keep going.

When we run into a kind scheme, we instantiate it (give it a new unique instance) and constrain it with the other kind in the constraint.

When we run into two kinds that cannot be unified (Type and KindFun), we throw an error.

When there are no more constraints left to solve, we are done and have succeeded!

-- | Solve constraints according to logic.
--   this process is iterative. We continue fetching
--   the next constraint and try to solve it.
--
--   Each step can either reduce or increase the number of constraints,
--   and we are done when there are no more constraints to solve,
--   or if we ran into a constraint that cannot be solved.
solveConstraints :: InferenceM ()
solveConstraints = do
  -- Pop the next constraint we should solve.
  constraint <- do
    c <- listToMaybe . constraints <$> Mtl.get
    Mtl.modify $ \s -> s { constraints = drop 1 $ constraints s }
    pure c

  case constraint of
    -- If we have two 'Type's, they unify. We can skip to the next constraint.
    Just (Equality Type Type) -> solveConstraints
    -- We have an equality between two kind functions.
    -- We add two new equality constraints matching the two firsts
    -- with the two seconds.
    Just (Equality (KindFun k1 k2) (KindFun k3 k4)) -> do
      Mtl.modify $ \s ->
        s { constraints = Equality k1 k3 : Equality k2 k4 : constraints s }
      solveConstraints
    -- When we run into a kind scheme, we instantiate it
    -- (we look at the kind and replace all closed kind variables
    -- with fresh kind variables), and add an equality constraint
    -- between the other kind and the instantiated kind.
    Just (Equality (KindScheme vars kind) k) -> do
      kind' <- instantiate kind vars
      Mtl.modify $ \s ->
        s { constraints = Equality k kind' : constraints s }
      solveConstraints
    -- Same as the previous scenario.
    Just (Equality k (KindScheme vars kind)) -> do
      kind' <- instantiate kind vars
      Mtl.modify $ \s ->
        s { constraints = Equality kind' k : constraints s }
      solveConstraints
    -- If we run into a kind variable on one of the sides,
    -- we replace all instances of it with the other kind and continue.
    Just (Equality (KindVar var) k) -> do
      replaceInState var k
      solveConstraints
    -- The same as the previous scenario, but the kind var is on the other side.
    Just (Equality k (KindVar var)) -> do
      replaceInState var k
      solveConstraints
    -- If we have an equality constraint between a 'Type' and
    -- a 'KindFun', we cannot unify the two, and unification fails.
    Just (Equality k1@Type k2@KindFun{}) -> Mtl.throwError (UnificationFailed k1 k2)
    Just (Equality k1@KindFun{} k2@Type) -> Mtl.throwError (UnificationFailed k1 k2)
    -- If there are no more constraints, we are done. Good job!
    Nothing -> pure ()
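For experimenting without mtl or uniplate, here is a pure sketch of the same loop, restricted to Type, KindFun, and KindVar (no kind schemes, so no instantiation). It follows the rules above: decompose function kinds, bind a kind variable after an occurs check by replacing it in the remaining constraints and in the substitution, and fail on Type versus KindFun. One deliberate difference: a trivial k ~ k constraint is skipped instead of being recorded in the substitution.

```haskell
import qualified Data.Map as Map

newtype KindVar = MkKindVar String
  deriving (Show, Eq, Ord)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

data Constraint = Equality Kind Kind
  deriving (Show, Eq)

data Error
  = UnificationFailed Kind Kind
  | OccursCheckFailed KindVar Kind
  deriving (Show, Eq)

-- | Replace every occurrence of a kind variable inside a kind.
replaceKindVar :: KindVar -> Kind -> Kind -> Kind
replaceKindVar var replacement = go
  where
    go Type = Type
    go (KindFun k1 k2) = KindFun (go k1) (go k2)
    go k@(KindVar v)
      | v == var = replacement
      | otherwise = k

-- | Does the kind variable appear inside the kind?
occursIn :: KindVar -> Kind -> Bool
occursIn _ Type = False
occursIn var (KindFun k1 k2) = occursIn var k1 || occursIn var k2
occursIn var (KindVar v) = var == v

-- | Solve constraints one at a time, building up a substitution.
solve :: [Constraint] -> Map.Map KindVar Kind -> Either Error (Map.Map KindVar Kind)
solve [] substitution = Right substitution
solve (Equality k1 k2 : rest) substitution =
  case (k1, k2) of
    -- Two 'Type's unify: move on.
    (Type, Type) -> solve rest substitution
    -- Match argument kinds with argument kinds, results with results.
    (KindFun a1 r1, KindFun a2 r2) ->
      solve (Equality a1 a2 : Equality r1 r2 : rest) substitution
    -- A kind variable on either side gets bound to the other side.
    (KindVar v, k) -> bind v k
    (k, KindVar v) -> bind v k
    -- 'Type' against 'KindFun' (in either order) cannot unify.
    _ -> Left (UnificationFailed k1 k2)
  where
    bind v k
      -- `k ~ k` carries no information.
      | k == KindVar v = solve rest substitution
      | occursIn v k = Left (OccursCheckFailed v k)
      | otherwise =
          let replace = replaceKindVar v k
              rest' = [Equality (replace a) (replace b) | Equality a b <- rest]
              substitution' = Map.insert v k (Map.map replace substitution)
           in solve rest' substitution'
```

Fed the constraints that elaborate produces for Option (k0 ~ k1 -> Type, then k1 ~ Type), it returns a substitution mapping k0 to Type -> Type and k1 to Type. Fed the constraints for Tree a = Node a Tree Tree, it fails with UnificationFailed (KindFun (KindVar k1) Type) Type, the same failure shown in the Examples section.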

Let's talk about a few of these operations.

Instantiating kind schemes

When we run into a kind scheme (a kind that closes over polymorphic kind variables) in a constraint, we actually want to work with an instance of that kind scheme. So we take the kind scheme and produce a kind in which each of its quantified kind variables is replaced with a fresh kind variable.

-- | Instantiate a kind.
--   We look at the kind and replace all closed kind variables
--   with fresh kind variables.
instantiate :: Kind -> [KindVar] -> InferenceM Kind
instantiate = foldM replaceKindVarWithFreshKindVar

-- | Replace a kind variable with a fresh variable in the kind.
replaceKindVarWithFreshKindVar :: Kind -> KindVar -> InferenceM Kind
replaceKindVarWithFreshKindVar kind var = do
  kindvar <- freshKindVar
  -- Uniplate.transformBi lets us perform reflection and
  -- apply a function to all instances of a certain type
  -- in a value. Think of it like `fmap`, but for any type.
  --
  -- It is a bit slow though, so it's worth replacing it with
  -- hand-rolled recursion or a functor, but it's convenient.
  pure $ flip Uniplate.transformBi kind $ \case
    kv | kv == var -> kindvar
    x -> x

Note: We are using the uniplate interface that works for every type with an instance of Data from Data.Data. It lets us write generic traversals and transformations with very little effort, but it is considerably slower than hand-written code, so for a real kind inference implementation you probably want to hand-write the traversals.
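Following that advice, here is a sketch of hand-written instantiation. It walks the kind with explicit recursion and threads a counter for fresh names instead of running in InferenceM; the counter-taking signature is our own. It also renames all quantified variables at once through a Map, so a freshly generated name can never be picked up and renamed a second time, as could happen when renaming one variable after another if a fresh name coincides with a later scheme variable.

```haskell
import qualified Data.Map as Map

newtype KindVar = MkKindVar String
  deriving (Show, Eq, Ord)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

-- | Instantiate the body of `KindScheme vars kind` by giving every
--   quantified variable a fresh name `k<n>`, starting from `counter`.
--   Returns the new kind and the next unused counter value.
instantiate :: Int -> [KindVar] -> Kind -> (Kind, Int)
instantiate counter vars kind = (rename kind, counter + length vars)
  where
    -- Map each quantified variable to its fresh replacement.
    fresh = Map.fromList (zip vars [MkKindVar ("k" <> show n) | n <- [counter ..]])
    rename Type = Type
    rename (KindFun k1 k2) = KindFun (rename k1) (rename k2)
    -- Variables that are not quantified are left alone.
    rename k@(KindVar v) = maybe k KindVar (Map.lookup v fresh)
```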

Replacing kind variables

When we run into a kind variable, we replace it with the kind on the other side of the equality constraint, both in the substitution and in the rest of the constraints, and then add the new mapping to the substitution.

We change it in the rest of the constraints so that if we have the following two constraints:

1. Equality (KindVar "k1") Type
2. Equality (KindVar "k1") (Type -> Type)

When we replace KindVar "k1" with Type in the rest of the constraints, instead of the next constraint being (2), it will be:

Equality Type (Type -> Type)

Which does not unify, and we catch the bug.

We also replace the kind variable inside the existing substitution and then add the new mapping to it, so that later we can look up the kind variable we placed on each type in the elaboration phase and find its kind.

-- | Replace every instance of 'KindVar var' in our state with 'kind'.
--   And add it to the substitution.
replaceInState :: KindVar -> Kind -> InferenceM ()
replaceInState var kind = do
  occursCheck var kind
  s <- Mtl.get
  let
    -- Uniplate.transformBi lets us perform reflection and
    -- apply a function to all instances of a certain type
    -- in a value. Think of it like `fmap`, but for any type.
    --
    -- Note that we are changing all instances of `Kind` of the form
    -- `KindVar v | v == var` in all of `State`! This includes both the
    -- `substitution` and the remaining `constraints`.
    --
    -- It is a bit slow though, so it's worth replacing it with
    -- hand-rolled recursion or a functor, but it's convenient.
    s' =
      flip Uniplate.transformBi s $ \case
        KindVar v | v == var -> kind
        x -> x
  Mtl.put $ s' { substitution = Map.insert var kind (substitution s') }
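The constraint-rewriting part of replaceInState can be written by hand as well. In this sketch (replaceKindVar and replaceInConstraints are our names, standing in for the uniplate traversal over the whole State), solving k1 ~ Type rewrites the remaining k1 ~ Type -> Type into Type ~ Type -> Type, which no longer unifies:

```haskell
newtype KindVar = MkKindVar String
  deriving (Show, Eq)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

data Constraint = Equality Kind Kind
  deriving (Show, Eq)

-- | Replace every occurrence of `var` with `replacement` inside a kind.
replaceKindVar :: KindVar -> Kind -> Kind -> Kind
replaceKindVar var replacement = go
  where
    go Type = Type
    go (KindFun k1 k2) = KindFun (go k1) (go k2)
    go k@(KindVar v)
      | v == var = replacement
      | otherwise = k

-- | Rewrite both sides of every remaining constraint.
replaceInConstraints :: KindVar -> Kind -> [Constraint] -> [Constraint]
replaceInConstraints var kind = map replaceBoth
  where
    replaceBoth (Equality a b) =
      Equality (replaceKindVar var kind a) (replaceKindVar var kind b)
```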

One important thing we need to check first is that the kind does not contain the kind variable (unless the kind is that variable itself); otherwise we would have an "infinite" kind. This is called an occurs check.

-- | We check that the kind variable does not appear in the kind
--   and throw an error if it does.
occursCheck :: KindVar -> Kind -> InferenceM ()
occursCheck var kind =
  if KindVar var == kind || null [ () | KindVar v <- Uniplate.universe kind, var == v ]
    then pure ()
    else do
      -- We try to find the type this kind variable belongs to with a reverse lookup,
      -- but this might not succeed because the kind variable might have been
      -- generated during constraint solving.
      -- We might be able to find the type if we looked at the substitution as well,
      -- but for now let's leave it at this "best effort" attempt.
      reverseEnv <- map swap . Map.toList . env <$> Mtl.get
      let typ = either (TypeName ()) (TypeVar ()) <$> lookup var reverseEnv
      Mtl.throwError (OccursCheckFailed typ var kind)

Once again we use the uniplate library, this time its universe function, which returns a value together with all values of the same type nested inside it. This gives us all of the kinds inside our kind, from which we select the kind variables.
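Without uniplate, universe for Kind is a short recursive function that returns the kind itself followed by every kind nested inside it. Here is a sketch of it (universeKind is our name) together with an occurs check written the same way as occursCheck above, returning a Bool instead of throwing:

```haskell
newtype KindVar = MkKindVar String
  deriving (Show, Eq)

data Kind
  = Type
  | KindFun Kind Kind
  | KindVar KindVar
  | KindScheme [KindVar] Kind
  deriving (Show, Eq)

-- | The kind itself and all kinds nested inside it, like `Uniplate.universe`.
universeKind :: Kind -> [Kind]
universeKind k = k : case k of
  Type -> []
  KindVar _ -> []
  KindFun k1 k2 -> universeKind k1 ++ universeKind k2
  KindScheme _ body -> universeKind body

-- | True if binding `var` to `kind` would create an infinite kind.
--   Binding a variable to itself is allowed.
occursCheckFails :: KindVar -> Kind -> Bool
occursCheckFails var kind =
  KindVar var /= kind && not (null [ () | KindVar v <- universeKind kind, v == var ])
```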

Other errors

If we run into an equality constraint between a Type and a KindFun, we throw an error, since we can't unify them.

Substitution

Once we finish with constraint solving, we'll have a substitution ready for us in State.

All we need to do now is look up the kind produced by the substitution for each kind variable.

-- | Look up what the kind of a kind variable is in the substitution
--   produced by constraint solving.
--   If there was no constraint on the kind variable, it won't appear
--   in the substitution, which means it can stay a kind variable which
--   we will close over later.
lookupKindVarInSubstitution :: KindVar -> InferenceM Kind
lookupKindVarInSubstitution kindvar =
  maybe (KindVar kindvar) id . Map.lookup kindvar . substitution <$> Mtl.get
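Because Type, Variant, and Datatype derive Functor and Traversable, applying the substitution to every annotation is a single traversal: infer uses traverse because the lookup runs in InferenceM. A pure sketch with simplified copies of the types (names are plain Strings here), where lookupKind is our name for a pure version of lookupKindVarInSubstitution:

```haskell
{-# LANGUAGE DeriveFunctor #-}
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)

newtype KindVar = MkKindVar String
  deriving (Show, Eq, Ord)

data Kind = Type | KindFun Kind Kind | KindVar KindVar
  deriving (Show, Eq)

-- | Simplified copy of the article's Type, with String names.
data Type a
  = TypeVar a String
  | TypeName a String
  | TypeApp a (Type a) (Type a)
  deriving (Show, Eq, Functor)

-- | Kind variables without an entry stay as they are;
--   generalization closes over them later.
lookupKind :: Map.Map KindVar Kind -> KindVar -> Kind
lookupKind substitution kv = fromMaybe (KindVar kv) (Map.lookup kv substitution)
```

With a substitution mapping k0 to Type -> Type and k1 to Type, fmap (lookupKind substitution) turns TypeApp k2 (TypeName k0 "Option") (TypeVar k1 "a") into TypeApp (KindVar k2) (TypeName (KindFun Type Type) "Option") (TypeVar Type "a").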

Generalization

When we are done with elaborating, solving constraints, and substituting over data types, we need to look at the kind produced for each data type, and close over the free kind variables.

Again, we use Uniplate.universe to find all of the kind variables and include them in the kind scheme.

-- | Close over kind variables we did not solve.
generalize :: Kind -> Kind
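generalize kind =
  -- Map.fromList removes duplicates, so each free kind variable appears once.
  KindScheme (Map.keys (Map.fromList [ (var, ()) | KindVar var <- Uniplate.universe kind ])) kind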
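A hand-rolled version that doesn't need uniplate can collect each free kind variable once, in order of first appearance. In this sketch, freeKindVars is our name, and variables bound by a nested scheme are treated as not free:

```haskell
import Data.List (nub)

newtype KindVar = MkKindVar String
  deriving (Show, Eq)

data Kind
  = Type
  | KindFun Kind Kind
  | KindVar KindVar
  | KindScheme [KindVar] Kind
  deriving (Show, Eq)

-- | Free kind variables, each listed once, in order of first appearance.
freeKindVars :: Kind -> [KindVar]
freeKindVars = nub . occurrences
  where
    occurrences Type = []
    occurrences (KindVar v) = [v]
    occurrences (KindFun k1 k2) = occurrences k1 ++ occurrences k2
    -- Variables bound by a nested scheme are not free.
    occurrences (KindScheme bound body) = filter (`notElem` bound) (occurrences body)

-- | Close over the free kind variables.
generalize :: Kind -> Kind
generalize kind = KindScheme (freeKindVars kind) kind
```

For a kind such as (k1 -> Type) -> k1 -> Type, this quantifies k1 only once.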

Examples

That's pretty much it! We can now define data types and observe the kinds we produce.

option :: Datatype ()
option =
  Datatype ()
    (MkTypeName "Option")
    [MkTypeVar "a"]
    [ Variant "Some" [TypeVar () $ MkTypeVar "a"]
    , Variant "None" []
    ]

main :: IO ()
main = do
  print $ map dtAnn <$> infer mempty [option]

This prints:

Right [KindScheme [] (KindFun Type Type)]

As promised, our kind inference engine is able to infer the kind of this type:

Cofree f a =
  | Cofree a (f (Cofree f a))

As expected:

Cofree : (Type -> Type) -> Type -> Type
Cofree f a =
  | Cofree a (f (Cofree f a))

And can catch errors such as:

Tree a =
  | Node a Tree Tree

And produce the error:

Unification failed between the following kinds:
  * k1 -> Type
  * Type
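To try the Cofree example yourself, its definition can be written as a Datatype () value and passed to infer, just like option. The following is a sketch with simplified copies of the model types (the gist may encode it differently):

```haskell
newtype TypeName = MkTypeName String
  deriving (Show, Eq)

newtype TypeVar = MkTypeVar String
  deriving (Show, Eq)

data Type a
  = TypeVar a TypeVar
  | TypeName a TypeName
  | TypeApp a (Type a) (Type a)
  deriving (Show, Eq)

data Variant a = Variant { vTypeConstructor :: String, vTypes :: [Type a] }
  deriving (Show, Eq)

data Datatype a = Datatype
  { dtAnn :: a
  , dtName :: TypeName
  , dtParameters :: [TypeVar]
  , dtVariants :: [Variant a]
  }
  deriving (Show, Eq)

-- | Cofree f a = | Cofree a (f (Cofree f a))
cofree :: Datatype ()
cofree =
  Datatype ()
    (MkTypeName "Cofree")
    [MkTypeVar "f", MkTypeVar "a"]
    [ Variant "Cofree"
        [ var "a"
          -- f (Cofree f a): f applied to the fully applied Cofree type
        , TypeApp () (var "f")
            (TypeApp () (TypeApp () (TypeName () (MkTypeName "Cofree")) (var "f")) (var "a"))
        ]
    ]
  where
    var = TypeVar () . MkTypeVar
```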

You can find other examples in the gist.

Summary

Kind inference using unification-based constraint solving works on data types in the same way that type inference with the same method works on expressions. While it can be a bit tricky, implementing a reasonably powerful kind inference engine is relatively straightforward.

You can find the source code in this gist. It includes pretty printing code, additional examples, and a lot of comments.