The Rust Calling Convention We Deserve

I will often say that the so-called “C ABI” is a very bad one, and a relatively unimaginative one when it comes to passing complicated types effectively. A lot of people ask me “ok, what would you use instead”, and I just point them to the Go register ABI, but it seems most people have trouble filling in the gaps of what I mean. This article explains what I mean in detail.

I have discussed calling conventions in the past, but as a reminder: the calling convention is the part of the ABI that concerns itself with how to pass arguments to and from a function, and how to actually call a function. This includes which registers arguments go in, which registers values are returned out of, what function prologues/epilogues look like, how unwinding works, etc.

This particular post is primarily about x86, but I intend to be reasonably generic (so that what I’ve written applies just as well to ARM, RISC-V, etc). I will assume a general familiarity with x86 assembly, LLVM IR, and Rust (but not rustc’s internals).

The Problem

Today, like many other natively compiled languages, Rust defines an unspecified calling convention that lets it call functions however it likes. In practice, Rust lowers to LLVM’s built-in C calling convention, which LLVM’s prologue/epilogue codegen generates calls for.

Rust is fairly conservative: it tries to generate LLVM function signatures that Clang could have plausibly generated. This has two significant benefits:

  1. There is a good probability that debuggers won’t choke on it. This is not a concern on Linux, though, because DWARF is very general and does not bake in the Linux C ABI. We will concern ourselves only with ELF-based systems and assume that debuggability is a nonissue.

  2. It is less likely to tickle LLVM bugs due to using ABI codegen that Clang does not exercise. I think that if Rust tickles LLVM bugs, we should actually fix them (a very small number of rustc contributors do in fact do this).

However, we are too conservative. We get terrible codegen for simple functions:

fn extract(arr: [i32; 3]) -> i32 {
  arr[1]
}
Rust
extract:
  mov   eax, dword ptr [rdi + 4]
  ret
x86 Assembly

arr is 12 bytes wide, so you’d think it would be passed in registers, but no! It is passed by pointer! Here Rust is actually more conservative than the Linux C ABI, which does pass the [i32; 3] in registers when extern "C" is requested.

extern "C" fn extract(arr: [i32; 3]) -> i32 {
  arr[1]
}
Rust
extract:
  mov   rax, rdi
  shr   rax, 32
  ret
x86 Assembly

The array is passed in rdi and rsi, with the i32s packed into registers. The function moves rdi into rax, the output register, and shifts the upper half down.

Not only does Rust produce patently bad code for passing things by value, but it also knows how to do it better, if you request a standard calling convention! We could be generating way better code than Clang, but we don’t!

Hereforth, I will describe how to do it.

-Zcallconv

Let’s suppose that we keep the current calling convention for extern "Rust"1, but we add a flag -Zcallconv that sets the calling convention for extern "Rust" when compiling a crate. The supported values will be -Zcallconv=legacy for the current one, and -Zcallconv=fast for the one we’re going to design. We could even let -O set -Zcallconv=fast automatically.

Why keep the old calling convention? Although I did sweep debuggability under the rug, one nice property -Zcallconv=fast will not have is placing arguments in the C ABI order, which means that a reader relying on the “Diana’s silk dress cost $89” mnemonic on x86 will get fairly confused.

I am also assuming we may not even support -Zcallconv=fast for some targets, like WASM, where there is no concept of “registers” and “spilling”. It may not even make sense to enable it for debug builds, because it will produce much worse code with optimizations turned off.

There is also a mild wrinkle with function pointers, and extern "Rust" {} blocks. Because this flag is per-crate, even though functions can advertise which version of extern "Rust" they use, function pointers have no such luxury. However, calling through a function pointer is slow and rare, so we can simply force them to use -Zcallconv=legacy. We can generate a shim to translate calling conventions as needed.

Similarly, we can, in principle, call any Rust function like this:

fn secret_call() -> i32 {
  extern "Rust" {
    fn my_func() -> i32;
  }
  unsafe { my_func() }
}
Rust

However, this mechanism can only be used to call unmangled symbols. Thus, we can simply force #[no_mangle] symbols to use the legacy calling convention.

Bending LLVM to Our Will

In an ideal world, LLVM would provide a way for us to specify the calling convention directly. E.g., this argument goes in that register, this return goes in that one, etc. Unfortunately, adding a calling convention to LLVM requires writing a bunch of C++.

However, we can get away with specifying our own calling convention by following this procedure.

  1. First, determine, for a given target triple, the maximum number of values that can be passed “by register”. I will explain how to do this below.

  2. Decide how to pass the return value. It will either fit in the output registers, or it will need to be returned “by reference”, in which case we pass an extra ptr argument to the function (tagged with the sret attribute) and the actual return value of the function is that pointer.

  3. Decide which arguments that have been passed by value need to be demoted to being passed by reference. This will be a heuristic, but generally will be approximately “arguments larger than the by-register space”. For example, on x86, this comes out to 176 bytes.

  4. Decide which arguments get passed by register, so as to maximize register space usage. This problem is NP-hard (it’s the knapsack problem) so it will require a heuristic. All other arguments are passed on the stack.

  5. Generate the function signature in LLVM IR. This will be all of the arguments that are passed by register encoded as various non-aggregates, such as i64, ptr, double, and <2 x i64>. Which non-aggregates are valid choices depends on the target, but the above are what you will generally get on a 64-bit architecture. Arguments passed on the stack will follow the “register inputs”.

  6. Generate a function prologue. This is code to decode each Rust-level argument from the register inputs, so that there are %ssa values corresponding to those that would be present when using -Zcallconv=legacy. This allows us to generate the same code for the body of the function regardless of calling convention. Redundant decoding code will be eliminated by DCE passes.

  7. Generate a function exit block. This is a block that contains a single phi instruction for the return type as it would be for -Zcallconv=legacy. This block will encode it into the requisite output format and then ret as appropriate. All exit paths through the function should br to this block instead of ret-ing.

  8. If a non-polymorphic, non-inline function may have its address taken (as a function pointer), either because it is exported out of the crate or the crate takes a function pointer to it, generate a shim that uses -Zcallconv=legacy and immediately tail-calls the real implementation. This is necessary to preserve function pointer equality.

The main upshot here is that we need to cook up heuristics for figuring out what goes in registers (since we allow reordering arguments to get better throughput). This is equivalent to the knapsack problem; knapsack heuristics are beyond the scope of this article. This should happen early enough that this information can be stuffed into rmeta to avoid needing to recompute it. We may want to use different, faster heuristics depending on -Copt-level. Note that correctness requires that we forbid linking code generated by multiple different Rust compilers, which is already the case, since Rust breaks ABI from release to release.

What Is LLVM Willing to Do?

Assuming we do that, how do we actually get LLVM to pass things in the way we want it to? We need to determine what the largest “by register” passing LLVM will permit is. The following LLVM program is useful for determining this on a particular version of LLVM:

%InputI = type [6 x i64]
%InputF = type [0 x double]
%InputV = type [8 x <2 x i64>]

%OutputI = type [3 x i64]
%OutputF = type [0 x double]
%OutputV = type [4 x <2 x i64>]

define void @inputs({ %InputI, %InputF, %InputV }) {
  %p = alloca [4096 x i8]
  store volatile { %InputI, %InputF, %InputV } %0, ptr %p
  ret void
}

%Output = type { %OutputI, %OutputF, %OutputV }
@gOutput = constant %Output zeroinitializer
define %Output @outputs() {
  %1 = load %Output, ptr @gOutput
  ret %Output %1
}
LLVM IR

When you pass an aggregate by-value to an LLVM function, LLVM will attempt to “explode” that aggregate into as many registers as possible. There are distinct register classes on different systems. For example, on both x86 and ARM, floats and vectors share the same register class (kind of2).

The above values are for x863. LLVM will pass six integers and eight SSE vectors by register, and return half as many (3 and 4) by register. Increasing any of the values generates extra loads and stores that indicate LLVM gave up and passed arguments on the stack.

The values for aarch64-unknown-linux are 8 integers and 8 vectors, for both inputs and outputs.

This is the maximum number of registers we get to play with for each class. Anything extra gets passed on the stack.

I recommend that every function have the same number of by-register arguments. So on x86, EVERY -Zcallconv=fast function’s signature should look like this:

declare {[3 x i64], [4 x <2 x i64>]} @my_func(
  i64 %rdi, i64 %rsi, i64 %rdx, i64 %rcx, i64 %r8, i64 %r9,
  <2 x i64> %xmm0, <2 x i64> %xmm1, <2 x i64> %xmm2, <2 x i64> %xmm3,
  <2 x i64> %xmm4, <2 x i64> %xmm5, <2 x i64> %xmm6, <2 x i64> %xmm7,
  ; other args...
)
LLVM IR

When passing pointers, the appropriate i64s should be replaced by ptr, and when passing doubles, they replace <2 x i64>s.

But you’re probably saying, “Miguel, that’s crazy! Most functions don’t pass 176 bytes!” And you’d be right, if not for the magic of LLVM’s very well-specified poison semantics.

We can get away with not doing extra work if every argument we do not use is passed poison. Because poison is equal to “the most convenient possible value at the present moment”, when LLVM sees poison passed into a function via register, it decides that the most convenient value is “whatever happens to be in the register already”, and so it doesn’t have to touch that register!

For example, if we wanted to pass a pointer via rcx, we would generate the following code.

; This is a -Zcallconv=fast-style function.
%Out = type {[3 x i64], [4 x <2 x i64>]}
define %Out @load_rcx(
  i64 %rdi, i64 %rsi, i64 %rdx,
  ptr %rcx, i64 %r8, i64 %r9,
  <2 x i64> %xmm0, <2 x i64> %xmm1,
  <2 x i64> %xmm2, <2 x i64> %xmm3,
  <2 x i64> %xmm4, <2 x i64> %xmm5,
  <2 x i64> %xmm6, <2 x i64> %xmm7
) {
  %load = load i64, ptr %rcx
  %out = insertvalue %Out poison,
                      i64 %load, 0, 0
  ret %Out %out
}

declare ptr @malloc(i64)
define i64 @make_the_call() {
  %1 = call ptr @malloc(i64 8)
  store i64 42, ptr %1
  %2 = call %Out @load_rcx(
    i64 poison, i64 poison, i64 poison,
    ptr %1,     i64 poison, i64 poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison)
  %3 = extractvalue %Out %2, 0, 0
  %4 = add i64 %3, 42
  ret i64 %4
}
LLVM IR
load_rcx:
  mov   rax, qword ptr [rcx]
  ret

make_the_call:
  push  rax
  mov   edi, 8
  call  malloc
  mov   qword ptr [rax], 42
  mov   rcx, rax
  call  load_rcx
  add   rax, 42
  pop   rcx
  ret
x86 Assembly

It is perfectly legal to pass poison to a function, if it does not interact with the poisoned argument in any proscribed way. And as we see, load_rcx() receives its pointer argument in rcx, whereas make_the_call() takes no penalty in setting up the call: loading poison into the other thirteen registers compiles down to nothing4, so it only needs to load the pointer returned by malloc into rcx.

This gives us almost total control over argument passing; unfortunately, it is not total. In an ideal world, the same registers are used for input and output, to allow easier pipelining of calls without introducing extra register traffic. This is true on ARM and RISC-V, but not x86. However, because register ordering is merely a suggestion for us, we can choose to allocate the return registers in whatever order we want. For example, we can pretend the order registers should be allocated in is rdx, rcx, rdi, rsi, r8, r9 for inputs, and rdx, rcx, rax for outputs.

%Out = type {[3 x i64], [4 x <2 x i64>]}
define %Out @square(
  i64 %rdi, i64 %rsi, i64 %rdx,
  i64 %rcx, i64 %r8, i64 %r9,
  <2 x i64> %xmm0, <2 x i64> %xmm1,
  <2 x i64> %xmm2, <2 x i64> %xmm3,
  <2 x i64> %xmm4, <2 x i64> %xmm5,
  <2 x i64> %xmm6, <2 x i64> %xmm7
) {
  %sq = mul i64 %rdx, %rdx
  %out = insertvalue %Out poison,
                      i64 %sq, 0, 1
  ret %Out %out
}

define i64 @make_the_call(i64) {
  %2 = call %Out @square(
    i64 poison, i64 poison, i64 %0,
    i64 poison, i64 poison, i64 poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison)
  %3 = extractvalue %Out %2, 0, 1

  %4 = call %Out @square(
    i64 poison, i64 poison, i64 %3,
    i64 poison, i64 poison, i64 poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison,
    <2 x i64> poison, <2 x i64> poison)
  %5 = extractvalue %Out %4, 0, 1

  ret i64 %5
}
LLVM IR
square:
  imul rdx, rdx
  ret

make_the_call:
  push rax
  mov rdx, rdi
  call square
  call square
  mov rax, rdx
  pop rcx
  ret
x86 Assembly

square generates extremely simple code: the input and output register is rdx, so no extra register traffic needs to be generated. Similarly, when we effectively do @square(@square(%0)), there is no setup between the functions. This is similar to code seen on aarch64, which uses the same register sequence for input and output. We can see that the “naive” version of this IR produces the exact same code on aarch64 for this reason.

define i64 @square(i64) {
  %2 = mul i64 %0, %0
  ret i64 %2
}

define i64 @make_the_call(i64) {
  %2 = call i64 @square(i64 %0)
  %3 = call i64 @square(i64 %2)
  ret i64 %3
}
LLVM IR
square:
  mul x0, x0, x0
  ret

make_the_call:
  str x30, [sp, #-16]!
  bl square
  ldr x30, [sp], #16
  b square  // Tail call.
ARM Assembly

Rust Structs and Unions

Now that we’ve established total control on how registers are assigned, we can turn towards maximizing use of these registers in Rust.

For simplicity, we can assume that rustc has already processed the user’s types into basic aggregates and unions; no enums here! We then have to make some decisions about which portions of the arguments to allocate to registers.

First, return values. This is relatively straightforward, since there is only one value to pass. The amount of data we need to return is not the size of the struct. For example, [(u64, u32); 2] measures 32 bytes wide. However, eight of those bytes are padding! We do not need to preserve padding when returning by value, so we can flatten the struct into (u64, u32, u64, u32) and sort by size into (u64, u64, u32, u32). This has no padding and is 24 bytes wide, which fits into the three return registers LLVM gives us on x86. We define the effective size of a type to be the number of non-undef bits it occupies. For [(u64, u32); 2], this is 192 bits, since it excludes the padding. For bool, this is one. For char this is technically 21, but it’s simpler to treat char as an alias for u32.
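
You can convince yourself of those numbers with size_of (this is just an illustration of the layout, not part of the proposed machinery):

use std::mem::size_of;

fn main() {
  // As laid out in memory, the array is 32 bytes, 8 of which are padding.
  assert_eq!(size_of::<[(u64, u32); 2]>(), 32);
  // Flattened and sorted by size, the same data needs only 24 bytes.
  assert_eq!(size_of::<(u64, u64, u32, u32)>(), 24);
}
Rust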

The reason for counting bits this way is that it permits significant compaction. For example, returning a struct full of bools can simply bit-pack the bools into a single register.

So, a return value is converted to a by-ref return if its effective size is larger than the output register space (on x86, this is three integer registers and four SSE registers, so we get 88 bytes total, or 704 bits).

Argument registers are much harder, because we hit the knapsack problem, which is NP-hard. The following relatively naive heuristic is where I would start, but it can be made infinitely smarter over time.

First, demote to by-ref any argument whose effective size is larger than the total by-register input space (on x86, 176 bytes or 1408 bits). This means we get a pointer argument instead. This is beneficial to do first, since a single pointer might pack better than the huge struct.

Enums should be replaced by the appropriate discriminant-union pair. For example, Option<i32> is, internally, (union { i32, () }, i1), while Option<Option<i32>> is (union { i32, (), () }, i2). Using a small non-power-of-two integer improves our ability to pack things, since enum discriminants are often quite tiny.
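
To make the discriminant-union picture concrete, here is a hand-rolled (purely illustrative) version of Option<i32> as a separate payload and tag; the real lowering happens inside the compiler, and here the union half degenerates to a plain i32 since () carries no data.

fn explode(opt: Option<i32>) -> (i32, bool) {
  match opt {
    Some(v) => (v, true),
    // The payload bits are "don't care" when the tag is false.
    None => (0, false),
  }
}

fn main() {
  assert_eq!(explode(Some(7)), (7, true));
  assert_eq!(explode(None), (0, false));
}
Rust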

Next, we need to handle unions. Because code is allowed to muck about with a union’s uninitialized bits behind our backs, we need to pass a union as an array of u8, unless it only has a single non-empty variant, in which case it is replaced with that variant5.

Now, we can proceed to flatten everything. All of the converted arguments are flattened into their most primitive components: pointers, integers, floats, and bools. Every field should be no larger than the smallest argument register; this may require splitting large types such as u128 or f64.

This big list of primitives is next sorted by effective size, from smallest to largest. We take the largest prefix of this that will fit in the available register space; everything else goes on the stack.

If part of a Rust-level input is sent to the stack in this way, and that part is larger than a small multiple of the pointer size (e.g., 2x), it is demoted to being passed by pointer-on-the-stack, to minimize memory traffic. Everything else is passed directly on the stack in the order those inputs were before the sort. This helps keep regions that need to be copied relatively contiguous, to minimize calls to memcpy.

The things we choose to pass in registers are allocated to registers in reverse size order, so e.g. first 64-bit things, then 32-bit things, etc. This is the same layout algorithm that repr(Rust) structs use to move all the padding into the tail. Once we get to the bools, those are bit-packed, 64 to a register.
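
As a sketch (not rustc code; the name and signature are invented), the prefix-selection step of this heuristic is just a sort plus a greedy take:

/// Given the effective sizes (in bits) of the flattened primitives and the
/// total register space, return how many of them (after sorting) go in
/// registers. Everything past the returned prefix goes on the stack.
fn select_register_prefix(mut effective_bits: Vec<u32>, register_space_bits: u32) -> usize {
  effective_bits.sort_unstable(); // smallest to largest
  let mut used = 0;
  effective_bits
    .iter()
    .take_while(|&&bits| {
      used += bits;
      used <= register_space_bits
    })
    .count()
}
Rust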

Here’s a relatively complicated example. My Rust function is as follows:

struct Options {
  colorize: bool,
  verbose_debug: bool,
  allow_spurious_failure: bool,
  retries: u32,
}

trait Context {
  fn check(&self, n: usize, colorize: bool);
}

fn do_thing<'a>(op_count: Option<usize>, context: &dyn Context,
                name: &'a str, code: [char; 6],
                options: Options,
) -> &'a str {
  if let Some(op_count) = op_count {
    context.check(op_count, options.colorize);
  }

  for c in code {
    if let Some((_, suf)) = name.split_once(c) {
      return suf;
    }
  }

  "idk"
}
Rust

The codegen for this function is quite complex, so I’ll only cover the prologue and epilogue. After sorting and flattening, our raw argument LLVM types are something like this:

gprs: i64, ptr, ptr, ptr, i64, i32, i32
xmm0: i32, i32, i32, i32
xmm1: i32, i1, i1, i1, i1
LLVM IR

Everything fits in registers! So, what does the LLVM function look like on x86?

%Out = type {[3 x i64], [4 x <2 x i64>]}
define %Out @do_thing(
  i64 %rdi, ptr %rsi, ptr %rdx,
  ptr %rcx, i64 %r8, i64 %r9,
  <4 x i32> %xmm0, <4 x i32> %xmm1,
  ; Unused.
  <2 x i64> %xmm2, <2 x i64> %xmm3,
  <2 x i64> %xmm4, <2 x i64> %xmm5,
  <2 x i64> %xmm6, <2 x i64> %xmm7
) {
  ; First, unpack all the primitives.
  %r9.0 = trunc i64 %r9 to i32
  %r9.1.i64 = lshr i64 %r9, 32
  %r9.1 = trunc i64 %r9.1.i64 to i32
  %xmm0.0 = extractelement <4 x i32> %xmm0, i32 0
  %xmm0.1 = extractelement <4 x i32> %xmm0, i32 1
  %xmm0.2 = extractelement <4 x i32> %xmm0, i32 2
  %xmm0.3 = extractelement <4 x i32> %xmm0, i32 3
  %xmm1.0 = extractelement <4 x i32> %xmm1, i32 0
  %xmm1.1 = extractelement <4 x i32> %xmm1, i32 1
  %xmm1.1.0 = trunc i32 %xmm1.1 to i1
  %xmm1.1.1.i32 = lshr i32 %xmm1.1, 1
  %xmm1.1.1 = trunc i32 %xmm1.1.1.i32 to i1
  %xmm1.1.2.i32 = lshr i32 %xmm1.1, 2
  %xmm1.1.2 = trunc i32 %xmm1.1.2.i32 to i1
  %xmm1.1.3.i32 = lshr i32 %xmm1.1, 3
  %xmm1.1.3 = trunc i32 %xmm1.1.3.i32 to i1

  ; Next, reassemble them into concrete values as needed.
  %op_count.0 = insertvalue { i64, i1 } poison, i64 %rdi, 0
  %op_count = insertvalue { i64, i1 } %op_count.0, i1 %xmm1.1.0, 1
  %context.0 = insertvalue { ptr, ptr } poison, ptr %rsi, 0
  %context = insertvalue { ptr, ptr } %context.0, ptr %rdx, 1
  %name.0 = insertvalue { ptr, i64 } poison, ptr %rcx, 0
  %name = insertvalue { ptr, i64 } %name.0, i64 %r8, 1
  %code.0 = insertvalue [6 x i32] poison, i32 %r9.0, 0
  %code.1 = insertvalue [6 x i32] %code.0, i32 %r9.1, 1
  %code.2 = insertvalue [6 x i32] %code.1, i32 %xmm0.0, 2
  %code.3 = insertvalue [6 x i32] %code.2, i32 %xmm0.1, 3
  %code.4 = insertvalue [6 x i32] %code.3, i32 %xmm0.2, 4
  %code = insertvalue [6 x i32] %code.4, i32 %xmm0.3, 5
  %options.0 = insertvalue { i32, i1, i1, i1 } poison, i32 %xmm1.0, 0
  %options.1 = insertvalue { i32, i1, i1, i1 } %options.0, i1 %xmm1.1.1, 1
  %options.2 = insertvalue { i32, i1, i1, i1 } %options.1, i1 %xmm1.1.2, 2
  %options = insertvalue { i32, i1, i1, i1 } %options.2, i1 %xmm1.1.3, 3

  ; Codegen as usual.
  ; ...
}
LLVM IR

Above, !dbg metadata for the argument values should be attached to the instruction that actually materializes it. This ensures that gdb does something halfway intelligent when you ask it to print argument values.

On the other hand, in current rustc, it gives LLVM eight pointer-sized parameters, so it winds up spending all six integer registers, plus two values passed on the stack. Not great!

This is not a complete description of what a completely over-engineered calling convention could entail: in some cases we might know that we have additional registers available (such as AVX registers on x86). There are cases where we might want to split a struct across registers and the stack.

This also isn’t even getting into what returns could look like. Results are often passed through several layers of functions via ?, which can result in a lot of redundant register moves. Often, a Result is large enough that it doesn’t fit in registers, so each call in the ? stack has to inspect an ok bit by loading it from memory. Instead, a Result return might be implemented as an out-parameter pointer for the error, with the ok variant’s payload, and the is ok bit, returned as an Option<T>. There are some fussy details with Into calls via ?, but the idea is implementable.
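
Here is a purely hypothetical sketch of that Result lowering; BigError, the function names, and the use of &mut Option<_> as a stand-in for a raw out-pointer are all made up for illustration.

struct BigError {
  code: u64,
  context: [u64; 8], // too big to return in registers
}

fn fallible(x: u64) -> Result<u64, BigError> {
  if x == 0 {
    Err(BigError { code: 1, context: [0; 8] })
  } else {
    Ok(x * 2)
  }
}

// The shape the transformed ABI would have: the error goes out through a
// location the caller provides, and the ok payload plus is-ok bit come
// back in registers as an Option.
fn fallible_lowered(x: u64, error_out: &mut Option<BigError>) -> Option<u64> {
  match fallible(x) {
    Ok(v) => Some(v),
    Err(e) => {
      *error_out = Some(e);
      None
    }
  }
}
Rust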

Optimization-Dependent ABI

Now, because we’re Rust, we’ve also got a trick up our sleeve that C doesn’t (but Go does)! When we’re generating the ABI that all callers will see (for -Zcallconv=fast), we can look at the function body. This means that a crate can advertise the precise ABI (in terms of register-passing) of its functions.

This opens the door to more extreme optimization-based ABIs. We can start by simply throwing out unused arguments: if the function never does anything with a parameter, don’t bother spending registers on it.

Another example: suppose that we know that an &T argument is not retained (a question the borrow checker can answer at this point in the compiler) and is never converted to a raw pointer (or written to memory a raw pointer is taken of, etc). We also know that T is fairly small, and T: Freeze. Then, we can replace the reference with the pointee directly, passed by value.

The most obvious candidates for this are APIs like HashMap::get(). If the key is something like an i32, we need to spill that integer to the stack and pass a pointer to it! This results in unnecessary, avoidable memory traffic.
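
For instance, a wrapper like the following has to give the key an address just to call get (an illustration of the pattern described above, not of any existing rustc optimization):

use std::collections::HashMap;

fn lookup(map: &HashMap<i32, String>, key: i32) -> Option<&String> {
  // `get` takes `&i32`, so the caller must put `key` somewhere addressable,
  // even though an i32 would fit comfortably in a register.
  map.get(&key)
}

fn main() {
  let mut map = HashMap::new();
  map.insert(42, "answer".to_string());
  assert_eq!(lookup(&map, 42).map(String::as_str), Some("answer"));
}
Rust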

Profile-guided ABI is a step further. We might know that some arguments are hotter than others, which might cause them to be prioritized in the register allocation order.

You could even imagine a case where a function takes a very large struct by reference, but three i64 fields are very hot, so the caller can preload those fields, passing them both by register and via the pointer to the large struct. The callee does not see additional cost: it had to issue those loads anyway. However, the caller probably has those values in registers already, which avoids some memory traffic.

Instrumentation profiles may even indicate that it makes sense to duplicate whole functions, which are identical except for their ABIs. Maybe they take different arguments by register to avoid costly spills.

Conclusion

This is a bit more advanced (and ranty) than my usual writing, but this is an aspect of Rust that I find really frustrating. We could be doing so much better than C++ ever can (because of their ABI constraints). None of this is new ideas; this is literally how Go does it!

So why don’t we? Part of the reason is that ABI codegen is complex, and as I described above, LLVM gives us very few useful knobs. It’s not a friendly part of rustc, and doing things wrong can have nasty consequences for usability. The other part is a lack of expertise. As of writing, only a handful of people contributing to rustc have the necessary grasp of LLVM’s semantics (and mood swings) to emit the Right Code such that we get good codegen and don’t crash LLVM.

Another reason is compilation time. The more complicated the function signatures, the more prologue/epilogue code we have to generate that LLVM has to chew on. But -Zcallconv is intended to only be used with optimizations turned on, so I don’t think this is a meaningful complaint. Nor do I think the project’s Goodhartization of compilation time as a metric is healthy… but I do not think this is ultimately a relevant drawback.

I, unfortunately, do not have the spare time to dive into fixing rustc’s ABI code, but I do know LLVM really well, and I know that this is a place where Rust has a low bus factor. For that reason, I am happy to provide the Rust compiler team expert knowledge on getting LLVM to do the right thing in service of making optimized code faster.

  1. Or just switch it to the codepath for extern "C" or extern "fastcall" since those are clearly better. We will always need to know how to generate code for the non-extern "Rust" calling conventions. 

  2. It’s Complicated. Passing a double burns a whole <2 x i64> slot. This seems bad, but it can be beneficial since keeping a double in vector registers reduces register traffic, since usually, fp instructions use the vector registers (or the fp registers shadow the vector registers, like on ARM). 

  3. On the one hand, you might say this “extended calling convention” isn’t an explicitly supported part of LLVM’s ccc calling convention. On the other hand, Hyrum’s Law cuts both ways: Rust is big enough of an LLVM user that LLVM cannot simply miscompile all Rust programs at this point, and the IR I propose Rust emits is extremely reasonable.

    If Rust causes LLVM to misbehave, that’s an LLVM bug, and we should fix LLVM bugs, not work around them. 

  4. Only on -O1 or higher, bizarrely. At -O0, LLVM decides that all of the poisons must have the same value, so it copies a bunch of registers around needlessly. This seems like a bug? 

  5. There are other cases where we might want to replace a union with one of its variants: for example, there are a lot of cases where Result<&T, Error> is secretly a union { ptr, u32 }, in which case it should be replaced with a single ptr.

Designing a SIMD Algorithm from Scratch

Another explainer on a fun, esoteric topic: optimizing code with SIMD (single instruction multiple data, also sometimes called vectorization). Designing a good, fast, portable SIMD algorithm is not a simple matter and requires thinking a little bit like a circuit designer.

Here’s the mandatory performance benchmark graph to catch your eye.

perf perf perf

“SIMD” often gets thrown around as a buzzword by performance and HPC (high performance computing) nerds, but I don’t think it’s a topic that has very friendly introductions out there, for a lot of reasons.

  • It’s not something you will really want to care about unless you think performance is cool.
  • APIs for programming with SIMD in most programming languages are garbage (I’ll get into why).
  • SIMD algorithms are hard to think about if you’re very procedural-programming-brained. A functional programming mindset can help a lot.

This post is mostly about vb64 (which stands for vector base64), a base64 codec I wrote to see for myself if Rust’s std::simd library is any good, but it’s also an excuse to talk about SIMD in general.

What is SIMD, anyways? Let’s dive in.

If you want to skip straight to the writeup on vb64, click here.

Problems with Physics

Unfortunately, computers exist in the real world[citation-needed], and are bound by the laws of nature. SIMD has relatively little to do with theoretical CS considerations, and everything to do with physics.

In the infancy of modern computing, you could simply improve performance of existing programs by buying new computers. This is often incorrectly attributed to Moore’s law (the number of transistors on IC designs doubles every two years). Moore’s law still appears to hold as of 2023, but some time in the last 15 years the Dennard scaling effect broke down: packing transistors more densely now increases power dissipation density. In simpler terms, we don’t know how to continue to increase the clock frequency of computers without literally liquefying them.

So, since the early aughts, the hot new thing has been bigger core counts. Make your program more multi-threaded and it will run faster on bigger CPUs. This comes with synchronization overhead, since now the cores need to cooperate. All control flow, be it jumps, virtual calls, or synchronization, will result in “stall”.

The main causes of stall are branches, instructions that indicate code can take one of two possible paths (like an if statement), and memory operations. Branches include all control flow: if statements, loops, function calls, function returns, even switch statements in C. Memory operations are loads and stores, especially ones that are cache-unfriendly.

Procedural Code Is Slow

Modern compute cores do not execute code line-by-line, because that would be very inefficient. Suppose I have this program:

let a = x + y;
let b = x ^ y;
println!("{a}, {b}");
Rust

There’s no reason for the CPU to wait to finish computing a before it begins computing b; it does not depend on a, and while the add is being executed, the xor circuits are idle. Computers say “program order be damned” and issue the add for a and the xor for b simultaneously. This is called instruction-level parallelism, and dependencies that get in the way of it are often called data hazards.

Of course, the Zen 2 in the machine I’m writing this with does not have one measly adder per core. It has dozens and dozens! The opportunities for parallelism are massive, as long as the compiler in your CPU’s execution pipeline can clear any data hazards in the way.

The better the core can do this, the more it can saturate all of the “functional units” for things like arithmetic, and the more numbers it can crunch per unit time, approaching maximum utilization of the hardware. Whenever the compiler can’t do this, the execution pipeline stalls and your code is slower.

Branches stall because they need to wait for the branch condition to be computed before fetching the next instruction (speculative execution is a somewhat iffy workaround for this). Memory operations stall because the data needs to physically arrive at the CPU, and the speed of light is finite in this universe.

Trying to reduce stall by improving opportunities for single-core parallelism is not a new idea. Consider the not-so-humble GPU, whose purpose in life is to render images. Images are vectors of pixels (i.e., color values), and rendering operations tend to be highly local. For example, a convolution kernel for a Gaussian blur will be two or even three orders of magnitude smaller than the final image, lending itself to locality.

Thus, GPUs are built for divide-and-conquer: they provide primitives for doing batched operations, and extremely limited control flow.

“SIMD” is synonymous with “batching”. It stands for “single instruction, multiple data”: a single instruction dispatches parallel operations on multiple lanes of data. GPUs are the original SIMD machines.

Lane-Wise

“SIMD” and “vector” are often used interchangeably. The fundamental unit a SIMD instruction (or “vector instruction”) operates on is a vector: a fixed-size array of numbers that you primarily operate on component-wise. These components are called lanes.

SIMD vectors are usually quite small, since they need to fit into registers. For example, on my machine, the largest vectors are 256 bits wide. This is enough for 32 bytes (a u8x32), 4 double-precision floats (an f64x4), or all kinds of things in between.

some 256-bit vectors
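
For a concrete feel of those sizes, Rust’s nightly-only std::simd (behind the portable_simd feature) provides aliases for exactly these shapes:

#![feature(portable_simd)]
use std::simd::{f64x4, u8x32};

fn main() {
  // Both aliases are 256 bits wide.
  assert_eq!(std::mem::size_of::<u8x32>() * 8, 256);
  assert_eq!(std::mem::size_of::<f64x4>() * 8, 256);
}
Rust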

Although this doesn’t seem like much, remember that reducing the work needed to keep the pipeline saturated by a factor of 4x can translate to that big of a speedup in latency.

One-Bit Lanes

The simplest vector operations are bitwise: and, or, xor. Ordinary integers can be thought of as vectors themselves, with respect to the bitwise operations. That’s literally what “bitwise” means: lane-wise with lanes that are one bit wide. An i32 is, in this regard, an i1x32.

In fact, as a warmup, let’s look at the problem of counting the number of 1 bits in an integer. This operation is called “population count”, or popcnt. If we view an i32 as an i1x32, popcnt is just a fold or reduce operation:

pub fn popcnt(x: u32) -> u32 {
  let mut bits = [0; 32];
  for (i, bit) in bits.iter_mut().enumerate() {
    *bit = (x >> i) & 1;
  }
  bits.into_iter().fold(0, |total, bit| total + bit)
}

In other words, we interpret the integer as an array of bits and then add the bits together to a 32-bit accumulator. Note that the accumulator needs to be higher precision to avoid overflow: accumulating into an i1 (as with the Iterator::reduce() method) will only tell us whether the number of 1 bits is even or odd.

Of course, this produces… comically bad code, frankly. We can do much better if we notice that we can vectorize the addition: first we add all of the adjacent pairs of bits together, then the pairs of pairs, and so on. This means the number of adds is logarithmic in the number of bits in the integer.

Visually, what we do is we “unzip” each vector, shift one to line up the lanes, add them, and then repeat with lanes twice as big.

first two popcnt merge steps

This is what that looks like in code.

pub fn popcnt(mut x: u32) -> u32 {
  // View x as a i1x32, and split it into two vectors
  // that contain the even and odd bits, respectively.
  let even = x & 0x55555555; // 0x5 == 0b0101.
  let odds = x & 0xaaaaaaaa; // 0xa == 0b1010.
  // Shift odds down to align the bits, and then add them together.
  // We interpret x now as a i2x16. When adding, each two-bit
  // lane cannot overflow, because the value in each lane is
  // either 0b00 or 0b01.
  x = even + (odds >> 1);

  // Repeat again but now splitting even and odd bit-pairs.
  let even = x & 0x33333333; // 0x3 == 0b0011.
  let odds = x & 0xcccccccc; // 0xc == 0b1100.
  // We need to shift by 2 to align, and now for this addition
  // we interpret x as a i4x8.
  x = even + (odds >> 2);

  // Again. The pattern should now be obvious.
  let even = x & 0x0f0f0f0f; // 0x0f == 0b00001111.
  let odds = x & 0xf0f0f0f0; // 0xf0 == 0b11110000.
  x = even + (odds >> 4); // i8x4

  let even = x & 0x00ff00ff;
  let odds = x & 0xff00ff00;
  x = even + (odds >> 8);  // i16x2

  let even = x & 0x0000ffff;
  let odds = x & 0xffff0000;
  // Because the value of `x` is at most 32, although we interpret this as a
  // i32x1 add, we could get away with just one e.g. i16 add.
  x = even + (odds >> 16);

  x // Done. All bits have been added.
}

This still won’t optimize down to a popcnt instruction, of course. The search scope for such a simplification is in the regime of superoptimizers. However, the generated code is small and fast, which is why this is the ideal implementation of popcnt for systems without such an instruction.

It’s especially nice because it is implementable for e.g. u64 with only one more reduction step (remember: it’s O(log n)!), and does not at any point require a full u64 addition.
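
For concreteness, here is a u64 version following the same pattern, with one extra reduction step (the masks are just the even/odd patterns widened):

pub fn popcnt64(mut x: u64) -> u64 {
  x = (x & 0x5555555555555555) + ((x & 0xaaaaaaaaaaaaaaaa) >> 1); // i2x32
  x = (x & 0x3333333333333333) + ((x & 0xcccccccccccccccc) >> 2); // i4x16
  x = (x & 0x0f0f0f0f0f0f0f0f) + ((x & 0xf0f0f0f0f0f0f0f0) >> 4); // i8x8
  x = (x & 0x00ff00ff00ff00ff) + ((x & 0xff00ff00ff00ff00) >> 8); // i16x4
  x = (x & 0x0000ffff0000ffff) + ((x & 0xffff0000ffff0000) >> 16); // i32x2
  x = (x & 0x00000000ffffffff) + (x >> 32); // the value is at most 64 here
  x
}
Rust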

Even though this is “just” using scalars, divide-and-conquer approaches like this are the bread and butter of the SIMD programmer.

Scaling Up: Operations on Real Vectors

Proper SIMD vectors provide more sophisticated semantics than scalars do, particularly because there is more need to provide replacements for things like control flow. Remember, control flow is slow!

What’s actually available is highly dependent on the architecture you’re compiling to (more on this later), but the way vector instruction sets are usually structured is something like this.

We have vector registers that are kind of like really big general-purpose registers. For example, on x86, most “high performance” cores (like my Zen 2) implement AVX2, which provides 256 bit ymm vectors. The registers themselves do not have a “lane count”; that is specified by the instructions. For example, the “vector byte add instruction” interprets the register as being divided into eight-bit lanes and adds them. The corresponding x86 instruction is vpaddb, which interprets a ymm as an i8x32.

The operations you usually get are:

  1. Bitwise operations. These don’t need to specify a lane width because it’s always implicitly 1: they’re bitwise.

  2. Lane-wise arithmetic. This is addition, subtraction, multiplication, division (both int and float), and shifts1 (int only). Lane-wise min and max are also common. These require specifying a lane width. Typically the smallest number of lanes is two or four.

  3. Lane-wise compare. Given a and b, we can create a new mask vector m such that m[i] = a[i] < b[i] (or any other comparison operation). A mask vector’s lanes contain boolean values with an unusual bit-pattern: all-zeros (for false) or all-ones (for true)2.

    • Masks can be used to select between two vectors: for example, given m, x, and y, you can form a fourth vector z such that z[i] = m[i] ? x[i] : y[i].
  4. Shuffles (sometimes called swizzles). Given a and x, create a third vector s such that s[i] = a[x[i]]. a is used as a lookup table, and x as a set of indices. Out of bounds produces a special value, usually zero. This emulates parallelized array access without needing to actually touch RAM (RAM is extremely slow).

    • Often there is a “shuffle2” or “riffle” operation that allows taking elements from one of two vectors. Given a, b, and x, we now define s as being s[i] = (a ++ b)[x[i]], where a ++ b is a double-width concatenation. How this is actually implemented depends on architecture, and it’s easy to build out of single shuffles regardless.

(1) and (2) are ordinary number crunching. Nothing deeply special about them.

The comparison and select operations in (3) are intended to help SIMD code stay “branchless”. Branchless code is written such that it performs the same operations regardless of its inputs, and relies on the properties of those operations to produce correct results. For example, this might mean taking advantage of identities like x * 0 = 0 and a ^ b ^ a = b to discard “garbage” results.
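
Here is what compare-and-select looks like in Rust’s std::simd (nightly, with the portable_simd feature and its prelude): a branchless lane-wise max.

#![feature(portable_simd)]
use std::simd::prelude::*;

// z[i] = if a[i] > b[i] { a[i] } else { b[i] }, with no branches.
fn lanewise_max(a: Simd<i32, 4>, b: Simd<i32, 4>) -> Simd<i32, 4> {
  let m = a.simd_gt(b); // mask: all-ones where a[i] > b[i], all-zeros elsewhere
  m.select(a, b)
}

fn main() {
  let a = Simd::from_array([1, 5, 3, 7]);
  let b = Simd::from_array([4, 2, 3, 8]);
  assert_eq!(lanewise_max(a, b).to_array(), [4, 5, 3, 8]);
}
Rust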

The shuffles described in (4) are much more powerful than meets the eye.

For example, “broadcast” (sometimes called “splat”) makes a vector whose lanes are all the same scalar, like Rust’s [42; N] array literal. A broadcast can be expressed as a shuffle: create a vector with the desired value in the first lane, and then shuffle it with an index vector of [0, 0, ...].

diagram of a broadcast

“Interleave” (also called “zip” or “pack”) takes two vectors a and b and creates two new vectors c and d whose lanes are alternating lanes from a and b. If the lane count is n, then c = [a[0], b[0], a[1], b[1], ...] and d = [a[n/2], b[n/2], a[n/2 + 1], b[n/2 + 1], ...]. This can also be implemented as a shuffle2, with shuffle indices of [0, n, 1, n + 1, ...]. “Deinterleave” (or “unzip”, or “unpack”) is the opposite operation: it interprets a pair of vectors as two halves of a larger vector of pairs, and produces two new vectors consisting of the halves of each pair.

Interleave can also be interpreted as taking a [T; N], transmuting it to a [[T; N/2]; 2], performing a matrix transpose to turn it into a [[T; 2]; N/2], and then transmuting that back to [T; N] again. Deinterleave is the same but it transmutes to [[T; 2]; N/2] first.

diagram of an interleave

“Rotate” takes a vector a with n lanes and produces a new vector b such that b[i] = a[(i + j) % n], for some chosen integer j. This is yet another shuffle, with indices [j, j + 1, ..., n - 1, 0, 1, ... j - 1].

diagram of a rotate
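
In std::simd, broadcasts and rotates really are just shuffles: the simd_swizzle! macro takes a vector and a const array of indices (nightly, portable_simd feature).

#![feature(portable_simd)]
use std::simd::{simd_swizzle, Simd};

fn main() {
  let a = Simd::from_array([10u32, 20, 30, 40]);

  // Broadcast lane 0 to every lane: indices [0, 0, 0, 0].
  let splat = simd_swizzle!(a, [0, 0, 0, 0]);
  assert_eq!(splat.to_array(), [10, 10, 10, 10]);

  // Rotate by j = 1: b[i] = a[(i + 1) % 4], indices [1, 2, 3, 0].
  let rotated = simd_swizzle!(a, [1, 2, 3, 0]);
  assert_eq!(rotated.to_array(), [20, 30, 40, 10]);
}
Rust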

Shuffles are worth trying to wrap your mind around. SIMD programming is all about reinterpreting larger-than-an-integer-sized blocks of data as smaller blocks of varying sizes, and shuffling is important for getting data into the right “place”.

Intrinsics and Instruction Selection

Earlier, I mentioned that what you get varies by architecture. This section is basically a giant footnote.

So, there’s two big factors that go into this.

  1. We’ve learned over time which operations tend to be most useful to programmers. x86 might have something that ARM doesn’t because it “seemed like a good idea at the time” but turned out to be kinda niche.
  2. Instruction set extensions are often market differentiators, even within the same vendor. Intel has AVX-512, which provides even more sophisticated instructions, but it’s only available on high-end server chips, because it makes manufacturing more expensive.

Toolchains generalize different extensions as “target features”. Features can be detected at runtime through architecture-specific magic. On Linux, the lscpu command will list what features the CPU advertises that it recognizes, which correlate with the names of features that e.g. LLVM understands. What features are enabled for a particular function affects how LLVM compiles it. For example, LLVM will only emit ymm-using code when compiling with +avx2.

So how do you write portable SIMD code? On the surface, the answer is mostly “you don’t”, but it’s more complicated than that, and for that we need to understand how the later parts of a compiler work.

When a user requests an add by writing a + b, how should I decide which instruction to use for it? This seems like a trick question… just an add right? On x86, even this isn’t so easy, since you have a choice between the actual add instruction, or a lea instruction (which, among other things, preserves the rflags register). This question becomes more complicated for more sophisticated operations. This general problem is called instruction selection.

Because which “target features” are enabled affects which instructions are available, they affect instruction selection. When I went over operations “typically available”, this means that compilers will usually be able to select good choices of instructions for them on most architectures.

Compiling with something like -march=native or -Ctarget-cpu=native gets you “the best” code possible for the machine you’re building on, but it might not be portable3 to different processors. Gentoo was quite famous for building packages from source on user machines to take advantage of this (not to mention that they loved using -O3, which mostly exists to slow down build times with little benefit).

There is also runtime feature detection, where a program decides which version of a function to call at runtime by asking the CPU what it supports. Code deployed on heterogeneous devices (like cryptography libraries) often makes use of this. Doing this correctly is very hard and something I don’t particularly want to dig deeply into here.

The situation is made worse by the fact that in C++, you usually write SIMD code using “intrinsics”, which are special functions with inscrutable names like _mm256_cvtps_epu32 that represent a low-level operation in a specific instruction set (this is a float to int cast from AVX2). Intrinsics are defined by hardware vendors, but don’t necessarily map down to single instructions; the compiler can still optimize these instructions by merging, deduplication, and through instruction selection.

As a result you wind up writing the same code multiple times for different instruction sets, with only minor maintainability benefits over writing assembly.

The alternative is a portable SIMD library, which does some instruction selection behind the scenes at the library level but tries to rely on the compiler for most of the heavy-duty work. For a long time I was skeptical that this approach would actually produce good, competitive code, which brings us to the actual point of this article: using Rust’s portable SIMD library to implement a somewhat fussy algorithm, and measuring performance.
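
As a taste of the portable approach (nightly Rust with the portable_simd feature), the following lane-wise add compiles down to a vpaddb on an AVX2 target, and the compiler picks the equivalent instructions elsewhere:

#![feature(portable_simd)]
use std::simd::Simd;

// Lane-wise, wrapping addition of 32 bytes at a time; no intrinsics needed.
pub fn add_bytes(a: Simd<u8, 32>, b: Simd<u8, 32>) -> Simd<u8, 32> {
  a + b
}
Rust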

Parsing with SIMD

Let’s design a SIMD implementation for a well-known algorithm. Although it doesn’t look like it at first, the power of shuffles makes it possible to parse text with SIMD. And this parsing can be very, very fast.

In this case, we’re going to implement base64 decoding. To review, base64 is an encoding scheme for arbitrary binary data into ASCII. We interpret a byte slice as a bit vector, and divide it into six-bit chunks called sextets. Then, each sextet from 0 to 63 is mapped to an ASCII character:

  1. 0 to 25 go to 'A' to 'Z'.
  2. 26 to 51 go to 'a' to 'z'.
  3. 52 to 61 go to '0' to '9'.
  4. 62 goes to +.
  5. 63 goes to /.

There are other variants of base64, but the bulk of the complexity is the same for each variant.

There are a few basic pitfalls to keep in mind.

  1. Base64 is a “big endian” format: specifically, the bits in each byte are big endian. Because a sextet can span only parts of a byte, this distinction is important.

  2. We need to beware of cases where the input length is not divisible by 4; ostensibly messages should be padded with = to a multiple of 4, but it’s easy to just handle messages that aren’t padded correctly.

The length of a decoded message is given by this function:

fn decoded_len(input: usize) -> usize {
  input / 4 * 3 + match input % 4 {
    1 | 2 => 1,
    3 => 2,
    _ => 0,
  }
}
Rust

Given all this, the easiest way to implement base64 is something like this.

fn decode(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error> {
  // Tear off at most two trailing =.
  let data = match data {
    [p @ .., b'=', b'='] | [p @ .., b'='] | p => p,
  };

  // Split the input into chunks of at most 4 bytes.
  for chunk in data.chunks(4) {
    let mut bytes = 0u32;
    for &byte in chunk {
      // Translate each ASCII character into its corresponding
      // sextet, or return an error.
      let sextet = match byte {
        b'A'..=b'Z' => byte - b'A',
        b'a'..=b'z' => byte - b'a' + 26,
        b'0'..=b'9' => byte - b'0' + 52,
        b'+' => 62,
        b'/' => 63,
        _ => return Err(Error(...)),
      };

      // Append the sextet to the temporary buffer.
      bytes <<= 6;
      bytes |= sextet as u32;
    }

    // Shift things so the actual data winds up at the
    // top of `bytes`.
    bytes <<= 32 - 6 * chunk.len();

    // Append the decoded data to `out`, keeping in mind that
    // `bytes` is big-endian encoded.
    let decoded = decoded_len(chunk.len());
    out.extend_from_slice(&bytes.to_be_bytes()[..decoded]);
  }

  Ok(())
}
Rust

So, what’s the process of turning this into a SIMD version? We want to follow one directive with inexorable, robotic dedication.

Eliminate all branches.

This is not completely feasible, since the input is of variable length. But we can try. There are several branches in this code:

  1. The for chunk in line. This one is the length check: it checks if there is any data left to process.
  2. The for &byte in line. This is the hottest loop: it branches once per input byte.
  3. The match byte line is several branches, to determine which of the five “valid” match arms we land in.
  4. The return Err line. Returning in a hot loop is extra control flow, which is not ideal.
  5. The call to decoded_len contains a match, which generates branches.
  6. The call to Vec::extend_from_slice. This contains not just branches, but potential calls into the allocator. Extremely slow.

(5) is the easiest to deal with. The match is mapping the values 0, 1, 2, 3 to 0, 1, 1, 2. Call this function f. Then, the sequence given by x - f(x) is 0, 0, 1, 1. This just happens to equal x / 2 (or x >> 1), so we can write a completely branchless version of decoded_len like so.

pub fn decoded_len(input: usize) -> usize {
  let mod4 = input % 4;
  input / 4 * 3 + (mod4 - mod4 / 2)
}

That’s one branch eliminated4. ✅
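
A quick sanity check (not from the original code) that the two formulas agree:

fn main() {
  for n in 0usize..1000 {
    let with_match = n / 4 * 3 + match n % 4 { 1 | 2 => 1, 3 => 2, _ => 0 };
    let branchless = { let mod4 = n % 4; n / 4 * 3 + (mod4 - mod4 / 2) };
    assert_eq!(with_match, branchless);
  }
}
Rust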

The others will not prove so easy. Let’s turn our attention to the innermost loop next, branches (2), (3), and (4).

The Hottest Loop

The superpower of SIMD is that because you operate on so much data at a time, you can unroll the loop so hard it becomes branchless.

The insight is this: we want to load at most four bytes, do something to them, and then spit out at most three decoded bytes. While doing this operation, we may encounter a syntax error so we need to report that somehow.

Here’s some facts we can take advantage of.

  1. We don’t need to figure out how many bytes are in the “output” of the hot loop: our handy branchless decoded_len() does that for us.
  2. Invalid base64 is extremely rare. We want that syntax error to cost as little as possible. If the user still cares about which byte was the problem, they can scan the input for it after the fact.
  3. A is zero in base64. If we’re parsing a truncated chunk, padding it with A won’t change the value5.

This suggests an interface for the body of the “hottest loop”. We can factor it out as a separate function, and simplify since we can assume our input is always four bytes now.

fn decode_hot(ascii: [u8; 4]) -> ([u8; 3], bool) {
  let mut bytes = 0u32;
  let mut ok = true;
  for byte in ascii {
    let sextet = match byte {
      b'A'..=b'Z' => byte - b'A',
      b'a'..=b'z' => byte - b'a' + 26,
      b'0'..=b'9' => byte - b'0' + 52,
      b'+' => 62,
      b'/' => 63,
      _ => !0,
    };

    bytes <<= 6;
    bytes |= sextet as u32;
    ok &= sextet != !0;
  }

  // This is the `to_be_bytes()` call.
  let [b1, b2, b3, _] = bytes.to_le_bytes();
  ([b3, b2, b1], ok)
}

// In decode()...
for chunk in data.chunks(4) {
  let mut ascii = [b'A'; 4];
  ascii[..chunk.len()].copy_from_slice(chunk);

  let (bytes, ok) = decode_hot(ascii);
  if !ok {
    return Err(Error)
  }

  let len = decoded_len(chunk.len());
  out.extend_from_slice(&bytes[..len]);
}
Rust

You’re probably thinking: why not return Option<[u8; 3]>? Returning an enum will make it messier to eliminate the if !ok branch later on (which we will!). We want to write branchless code, so let’s focus on finding a way of producing that three-byte output without needing to do early returns.

Now’s when we want to start talking about vectors rather than arrays, so let’s try to rewrite our function as such.

fn decode_hot(ascii: Simd<u8, 4>) -> (Simd<u8, 4>, bool) {
  unimplemented!()
}
Rust

Note that the output is now four bytes, not three. SIMD lane counts need to be powers of two, and that last element will never get looked at, so we don’t need to worry about what winds up there.

The callsite also needs to be tweaked, but only slightly, because Simd<u8, 4> is From<[u8; 4]>.

ASCII to Sextet

Let’s look at the first part of the for byte in ascii loop. We need to map each lane of the Simd<u8, 4> to the corresponding sextet, and somehow signal which ones are invalid. First, notice something special about the match: almost every arm can be written as byte - C for some constant C. The non-range case looks a little silly, but humor me:

let sextet = match byte {
  b'A'..=b'Z' => byte - b'A',
  b'a'..=b'z' => byte - b'a' + 26,
  b'0'..=b'9' => byte - b'0' + 52,
  b'+'        => byte - b'+' + 62,
  b'/'        => byte - b'/' + 63,
  _ => !0,
};
Rust

So, it should be sufficient to build a vector offsets that contains the appropriate constant C for each lane, and then let sextets = ascii - offsets;

How can we build offsets? Using compare-and-select.

// A lane-wise version of `x >= start && x <= end`.
fn in_range(bytes: Simd<u8, 4>, start: u8, end: u8) -> Mask<i8, 4> {
  bytes.simd_ge(Simd::splat(start)) & bytes.simd_le(Simd::splat(end))
}

// Create masks for each of the five ranges.
// Note that these are disjoint: for any two masks, m1 & m2 == 0.
let uppers = in_range(ascii, b'A', b'Z');
let lowers = in_range(ascii, b'a', b'z');
let digits = in_range(ascii, b'0', b'9');
let pluses = ascii.simd_eq(Simd::splat(b'+'));
let solidi = ascii.simd_eq(Simd::splat(b'/'));

// If any byte was invalid, none of the masks will select for it,
// so that lane will be 0 in the or of all the masks. This is our
// validation check.
let ok = (uppers | lowers | digits | pluses | solidi).all();

// Given a mask, create a new vector by splatting `value`
// over the set lanes.
fn masked_splat(mask: Mask<i8, 4>, value: i8) -> Simd<i8, 4> {
  mask.select(Simd::splat(value), Simd::splat(0))
}

// Fill the lanes of the offset vector by filling the
// set lanes with the corresponding offset. This is like
// a "vectorized" version of the `match`.
let offsets = masked_splat(uppers,  65)
            | masked_splat(lowers,  71)
            | masked_splat(digits,  -4)
            | masked_splat(pluses, -19)
            | masked_splat(solidi, -16);

// Finally, Build the sextets vector.
let sextets = ascii.cast::<i8>() - offsets;
Rust

This solution is quite elegant, and will produce very competitive code, but it’s not actually ideal. We need to do a lot of comparisons here: eight in total. We also keep lots of values alive at the same time, which might lead to unwanted register pressure.

SIMD Hash Table

Let’s look at the byte representations of the ranges. A-Z, a-z, and 0-9 are, as byte ranges, 0x41..0x5b, 0x61..0x7b, and 0x30..0x3a. Notice they all have different high nybbles! What’s more, + and / are 0x2b and 0x2f, so the function byte >> 4 is almost enough to distinguish all the ranges. If we subtract one if byte == b'/', we have a perfect hash for the ranges.

In other words, the value (byte >> 4) - (byte == '/') maps the ranges as follows:

  • A-Z goes to 4 or 5.
  • a-z goes to 6 or 7.
  • 0-9 goes to 3.
  • + goes to 2.
  • / goes to 1.
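
Here is a quick scalar check of that mapping (a throwaway test, not part of the codec):

fn hash(byte: u8) -> u8 {
  (byte >> 4).wrapping_sub((byte == b'/') as u8)
}

fn main() {
  assert!((b'A'..=b'Z').all(|b| matches!(hash(b), 4 | 5)));
  assert!((b'a'..=b'z').all(|b| matches!(hash(b), 6 | 7)));
  assert!((b'0'..=b'9').all(|b| hash(b) == 3));
  assert_eq!(hash(b'+'), 2);
  assert_eq!(hash(b'/'), 1);
}
Rust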

This is small enough that we could cram a lookup table of values for building the offsets vector into another SIMD vector, and use a shuffle operation to do the lookup.

This is not my original idea; I came across a GitHub issue where an anonymous user points out this perfect hash.

Our new ascii-to-sextet code looks like this:

// Compute the perfect hash for each lane.
let hashes = (ascii >> Simd::splat(4))
  + Simd::simd_eq(ascii, Simd::splat(b'/'))
    .to_int()  // to_int() is equivalent to masked_splat(-1, 0).
    .cast::<u8>();

// Look up offsets based on each hash and subtract them from `ascii`.
let sextets = ascii
    // This lookup table corresponds to the offsets we used to build the
    // `offsets` vector in the previous implementation, placed in the
    // indices that the perfect hash produces.
  - Simd::<i8, 8>::from([0, -16, -19, -4, 65, 65, 71, 71])
    .cast::<u8>()
    .swizzle_dyn(hashes);
Rust

There is a small wrinkle here: Simd::swizzle_dyn() requires that the index array be the same length as the lookup table. This is annoying because right now ascii is a Simd<u8, 4>, but that will not be the case later on, so I will simply sweep this under the rug.

Note that we no longer get validation as a side-effect of computing the sextets vector. The same GitHub issue also provides an exact bloom-filter for checking that a particular byte is valid; you can see my implementation here. I’m not sure how the OP constructed the bloom filter, but the search space is small enough that you could have written a little script to brute force it.

Riffling the Sextets

Now comes a much trickier operation: we need to somehow pack all four sextets into three bytes. One way to try to wrap our heads around what the packing code in decode_hot() is doing is to pass in the all-ones sextet in one of the four bytes, and see where those ones end up in the return value.

This is not unlike how radioactive dyes are used in biology to track the movement of molecules or cells through an organism.

fn bits(value: u32) -> String {
  let [b1, b2, b3, b4] = value.reverse_bits().to_le_bytes();
  format!("{b1:08b} {b2:08b} {b3:08b} {b4:08b}")
}

fn decode_pack(input: [u8; 4]) {
  let mut output = 0u32;
  for byte in input {
    output <<= 6;
    output |= byte as u32;
  }
  output <<= 8;

  println!("{}\n{}\n", bits(u32::from_be_bytes(input)), bits(output));
}

decode_pack([0b111111, 0, 0, 0]);
decode_pack([0, 0b111111, 0, 0]);
decode_pack([0, 0, 0b111111, 0]);
decode_pack([0, 0, 0, 0b111111]);

// Output:
// 11111100 00000000 00000000 00000000
// 00111111 00000000 00000000 00000000
//
// 00000000 11111100 00000000 00000000
// 11000000 00001111 00000000 00000000
//
// 00000000 00000000 11111100 00000000
// 00000000 11110000 00000011 00000000
//
// 00000000 00000000 00000000 11111100
// 00000000 00000000 11111100 00000000

Bingo. Playing around with the inputs lets us verify which pieces of the bytes wind up where. For example, by passing 0b110000 as input[1], we see that the two high bits of input[1] correspond to the low bits of output[0]. I’ve written the code so that the bits in each byte are printed in little-endian order, so bits on the left are the low bits.

Putting this all together, we can draw a schematic of what this operation does to a general Simd<u8, 4>.

the riffling operation

Now, there’s no single instruction that will do this for us. Shuffles can be used to move bytes around, but we’re dealing with pieces of bytes here. We also can’t really do a shift, since we need bits that are overshifted to move into adjacent lanes.

The trick is to just make the lanes bigger.

Among the operations available for SIMD vectors are lane-wise casts, which allow us to zero-extend, sign-extend, or truncate each lane. So what we can do is cast sextets to a vector of u16, do the shift there and then… somehow put the parts back together?

Let’s see how far shifting gets us. How much do we need to shift things by? First, notice that the order of the bits within each chunk that doesn’t cross a byte boundary doesn’t change. For example, the four low bits of input[1] are in the same order when they become the high bits of output[1], and the two high bits of input[1] are also in the same order when they become the low bits of output[0].

This means we can determine how far to shift by comparing the bit position of the lowest bit of a byte of input with the bit position of the corresponding bit in output.

input[0]’s low bit is the third bit of output[0], so we need to shift input[0] by 2. input[1]’s lowest bit is the fifth bit of output[1], so we need to shift by 4. Analogously, the shifts for input[2] and input[3] turn out to be 6 and 0. In code:

let sextets = ...;
let shifted = sextets.cast::<u16>() << Simd::from([2, 4, 6, 0]);
Rust

So now we have a Simd<u16, 4> that contains the individual chunks that we need to move around, in the high and low bytes of each u16, which we can think of as being analogous to a [[u8; 2]; 4]. For example, shifted[0][0] contains sextets[0], shifted left by two bits. This corresponds to the red segment in the first schematic. The smaller blue segment is given by shifted[1][1], i.e., the high byte of the second u16. It’s already in the right place within that byte, so we want output[0] = shifted[0][0] | shifted[1][1].

This suggests a more general strategy: we want to take two vectors, the low bytes and the high bytes of each u16 in shifted, respectively, and somehow shuffle them so that when or’ed together, they give the desired output.

Look at the schematic again: if we had a vector consisting of [..aaaaaa, ....bbbb, ......cc], we could or it with a vector like [bb......, cccc...., dddddd..] to get the desired result.

One problem: dddddd.. is shifted[3][0], i.e., it’s a low byte. If we change the vector we shift by to [2, 4, 6, 8], though, it winds up in shifted[3][1], since it’s been shifted up by 8 bits: a full byte.

// Split shifted into low byte and high byte vectors.
// Same way you'd split a single u16 into bytes, but lane-wise.
let lo = shifted.cast::<u8>();
let hi = (shifted >> Simd::from([8; 4])).cast::<u8>();

// Align the lanes: we want to get shifted[0][0] | shifted[1][1],
// shifted[1][0] | shifted[2][1], etc.
let output = lo | hi.rotate_lanes_left::<1>();
Rust

Et voila, here is our new, totally branchless implementation of decode_hot().

fn decode_hot(ascii: Simd<u8, 4>) -> (Simd<u8, 4>, bool) {
  let hashes = (ascii >> Simd::splat(4))
    + Simd::simd_eq(ascii, Simd::splat(b'/'))
      .to_int()
      .cast::<u8>();

  let sextets = ascii
    - Simd::<i8, 8>::from([0, -16, -19, -4, 65, 65, 71, 71])
      .cast::<u8>()
      .swizzle_dyn(hashes);  // Not quite right yet, see next section.

  let ok = /* bloom filter shenanigans */;

  let shifted = sextets.cast::<u16>() << Simd::from([2, 4, 6, 8]);
  let lo = shifted.cast::<u8>();
  let hi = (shifted >> Simd::splat(8)).cast::<u8>();
  let output = lo | hi.rotate_lanes_left::<1>();

  (output, ok)
}
Rust

The compactness of this solution should not be understated: its simplicity is a large part of what makes it so efficient, because it aggressively leverages the primitives the hardware offers us.

Scaling Up

Ok, so now we have to contend with a new aspect of our implementation that’s crap: a Simd<u8, 4> is tiny. That’s not even 128 bits, the size of the smallest vector registers on x86. What we need to do is make decode_hot() generic on the lane count. This will allow us to tune the number of lanes to batch together depending on benchmarks later on.

fn decode_hot<const N: usize>(ascii: Simd<u8, N>) -> (Simd<u8, N>, bool)
where
  // This makes sure N is a small power of 2.
  LaneCount<N>: SupportedLaneCount,
{
  let hashes = (ascii >> Simd::splat(4))
    + Simd::simd_eq(ascii, Simd::splat(b'/'))
      .to_int()
      .cast::<u8>();

  let sextets = ascii
    - tiled(&[0, -16, -19, -4, 65, 65, 71, 71])
      .cast::<u8>()
      .swizzle_dyn(hashes);  // Works fine now, as long as N >= 8.

  let ok = /* bloom filter shenanigans */;

  let shifted = sextets.cast::<u16>() << tiled(&[2, 4, 6, 8]);
  let lo = shifted.cast::<u8>();
  let hi = (shifted >> Simd::splat(8)).cast::<u8>();
  let output = lo | hi.rotate_lanes_left::<1>();

  (output, ok)
}

/// Generates a new vector made up of repeated "tiles" of identical
/// data.
const fn tiled<T, const N: usize>(tile: &[T]) -> Simd<T, N>
where
  T: SimdElement,
  LaneCount<N>: SupportedLaneCount,
{
  let mut out = [tile[0]; N];
  let mut i = 0;
  while i < N {
    out[i] = tile[i % tile.len()];
    i += 1;
  }
  Simd::from_array(out)
}
Rust

We have to change virtually nothing, which is pretty awesome! But unfortunately, this code is subtly incorrect. Remember how in the N = 4 case, output had a garbage value in its highest lane that we ignored? Well, now that garbage data is interleaved into output: every fourth lane contains garbage.

We can use a shuffle to delete these lanes, thankfully. Specifically, we want shuffled[i] = output[i + i / 3], which skips every fourth index. So, shuffled[3] = output[4], skipping over the garbage value in output[3]. If i + i / 3 runs off the end of the vector, that’s ok, because that’s the high quarter of the final output vector, which is ignored anyways. In code:

fn decode_hot<const N: usize>(ascii: Simd<u8, N>) -> (Simd<u8, N>, bool)
where
  // This makes sure N is a small power of 2.
  LaneCount<N>: SupportedLaneCount,
{
  /* snip */

  let decoded_chunks = lo | hi.rotate_lanes_left::<1>();
  let output = swizzle!(N; decoded_chunks, array!(N; |i| i + i / 3));

  (output, ok)
}
Rust

swizzle!() is a helper macro6 for generating generic implementations of std::simd::Swizzle, and array!() is something I wrote for generating generic-length array constants; the closure is called once for each i in 0..N.
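
To convince yourself that the index math does what we want, here is a small helper of my own (purely illustrative, not part of vb64) that computes the mapping:

// My own quick check of the shuffle indices i + i / 3.
const fn skip_garbage<const N: usize>() -> [usize; N] {
  let mut idx = [0; N];
  let mut i = 0;
  while i < N {
    idx[i] = i + i / 3;
    i += 1;
  }
  idx
}

#[test]
fn garbage_lanes_are_skipped() {
  // Lanes 3 and 7 (the garbage lanes) never appear, and the
  // out-of-range indices 8 and 9 only touch the ignored top quarter.
  assert_eq!(skip_garbage::<8>(), [0, 1, 2, 4, 5, 6, 8, 9]);
}
Rust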

So now we can decode 32 base64 bytes in parallel by calling decode_hot::<32>(). We’ll try to keep things generic from here, so we can tune the lane parameter based on benchmarks.

The Outer Loop

Let’s look at decode() again. Let’s start by making it generic on the internal lane count, too.

fn decode<const N: usize>(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error>
where
  LaneCount<N>: SupportedLaneCount,
{
  let data = match data {
    [p @ .., b'=', b'='] | [p @ .., b'='] | p => p,
  };

  for chunk in data.chunks(N) { // N-sized chunks now.
    let mut ascii = [b'A'; N];
    ascii[..chunk.len()].copy_from_slice(chunk);

    let (dec, ok) = decode_hot::<N>(ascii.into());
    if !ok {
      return Err(Error);
    }

    let decoded = decoded_len(chunk.len());
    out.extend_from_slice(&dec[..decoded]);
  }

  Ok(())
}
Rust

What branches are left? There’s still the branch from for chunk in .... It’s not ideal because it can’t do an exact pointer comparison, and needs to do a >= comparison on a length instead.

We call [T]::copy_from_slice, which is super slow because it needs to make a variable-length memcpy call, which can’t be inlined. Function calls are branches! The bounds checks are also a problem.

We branch on ok every loop iteration, still. Not returning early in decode_hot doesn’t win us anything (yet).

We potentially call the allocator in extend_from_slice, and perform another non-inline-able memcpy call.

Preallocating with Slop

The last of these is the easiest to address: we can reserve space in out, since we know exactly how much data we need to write thanks to decoded_len. Better yet, we can reserve some “slop”: i.e., scratch space past where the end of the message would be, so we can perform full SIMD stores, instead of the variable-length memcpy.

This way, in each iteration, we write the full SIMD vector, including any garbage bytes in the upper quarter. Then, the next write is offset 3/4 * N bytes over, so it overwrites the garbage bytes with decoded message bytes. The garbage bytes from the final write get “deleted” by not being included in the final Vec::set_len() that “commits” the memory we wrote to.

fn decode<const N: usize>(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error>
where LaneCount<N>: SupportedLaneCount,
{
  let data = match data {
    [p @ .., b'=', b'='] | [p @ .., b'='] | p => p,
  };

  let final_len = decoded_len(data.len());
  out.reserve(final_len + N);  // Reserve with slop for the trailing full store.

  // Get a raw pointer to where we should start writing.
  let mut ptr = out.as_mut_ptr_range().end;
  let start = ptr;

  for chunk in data.chunks(N) { // N-sized chunks now.
    /* snip */

    let decoded = decoded_len(chunk.len());
    unsafe {
      // Do a raw write and advance the pointer.
      ptr.cast::<Simd<u8, N>>().write_unaligned(dec);
      ptr = ptr.add(decoded);
    }
  }

  unsafe {
    // Update the vector's final length.
    // This is the final "commit".
    let len = ptr.offset_from(start) as usize;
    out.set_len(out.len() + len);
  }

  Ok(())
}
Rust

This is safe because we’ve reserved enough memory for every store we perform, and the distance ptr advances is exactly the number of bytes actually decoded. We could also compute the final length of out ahead of time.

Note that if we early return due to if !ok, out remains unmodified, because even though we did write to its buffer, we never execute the “commit” part, so the code remains correct.

Delaying Failure

Next up, we can eliminate the if !ok branches by waiting to return an error until as late as possible: just before the set_len call.

Remember our observation from before: most base64 encoded blobs are valid, so this unhappy path should be very rare. Also, a syntax error can’t cause the code that follows to misbehave arbitrarily (at worst we decode garbage), so letting it go wild doesn’t hurt anything.

fn decode<const N: usize>(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error>
where LaneCount<N>: SupportedLaneCount,
{
  /* snip */
  let mut error = false;
  for chunk in data.chunks(N) {
    let mut ascii = [b'A'; N];
    ascii[..chunk.len()].copy_from_slice(chunk);

    let (dec, ok) = decode_hot::<N>(ascii.into());
    error |= !ok;

    /* snip */
  }

  if error {
    return Err(Error);
  }

  unsafe {
    let len = ptr.offset_from(start) as usize;
    out.set_len(out.len() + len);
  }

  Ok(())
}
Rust

The branch is still “there”, sure, but it’s out of the hot loop.

Because we never hit the set_len call and commit whatever garbage we wrote, said garbage essentially disappears when we return early, to be overwritten by future calls to Vec::push().

Unroll It Harder

Ok, let’s look at the memcpy from copy_from_slice at the start of the hot loop. The loop has already been partly unrolled: it handles N elements with SIMD at each step, doing something funny on the last step to make up for the missing data (padding with A).

We can take this a step further by doing an “unroll and jam” optimization. This type of unrolling splits the loop into two parts: a hot vectorized loop and a cold remainder part. The hot loop always handles length N input, and the remainder runs at most once and handles i < N input.

Rust provides a slice method for hand-rolled (lol) unroll-and-jam: [T]::chunks_exact().

fn decode<const N: usize>(data: &[u8], out: &mut Vec<u8>) -> Result<(), Error>
where LaneCount<N>: SupportedLaneCount,
{
  /* snip */
  let mut error = false;
  let mut chunks = data.chunks_exact(N);
  for chunk in &mut chunks {
    // Simd::from_slice() can do a load in one instruction.
    // The bounds check is easy for the compiler to elide.
    let (dec, ok) = decode_hot::<N>(Simd::from_slice(chunk));
    error |= !ok;
    /* snip */
  }

  let rest = chunks.remainder();
  if !rest.is_empty() {
    let mut ascii = [b'A'; N];
    ascii[..rest.len()].copy_from_slice(rest);

    let (dec, ok) = decode_hot::<N>(ascii.into());
    /* snip */
  }

  /* snip */
}
Rust

Splitting into two parts lets us call Simd::from_slice(), which performs a single, vector-sized load.

So, How Fast Is It?

At this point, it looks like we’ve addressed every branch that we can, so some benchmarks are in order. I wrote a benchmark that decodes messages of every length from 0 to something like 200 or 500 bytes, and compared it against the baseline base64 implementation on crates.io.

I compiled with -Zbuild-std and -Ctarget-cpu=native to try to get the best results. Based on some tuning, N = 32 was the best length, since it used one YMM register for each iteration of the hot loop.

a performance graph; our code is really good compared to the baseline, but variance is high

So, we have the baseline beat. But what’s up with that crazy heartbeat waveform? You can tell it has something to do with the “remainder” part of the loop, since it correlates strongly with data.len() % 32.

I stared at the assembly for a while. I don’t remember what was there, but I think that copy_from_slice had been inlined and unrolled into a loop that loaded one byte at a time. The moral equivalent of this:

let mut ascii = [b'A'; N];
for (a, b) in ascii.iter_mut().zip(chunk) {
  *a = *b;
}
Rust

I decided to try Simd::gather_or(), which is kind of like a “vectorized load”. It wound up producing worse assembly, so I gave up on using a gather and instead wrote a carefully optimized loading function by hand.

Unroll and Jam, Revisited

The idea here is to perform the largest scalar loads Rust offers where possible. The strategy is again unroll and jam: perform u128 loads in a loop and deal with the remainder separately.

The hot part looks like this:

let mut buf = [b'A'; N];

// Load a bunch of big 16-byte chunks. LLVM will lower these to XMM loads.
let ascii_ptr = buf.as_mut_ptr();
let mut write_at = ascii_ptr;
if slice.len() >= 16 {
  for i in 0..slice.len() / 16 {
    unsafe {
      let word = slice.as_ptr().cast::<u128>().add(i).read_unaligned();
      write_at.cast::<u128>().write_unaligned(word);

      // Advance past the 16 bytes we just copied.
      write_at = write_at.add(16);
    }
  }
}
Rust

The cold part seems hard to optimize at first. What’s the least number of unaligned loads you need to do to load 15 bytes from memory? It’s two! You can load a u64 from p, and then another one from p + 7; these loads (call them a and b) overlap by one byte, but we can or them together to merge that byte, so our loaded value is a as u128 | (b as u128 << 56).

A similar trick works if the data to load is between a u32 and a u64. Finally, to load 1, 2, or 3 bytes, we can load p, p + len/2 and p + len-1; depending on whether len is 1, 2, or 3, this will potentially load the same byte multiple times; however, this reduces the number of branches necessary, since we don’t need to distinguish the 1, 2, or 3 cases.

This is the kind of code that’s probably easier to read than to explain.

unsafe {
  let ptr = slice.as_ptr().offset(write_at.offset_from(ascii_ptr));
  let len = slice.len() % 16;

  if len >= 8 {
    // Load two overlapping u64s.
    let lo = ptr.cast::<u64>().read_unaligned() as u128;
    let hi = ptr.add(len - 8).cast::<u64>().read_unaligned() as u128;
    let data = lo | (hi << ((len - 8) * 8));

    let z = u128::from_ne_bytes([b'A'; 16]) << (len * 8);
    write_at.cast::<u128>().write_unaligned(data | z);
  } else if len >= 4 {
    // Load two overlapping u32s.
    let lo = ptr.cast::<u32>().read_unaligned() as u64;
    let hi = ptr.add(len - 4).cast::<u32>().read_unaligned() as u64;
    let data = lo | (hi << ((len - 4) * 8));

    let z = u64::from_ne_bytes([b'A'; 8]) << (len * 8);
    write_at.cast::<u64>().write_unaligned(data | z);
  } else {
    // Load 3 overlapping u8s.

    // For len       1       2       3     ...
    // ... this is  ptr[0]  ptr[0]  ptr[0]
    let lo = ptr.read() as u32;
    // ... this is  ptr[0]  ptr[1]  ptr[1]
    let mid = ptr.add(len / 2).read() as u32;
    // ... this is  ptr[0]  ptr[1]  ptr[2]
    let hi = ptr.add(len - 1).read() as u32;

    let data = lo | (mid << ((len / 2) * 8)) | hi << ((len - 1) * 8);

    let z = u32::from_ne_bytes([b'A'; 4]) << (len * 8);
    write_at.cast::<u32>().write_unaligned(data | z);
  }
}
Rust

I learned this type of loading code while contributing to Abseil: it’s very useful for loading variable-length data for data-hungry algorithms, like a codec or a hash function.

Here’s the same benchmark again, but with our new loading code.

a performance graph; our code is even better and the variance is very tight

The results are really, really good. The variance is super tight, and our performance is 2x that of the baseline pretty much everywhere. Success.

Encoding? Web-Safe?

Writing an encoding function is simple enough: first, implement an encode_hot() function that reverses the operations from decode_hot(). The perfect hash from before won’t work, so you’ll need to invent a new one.
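
For reference, here is a scalar sketch of the job encode_hot() has to do: it just applies the decoder’s offsets in reverse. This is my own sketch, not vb64’s actual SIMD code, which needs to pick the offset branchlessly:

// Map a sextet (0..64) back to its base64 character by adding the
// same offsets the decoder subtracts.
fn encode_one(sextet: u8) -> u8 {
  debug_assert!(sextet < 64);
  let offset: i8 = match sextet {
    0..=25 => 65,   // 'A'..='Z'
    26..=51 => 71,  // 'a'..='z'
    52..=61 => -4,  // '0'..='9'
    62 => -19,      // '+'
    _ => -16,       // '/'
  };
  sextet.wrapping_add(offset as u8)
}
Rust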

The loading/storing code around the encoder is slightly different, too. vb64 implements a very efficient encoding routine, so I suggest taking a look at the source code if you’re interested.

There is a base64 variant called web-safe base64, that replaces the + and / characters with - and _. Building a perfect hash for these is trickier: you would probably have to do something like (byte >> 4) - (byte == '_' ? '_' : 0). I don’t support web-safe base64 yet, but only because I haven’t gotten around to it.

Conclusion

My library doesn’t really solve an important problem; base64 decoding isn’t a bottleneck… anywhere that I know of, really. But writing SIMD code is really fun! Writing branchless code is often overkill but can give you a good appreciation for what your compilers can and can’t do for you.

This project was also an excuse to try std::simd. I think it’s great overall, and generates excellent code. There’s some rough edges I’d like to see fixed to make SIMD code even simpler, but overall I’m very happy with the work that’s been done there.

This is probably one of the most complicated posts I’ve written in a long time. SIMD (and performance in general) is a complex topic that requires a breadth of knowledge of tricks and hardware, a lot of which isn’t written down. More of it is written down now, though.

  1. Shifts are better understood as arithmetic. They have a lane width, and closely approximate multiplication and division. AVX2 doesn’t even have a vector shift for 8-bit lanes, or any vector integer division: you emulate them with multiplication.

  2. The two common representations of true and false, i.e. 1 and 0 or 0xff... and 0, are related by the two’s complement operation.

    For example, if I write uint32_t m = -(a == b);, m will be zero if a == b is false, and all-ones otherwise. This is because applying any arithmetic operation to a bool promotes it to int, so false maps to 0 and true maps to 1. Applying the - sends 0 to 0 and 1 to -1, and it’s useful to know that in two’s complement, -1 is represented as all-ones.

    The all-ones representation for true is useful, because it can be used to implement branchless select very easily. For example,

    int select_if_eq(int a, int b, int x, int y) {
      int mask = -(a == b);
      return (mask & x) | (~mask & y);
    }
    C++

    This function returns x if a == b, and y otherwise. Can you tell why? 

  3. Target features also affect ABI in subtle ways that I could write many, many more words on. Compiling libraries you plan to distribute with weird target feature flags is a recipe for disaster. 

  4. Why can’t we leave this kind of thing to LLVM? Finding this particular branchless implementation is tricky. LLVM is smart enough to fold the match into a switch table, but that’s unnecessary memory traffic to look at the table. (In this domain, unnecessary memory traffic makes our code slower.)

    Incidentally, with the code I wrote for the original decoded_len(), LLVM produces a jump and a lookup table, which is definitely an odd choice? I went down something of a rabbit-hole. https://github.com/rust-lang/rust/issues/118306

    As for getting LLVM to find the “branchless” version of the lookup table? The search space is quite large, and this kind of “general strength reduction” problem is fairly open (keywords: “superoptimizers”). 

  5. To be clear on why this works: suppose that in our reference implementation, we only handle inputs that are a multiple-of-4 length, and are padded with = as necessary, and we treat = as zero in the match. Then, for the purposes of computing the bytes value (before appending it to out), we can assume the chunk length is always 4. 

  6. See vb64/src/util.rs

What is a Matrix? A Miserable Pile of Coefficients!

Linear algebra is undoubtedly the most useful field in all of algebra. It finds applications in all kinds of science and engineering, like quantum mechanics, graphics programming, and machine learning. It is the “most well-behaved” algebraic theory, in that other abstract algebra topics often try to approximate linear algebra, when possible.

For many students, linear algebra means vectors and matrices and determinants, and complex formulas for computing them. Matrices, in particular, come equipped with a fairly complicated, and a fortiori convoluted, multiplication operation.

This is not the only way to teach linear algebra, of course. Matrices and their multiplication appear complicated, but actually are a natural and compact way to represent a particular type of function, i.e., a linear map (or linear transformation).

This article is a short introduction to viewing linear algebra from the perspective of abstract algebra, from which matrices arise as a computational tool, rather than an object of study in and of themselves. I do assume some degree of familiarity with the idea of a matrix.

Linear Spaces

Most linear algebra courses open with a description of vectors in Euclidean space: $\R^n$. Vectors there are defined as tuples of real numbers that can be added, subtracted, and scaled. Two vectors can be combined into a number through the dot product. Vectors come equipped with a notion of magnitude and direction.

However, this highly geometric picture can be counterproductive, since it is hard to apply geometric intuition directly to higher dimensions. It also obscures how this connects to working over a different number system, like the complex numbers.

Instead, I’d like to open with the concept of a linear space, which is somewhat more abstract than a vector space1.

First, we will need a notion of a “coefficient”, which is essentially something that you can do arithmetic with. We will draw coefficients from a designated ground field $K$. A field is a setting for doing arithmetic: a set of objects that can be added, subtracted, multiplied, and divided in the “usual fashion”, along with special $0$ and $1$ values. E.g. $a + 0 = a$, $1a = a$, $a(b + c) = ab + ac$, and so on.

Not only are the real numbers $\R$ a field, but so are the complex numbers $\C$, and the rational numbers $\Q$. If we drop the “division” requirement, we can also include the integers $\Z$, or polynomials with rational coefficients $\Q[x]$, for example.

Having chosen our coefficients $K$, a linear space $V$ over $K$ is another set of objects that can be added and subtracted (including a special value $0$)2, along with a scaling operation, which takes a coefficient $c \in K$ and one of our objects $v \in V$ and produces a new $cv \in V$.

The important part of the scaling operation is that it’s compatible with addition: if we have $a, b \in K$ and $v, w \in V$, we require that

\begin{gather*}a (v + w) = av + aw \\ (a + b) v = av + bv\end{gather*}
Math

This is what makes a linear space “linear”: you can write equations that look like first-degree polynomials (e.g. $ax + b$), and which can be manipulated like first-degree polynomials.

These polynomials are called linear because their graph looks like a line. There’s no multiplication, so we can’t have $x^2$, but we do have multiplication by a coefficient. This is what makes linear algebra “linear”.

Some examples: $n$-tuples of elements drawn from any field are a linear space over that field, by componentwise addition and scalar multiplication; e.g., $\R^3$. Setting $n = 1$ shows that every field is a linear space over itself.

Polynomials in one variable over some field, $K[x]$, are also a linear space, since polynomials can be added together and scaled by any value in $K$ (since lone coefficients are degree zero polynomials). Real-valued functions also form a linear space over $\R$ in a similar way.
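
Concretely, in $\R^3$, addition and scaling are done coordinate by coordinate:

(1, 2, 3) + (4, 5, 6) = (5, 7, 9), \qquad 2 \cdot (1, 2, 3) = (2, 4, 6)
Math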

Linear Transformations

A linear map is a function $f: V \to W$ between two linear spaces $V$ and $W$ over $K$ which “respects” the linear structure in a particular way. That is, for any $c \in K$ and $v, w \in V$,

\begin{gather*}f(v + w) = f(v) + f(w) \\ f(cv) = c \cdot f(v)\end{gather*}
Math

We call this type of relationship (respecting addition and scaling) “linearity”. One way to think of it is that $f$ is kind of like a different kind of coefficient: it distributes over addition and commutes with the “ordinary” coefficients from $K$. However, applying $f$ produces a value from $W$ rather than $V$.

Another way to think of it is that if we have a linear polynomial with zero constant term, like $p(x) = ax$, then $f(p(x)) = p(f(x))$. We say that $f$ commutes with all such polynomials.

The most obvious sort of linear map is scaling. Given any coefficient $c \in K$, it defines a “scaling map”:

\begin{gather*}\mu_c: V \to V \\ v \mapsto cv\end{gather*}
Math

It’s trivial to check this is a linear map, by plugging it into the above equations: it’s linear because scaling is distributive and commutative.
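
Spelled out, the check is just the distributive and commutative laws of the field, plugged into the two linearity equations:

\begin{gather*}\mu_c(v + w) = c(v + w) = cv + cw = \mu_c(v) + \mu_c(w) \\ \mu_c(av) = c(av) = a(cv) = a \cdot \mu_c(v)\end{gather*}
Math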

Linear maps are the essential thing we study in linear algebra, since they describe all the different kinds of relationships between linear spaces.

Some linear maps are complicated. For example, a function from $\R^2 \to \R^2$ that rotates the plane by some angle $\theta$ is linear, as are operations that stretch or shear the plane. However, they can’t “bend” or “fold” the plane: they are all fairly rigid motions. In the linear space $\Q[x]$ of rational polynomials, multiplication by any polynomial, such as $x$ or $x^2 - 1$, is a linear map. The notion of “linear map” depends heavily on the space we’re in.

Unfortunately, linear maps as defined are quite opaque, and do not lend themselves well to calculation. However, we can build an explicit representation using a linear basis.

Linear Basis

For any linear space, we can construct a relatively small set of elements such that any element of the space can be expressed as some linear function of these elements.

Explicitly, for any $V$, we can construct a sequence3 $e_i$ such that for any $v \in V$, we can find $c_i \in K$ such that

v = \sum_i c_i e_i.
Math

Such a set $e_i$ is called a basis if it is linearly independent: no one $e_i$ can be expressed as a linear function of the rest. The dimension of $V$, denoted $\dim V$, is the number of elements in any choice of basis. This value does not depend on the choice of basis4.
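
For example, here are two dimensions we can compute right away:

\dim \R^3 = 3, \qquad \dim \{\text{polynomials of degree } < n\} = n
Math

(The latter space has basis $\{1, x, ..., x^{n-1}\}$.)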

Constructing a basis for any $V$ is easy: we can do this recursively. First, pick any nonzero element $e_1$ of $V$, and define a new linear space $V/e_1$ where we have identified all elements that differ by a multiple of $e_1$ as equal (i.e., if $v - w = ce_1$, we treat $v$ and $w$ as equal in $V/e_1$).

Then, a basis for $V$ is a basis of $V/e_1$ with $e_1$ added. The construction of $V/e_1$ is essentially “collapsing” the dimension $e_1$ “points” in, giving us a new space where we’ve “deleted” all of the elements that have a nonzero $e_1$ component.

However, this only works when the dimension is finite; more complex methods must be used for infinite-dimensional spaces. For example, the polynomials $\Q[x]$ are an infinite-dimensional space, with basis elements $\{1, x, x^2, x^3, ...\}$. In general, for any linear space $V$, it is always possible to arbitrarily choose a basis, although it may be infinite5.

Bases are useful because they give us a concrete representation of any element of $V$. Given a fixed basis $e_i$, we can represent any $w = \sum_i c_i e_i$ by the coefficients $c_i$ themselves. For a finite-dimensional $V$, this brings us back to column vectors: $(\dim V)$-tuples of coefficients from $K$ that are added and scaled componentwise.

\Mat{c_0 \\ c_1 \\ \vdots \\ c_n} \,\underset{\text{given } e_i}{:=}\, \sum_i c_i e_i
Math

The $i$th basis element is represented as the vector whose entries are all $0$ except for the $i$th one, which is $1$. E.g.,

\Mat{1 \\ 0 \\ \vdots \\ 0} \,\underset{\text{given } e_i}{=}\, e_1, \,\,\, \Mat{0 \\ 1 \\ \vdots \\ 0} \,\underset{\text{given } e_i}{=}\, e_2, \,\,\, ...
Math

It is important to recall that the choice of basis is arbitrary. From the mathematical perspective, any basis is just as good as any other, although some may be more computationally convenient.

Over $\R^2$, $(1, 0)$ and $(0, 1)$ are sometimes called the “standard basis”, but $(1, 2)$ and $(3, -4)$ are also a basis for this space. One easy mistake to make, particularly when working over the tuple space $K^n$, is to confuse the actual elements of the linear space with the coefficient vectors that represent them. Working with abstract linear spaces eliminates this source of confusion.
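
As a small worked example (my own numbers), take the element $(5, 0) \in \R^2$. Its coefficients depend on which basis we pick:

\begin{gather*}(5, 0) = 5 \cdot (1, 0) + 0 \cdot (0, 1) \\ (5, 0) = 2 \cdot (1, 2) + 1 \cdot (3, -4)\end{gather*}
Math

The same element is represented by the coefficient vector $(5, 0)$ in the first basis, but by $(2, 1)$ in the second.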

Representing Linear Transformations

Working with finite-dimensional linear spaces $V$ and $W$, let’s choose bases $e_i$ and $d_j$ for them, and let’s consider a linear map $f: V \to W$.

The powerful thing about bases is that we can more compactly express the information content of $f$. Given any $v \in V$, we can decompose it into a linear function of the basis (for some coefficients), so we can write

f(v) = f\left(\sum_i c_i e_i\right) = \sum_i f(c_i e_i) = \sum_i c_i \cdot f(e_i)
Math

In other words, to specify $f$, we only need to specify what it does to each of the $\dim V$ basis elements. But what’s more, because $W$ also has a basis, we can write

f(e_i) = \sum_j A_{ij} d_j
Math

Putting these two formulas together, we have an explicit closed form for $f(v)$, given the coefficients $A_{ij}$ of $f$, and the coefficients $c_i$ of $v$:

f(v) = \sum_{i,j} c_i A_{ij} d_j
Math

Alternatively, we can express $v$ and $f(v)$ as column vectors, and $f$ as the matrix $A$ with entries $A_{ij}$. The entries of the resulting column vector are given by the above explicit formula for $f(v)$, fixing the value of $j$ in each entry.

\underbrace{\Mat{ A_{0,0} & A_{1,0} & \cdots & A_{n,0} \\ A_{0,1} & A_{1,1} & \cdots & A_{n,1} \\ \vdots & \vdots & \ddots & \vdots \\ A_{0,m} & A_{1,m} & \cdots & A_{n,m} }}_A \, \underbrace{\Mat{c_0 \\ c_1 \\ \vdots \\ c_n}}_v = \underbrace{\Mat{ \sum_i c_i A_{i,0} \\ \sum_i c_i A_{i,1} \\ \vdots \\ \sum_i c_i A_{i,m} }}_{Av}
Math

(Remember, this is all dependent on the choices of bases $e_i$ and $d_j$!)

Behold, we have derived the matrix-vector multiplication formula: the $j$th entry of the result is the dot product of the vector and the $j$th row of the matrix.

But it is crucial to keep in mind that we had to choose bases $e_i$ and $d_j$ to be entitled to write down a matrix for $f$. The values of the coefficients depend on the choice of basis.

If your linear space happens to be $\R^n$, there is an “obvious” choice of basis, but not every linear space over $\R$ is $\R^n$! Importantly, the actual linear algebra does not change depending on the basis6.
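
To make this concrete, here is a small made-up example. Suppose $f: \R^2 \to \R^2$ sends $f(e_1) = d_1 + 2d_2$ and $f(e_2) = 3d_1 - d_2$. Its matrix has those images as its columns, and for $v = 5e_1 + e_2$ the formula gives

A = \Mat{1 & 3 \\ 2 & -1}, \qquad A\Mat{5 \\ 1} = \Mat{1 \cdot 5 + 3 \cdot 1 \\ 2 \cdot 5 + (-1) \cdot 1} = \Mat{8 \\ 9}
Math

which matches computing $f(v) = 5f(e_1) + f(e_2) = 8d_1 + 9d_2$ directly.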

Matrix Multiplication

So, where does matrix multiplication come from? An $n \times m$7 matrix $A$ represents some linear map $f: V \to W$, where $\dim V = n$, $\dim W = m$, and appropriate choices of basis ($e_i$, $d_j$) have been made.

Keeping in mind that linear maps are supreme over matrices, suppose we have a third linear space $U$, and a map $g: U \to V$, and let $\ell = \dim U$. Choosing a basis $h_k$ for $U$, we can represent $g$ as a matrix $B$ of dimension $\ell \times n$.

Then, we’d like for the matrix product $AB$ to be the same matrix we’d get from representing the composite map $fg: U \to W$ as a matrix, using the aforementioned choices of bases for $U$ and $W$ (the basis choice for $V$ should “cancel out”).

Recall our formula for $f(v)$ in terms of its matrix coefficients $A_{ij}$ and the coefficients of the input $v$, which we call $c_i$. We can produce a similar formula for $g(u)$, giving it matrix coefficients $B_{ki}$, and coefficients $b_k$ for $u$. (I apologize for the number of indices and coefficients here.)

\begin{align*}f(v) &= \sum_{i,j} c_i A_{ij} d_j \\ g(u) &= \sum_{k,i} b_k B_{ki} e_i\end{align*}
Math

If we write $f(g(u))$, then $c_i$ is the coefficient $e_i$ is multiplied by; i.e., we fix $i$, and drop it from the summation: $c_i = \sum_k b_k B_{ki}$.

Substituting that into the above formula, we now have something like the following.

\begin{align*}f(g(u)) &= \sum_{i,j} \sum_{k} b_k B_{ki} A_{ij} d_j \\ f(g(u)) &= \sum_{k,j} b_k \left(\sum_{i} A_{ij} B_{ki} \right) d_j &(\star)\end{align*}
Math

In $(\star)$, we’ve rearranged things so that the sum in parentheses is the $(k,j)$th matrix coefficient of the composite $fg$. Because we wanted $AB$ to represent $fg$, it must be an $\ell \times m$ matrix whose entries are

(AB)_{kj} = \sum_{i} A_{ij} B_{ki}
Math

This is matrix multiplication. It arises naturally out of composition of linear maps. In this way, the matrix multiplication formula is not a definition, but a theorem of linear algebra!

Theorem (Matrix Multiplication)

Given an $n \times m$ matrix $A$ and an $\ell \times n$ matrix $B$, both with coefficients in $K$, the product $AB$ is an $\ell \times m$ matrix with entries

(AB)_{kj} = \sum_{i} A_{ij} B_{ki}
Math

If the matrix dimension is read as $n \to m$ instead of $n \times m$, the shape requirements are more obvious: two matrices $A$ and $B$ can be multiplied together only when they represent a pair of maps $V \to W$ and $U \to V$.
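
Continuing the made-up example from above (again, my own numbers): let $g: \R^2 \to \R^2$ swap the basis vectors, so $g(h_1) = e_2$ and $g(h_2) = e_1$. Composing by hand, $f(g(h_1)) = f(e_2) = 3d_1 - d_2$ and $f(g(h_2)) = f(e_1) = d_1 + 2d_2$, and the matrix product has exactly those images as its columns:

B = \Mat{0 & 1 \\ 1 & 0}, \qquad AB = \Mat{1 & 3 \\ 2 & -1}\Mat{0 & 1 \\ 1 & 0} = \Mat{3 & 1 \\ -1 & 2}
Math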

Other Consequences, and Conclusion

The identity matrix is an $n \times n$ matrix:

I_n = \Mat{ 1 \\ & 1 \\ && \ddots \\ &&& 1 }
Math

We want it to be such that for any appropriately-sized matrices $A$ and $B$, we have $AI_n = A$ and $I_n B = B$. Lifted up to linear maps, this means that $I_n$ should represent the identity map $V \to V$, when $\dim V = n$. This map sends each basis element $e_i$ to itself, so the columns of $I_n$ should be the basis vectors, in order:

\Mat{1 \\ 0 \\ \vdots \\ 0} \Mat{0 \\ 1 \\ \vdots \\ 0} \cdots \Mat{0 \\ 0 \\ \vdots \\ 1}
Math

If we shuffle the columns, we’ll get a permutation matrix, which shuffles the coefficients of a column vector. For example, consider this matrix.

\Mat{ 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 1 }
Math

This is similar to the identity, but we’ve swapped the first two columns. Thus, it will swap the first two coefficients of any column vector.
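
Indeed, applying it to a generic column vector swaps the first two entries:

\Mat{ 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 1 } \Mat{x \\ y \\ z} = \Mat{y \\ x \\ z}
Math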

Matrices may seem unintuitive when they’re introduced as a subject of study. Every student encountering matrices for the first time may ask “If they add componentwise, why don’t they multiply componentwise too?”

However, approaching matrices as a computational and representational tool shows that the convoluted-looking matrix multiplication formula is a direct consequence of linearity.

\begin{gather*}f(v + w) = f(v) + f(w) \\ f(cv) = c \cdot f(v)\end{gather*}
Math
  1. In actual modern mathematics, the objects I describe are still called vector spaces, which I think generates unnecessary confusion in this case. “Linear space” is a bit more on the nose for what I’m going for. 

  2. This type of structure (just the addition part) is also called an “abelian group”. 

  3. Throughout, $i$, $j$, and $k$ are indices in some unspecified but ordered indexing set, usually $\{1, 2, ..., n\}$. I will not bother giving this index set a name.

  4. This is sometimes called the dimension theorem, which is somewhat tedious to prove. 

  5. An example of a messy infinite-dimensional basis is $\R$ considered as a linear space over $\Q$ (in general, every field is a linear space over its subfields). The basis for this space essentially has to be “$1$, and all irrational numbers”, except if we include e.g. $e$ and $\pi$ we can’t include $e + \frac{1}{2}\pi$, which is a $\Q$-linear combination of $e$ and $\pi$.

    On the other hand, $\C$ is two-dimensional over $\R$, with basis $\{1, i\}$.

    Incidentally, this idea of “viewing a field $K$ as a linear space over its subfield $F$” is so useful that the resulting dimension is called the “degree of the field extension $K/F$”, and given the symbol $[K : F]$.

    Thus, $[\R : \Q] = \infty$ and $[\C : \R] = 2$.

  6. You may recall from linear algebra class that two matrices $A$ and $B$ of the same shape are similar if there are two appropriately-sized invertible square matrices $S$ and $R$ such that $SAR = B$. These matrices $S$ and $R$ represent a change of basis, and indicate that the linear maps $A, B: V \to W$ these matrices come from do “the same thing” to elements of $V$.

    Over an algebraically closed field like $\C$ (i.e. all non-constant polynomials have roots), there is an even stronger way to capture the information content of a linear map via Jordan canonicalization, which takes any square matrix $A$ and produces an almost-diagonal square matrix that only depends on the eigenvalues of $A$, which is the same for similar matrices, and thus basis-independent.

  7. Here, as always, matrix dimensions are given in “input × output” order: an $n \times m$ matrix has $n$ columns and $m$ rows. You can think of this as going from “input dimension” to “output dimension”.