III. Chainable Middleware

It turns out that Handlers provide a nice way to create a chain of HTTP middleware.

There's more than one way to go about this, but they all build on http.Handler. If you google it, you'll come across examples that have you nest function calls in a way that gets grim pretty quickly:

http.Handle("/", seriously(
	who(
		wants(
			this(handler),
		),
	),
))

However, if you google "golang chainable middleware" (or similar), you'll find a much nicer pattern! Let's explore it, and see how Handlers help us here.

Generating a Middleware

Let's see a function that generates a middleware:

// LogMiddleware wraps next in a handler that logs
// each request's method and URI, then calls next.
func LogMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s: %s", r.Method, r.RequestURI)
		next.ServeHTTP(w, r)
	})
}

We'll work with http.HandlerFunc instead of http.Handler here. The only reason is that it lets us pass a regular function directly (see the previous article on Handlers if you haven't read it).
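To make that trade-off concrete, here is a small sketch. It assumes a hypothetical TagMiddleware (same shape as LogMiddleware, but it sets a response header so the effect is easy to check) and a hypothetical struct-based handler called greeter. A plain function can be passed straight in; an existing http.Handler can still be passed by using its ServeHTTP method value, which has exactly the func(ResponseWriter, *Request) shape.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// TagMiddleware is hypothetical: it sets a header, then calls next.
func TagMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Middleware", "seen")
		next.ServeHTTP(w, r)
	}
}

// greeter is a hypothetical struct that implements http.Handler.
type greeter struct{ name string }

func (g greeter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "hello, %s", g.name)
}

// serve runs h against a fake request and returns the body and the header.
func serve(h http.Handler) (string, string) {
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Body.String(), rec.Header().Get("X-Middleware")
}

// wrapFunc passes a plain function literal directly.
func wrapFunc() http.HandlerFunc {
	return TagMiddleware(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "hello, func")
	})
}

// wrapHandler passes an http.Handler via its ServeHTTP method value.
func wrapHandler() http.HandlerFunc {
	return TagMiddleware(greeter{name: "struct"}.ServeHTTP)
}

func main() {
	fmt.Println(serve(wrapFunc()))    // hello, func seen
	fmt.Println(serve(wrapHandler())) // hello, struct seen
}
```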

LogMiddleware takes an http.HandlerFunc and returns another http.HandlerFunc.

The function we return does its own work (here, logging) and then calls ServeHTTP on the "next" http.HandlerFunc. We know the ServeHTTP method is available to call thanks to the magic of http.HandlerFunc, as described in the previous article.
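One way to see LogMiddleware in action without starting a server is a sketch like the following, which uses the standard net/http/httptest package to send a fake request. The runOnce helper and the "/hello" path are illustrative choices; the log output is captured in a buffer (with timestamps turned off) so we can confirm both that the middleware logged and that the "next" handler produced the body.

```go
package main

import (
	"bytes"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
)

// LogMiddleware, same logic as above.
func LogMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s: %s", r.Method, r.RequestURI)
		next.ServeHTTP(w, r)
	})
}

// runOnce sends one fake GET request through LogMiddleware and returns
// the response body plus whatever the middleware logged.
func runOnce() (body, logged string) {
	var buf bytes.Buffer
	log.SetOutput(&buf)
	log.SetFlags(0) // no timestamps, so the log line is predictable
	defer func() {
		log.SetOutput(os.Stderr)
		log.SetFlags(log.LstdFlags)
	}()

	base := func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "HELLO")
	}
	h := LogMiddleware(base)

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/hello", nil))
	return rec.Body.String(), buf.String()
}

func main() {
	body, logged := runOnce()
	fmt.Printf("body=%q logged=%q\n", body, logged)
}
```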

We can keep nesting middleware, with each one calling the "next" Handler. Let's write some code and make that idea concrete.
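Before writing any helper, it may help to see the nesting done by hand. This sketch uses a hypothetical named middleware that records its name in a slice before calling next; wrapping one call inside another shows that the outermost middleware runs first and the base handler runs last.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// trace records the order in which each layer runs.
var trace []string

// named returns a hypothetical middleware that records its name,
// then calls next.
func named(name string) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			trace = append(trace, name)
			next.ServeHTTP(w, r)
		}
	}
}

func handler(w http.ResponseWriter, r *http.Request) {
	trace = append(trace, "handler")
}

// nestByHand wraps handler in two middleware without any helper.
// The outermost call ("first") runs first.
func nestByHand() []string {
	trace = nil
	h := named("first")(named("second")(handler))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return trace
}

func main() {
	fmt.Println(nestByHand()) // [first second handler]
}
```

This works, but the nesting grows one level deeper with each middleware, which is what the rest of this section tidies up.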

Codify the Middleware

First, we'll codify this "pattern" of generating Middleware functions as a type:

// Middleware is a func type that allows
// for chaining middleware
type Middleware func(http.HandlerFunc) http.HandlerFunc

Naming the type lets the compiler check that every function we treat as middleware has the right signature, as we'll see later.
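Here is a rough sketch of what the named type buys us. WithHeader is a hypothetical middleware "factory" that takes configuration and returns a Middleware, and the `var _ Middleware = NoopMiddleware` line is a common Go idiom for a compile-time signature check: it stops building if the function's shape ever drifts from Middleware.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// Middleware is the type from this article.
type Middleware func(http.HandlerFunc) http.HandlerFunc

// WithHeader is a hypothetical factory: configuration in, Middleware out.
func WithHeader(key, value string) Middleware {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			w.Header().Set(key, value)
			next.ServeHTTP(w, r)
		}
	}
}

// NoopMiddleware is hypothetical and simply returns next unchanged.
func NoopMiddleware(next http.HandlerFunc) http.HandlerFunc { return next }

// Compile-time check: fails to build if NoopMiddleware's signature
// no longer matches Middleware.
var _ Middleware = NoopMiddleware

// headerValue runs a do-nothing handler wrapped in WithHeader and
// returns the header it set.
func headerValue() string {
	h := WithHeader("X-Powered-By", "middleware")(func(w http.ResponseWriter, r *http.Request) {})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Header().Get("X-Powered-By")
}

func main() {
	fmt.Println(headerValue()) // middleware
}
```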

Chaining the Middleware

The LogMiddleware function logs information about the request before calling the "next" middleware. Every middleware calls its "next" in the same way until the last Handler runs (or until a middleware decides to short-circuit the process and do something else, e.g. return a "not authorized" response).

Since each middleware calls the "next" middleware, we call this a chain of middleware.

That becomes a bit clearer if you see multiple middleware in use. Let's pretend we have 2 middleware:

// LogMiddleware logs some output for each request received
func LogMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s: %s", r.Method, r.RequestURI)
		h.ServeHTTP(w, r)
	})
}

// RateLimitMiddleware limits how often a
// request can be made from a given client
func RateLimitMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Love me some hand-wavy magic:
		// rateLimitBreached is left undefined here
		if rateLimitBreached(r) {
			// 429 Too Many Requests
			w.WriteHeader(http.StatusTooManyRequests)
			fmt.Fprint(w, "Rate limit reached")
			return
		}

		h.ServeHTTP(w, r)
	})
}
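The rate limiter above leans on a rateLimitBreached helper that the article deliberately leaves as hand-wavy magic. Purely for illustration, here is one possible shape for it: a simple fixed-window counter keyed by client IP. The limiter type, the "2 requests per minute" setting, and the use of RemoteAddr as the client key are all assumptions, not part of the original code; production limiters usually need more care (proxies, memory growth, smoother windows).

```go
package main

import (
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"sync"
	"time"
)

// fixedWindowLimiter (hypothetical) allows each client at most limit
// requests per window; all counters reset together when the window ends.
type fixedWindowLimiter struct {
	mu     sync.Mutex
	limit  int
	window time.Duration
	start  time.Time
	counts map[string]int
}

func newFixedWindowLimiter(limit int, window time.Duration) *fixedWindowLimiter {
	return &fixedWindowLimiter{
		limit:  limit,
		window: window,
		start:  time.Now(),
		counts: make(map[string]int),
	}
}

// clientKey identifies a client by IP, ignoring the port. Behind a proxy,
// RemoteAddr is the proxy's address, so a real setup needs another key.
func clientKey(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}

// breached counts this request and reports whether the client is over its limit.
func (l *fixedWindowLimiter) breached(r *http.Request) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if time.Since(l.start) >= l.window {
		l.start = time.Now()
		l.counts = make(map[string]int)
	}
	key := clientKey(r)
	l.counts[key]++
	return l.counts[key] > l.limit
}

// limiter backs rateLimitBreached: 2 requests per minute per client (assumed).
var limiter = newFixedWindowLimiter(2, time.Minute)

func rateLimitBreached(r *http.Request) bool {
	return limiter.breached(r)
}

// threeRequests sends three requests from one client through a fresh limiter.
func threeRequests() []bool {
	l := newFixedWindowLimiter(2, time.Minute)
	r := httptest.NewRequest(http.MethodGet, "/", nil)
	r.RemoteAddr = "192.0.2.1:1234"
	var results []bool
	for i := 0; i < 3; i++ {
		results = append(results, l.breached(r))
	}
	return results
}

func main() {
	fmt.Println(threeRequests()) // [false false true]
}
```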

We have 2 middleware, and then we have a function that actually handles the web request:

func(writer http.ResponseWriter, r *http.Request) {
	writer.WriteHeader(200)
	fmt.Fprint(writer, "HELLO")
}

Let's whip up a helper function that will wrap an HTTP handler function (the thing doing the actual work of responding to a request) in our various middleware:

// CompileMiddleware takes the base http.HandlerFunc h
// and wraps it in the given list of Middleware m
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	if len(m) < 1 {
		return h
	}

	wrapped := h

	// loop in reverse to preserve middleware order
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}

	return wrapped
}
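Why loop in reverse? A quick way to check is to compile the same stack twice: once with the reverse loop used above, and once with a hypothetical compileForward that loops front to back. The record helper (also hypothetical) logs each layer's name as it runs. The empty-slice early return is dropped here because the loop already handles that case.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware: same reverse-loop logic as above.
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	wrapped := h
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}
	return wrapped
}

// compileForward is a hypothetical variant that loops forward instead.
func compileForward(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	wrapped := h
	for _, mw := range m {
		wrapped = mw(wrapped)
	}
	return wrapped
}

// record returns a Middleware that appends name to *log before calling next.
func record(name string, log *[]string) Middleware {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			*log = append(*log, name)
			next.ServeHTTP(w, r)
		}
	}
}

// order compiles the stack [A, B, C] with compile and returns the run order.
func order(compile func(http.HandlerFunc, []Middleware) http.HandlerFunc) []string {
	var log []string
	stack := []Middleware{record("A", &log), record("B", &log), record("C", &log)}
	base := func(w http.ResponseWriter, r *http.Request) { log = append(log, "handler") }
	compile(base, stack).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return log
}

func main() {
	fmt.Println(order(CompileMiddleware)) // [A B C handler]
	fmt.Println(order(compileForward))    // [C B A handler]
}
```

Wrapping from the last middleware inward means the first entry in the slice ends up outermost, so it runs first.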

Putting it all together along with our basic web server looks like this:

package main

import (
	"fmt"
	"log"
	"net"
	"net/http"
)

// Middleware is a func type that allows for
// chaining middleware
type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware takes the base http.HandlerFunc h
// and wraps it in the given list of Middleware m
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	if len(m) < 1 {
		return h
	}

	wrapped := h

	// loop in reverse to preserve middleware order
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}

	return wrapped
}

// Let's define the middleware!

// LogMiddleware logs some output for each
// request received
func LogMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		log.Printf("%s: %s", r.Method, r.RequestURI)
		h.ServeHTTP(w, r)
	})
}

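// RateLimitMiddleware limits how often a
// request can be made from a given client
func RateLimitMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if rateLimitBreached(r) {
			// 429 Too Many Requests
			w.WriteHeader(http.StatusTooManyRequests)
			fmt.Fprint(w, "Rate limit reached")
			return
		}

		h.ServeHTTP(w, r)
	})
}

// rateLimitBreached is a placeholder so this program compiles;
// it never limits anything. Swap in real rate-limiting logic here.
func rateLimitBreached(r *http.Request) bool {
	return false
}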

func main() {

	mux := http.NewServeMux()

	// Define our middleware stack
	// These run in the order given
	stack := []Middleware{
		LogMiddleware,
		RateLimitMiddleware,
	}

	// Assign our base HTTP Handler to a variable
	handleAllRequests := func(writer http.ResponseWriter, r *http.Request) {
		writer.WriteHeader(200)
		fmt.Fprint(writer, "HELLO")
	}

	// Set our handler as a "wrapped" handler.
	// Each middleware is called before finally
	// calling the handleAllRequests http Handler
	mux.HandleFunc("/", CompileMiddleware(handleAllRequests, stack))

	srv := &http.Server{
		Handler: mux,
	}

	ln, err := net.Listen("tcp", ":80")
	if err != nil {
		panic(err)
	}

	// Serve only returns on error, so log it and exit
	log.Fatal(srv.Serve(ln))
}

Fairly simple, if a little verbose (just like Golang itself). Any request will be logged, then checked against a rate limit. If the rate limit is not reached, then our handleAllRequests handler is finally run.

Note that the middleware run in the order they appear in stack, and our base handler runs last.
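"Runs in order" is only half the story when a middleware also does work after calling next. In this sketch, a hypothetical around middleware records a step before and after ServeHTTP; with the same CompileMiddleware logic, the "before" steps run in stack order and the "after" steps unwind in reverse, like layers of an onion.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware: same reverse-loop logic as in the article.
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	wrapped := h
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}
	return wrapped
}

// around is hypothetical: it records a step before and after calling next.
func around(name string, steps *[]string) Middleware {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			*steps = append(*steps, name+" before")
			next.ServeHTTP(w, r)
			*steps = append(*steps, name+" after")
		}
	}
}

// run sends one request through a two-middleware stack and returns the steps.
func run() []string {
	var steps []string
	stack := []Middleware{around("log", &steps), around("ratelimit", &steps)}
	base := func(w http.ResponseWriter, r *http.Request) { steps = append(steps, "handler") }
	CompileMiddleware(base, stack).ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return steps
}

func main() {
	for _, s := range run() {
		fmt.Println(s)
	}
	// log before
	// ratelimit before
	// handler
	// ratelimit after
	// log after
}
```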

Quick Review

The annoying part (but the part that makes this interesting for me to learn and write about) is how Golang's type system can be a bit obtuse.

A little review from the last article: net/http declares type HandlerFunc func(ResponseWriter, *Request), a named type that is itself a function. That type has a ServeHTTP method, which simply calls the function. This lets us pass a regular function and have the http stdlib convert it to an http.Handler (which requires that ServeHTTP method).
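To see that trick without the standard library doing it for us, here is a sketch that re-creates it with a hypothetical MyHandlerFunc type. It mirrors the pattern net/http uses: a function type with a ServeHTTP method that calls the function value itself. The hello function and serve helper are illustrative.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// MyHandlerFunc mirrors the HandlerFunc pattern: a named type whose
// underlying type is a function.
type MyHandlerFunc func(http.ResponseWriter, *http.Request)

// ServeHTTP makes MyHandlerFunc satisfy http.Handler by calling itself.
func (f MyHandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	f(w, r)
}

// Compile-time check that MyHandlerFunc implements http.Handler.
var _ http.Handler = MyHandlerFunc(nil)

func hello(w http.ResponseWriter, r *http.Request) {
	fmt.Fprint(w, "HELLO")
}

// serve converts a plain function to an http.Handler and calls it.
func serve() string {
	// Explicit conversion: hello's own (unnamed) func type has no
	// ServeHTTP method, but MyHandlerFunc does.
	var h http.Handler = MyHandlerFunc(hello)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Body.String()
}

func main() {
	fmt.Println(serve()) // HELLO
}
```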

On top of that, we abuse it a bit for our Middleware, letting us build a chain of http.Handlers (technically http.HandlerFuncs) in which each middleware calls the next one (in a specific order!) before the actual HTTP handler finally returns a response.

Sidenote: Any middleware that doesn't call ServeHTTP breaks the chain. This is on purpose: a middleware can abort the chain and respond with whatever makes sense. That's why middleware are often used on routes that require authentication. There's no use processing the request further if we already know the user isn't authenticated to perform the action.
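As a sketch of that short-circuit, here is a hypothetical RequireToken middleware that checks a bearer token in the Authorization header (the token value and header format are assumptions for illustration). When the check fails it writes a 401 and returns without calling next, so the base handler never runs.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// RequireToken is hypothetical: requests without the expected
// Authorization header get 401 and never reach next.
// (A real check would compare tokens in constant time.)
func RequireToken(token string) func(http.HandlerFunc) http.HandlerFunc {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("Authorization") != "Bearer "+token {
				http.Error(w, "not authorized", http.StatusUnauthorized)
				return // break the chain: next is never called
			}
			next.ServeHTTP(w, r)
		}
	}
}

// try sends one request with the given Authorization header and reports
// the status code and whether the base handler ran.
func try(auth string) (int, bool) {
	reached := false
	base := func(w http.ResponseWriter, r *http.Request) {
		reached = true
		fmt.Fprint(w, "secret stuff")
	}
	h := RequireToken("s3cret")(base)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	if auth != "" {
		req.Header.Set("Authorization", auth)
	}
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code, reached
}

func main() {
	fmt.Println(try(""))              // 401 false
	fmt.Println(try("Bearer s3cret")) // 200 true
}
```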

So, now we have middleware! Let's next see how to share information amongst our middleware.