
Testing in Go: WebSockets


WebSockets offer full-duplex communication between an untrusted client and a server that we own, over a single TCP connection. This means that, instead of continually polling the web server for updates and paying for a full HTTP request and response (and possibly a new TCP handshake) each time, we can keep one connection open and both send and receive messages over it.

In Go’s ecosystem there are a few different implementations of the WebSocket protocol. Some libraries are pure implementations of the protocol. Others, though, have chosen to build on top of the WebSocket protocol to create better abstractions for their particular use case.

Here’s a non-exhaustive list of Go WebSocket protocol implementations:

In this article we will use the excellent gorilla/websocket implementation of the WebSocket protocol, from the Gorilla Web Toolkit project. You will notice that testing WebSocket handlers is not much different from testing HTTP handlers. Still, there are aspects of WebSockets that we have to take into account while testing.

Auctions

Online auction houses are one of the businesses whose backbone is real-time communication. During an auction, seconds make the difference between winning and losing a collectible item that you have wanted for so long.

Let’s use a simple auction application powered by gorilla/websocket as an example for this article.

First, we will define two very simple entities, Bid and Auction, that we will use in our WebSocket handlers. The Auction type will also have a Bid method that we will use to place a new bid on the auction.

Entities

Let’s look at the Auction and Bid types, in all of their glory:

type Bid struct {
	UserID int     `json:"user_id"`
	Amount float64 `json:"amount"`
}

type Auction struct {
	ItemID  int   `json:"item_id"`
	EndTime int64 `json:"end_time"`
	Bids    []*Bid
}

func NewAuction(d time.Duration, itemID int, b []*Bid) Auction {
	return Auction{
		ItemID:  itemID,
		EndTime: time.Now().Add(d).Unix(),
		Bids:    b,
	}
}

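Both of the types are fairly simple, encapsulating very little data. The NewAuction constructor builds an auction for the given itemID that ends after the duration d, with an initial slice of *Bid. A negative duration produces an auction that has already ended, which will come in handy in the tests.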

Bidding

We will place a bid on an auction through the Bid method:

func (a *Auction) Bid(amount float64, userID int) (*Bid, error) {
	if len(a.Bids) > 0 {
		largestBid := a.Bids[len(a.Bids)-1]
		if largestBid.Amount >= amount {
			return nil, fmt.Errorf("amount must be larger than %.2f", largestBid.Amount)
		}
	}

	if a.EndTime < time.Now().Unix() {
		return nil, fmt.Errorf("auction already closed")
	}

	bid := Bid{
		Amount: amount,
		UserID: userID,
	}

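	// Each WebSocket connection is served on its own goroutine, so a real
	// application should guard this method with a sync.Mutex, locking before
	// the checks above so that checking and appending happen atomically.
	a.Bids = append(a.Bids, &bid)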

	return &bid, nil
}

The Auction’s Bid method is where the bidding magic happens. It takes an amount and a userID as arguments and adds a Bid to the Auction. Before doing so, it checks that the new amount is larger than the largest bid so far (always the last one, since every accepted bid must exceed its predecessor) and that the auction has not closed yet. If either check fails, it returns an appropriate error to the caller.

With the types and the Bid method out of the way, let’s dive into the WebSocket mechanics.

Handling WebSockets

Imagine a web frontend that can place bids on an auction in real time. With every JSON message it sends over the WebSocket, it will supply the identifier of the user placing the bid (UserID) and the amount of the bid (Amount). Once the server accepts the message, it will place the bid and reply to the client with a meaningful answer.

On the server side, this communication will be handled by a net/http Handler. It takes care of all of the WebSocket intricacies in a few notable steps:

  1. Upgrade the incoming HTTP connection to a WebSocket one
  2. Accept incoming messages from a client
  3. Decode the bid from the inbound message
  4. Place the bid
  5. Send an outbound message with the reply to the client
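Of these steps, only the first differs from an ordinary HTTP exchange. Per RFC 6455, the client opens with a GET request carrying the headers Upgrade: websocket, Connection: Upgrade and a random base64-encoded Sec-WebSocket-Key, and the server agrees by answering 101 Switching Protocols with a Sec-WebSocket-Accept header derived from that key. gorilla/websocket’s Upgrader does all of this for us; the sketch below, built around a hypothetical acceptKey function, only illustrates how the accept value is computed. The sample key and its expected accept value come from the example in RFC 6455.

```go
package main

import (
	"crypto/sha1"
	"encoding/base64"
)

// websocketGUID is the fixed GUID that RFC 6455 appends to the client's key.
const websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// acceptKey returns the Sec-WebSocket-Accept value for a Sec-WebSocket-Key:
// the base64-encoded SHA-1 hash of the key concatenated with the GUID.
func acceptKey(clientKey string) string {
	sum := sha1.Sum([]byte(clientKey + websocketGUID))
	return base64.StdEncoding.EncodeToString(sum[:])
}
```

For the RFC’s sample key dGhlIHNhbXBsZSBub25jZQ== this yields s3pPLMBiTxaQ9kYGzzhZRbK+xOo=.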

Let’s write such a handler.

First, let’s define the inbound and outbound message types:

type inbound struct {
	UserID int     `json:"user_id"`
	Amount float64 `json:"amount"`
}

type outbound struct {
	Body string `json:"body"`
}

These two types represent the inbound and outbound messages, that is, the data flowing between the client and the server. An inbound message represents a bid, while an outbound message carries a short text reply in its Body.
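To make the wire format concrete, here is a small sketch of a client message being decoded and a confirmation reply being encoded with these two types. The struct tags map the Go field names to the snake_case JSON keys. decodeBid and encodeReply are hypothetical helpers for illustration only; the handler does the same work inline.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type inbound struct {
	UserID int     `json:"user_id"`
	Amount float64 `json:"amount"`
}

type outbound struct {
	Body string `json:"body"`
}

// decodeBid parses a raw client message such as {"user_id":1,"amount":10}.
func decodeBid(raw []byte) (inbound, error) {
	var in inbound
	err := json.Unmarshal(raw, &in)
	return in, err
}

// encodeReply builds the confirmation message sent back to the client.
func encodeReply(amount float64) ([]byte, error) {
	return json.Marshal(outbound{Body: fmt.Sprintf("Bid placed: %.2f", amount)})
}
```

A bid of 10 from user 1 therefore travels as {"user_id":1,"amount":10} and is answered with {"body":"Bid placed: 10.00"}.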

Next, let’s define the bidsHandler, including its ServeHTTP method containing the HTTP connection upgrade:

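// CheckOrigin is set once here rather than inside ServeHTTP, because writing
// to a shared package-level variable on every request is a data race.
// Accepting every origin is fine for this example, but not for production.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool { return true },
}

type bidsHandler struct {
	auction *Auction
}

func (bh bidsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Println("upgrade:", err)
		return
	}
	defer ws.Close()

	// More to come...
}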

First, we define a websocket.Upgrader, whose Upgrade method takes the http.ResponseWriter and the *http.Request from the handler and upgrades the connection. Because this is just an example application, the upgrader’s CheckOrigin function always returns true, without checking the origin of the incoming request.

Once the upgrade succeeds, Upgrade returns a *websocket.Conn, which we store in the ws variable. The handler reads the client’s incoming messages from this connection and writes outbound messages to it, which sends them to the client. The deferred ws.Close() closes the connection when the handler returns.

Let’s add the message loop next:

func (bh bidsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Code from above...

	for {
		_, m, err := ws.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("error: %v", err)
			}
			return
		}

		var in inbound
		err = json.Unmarshal(m, &in)
		if err != nil {
			handleError(ws, err)
			continue
		}

		bid, err := bh.auction.Bid(in.Amount, in.UserID)
		if err != nil {
			handleError(ws, err)
			continue
		}

		out, err := json.Marshal(outbound{Body: fmt.Sprintf("Bid placed: %.2f", bid.Amount)})
		if err != nil {
			handleError(ws, err)
			continue
		}

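		err = ws.WriteMessage(websocket.BinaryMessage, out)
		if err != nil {
			// Writing failed, so the connection is most likely broken and a
			// handleError reply would fail as well; log it and stop.
			log.Println("write:", err)
			return
		}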
	}
}

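This for loop does a few things. First, it reads a new WebSocket message using ws.ReadMessage(), which returns the message type (text or binary), the message payload (m) and a potential error (err). If reading fails, for example because the client has gone away, the handler returns and the deferred ws.Close() cleans up. Before returning, websocket.IsUnexpectedCloseError lets us log only close errors whose code is neither CloseGoingAway nor CloseAbnormalClosure, the two codes we expect when a client simply leaves.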

Once a message is retrieved, we decode it with json.Unmarshal into in, an inbound value. We then invoke bh.auction.Bid, which places a bid on the auction, using the bid amount (in.Amount) and the bidder’s ID (in.UserID) as arguments. Bid returns the placed bid (bid) and an error (err). When decoding or bidding fails, the error is handed to handleError, a helper whose code is not shown here; as the tests below rely on, it sends the error text back to the client as an outbound message. The loop then moves on to the next message.

After the bid is placed, we use json.Marshal to convert an outbound message containing the bid confirmation into a slice of bytes ([]byte). Then we send the bytes to the client with ws.WriteMessage, which completes one iteration of the request-response loop.

We can ignore the client side for now. Let’s now see how we can test this WebSockets handler code.

Testing WebSockets handlers

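Although writing WebSocket handlers is more involved than writing ordinary HTTP handlers, testing them is about as simple. A WebSocket connection starts life as an HTTP request that the server upgrades, so we can test it with the same net/http/httptest tools we use for HTTP servers. The one difference is that we need a real test server (httptest.NewServer) rather than an httptest.ResponseRecorder, because the upgrade has to take over the underlying connection, which the recorder does not support.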
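The reason a recorder will not do is that an upgrade needs to hijack the raw connection through the http.Hijacker interface, and gorilla/websocket’s Upgrader returns an error when the ResponseWriter cannot be hijacked. The sketch below checks both kinds of ResponseWriter a test could hand to a handler; canHijack and hijackSupport are hypothetical names used only for this illustration.

```go
package main

import (
	"net/http"
	"net/http/httptest"
)

// canHijack reports whether w lets a handler take over the raw connection.
func canHijack(w http.ResponseWriter) bool {
	_, ok := w.(http.Hijacker)
	return ok
}

// hijackSupport checks an httptest.ResponseRecorder and the ResponseWriter
// that a real httptest.Server passes to its handler.
func hijackSupport() (recorder, server bool, err error) {
	recorder = canHijack(httptest.NewRecorder())

	result := make(chan bool, 1)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		result <- canHijack(w)
	}))
	defer srv.Close()

	resp, err := http.Get(srv.URL)
	if err != nil {
		return recorder, false, err
	}
	resp.Body.Close()

	return recorder, <-result, nil
}
```

The recorder reports false while the test server’s writer reports true, which is why the tests below start a server with httptest.NewServer.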

We will begin by adding the test setup:

func TestBidsHandler(t *testing.T) {
	tcs := []struct {
		name     string
		bids     []*Bid
		duration time.Duration
		message  inbound
		reply    outbound
	}{
		{
			name:     "with good bid",
			bids:     []*Bid{},
			duration: time.Hour * 1,
			message:  inbound{UserID: 1, Amount: 10},
			reply:    outbound{Body: "Bid placed: 10.00"},
		},
		{
			name: "with bad bid",
			bids: []*Bid{
				{
					UserID: 1,
					Amount: 20,
				},
			},
			duration: time.Hour * 1,
			message:  inbound{UserID: 1, Amount: 10},
			reply:    outbound{Body: "amount must be larger than 20.00"},
		},
		{
			name: "good bid on expired auction",
			bids: []*Bid{
				{
					UserID: 1,
					Amount: 20,
				},
			},
			duration: time.Hour * -1,
			message:  inbound{UserID: 1, Amount: 30},
			reply:    outbound{Body: "auction already closed"},
		},
	}

	for _, tt := range tcs {
		t.Run(tt.name, func(t *testing.T) {
			a := NewAuction(tt.duration, 1, tt.bids)
			h := bidsHandler{&a}

			// To be added...
		})
	}
}

First, we define the test case struct. Each test case has a name, its human-readable description, plus a bids slice and a duration used to create the test Auction. It also has an inbound message and an outbound reply: the message the test sends to the handler and the reply it expects back.

Then, in TestBidsHandler, we add three test cases: one where the client places a good bid, one where the client places a bad bid that is lower than the largest bid, and one where the client places an otherwise good bid on an expired auction (created with a negative duration).

In the for loop, we create a subtest for each test case, which uses the NewAuction constructor to create a fresh test auction, so the test cases do not share state. We also create a bidsHandler that holds a pointer to the newly created Auction.

Let’s finish off the subtest function:

func TestBidsHandler(t *testing.T) {
	// Test cases and other setup from above...

	for _, tt := range tcs {
		t.Run(tt.name, func(t *testing.T) {
			a := NewAuction(tt.duration, 1, tt.bids)
			h := bidsHandler{&a}

			s, ws := newWSServer(t, h)
			defer s.Close()
			defer ws.Close()

			sendMessage(t, ws, tt.message)

			reply := receiveWSMessage(t, ws)

			if reply != tt.reply {
				t.Fatalf("Expected '%+v', got '%+v'", tt.reply, reply)
			}
		})
	}
}

We added a few new functions to the subtest body. newWSServer starts a test server running our handler and opens a WebSocket connection to it, returning both the server and the connection. Then sendMessage sends the test case’s message to the server through the WebSocket connection. Finally, receiveWSMessage retrieves the server’s reply, whose correctness we assert by comparing it to the test case’s expected reply.

So, what do each of these small functions do? Let’s break them down one by one.

func newWSServer(t *testing.T, h http.Handler) (*httptest.Server, *websocket.Conn) {
	t.Helper()

	s := httptest.NewServer(h)
	wsURL := httpToWs(t, s.URL)

	ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		s.Close() // the caller's deferred s.Close() is not registered yet
		t.Fatal(err)
	}

	return s, ws
}

The newWSServer function uses httptest.NewServer to mount the handler on a test HTTP server. It then converts the server’s URL into a WebSocket URL with the httpToWs function, which simply replaces the http scheme in the URL with ws (or https with wss).
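The article does not show httpToWs, so here is one possible implementation, a sketch rather than the author’s code. Instead of a plain string replacement it parses the URL with net/url and swaps only the scheme, leaving host, port, path and query alone. wsURLFromHTTP is a hypothetical helper split out so the conversion can be checked without a *testing.T.

```go
package main

import (
	"net/url"
	"testing"
)

// wsURLFromHTTP rewrites http:// to ws:// and https:// to wss://; other
// schemes are returned unchanged.
func wsURLFromHTTP(raw string) (string, error) {
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	switch u.Scheme {
	case "http":
		u.Scheme = "ws"
	case "https":
		u.Scheme = "wss"
	}
	return u.String(), nil
}

// httpToWs matches the call in newWSServer and fails the test on a bad URL.
func httpToWs(t *testing.T, raw string) string {
	t.Helper()

	wsURL, err := wsURLFromHTTP(raw)
	if err != nil {
		t.Fatal(err)
	}
	return wsURL
}
```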

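To establish a WebSocket connection, we use websocket.DefaultDialer, a ready-to-use dialer with sensible defaults (it honors the proxy environment variables and uses a 45-second handshake timeout). Calling its Dial method with the WebSocket server URL (wsURL) performs the upgrade handshake and returns the WebSocket connection, along with the handshake’s *http.Response, which we ignore.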

func sendMessage(t *testing.T, ws *websocket.Conn, msg inbound) {
	t.Helper()

	m, err := json.Marshal(msg)
	if err != nil {
		t.Fatal(err)
	}

	if err := ws.WriteMessage(websocket.BinaryMessage, m); err != nil {
		t.Fatalf("%v", err)
	}
}

The sendMessage function takes the WebSocket connection (ws) and an inbound message as arguments. It marshals the message into JSON and sends it over the WebSocket connection as a binary message.

func receiveWSMessage(t *testing.T, ws *websocket.Conn) outbound {
	t.Helper()

	_, m, err := ws.ReadMessage()
	if err != nil {
		t.Fatalf("%v", err)
	}

	var reply outbound
	err = json.Unmarshal(m, &reply)
	if err != nil {
		t.Fatal(err)
	}

	return reply
}

receiveWSMessage takes the WebSocket connection (ws) as an argument and reads a message using ws.ReadMessage(). Once the message is retrieved, it unmarshals it into an outbound message using json.Unmarshal. Finally, receiveWSMessage returns the outbound message to the test, so the test can continue with its assertions.

If we run the tests, we will see them pass:

$ go test ./... -v
=== RUN   TestBidsHandler
=== RUN   TestBidsHandler/with_good_bid
=== RUN   TestBidsHandler/with_bad_bid
=== RUN   TestBidsHandler/good_bid_on_expired_auction
--- PASS: TestBidsHandler (0.00s)
    --- PASS: TestBidsHandler/with_good_bid (0.00s)
    --- PASS: TestBidsHandler/with_bad_bid (0.00s)
    --- PASS: TestBidsHandler/good_bid_on_expired_auction (0.00s)
PASS
ok  	github.com/fteem/go-playground/testing-in-go-web-sockets	0.013s

You can see the example code on GitHub.

Also, to see another approach to testing WebSockets in Go, you can head over to the WebSockets chapter of the book “Learn Go with tests”.

More WebSockets reading

If you would like to learn more about the details of the WebSocket protocol, I recommend reading RFC 6455, which defines the protocol itself. In addition, you can read more in the follow-up RFCs regarding the WebSocket protocol: