Type-level programming in Julia

For a machine learning project, I needed a way to easily turn a big row from a DataFrame into a Float64 vector. The columns of the dataset contained everything from categorical to ordinal data (with lots of missing values). Writing an encoder by hand would have been long and tedious. Here is what I wanted to do instead:

  1. Given a type, generate an encoding function from that type to a vector of floats.
  2. Given a type, generate a decoding function from a vector of floats to that type.

1. Interface

Programming at the type level is easy in Julia, since types are first-class values. Here is what the API we want to build looks like:

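abstract type Encodable end

# Each encodable type A gets a method of encdecsize returning a triple:
#   encodefunc(x)   -> Vector{Float64} of length sizeofencoding
#   decodefunc(v)   -> a value of type A
#   sizeofencoding  -> Int, the length of every encoding
function encdecsize(t::Type{A}) where {A <: Encodable}
    # ...
    # ...
    encodefunc, decodefunc, sizeofencoding
end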

The encodefunc turns a value of our type into a fixed-size vector of floats, and the decodefunc reverses that transformation. The sizeofencoding output is needed by the methods for composite types (tuples and unions), as we will see later. Now, let’s write implementations for concrete types.
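To make the contract concrete, here is a minimal sketch that implements it for a small user-defined type. `Temperature` and `check_roundtrip` are hypothetical names introduced only for this illustration; the point is the shape of the returned triple and the two properties it should satisfy: the encoding has the advertised length, and decoding undoes encoding.

```julia
abstract type Encodable end

# Hypothetical example type, not part of the original project.
struct Temperature <: Encodable
    celsius::Float64
end

# Dispatch on the type object itself and return (encode, decode, size).
function encdecsize(::Type{Temperature})
    encode(t::Temperature) = [t.celsius]
    decode(v) = Temperature(v[1])
    encode, decode, 1
end

# Hypothetical helper: checks both promises of the interface for one value x.
function check_roundtrip(::Type{T}, x::T) where {T}
    enc, dec, n = encdecsize(T)
    v = enc(x)
    length(v) == n && dec(v) == x
end

check_roundtrip(Temperature, Temperature(21.5))  # true
```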

2. Concrete types

The encoding functions for primitive types are easy to write, as they do not require much more knowledge about Julia’s type system than simple ::Type{Thing} dispatching.

function encdecsize(::Type{Float64})
    encode(f) = [f]
    decode(v) = v[1]
    encode, decode, 1
end

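Here, Type{Float64} is the type whose only instance is Float64 itself, so this method is selected when we call encdecsize(Float64), not when we pass a Float64 value.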

@show Float64 isa Type{Float64}
Float64 isa Type{Float64} = true
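This is what lets a single function name dispatch either on a value or on a type. A tiny illustration, using a hypothetical function `describe` that has one method for the type object Float64 and one for Float64 values:

```julia
# Hypothetical function: one method for the type object, one for values.
describe(::Type{Float64}) = "the type Float64"
describe(::Float64) = "a Float64 value"

describe(Float64)  # "the type Float64"
describe(1.0)      # "a Float64 value"

# Type{Float64} has exactly one instance: Float64 itself.
Float64 isa Type{Float64}  # true
Int64 isa Type{Float64}    # false
```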

Let’s check that it works:

@show en, de, s = encdecsize(Float64)
@show encoded = en(1.0)
@show decoded = de(encoded)
(en, de, s) = encdecsize(Float64) = (var"#encode#30"(), var"#decode#31"(), 1)
encoded = en(1.0) = [1.0]
decoded = de(encoded) = 1.0

The encoding of Missing has size zero, since a value of that type carries no information: it is always missing!

function encdecsize(::Type{Missing})
    encode(v) = Float64[]
    decode(v) = missing
    encode, decode, 0
end
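The same argument applies to any singleton type, that is, any type with exactly one value (Nothing is another example). A hedged sketch of a generic version, using a hypothetical name `encdecsize_singleton` and relying on the fact that Julia exposes the unique value of a singleton type through the `instance` field of the type object:

```julia
# Hypothetical generalization of the Missing case to any singleton type.
function encdecsize_singleton(::Type{T}) where {T}
    # Singleton types such as Missing or Nothing store their only value
    # in the `instance` field; other types do not define it.
    isdefined(T, :instance) || throw(ArgumentError("$T is not a singleton type"))
    value = T.instance
    encode(x) = Float64[]   # nothing to store
    decode(v) = value       # the only possible value
    encode, decode, 0
end

enc, dec, s = encdecsize_singleton(Nothing)
dec(enc(nothing)) === nothing  # true, and s == 0
```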

Bool is a slightly more interesting case. In the decoding function, we might receive a number that is not exactly zero or one. We want the closest Boolean, so we compare with 0.5.

function encdecsize(::Type{Bool})
    encode(v) = [1.0*v]
    decode(v) = v[1] > 0.5
    encode, decode, 1
end
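The ordinal columns mentioned in the introduction could be handled in the same spirit. As a sketch, assuming ordinal levels are stored as Int (an assumption about the data, not necessarily how the original project stored them), the decoder rounds to the nearest level, just as the Bool decoder snaps to the nearest truth value:

```julia
# Hypothetical method for ordinal levels stored as Int.
function encdecsize(::Type{Int})
    encode(x) = [Float64(x)]
    # Like the Bool decoder, snap whatever we receive to the nearest level.
    decode(v) = round(Int, v[1])
    encode, decode, 1
end

_, decint, _ = encdecsize(Int)
decint([2.6])   # 3
decint([-0.4])  # 0
```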

3. Composite types

Most types we use daily are not simple primitive types: they are composite. Think tuples and unions. We need a way to encode values of these types too.

3.1. Tuples

Since a tuple contains values of several types, we need to know how to encode (and decode) each of those types. Here, I chose to encode the components of the tuple one after the other. This is why our API must report encoding sizes: each encoded component lives at a different offset in the vector, and when decoding we need these sizes to compute the slice that belongs to each component.
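Concretely, the slice owned by each component follows from a running sum of the sizes. A small sketch with a hypothetical helper `component_ranges`, which computes up front the same ranges that the tuple decoder walks through one at a time; a zero-size component such as Missing gets an empty range:

```julia
# Hypothetical helper: the index range of each encoded component,
# given the encoding size of every component in order.
function component_ranges(sizes::Vector{Int})
    stops = cumsum(sizes)          # last index of each component
    starts = stops .- sizes .+ 1   # first index of each component
    [start:stop for (start, stop) in zip(starts, stops)]
end

component_ranges([1, 1, 2])  # [1:1, 2:2, 3:4]
component_ranges([1, 0, 2])  # [1:1, 2:1, 2:3], where 2:1 is empty
```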

function encdecsize(::Type{T}) where {T <: Tuple}
    encs = []
    decs = []
    sizes = Int[]
    
    # We collect encoding functions, decoding functions and encoding
    # sizes for each component type.
    for t in T.types
        enc, dec, s = encdecsize(t)
        push!(encs, enc)
        push!(decs, dec)
        push!(sizes, s)
    end

    function encode(tuple::T)
        out = Float64[]
        # We concatenate the encoded components.
        for (i, elem) in enumerate(tuple)
            out = vcat(out, encs[i](elem))
        end
        out
    end
    function decode(v)::T
        start = 1
        out = []
        for (i, s) in enumerate(sizes)
            # each time we decode a component, we must "consume" the
            # associated part of the vector
            push!(out, decs[i](v[start:start+s-1]))
            start += s
        end
        # we turn back the Vector of Any into a tuple
        (out...,)
    end
    encode, decode, sum(sizes)
end
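Two details of this method are worth seeing in isolation: which component types the loop visits, and why the `::T` return annotation on `decode` accepts a tuple rebuilt from a Vector{Any}. The check below uses `fieldtypes`, which for a concrete tuple type like this one lists the same component types as `T.types`:

```julia
T = Tuple{Float64, Bool, Union{Missing, Float64}}

# The component types the tuple encoder iterates over.
fieldtypes(T)    # (Float64, Bool, Union{Missing, Float64})

# Splatting the decoded Vector{Any} yields a tuple of the concrete types...
decoded = (Any[1.0, false, missing]...,)
typeof(decoded)  # Tuple{Float64, Bool, Missing}

# ...which is still an instance of T, because tuple types are covariant,
# so the `::T` annotation on `decode` accepts it as is.
decoded isa T    # true
```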

3.2. Union

This part was the most time-consuming, as I had to wrestle with Julia’s method dispatch system to get it to recognise “union of exactly two types, none of which is the empty union”. I discovered that the type of a “Union type” is Union itself. Nice!

@show Union{Float32, Bool} isa Union
Union{Float32, Bool} isa Union = true
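A union of more than two types is still stored as two-member nodes: its field `a` holds one member and its field `b` holds the union of the others. The empty union Union{} is not an instance of Union at all. A quick check of that shape (the order in which members are stored is an internal detail, so the sketch avoids depending on it):

```julia
u = Union{Int, Float64, String}

typeof(u) === Union  # true
u.a isa Union        # false: `a` holds a single member type
u.b isa Union        # true: `b` holds the union of the other two

# The empty union has its own type and is not a Union instance.
Union{} isa Union    # false

# A "union" of a single type is just that type.
Union{Int} === Int   # true
```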

The trickiest part of this function is handling the tag: a few extra entries at the start of the encoding that record which member of the union the value belongs to, so that the union decoder knows which member's decoding function to call.
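Put together, a union value is encoded as the tag bits followed by the member's encoding, padded to the size of the largest member. As a worked example, take a hypothetical three-member union whose members encode to 0, 1 and 2 floats (for instance Missing, Float64 and Tuple{Float64, Float64}):

```julia
sizes = [0, 1, 2]                # hypothetical member encoding sizes
n = length(sizes)                # 3 members
tagsize = ceil(Int, log2(n))     # 2 bits distinguish up to 4 members
innersize = maximum(sizes)       # room for the largest member: 2
total = tagsize + innersize      # every encoded value has length 4

# Layout of an encoded value: [tag bits..., payload..., zero padding...]
tag_range = 1:tagsize            # 1:2
payload_range = tagsize+1:total  # 3:4
```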

We first need a padding function to ensure that every output of the union's encoding function has the same length, whichever member the value belongs to.

function pad(v, len)
    vcat(v, zeros(max(0, len - length(v))))
end

Then, we need a way to encode (and decode) the tag itself, that is, the index of the value's type among the union's members, as a sequence of binary digits.

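function binenc(n, len)
    # Bit i of n (least significant first) goes to position i + 1.
    [1.0 * (n>>i & 1) for i in 0:len-1]
end
function bindec(v)
    # Snap each bit to 0 or 1 (threshold 0.5, as for Bool), then
    # rebuild the integer from its binary digits.
    sum((v[i+1] > 0.5) * 2^i
        for i in 0:length(v)-1)
end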

Finally, we need a way to collect every leaf of a tree of Unions, since a union of several types is stored as nested two-member unions:

gettypes(u::Type) = [u]
gettypes(u::Union) = [u.a;gettypes(u.b)]
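A quick check that this recursion flattens a union into its member types. The results are compared as sets, because the order in which Julia stores the members is an internal detail:

```julia
gettypes(u::Type) = [u]
gettypes(u::Union) = [u.a; gettypes(u.b)]

gettypes(Float64)  # [Float64]
Set(gettypes(Union{Int, Float64, String})) == Set([Int, Float64, String])  # true
length(gettypes(Union{Missing, Bool, Int, Float64}))  # 4
```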

Here is the code:

function encdecsize(un::Union)
    encs = []
    decs = []
    sizes = Int[]
    types = gettypes(un)

    # As before, we collect enc, dec and size of each subtype.
    for t in types
        enc, dec, s = encdecsize(t)
        push!(encs, enc)
        push!(decs, dec)
        push!(sizes, s)
    end

    n = length(types)

    # The tag must contain enough "bits" to encode every variant.
    tagsize = ceil(Int, log2(n))

    # Apart from the tag, we must be able to encode the biggest
    # variant.
    innersize = maximum(sizes)

    function encode(v)
        for (i, t) in enumerate(types)
            if v isa t
                # We write the tag, then the padded encoded value.
                return vcat(
                    binenc(i-1, tagsize),
                    pad(encs[i](v), innersize),
                )
            end
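        end
        # No member of the union matched: the value has the wrong type.
        error("value $(repr(v)) is not an instance of $un")
    end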

    function decode(v)
        tag = v[1:tagsize]
        # We calculate the tag, then retrieve the correct decoding
        # function. Clamping keeps a noisy tag within the valid range.
        dectag = clamp(round(Int, bindec(tag)), 0, n - 1)
        decs[dectag+1](v[tagsize+1:tagsize+sizes[dectag+1]])
    end
    encode, decode, tagsize+innersize
end

4. Wrapping up

Let’s test our code on some moderately complex tuples:

let
    vec = Tuple{Float64, Bool, Union{Missing, Float64}}[
        (1.0, false, missing),
        (2.0, true, 15.0),
        (2.0, true, 13.0),
    ]

    en, de, _ = encdecsize(eltype(typeof(vec)))
    for v in vec
        @show encoded = en(v)
        @show decoded = de(encoded)
    end
end
encoded = en(v) = [1.0, 0.0, 0.0, 0.0]
decoded = de(encoded) = (1.0, false, missing)
encoded = en(v) = [2.0, 1.0, 1.0, 15.0]
decoded = de(encoded) = (2.0, true, 15.0)
encoded = en(v) = [2.0, 1.0, 1.0, 13.0]
decoded = de(encoded) = (2.0, true, 13.0)
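For the machine learning use case from the introduction, the per-row vectors usually end up stacked into one matrix. A sketch, using a hypothetical stand-in encoder `en_row` with the same shape as `en` above but for the simpler type Tuple{Float64, Bool}, so the block runs on its own:

```julia
# Hypothetical stand-in for an encoder returned by encdecsize.
en_row(t::Tuple{Float64, Bool}) = [t[1], 1.0 * t[2]]

rows = Tuple{Float64, Bool}[(1.0, false), (2.0, true), (3.5, true)]

# One encoded column per row: a 2×3 Matrix{Float64}.
X = reduce(hcat, en_row.(rows))
```

Stacking rows as columns fits Julia's column-major arrays: each sample's features end up contiguous in memory.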

Author: Justin Veilleux

Created: 2025-09-25 Thu 18:12
