Type-level programming in Julia
For a machine learning project, I needed a way to easily turn a big row from a `DataFrame` into a `Float64` vector. The columns of the dataset contained everything from categorical data to ordinal data (with lots of `missing` values). Writing an encoder manually would have been long and tedious. Here is what I wanted to do instead:
- Given a type, generate an encoding function from that type to a vector of floats
- Given a type, generate a decoding function from a vector of floats to that type.
1. Interface
Type-level programming is easy in Julia, since types are first-class values. Here is what the API we want to build looks like:
```julia
abstract type Encodable end

function encdecsize(t::Type{A}) where {A <: Encodable}
    # ...
    # ...
    encodefunc, decodefunc, sizeofencoding
end
```
The `encodefunc` turns a value of our type into a fixed-size vector of floats. The `decodefunc` reverses that transformation. The `sizeofencoding` output is needed by the methods for parametric types such as tuples and unions, as we will see later. Now, let’s write implementations for concrete types.
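The interface implies a contract: decoding an encoding should give back the original value, and every encoding should have the advertised size. Below is a minimal sketch of a checker for that contract. `check_contract` is a hypothetical helper, not part of the original code, and the block repeats the `Float64` method from section 2 so that it runs on its own.

```julia
# Hypothetical helper: checks that an (encode, decode, size) triple
# round-trips `x` and produces a vector of the advertised length.
function check_contract(T::Type, x)
    enc, dec, n = encdecsize(T)
    v = enc(x)
    # isequal (rather than ==) also treats `missing` as equal to itself.
    length(v) == n && isequal(dec(v), x)
end

# The Float64 method from section 2, repeated so this block is self-contained.
function encdecsize(::Type{Float64})
    encode(f) = [f]
    decode(v) = v[1]
    encode, decode, 1
end

@show check_contract(Float64, 3.5)
```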
2. Concrete types
The encoding functions for primitive types are easy to write, as they require little more knowledge of Julia’s type system than simple `::Type{Thing}` dispatching.
```julia
function encdecsize(::Type{Float64})
    encode(f) = [f]
    decode(v) = v[1]
    encode, decode, 1
end
```
Here, `Type{Float64}` is the type whose only instance is `Float64` itself, so this method is selected only when we pass the type `Float64`:
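To make the `::Type{Thing}` pattern concrete, here is a small sketch with a hypothetical function `kind` (not from the original post). A method on `Type{Float64}` only matches the type `Float64`, not a `Float64` value, and the more specific method wins over a generic `::Type` fallback.

```julia
# Hypothetical function, used only to illustrate dispatch on Type{T}.
kind(::Type{Float64}) = "a 64-bit float type"
kind(::Type{Bool}) = "the boolean type"
kind(::Type) = "some other type"   # fallback for any other type

@show Float64 isa Type{Float64}   # the type itself is the instance
@show 1.0 isa Type{Float64}       # a Float64 *value* is not
@show kind(Float64)
@show kind(Int)
```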
```julia
@show Float64 isa Type{Float64}
```

```
Float64 isa Type{Float64} = true
```
Let’s see if it works:

```julia
@show en, de, s = encdecsize(Float64)
@show encoded = en(1.0)
@show decoded = de(encoded)
```

```
(en, de, s) = encdecsize(Float64) = (var"#encode#30"(), var"#decode#31"(), 1)
encoded = en(1.0) = [1.0]
decoded = de(encoded) = 1.0
```
The encoding of a `Missing` has size zero, since a value of that type carries no information: it is always `missing`!
```julia
function encdecsize(::Type{Missing})
    encode(v) = Float64[]
    decode(v) = missing
    encode, decode, 0
end
```
Another (slightly) interesting case is `Bool`. The decoding function might receive a number that is not exactly zero or one. We want to return the closest result, so we compare with 0.5.
```julia
function encdecsize(::Type{Bool})
    encode(v) = [1.0*v]
    decode(v) = v[1] > 0.5
    encode, decode, 1
end
```
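Here is a quick check of that thresholding on slightly noisy inputs, such as a model might produce. The block repeats the `Bool` method so it runs on its own; the specific noisy values are arbitrary examples.

```julia
# The Bool method from the text, repeated so this block is self-contained.
function encdecsize(::Type{Bool})
    encode(v) = [1.0 * v]
    decode(v) = v[1] > 0.5
    encode, decode, 1
end

enc, dec, n = encdecsize(Bool)
@show enc(true), enc(false)
# Values above 0.5 decode to true, values at or below 0.5 decode to false.
for noisy in ([0.93], [0.12], [0.51], [0.49])
    println(noisy, " => ", dec(noisy))
end
```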
3. Composite types
Most types we use daily are not simple primitive types: they are composite. Think tuples and unions. We need a way to encode values of these types too.
3.1. Tuples
Since a tuple can hold values of several types, we need to know how to encode (and decode) each of those types. Here, I chose to encode the components of the tuple one after the other. This is why our API must produce encoding sizes: each encoded component lives at a different offset in the vector, and we need these sizes to compute (when decoding) the slice that holds each component.
```julia
function encdecsize(::Type{T}) where {T <: Tuple}
    encs = []
    decs = []
    sizes = Int[]
    # We collect encoding functions, decoding functions and encoding
    # sizes for each component type.
    for t in fieldtypes(T)
        enc, dec, s = encdecsize(t)
        push!(encs, enc)
        push!(decs, dec)
        push!(sizes, s)
    end
    function encode(tuple::T)
        out = Float64[]
        # We concatenate the encoded components.
        for (i, elem) in enumerate(tuple)
            out = vcat(out, encs[i](elem))
        end
        out
    end
    function decode(v)::T
        start = 1
        out = []
        for (i, s) in enumerate(sizes)
            # Each time we decode a component, we must "consume" the
            # associated part of the vector.
            push!(out, decs[i](v[start:start+s-1]))
            start += s
        end
        # We turn the Vector{Any} back into a tuple; the ::T return
        # annotation converts it to the exact tuple type.
        (out...,)
    end
    encode, decode, sum(sizes)
end
```
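The offset bookkeeping inside `decode` can also be computed up front from the sizes alone. The sketch below uses a hypothetical helper, `component_ranges`, which is not part of the code above; it returns the slice of the encoded vector that belongs to each component. A zero-size component such as `Missing` gets an empty range, so it consumes nothing.

```julia
# Hypothetical helper: index range of each tuple component inside the
# encoded vector, given the encoding size of each component.
function component_ranges(sizes::Vector{Int})
    stops = cumsum(sizes)          # last index of each component
    starts = stops .- sizes .+ 1   # first index of each component
    [a:b for (a, b) in zip(starts, stops)]
end

# Example layout for Tuple{Float64, Missing, Bool}: sizes 1, 0 and 1.
ranges = component_ranges([1, 0, 1])
v = [3.5, 1.0]
@show ranges
@show [v[r] for r in ranges]
```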
3.2. Union
This part was the most time-consuming, as I had to wrestle with Julia’s method dispatch system to get it to recognise “a union of exactly two types, neither of which is the empty union”. I discovered that the type of a “Union type” is `Union` itself. Nice!
```julia
@show Union{Float32, Bool} isa Union
```

```
Union{Float32, Bool} isa Union = true
```
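Because union types are instances of `Union`, a method can dispatch on `::Union` just as the earlier methods dispatch on `::Type{...}`. Here is a small sketch with a hypothetical function `describe_type` (not from the original post). Note that the empty union `Union{}` has its own type and is not an instance of `Union`.

```julia
# Hypothetical function illustrating that methods can dispatch on ::Union.
describe_type(::Union) = "a union of types"
describe_type(::DataType) = "a plain data type"

@show typeof(Union{Float32, Bool})
@show typeof(Union{})   # the empty union is not a Union instance
@show describe_type(Union{Missing, Float64})
@show describe_type(Float64)
```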
The most complicated part of the union encoder is handling the tag of the value correctly: the tag tells the union decoder which child decoding function to use.

First, we need a padding function to ensure every output of our encoding function has the same length:
```julia
function pad(v, len)
    vcat(v, zeros(max(0, len - length(v))))
end
```
Then, we need a way to encode (and decode) which variant the value belongs to. We write its index in binary, one bit per float:
```julia
function binenc(n, len)
    # Slot i+1 holds bit i of n (least significant bit first).
    [1.0 * (n>>i & 1)
     for i in 0:len-1]
end

function bindec(v)
    sum(v[i+1] * 2^i
        for i in 0:length(v)-1)
end
```
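A short worked example of the tag format and the resulting layout. The helpers are repeated from the text so the block runs on its own. For the layout part, it assumes that `Missing` is variant 0 and `Float64` is variant 1 of `Union{Missing, Float64}`, which matches the outputs shown at the end of this post.

```julia
# binenc, bindec and pad as defined in the text, repeated so this block
# is self-contained.
binenc(n, len) = [1.0 * (n >> i & 1) for i in 0:len-1]
bindec(v) = sum(v[i+1] * 2^i for i in 0:length(v)-1)
pad(v, len) = vcat(v, zeros(max(0, len - length(v))))

# Tags are least-significant-bit first: 6 (binary 110) is written 0, 1, 1.
@show binenc(6, 3)
@show bindec(binenc(6, 3))

# Layout of Union{Missing, Float64}: 1 tag slot followed by 1 payload slot.
# Assumption: Missing is variant 0 and Float64 is variant 1.
@show vcat(binenc(0, 1), pad(Float64[], 1))   # encodes missing
@show vcat(binenc(1, 1), pad([15.0], 1))      # encodes 15.0
```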
Finally, we need a way to collect every leaf in a tree of `Union`s. Julia stores a union of several types as a chain of two-element unions, with fields `a` and `b`:
```julia
gettypes(u::Type) = [u]
gettypes(u::Union) = [u.a; gettypes(u.b)]
```
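To see the recursion in action, here is `gettypes` applied to a three-way union and to a plain type. The order of the leaves follows Julia’s internal ordering of union members, so the check below compares them as a set rather than assuming a particular order.

```julia
# The helper from the text, repeated so this block is self-contained.
gettypes(u::Type) = [u]
gettypes(u::Union) = [u.a; gettypes(u.b)]

leaves = gettypes(Union{Int, Float64, String})
@show leaves
@show gettypes(Float64)   # a non-union type is its own single leaf
```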
Here is the code:
```julia
function encdecsize(un::Union)
    encs = []
    decs = []
    sizes = Int[]
    types = gettypes(un)
    # As before, we collect enc, dec and size of each subtype.
    for t in types
        enc, dec, s = encdecsize(t)
        push!(encs, enc)
        push!(decs, dec)
        push!(sizes, s)
    end
    n = length(types)
    # The tag must contain enough "bits" to encode every variant.
    tagsize = ceil(Int, log2(n))
    # Apart from the tag, we must be able to encode the biggest
    # variant.
    innersize = maximum(sizes)
    function encode(v)
        for (i, t) in enumerate(types)
            if v isa t
                # We write the tag, then the padded encoded value.
                return vcat(
                    binenc(i-1, tagsize),
                    pad(encs[i](v), innersize),
                )
            end
        end
        error("value $(repr(v)) does not belong to $un")
    end
    function decode(v)
        tag = v[1:tagsize]
        # We calculate the tag, then retrieve the correct decoding
        # function. Clamping keeps a noisy tag inside the valid range.
        dectag = clamp(round(Int, bindec(tag)), 0, n - 1)
        decs[dectag+1](v[tagsize+1:tagsize+sizes[dectag+1]])
    end
    encode, decode, tagsize+innersize
end
```
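The total size of a union encoding is the tag size plus the size of the largest variant. The sketch below isolates the tag-size computation in a hypothetical helper, `tagsize`, and tabulates it. It uses `log2` directly: `log(2, n)` divides two rounded logarithms, so it is not guaranteed to be exact for powers of two.

```julia
# Hypothetical helper mirroring the tag-size computation in encdecsize(::Union).
tagsize(n) = ceil(Int, log2(n))

for n in 2:9
    println(n, " variants -> ", tagsize(n), " tag slot(s)")
end

# Union{Missing, Float64}: encoding sizes 0 (Missing) and 1 (Float64).
sizes = [0, 1]
@show tagsize(length(sizes)) + maximum(sizes)
```

A one-hot tag (one slot per variant) would be a possible alternative design: it spends more slots, but each slot is independent, which some models may find easier to predict than a binary code.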
4. Wrapping up
Let’s test our code on some semi-complex tuples:
```julia
let
    vec = Tuple{Float64, Bool, Union{Missing, Float64}}[
        (1.0, false, missing),
        (2.0, true, 15.0),
        (2.0, true, 13.0),
    ]
    en, de, _ = encdecsize(eltype(typeof(vec)))
    for v in vec
        @show encoded = en(v)
        @show decoded = de(encoded)
    end
end
```
```
encoded = en(v) = [1.0, 0.0, 0.0, 0.0]
decoded = de(encoded) = (1.0, false, missing)
encoded = en(v) = [2.0, 1.0, 1.0, 15.0]
decoded = de(encoded) = (2.0, true, 15.0)
encoded = en(v) = [2.0, 1.0, 1.0, 13.0]
decoded = de(encoded) = (2.0, true, 13.0)
```