IV. Adding Context to Requests

We're going to discuss Golang's "context" objects (context.Context). I'll assume you're at least passingly familiar with them.

It's useful if your request handlers can share information about a request.

Often the request data itself (http.Request) has everything you need, but sometimes your application has its own data. For example - the authenticated user.

To help here, one common pattern is to pass a context object through the middleware chain: one middleware sets some data, and the middleware after it can read that data.

Now, using contexts is generally agreed to be a good thing, but what data you save to a context is disagreed upon. The rules of thumb that I like are:

  1. For HTTP requests, only put information in a context that is specific to that request
  2. Don't put data into a context that lives on longer than that one request

Something that belongs in a request context: the current user, or perhaps a DB transaction used just for that request.

Something that doesn't belong in a request context: A Logger or DB connection (which is indeed different from a specific transaction).
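To make those rules of thumb concrete, here's a minimal sketch (not taken from any real codebase). The logger outlives every request, so it lives on a struct; the user is specific to one request, so it rides along in the context. App, newApp, serveAs, and the ctxKey type are all names I made up for illustration.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"os"
)

// ctxKey is a hypothetical unexported key type for context values.
type ctxKey string

const userKey ctxKey = "user"

// App holds long-lived dependencies. They outlive any single
// request, so they live here rather than in a request context.
type App struct {
	Logger *log.Logger
}

// newApp builds an App with a logger that writes to stderr.
func newApp() *App {
	return &App{Logger: log.New(os.Stderr, "", log.LstdFlags)}
}

// Hello reads request-scoped data (the user) from the context
// and long-lived data (the logger) from the App.
func (a *App) Hello(w http.ResponseWriter, r *http.Request) {
	user, _ := r.Context().Value(userKey).(string)
	a.Logger.Printf("handling request for user %q", user)
	fmt.Fprintf(w, "hello, %s", user)
}

// serveAs runs one fake request through Hello as the given user
// and returns the response body.
func serveAs(app *App, user string) string {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req = req.WithContext(context.WithValue(req.Context(), userKey, user))
	rec := httptest.NewRecorder()
	app.Hello(rec, req)
	return rec.Body.String()
}

func main() {
	fmt.Println(serveAs(newApp(), "fideloper"))
}
```

The same split works for a DB connection pool (on the struct) versus a per-request transaction (in the context).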

Context and Cloning

Pontificating about programming aside, there are a few annoying things to explain about Go's context object, especially with regard to HTTP requests.

First, an http.Request object has a few pertinent methods:

  1. req.Context() returns the request's context. If none was set on the request, it returns a new context.Background().
  2. req.WithContext(ctx) returns a shallow copy of the request with the provided context. Requests are (should be) immutable, and contexts are definitely immutable.

This means adding a context to a request nets us a copied request object with a new context on it.
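Here's a small sketch of that behavior, using httptest to fake a request. The helper names (withUser, userOf) and the ctxKey type are assumptions for the example, not anything from net/http.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// ctxKey is a hypothetical unexported key type for context values.
type ctxKey string

const userKey ctxKey = "user"

// withUser returns a fake request plus the copy that
// WithContext produces once a user is attached.
func withUser(name string) (orig, copied *http.Request) {
	orig = httptest.NewRequest(http.MethodGet, "/", nil)
	ctx := context.WithValue(orig.Context(), userKey, name)
	return orig, orig.WithContext(ctx)
}

// userOf reads the user from a request's context,
// returning "" if none is set.
func userOf(r *http.Request) string {
	name, _ := r.Context().Value(userKey).(string)
	return name
}

func main() {
	orig, copied := withUser("fideloper")

	// WithContext hands back a different *http.Request
	fmt.Println("same request:", orig == copied)
	// The original request's context is untouched
	fmt.Printf("original user: %q\n", userOf(orig))
	// Only the copy carries the new value
	fmt.Printf("copied user: %q\n", userOf(copied))
}
```

The practical upshot: whatever comes next in the chain has to receive the copy, because the original request never learns about the new context.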

But just what the hell is a shallow copy of a request?

Here's WithContext(ctx) from stdlib, with a bit of the relevant stdlib comments (which will change after Go 1.19):

// To change the context of a request, such as an incoming request you
// want to modify before sending back out, use Request.Clone. Between
// those two uses, it's rare to need WithContext.
func (r *Request) WithContext(ctx context.Context) *Request {
    if ctx == nil {
        panic("nil context")
    }
    r2 := new(Request)
    *r2 = *r
    r2.ctx = ctx
    return r2
}

Great, so I should never actually use WithContext()?!? I asked people smarter than me (who were also confused, it wasn't just me)! One of those people went to the source to ask.

It turns out, using WithContext() is just fine for our use case. We can run newReq := r.WithContext(myShinyNewCtx) in our middleware, and pass that along as if it is our original request.

Using r.Clone() is a "deep copy". It's better suited for making a completely new copy of the request with its own "lifecycle". For example, the built-in httputil.NewSingleHostReverseProxy() makes use of Clone() in order to take a received request, and then modify it as needed before passing the cloned & modified request to an upstream server.

Here's the Clone() method:

// Clone returns a deep copy of r with its context changed to ctx.
// The provided ctx must be non-nil.
//
// For an outgoing client request, the context controls the entire
// lifetime of a request and its response: obtaining a connection,
// sending the request, and reading the response headers and body.
func (r *Request) Clone(ctx context.Context) *Request {
    if ctx == nil {
        panic("nil context")
    }
    r2 := new(Request)
    *r2 = *r
    r2.ctx = ctx
    r2.URL = cloneURL(r.URL)
    if r.Header != nil {
        r2.Header = r.Header.Clone()
    }
    if r.Trailer != nil {
        r2.Trailer = r.Trailer.Clone()
    }
    if s := r.TransferEncoding; s != nil {
        s2 := make([]string, len(s))
        copy(s2, s)
        r2.TransferEncoding = s2
    }
    r2.Form = cloneURLValues(r.Form)
    r2.PostForm = cloneURLValues(r.PostForm)
    r2.MultipartForm = cloneMultipartForm(r.MultipartForm)
    return r2
}

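It does more stuff! After the same struct copy (*r2 = *r), Clone also copies the fields that are pointers, maps, or slices (URL, Header, Trailer, TransferEncoding, and the form data), so changes made to the clone can't leak back into the original. Plain values such as strings and bools are already copied by the struct assignment. A "shallow" copy is fine in middleware because the copy stands in for the same request, so sharing that underlying data is generally harmless there.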
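One way to see the difference in practice is to poke at the Header map. This is just a sketch using the standard library; the two helper functions are made up for the demo. A header set on the WithContext copy is visible on the original request, while a header set on the Clone is not.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// sharedAfterWithContext reports whether a header set on the
// WithContext copy is visible on the original request.
func sharedAfterWithContext() bool {
	orig := httptest.NewRequest(http.MethodGet, "/", nil)
	shallow := orig.WithContext(orig.Context())
	shallow.Header.Set("X-Demo", "1")
	return orig.Header.Get("X-Demo") == "1"
}

// sharedAfterClone does the same check against a Clone.
func sharedAfterClone() bool {
	orig := httptest.NewRequest(http.MethodGet, "/", nil)
	deep := orig.Clone(orig.Context())
	deep.Header.Set("X-Demo", "1")
	return orig.Header.Get("X-Demo") == "1"
}

func main() {
	// The shallow copy shares the original's header map
	fmt.Println("WithContext shares headers:", sharedAfterWithContext())
	// The clone gets its own header map
	fmt.Println("Clone shares headers:", sharedAfterClone())
}
```

If you plan to rewrite headers or the URL before sending a request somewhere else, that's the case where Clone earns its keep.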

Anyway, let's do some contexting.

Adding Context

We'll stick with our example of adding information about the current authenticated user. Let's add a Middleware that "adds" the current user to the request's context.

// UserMiddleware gets the current user and adds it to a new context
func UserMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx := context.WithValue(r.Context(), "user", "fideloper")
        newReq := r.WithContext(ctx)
        h.ServeHTTP(w, newReq)
    })
}

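Contexts are immutable, so each "change" requires creating a new context based off of an old one. We grab r.Context() to use as the parent. For a request received by the server, that's the context net/http attached to the incoming request (it's canceled when the client's connection closes or the handler returns) rather than a bare context.Background(). Either way, our new context wraps it.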
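If "immutable" feels abstract, here's a tiny sketch of what it means: attaching a value produces a child context, and the parent never sees that value. The derive function and ctxKey type are hypothetical names for the example.

```go
package main

import (
	"context"
	"fmt"
)

// ctxKey is a hypothetical unexported key type for context values.
type ctxKey string

// derive attaches a user to a child context only, then returns
// what the parent and the child each report for that key.
func derive() (parentVal, childVal any) {
	parent := context.Background()
	child := context.WithValue(parent, ctxKey("user"), "fideloper")
	return parent.Value(ctxKey("user")), child.Value(ctxKey("user"))
}

func main() {
	parentVal, childVal := derive()
	fmt.Println("parent:", parentVal)
	fmt.Println("child:", childVal)
}
```

The child still answers lookups for anything the parent holds, which is why wrapping r.Context() keeps whatever earlier middleware stored.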

We pass our new request object newReq along in h.ServeHTTP, leaving the old one to die a lonely death when the garbage collector comes calling.

We can add this into our Middleware stack, and then we get a user object (just a string for now) that any other middleware/handler can retrieve.

Here's the whole thing:

package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "net/http"
)

// Middleware is a func type that
// allows for chaining middleware
type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware takes the base http.HandlerFunc h
// and wraps around the given list of Middleware m
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
    if len(m) < 1 {
        return h
    }

    wrapped := h

    // loop in reverse to preserve middleware order
    for i := len(m) - 1; i >= 0; i-- {
        wrapped = m[i](wrapped)
    }

    return wrapped
}

// Let's define some middleware!

// LogMiddleware logs some output 
// for each request received
func LogMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        log.Printf("%s: %s", r.Method, r.RequestURI)
        h.ServeHTTP(w, r)
    })
}

// UserMiddleware gets the current user 
// and adds it to a new context
func UserMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx := context.WithValue(r.Context(), "user", "fideloper")
        newReq := r.WithContext(ctx)
        h.ServeHTTP(w, newReq)
    })
}

// RateLimitMiddleware limits how often
// a request can be made from a given client
func RateLimitMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // rateLimitBreached is not defined in this listing; it's
        // assumed to report whether this client is over its limit
        if rateLimitBreached(r) {
            w.WriteHeader(403)
            fmt.Fprint(w, "Rate limit reached")
            return
        }

        h.ServeHTTP(w, r)
    })
}

func main() {
    mux := http.NewServeMux()

    // Define our middleware stack
    // The first middleware listed runs first
    stack := []Middleware{
        LogMiddleware,
        UserMiddleware,
        RateLimitMiddleware,
    }

    // Assign our actual HTTP Handler to a variable
    handleAllRequests := func(writer http.ResponseWriter, r *http.Request) {
        writer.WriteHeader(200)
        fmt.Fprint(writer, "HELLO!")
        fmt.Fprintf(writer, " HERE IS YOUR USER: %s", r.Context().Value("user"))
    }

    // Set our handler as a "wrapped" handler - each middleware is called
    // before finally calling the handleAllRequests http Handler
    mux.HandleFunc("/", CompileMiddleware(handleAllRequests, stack))

    srv := &http.Server{
        Handler: mux,
    }

    ln, err := net.Listen("tcp", ":80")
    if err != nil {
        panic(err)
    }

    srv.Serve(ln)
}

In addition to creating the UserMiddleware and adding it to the stack, we updated the base handler to print out information about the current user, retrieved from the request context.

That's this part:

// Assign our actual HTTP Handler to a variable
handleAllRequests := func(writer http.ResponseWriter, r *http.Request) {
    writer.WriteHeader(200)
    fmt.Fprint(writer, "HELLO!")
    fmt.Fprintf(writer, " HERE IS YOUR USER: %s", r.Context().Value("user"))
}
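To check that the value really makes it from the middleware to the handler, you can run both against a fake request with httptest. This sketch copies UserMiddleware and the handler from above (string key and all); the respond helper is a made-up name. It also shows what the handler prints when the middleware is left out of the chain.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// UserMiddleware mirrors the middleware above: it stores a
// user in a new context and passes the copied request along.
func UserMiddleware(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), "user", "fideloper")
		h.ServeHTTP(w, r.WithContext(ctx))
	}
}

// handleAllRequests mirrors the base handler above.
func handleAllRequests(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(200)
	fmt.Fprint(w, "HELLO!")
	fmt.Fprintf(w, " HERE IS YOUR USER: %s", r.Context().Value("user"))
}

// respond sends one fake GET through h and returns the body.
func respond(h http.HandlerFunc) string {
	rec := httptest.NewRecorder()
	h(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	return rec.Body.String()
}

func main() {
	// With the middleware in front, the handler sees the user
	fmt.Println(respond(UserMiddleware(handleAllRequests)))
	// Without it, Value returns nil and %s can't format that
	fmt.Println(respond(handleAllRequests))
}
```

The second line comes out as "HELLO! HERE IS YOUR USER: %!s(<nil>)", which is a decent argument for not reading untyped values straight out of the context.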

But I Want Types!

One thing sort of sucks: The type any.

The context.WithValue function accepts a value of type any, and r.Context().Value("foo") returns a value of type any.

This means Go's compiler (and our IDEs) can't enforce types, nor help us know what data is being stored in or read from the context object. But we want type safety! That's why we use Go!
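Here's a quick sketch of the problem. Nothing stops a caller from storing the wrong type under a key, so the reader has to use a type assertion, and the comma-ok form is what keeps a bad value from turning into a panic. The lookup helper is a hypothetical name.

```go
package main

import (
	"context"
	"fmt"
)

// lookup stores an arbitrary value under "user", then tries to
// read it back as a string. The compiler accepts any value here.
func lookup(stored any) (string, bool) {
	ctx := context.WithValue(context.Background(), "user", stored)

	// Comma-ok assertion: ok is false when the value is missing
	// or is not a string, instead of panicking.
	name, ok := ctx.Value("user").(string)
	return name, ok
}

func main() {
	name, ok := lookup("fideloper")
	fmt.Printf("%q %v\n", name, ok)

	// An int compiles just as happily as a string
	name, ok = lookup(42)
	fmt.Printf("%q %v\n", name, ok)
}
```

Every reader of the context needs that same assertion, which is the motivation for wrapping it in helpers.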

This article (and this one) cover some ways to get type safety. I've not settled on what I like best, but here's a stab at it.

First, let's assume that our context shouldn't just receive a string representing a user. We'll instead make a User struct and some helper functions to manage it:

// Still a bit contrived,
// but bear with me
type User struct {
    Username string
}

// setUser adds a user to a context, returning 
// a new context with the user attached
func setUser(ctx context.Context, u *User) context.Context {
    return context.WithValue(ctx, "user", u)
}

// getUser returns an instance of User,
// if set, from the given context
func getUser(ctx context.Context) *User {
    user, ok := ctx.Value("user").(*User)

    if !ok {
        return nil
    }

    return user
}

These helper functions give us a type-safe way to get and set the User on our context, and give the compiler something to chew on.
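One refinement worth considering: the context package docs recommend that keys not be plain strings (or any built-in type), because two packages that both pick "user" would collide. A common approach is an unexported key type. This is a sketch of that variation on the helpers above, not the only way to do it; userKey and the two demo functions are names I've made up.

```go
package main

import (
	"context"
	"fmt"
)

type User struct {
	Username string
}

// userKey is an unexported, empty struct type used only as a
// context key. No other package can construct one.
type userKey struct{}

// setUser returns a new context carrying u under userKey{}.
func setUser(ctx context.Context, u *User) context.Context {
	return context.WithValue(ctx, userKey{}, u)
}

// getUser returns the User stored by setUser, or nil.
func getUser(ctx context.Context) *User {
	u, _ := ctx.Value(userKey{}).(*User)
	return u
}

// stringKeySeesUser reports whether a lookup with the plain
// string "user" can find what setUser stored.
func stringKeySeesUser() bool {
	ctx := setUser(context.Background(), &User{Username: "fideloper"})
	return ctx.Value("user") != nil
}

// usernameAfterStringOverwrite simulates other code storing
// its own "user" string key on top of our context.
func usernameAfterStringOverwrite() string {
	ctx := setUser(context.Background(), &User{Username: "fideloper"})
	ctx = context.WithValue(ctx, "user", "someone-else")
	return getUser(ctx).Username
}

func main() {
	fmt.Println("string key finds our user:", stringKeySeesUser())
	fmt.Println("user after overwrite:", usernameAfterStringOverwrite())
}
```

Because the key type is unexported, the only way to touch the value is through setUser and getUser, which is exactly the choke point we wanted.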

Our UserMiddleware becomes this:

// UserMiddleware gets the current user and adds it to a new context
func UserMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx := setUser(r.Context(), &User{
            Username: "fideloper",
        })
        newReq := r.WithContext(ctx)
        h.ServeHTTP(w, newReq)
    })
}

The only change there is to use the setUser function to add a new User to the context (which returns a new Context - remember, contexts are immutable).

Then we can update our base handler to use getUser to retrieve the User. I chose to return nil if no user is associated, rather than an error. You do you.

// Assign our actual HTTP Handler to a variable
handleAllRequests := func(writer http.ResponseWriter, r *http.Request) {
    writer.WriteHeader(200)
    fmt.Fprint(writer, "HELLO!")
    
    user := getUser(r.Context())

    if user != nil {
        fmt.Fprintf(writer, " HERE IS YOUR USER: %s", user.Username)
        return
    }

    fmt.Fprint(writer, " NO USER AUTHENTICATED")
}
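If nil-checking in every handler gets old, one alternative is a guard middleware that rejects the request before the handler ever runs. This is only a sketch of that idea: RequireUser, hello, and statusFor are hypothetical names, and the helpers use an unexported key type rather than the string key shown above.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
)

type User struct {
	Username string
}

// userKey is a hypothetical unexported context key type.
type userKey struct{}

func setUser(ctx context.Context, u *User) context.Context {
	return context.WithValue(ctx, userKey{}, u)
}

func getUser(ctx context.Context) *User {
	u, _ := ctx.Value(userKey{}).(*User)
	return u
}

// RequireUser responds with 401 when no user is in the context,
// so handlers behind it can assume getUser is non-nil.
func RequireUser(h http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if getUser(r.Context()) == nil {
			http.Error(w, "no user authenticated", http.StatusUnauthorized)
			return
		}
		h(w, r)
	}
}

// hello only ever runs behind RequireUser.
func hello(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "hello, %s", getUser(r.Context()).Username)
}

// statusFor sends one fake request through RequireUser(hello),
// optionally attaching a user first, and returns the status code.
func statusFor(withUser bool) int {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	if withUser {
		req = req.WithContext(setUser(req.Context(), &User{Username: "fideloper"}))
	}
	rec := httptest.NewRecorder()
	RequireUser(hello)(rec, req)
	return rec.Code
}

func main() {
	fmt.Println("with user:", statusFor(true))
	fmt.Println("without user:", statusFor(false))
}
```

The trade-off is that the guard has to sit after whatever middleware sets the user, so ordering in the stack starts to matter.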

Here's the whole thing:

package main

import (
    "context"
    "fmt"
    "log"
    "net"
    "net/http"
)

type User struct {
    Username string
}

func setUser(ctx context.Context, u *User) context.Context {
    return context.WithValue(ctx, "user", u)
}

func getUser(ctx context.Context) *User {
    user, ok := ctx.Value("user").(*User)

    if !ok {
        return nil
    }

    return user
}

// Middleware is func type that allows for chaining middleware
type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware takes the base http.HandlerFunc h and wraps around the given list of Middleware m
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
    if len(m) < 1 {
        return h
    }

    wrapped := h

    // loop in reverse to preserve middleware order
    for i := len(m) - 1; i >= 0; i-- {
        wrapped = m[i](wrapped)
    }

    return wrapped
}

// Let's define some middleware!

// LogMiddleware logs some output for each request received
func LogMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        log.Printf("%s: %s", r.Method, r.RequestURI)
        h.ServeHTTP(w, r)
    })
}

// UserMiddleware gets the current user and adds it to a new context
func UserMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ctx := setUser(r.Context(), &User{
            Username: "fideloper",
        })
        newReq := r.WithContext(ctx)
        h.ServeHTTP(w, newReq)
    })
}

// RateLimitMiddleware limits how often a request can be made from a given client
func RateLimitMiddleware(h http.HandlerFunc) http.HandlerFunc {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        // rateLimitBreached is not defined in this listing; it's
        // assumed to report whether this client is over its limit
        if rateLimitBreached(r) {
            w.WriteHeader(403)
            fmt.Fprint(w, "Rate limit reached")
            return
        }

        h.ServeHTTP(w, r)
    })
}

func main() {
    mux := http.NewServeMux()

    // Define our middleware stack
    // The first middleware listed runs first
    stack := []Middleware{
        LogMiddleware,
        UserMiddleware,
        RateLimitMiddleware,
    }

    // Assign our actual HTTP Handler to a variable
    handleAllRequests := func(writer http.ResponseWriter, r *http.Request) {
        writer.WriteHeader(200)
        fmt.Fprint(writer, "HELLO!")
        user := getUser(r.Context())

        if user != nil {
            fmt.Fprintf(writer, " HERE IS YOUR USER: %s", user.Username)
            return
        }

        fmt.Fprint(writer, " NO USER AUTHENTICATED")
    }

    // Set our handler as a "wrapped" handler - each middleware is called before
    // finally calling the handleAllRequests http Handler
    mux.HandleFunc("/", CompileMiddleware(handleAllRequests, stack))

    srv := &http.Server{
        Handler: mux,
    }

    ln, err := net.Listen("tcp", ":80")
    if err != nil {
        panic(err)
    }

    srv.Serve(ln)
}
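Since the order of the stack decides which middleware can see the user, it's worth checking how CompileMiddleware actually orders things. This sketch reuses CompileMiddleware as written above; tag and runOrder are made-up helpers that record each step as it runs.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

type Middleware func(http.HandlerFunc) http.HandlerFunc

// CompileMiddleware is the same wrapping loop used above.
func CompileMiddleware(h http.HandlerFunc, m []Middleware) http.HandlerFunc {
	wrapped := h
	for i := len(m) - 1; i >= 0; i-- {
		wrapped = m[i](wrapped)
	}
	return wrapped
}

// tag returns a middleware that appends its name to order
// when a request passes through it.
func tag(name string, order *[]string) Middleware {
	return func(h http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			*order = append(*order, name)
			h(w, r)
		}
	}
}

// runOrder sends one fake request through a three-middleware
// stack and returns the order in which everything ran.
func runOrder() []string {
	var order []string
	stack := []Middleware{
		tag("log", &order),
		tag("user", &order),
		tag("ratelimit", &order),
	}
	base := func(w http.ResponseWriter, r *http.Request) {
		order = append(order, "handler")
	}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	CompileMiddleware(base, stack)(httptest.NewRecorder(), req)
	return order
}

func main() {
	fmt.Println(runOrder())
}
```

The first entry in the slice is the outermost wrapper, so it runs first and the base handler runs last. Anything that needs the user from the context has to come after UserMiddleware in the list.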

In reality I'd have several files and/or some packages of my own here to manage Users, Middleware, etc. In this example, we're throwing it all into one file.

But now we know how to use context objects to pass data through our middleware and handlers!

No More Web Servers

That's a wrap on web servers. The next thing we'll look into is more fun: Reverse Proxies.