6 min read
WebSocket Session Manager using Boost Beast

In this post, we’ll explore how to implement a WebSocket session manager for a Boost Beast server built on stackful coroutines. The session manager tracks active WebSocket sessions, synchronizes them, and facilitates data exchange between them—something I needed when two WebSocket sessions had to share data.

Before creating the session manager class, let’s start by defining the Session class and its dependencies:

#include <atomic>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

#include <boost/algorithm/string.hpp>
#include <boost/asio.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast.hpp>
#include <boost/beast/core.hpp>
#include <boost/beast/websocket.hpp>

/// Asynchronously waits until `sfut` becomes ready without blocking the thread
/// that runs the calling coroutine.
///
/// The blocking `sfut.wait()` runs on a separate thread started with
/// `boost::async`. Once it returns, the completion handler is dispatched onto the
/// handler's associated executor and invoked with a default (success)
/// error_code and the value 1. A work guard taken from the handler keeps that
/// executor's context from running out of work while the wait is pending.
///
/// NOTE(review): `boost::async` needs Boost.Thread (<boost/thread/future.hpp>),
/// which is not among the includes shown. The boost::future it returns is
/// discarded here; depending on Boost.Thread configuration its destructor may
/// block until the wait finishes — confirm this stays non-blocking.
///
/// @param sfut   future to wait on (copied into the worker thread)
/// @param token  completion token, e.g. a yield_context or `yield[ec]`
/// @return       whatever the completion token's async_result produces
template <typename Token>
typename net::async_result<typename std::decay<Token>::type, void(boost::system::error_code, int)>::return_type
static async_wait(std::shared_future<void> sfut, Token&& token) {
    auto init = [sfut](auto handler) {
        // Build the work guard from the handler *before* the handler is moved
        // into the worker lambda. Inside a capture list, `handler` in an
        // initializer names this (outer) parameter, so creating the guard after
        // `handler = std::move(handler)` would read a moved-from handler (and
        // the initialization order of closure members is not guaranteed anyway).
        auto work = net::make_work_guard(handler);
        boost::async(boost::launch::async,
                     [sfut, handler = std::move(handler), work = std::move(work)]() mutable {
                         sfut.wait();

                         // Waiting cannot fail; report success with result 1.
                         const boost::system::error_code ec;
                         const int result = 1;

                         // Run the completion on the handler's own executor
                         // instead of on this worker thread.
                         auto executor = net::get_associated_executor(handler);
                         net::dispatch(executor, [ec, handler = std::move(handler), result]() mutable {
                             std::move(handler)(ec, result);
                         });
                     });
    };
    return net::async_initiate<Token, void(boost::system::error_code, int)>(init, token);
}

/// State shared by one left/right session pair (held inside a Communicator).
struct Data {
    // Each side fulfils its own promise once its WebSocket upgrade succeeds;
    // the other side waits on a future obtained from the matching promise.
    std::promise<void> left_connection_ready;
    std::promise<void> right_connection_ready;

    // Id shared by both sessions, generated on first use; guarded by `mtx`.
    std::optional<int> unique_id{};

    // Message index shared by both sessions; atomic, so no lock is needed.
    std::atomic<unsigned int> counter{0};

    // Protects `unique_id`.
    std::mutex mtx;
};

/// Named holder for a payload shared between cooperating sessions.
///
/// @tparam DataType  payload type; must be default-constructible. It is
///                   value-initialized, so scalar payloads start at zero
///                   instead of holding an indeterminate value.
template <typename DataType>
struct Communicator {
    /// @param id  identifier of this communicator (stored verbatim)
    Communicator(const std::string& id) : id(id) {}

    std::string id;
    DataType data{};
};

class Session {
   public:
    Session(websocket::stream<beast::tcp_stream>& ws, net::yield_context yield)
        : ws(ws), yield(yield) {}

    void attach_comm(std::shared_ptr<Communicator<Data>> comm) {
        this->comm = comm;
    }

    bool do_handshake() {
        beast::error_code ec;
        beast::flat_buffer read_buffer;
				// HTTP handshake
        http::async_read(ws.next_layer(), read_buffer, this->req_parser, yield[ec]);
        if (ec) {
            std::cerr << ec.message() << '\n';
            return false;
        }

        // Get the endpoint
        websocket::request_type req(req_parser.get());
        std::string target_uri = req.target();
        std::vector<std::string> paths;
        boost::split(paths, target_uri, boost::is_any_of("/"));
        if ((paths.size() != 2) || !(paths[0] == "left" || paths[0] == "right")) {
            std::cerr << "Invalid path" << '\n';
            ws.async_write(net::buffer("Invalid path"), yield[ec]);
            if (ec) {
                std::cerr << ec.message() << '\n';
            }
            ws.async_close(websocket::close_code::normal, yield);
            return false;
        }

        this->session_type = paths[0];
        this->id = paths[1];

        return true;
    }

    void run() {
        beast::error_code ec;
				// Upgrading http connection to websocket
        ws.async_accept(req_parser.get(), yield[ec]);
        if (ec) {
            std::cerr << ec.message() << '\n';
            return;
        }

        // Wait until the counterpart session is ready
        std::future<void> fut;
        if (this->session_type == "left") {  // wait for right session
            this->comm->data.left_connection_ready.set_value();
            fut = comm->data.right_connection_ready.get_future();
        } else if (this->session_type == "right") {  // wait for left session
            this->comm->data.right_connection_ready.set_value();
            fut = comm->data.left_connection_ready.get_future();
        }
        if (fut.valid()) {
            async_wait(fut.share(), yield[ec]);
            if (ec) {
                std::cerr << ec.message() << '\n';
                return;
            }
        } else {
            std::cerr << "connection ready promise is invalid\n";
            return;
        }

        // generate unique id for both left and right sessions
        int current_unique_id = -1;
        {
            std::lock_guard<std::mutex> lock(comm->data.mtx);
            if (!(comm->data.unique_id.has_value())) {
                comm->data.unique_id = rand();
            }
            current_unique_id = comm->data.unique_id.value();
        }

        for (;;) {
            beast::flat_buffer buffer;
            ws.async_read(buffer, yield[ec]);
            if (ec == websocket::error::closed) {
                break;
            }

            if (ec) {
                std::cerr << ec.message() << '\n';
                return;
            }

            // Echo the message back with unique id and a unique message index
						int current_idx = (comm->data.counter)++;
            auto message = std::to_string(current_unique_id) + " | "
							+ std::to_string(current_idx) + " | "
							+ beast::buffers_to_string(buffer.cdata());
            ws.async_write(net::buffer(message), yield[ec]);
            if (ec) {
                std::cerr << ec.message() << '\n';
                return;
            }
        }
    }

   private:
    websocket::stream<beast::tcp_stream>& ws;
    net::yield_context yield;
    std::shared_ptr<Communicator<Data>> comm;

    http::request_parser<http::string_body> req_parser;
    std::string id;
    std::string session_type;
};

The Session class encapsulates the logic of a WebSocket session, covering everything from the handshake to client data processing. We separate the handshake and run logic because, between these two methods, the session manager needs to attach a Communicator<Data> that is shared by a pair of sessions.

In the do_handshake method, we determine which sessions can communicate with each other by parsing the request target after reading the HTTP request and before upgrading the connection to WebSocket. For example, requests to /left/abc and /right/abc form a pair because their ids are identical. The session manager uses this id to share a Communicator<Data> between the two sessions.

After upgrading to the WebSocket connection in the run method, there are three distinct code parts:

  1. Waiting until the counterpart session is ready:
    We use two std::promise objects to signal the readiness of a session. While waiting for the counterpart's promise to be fulfilled, we use a custom async_wait function so that the wait does not block the coroutine's thread, similar to what we discussed in Custom Async Function on Boost Beast Coroutine.
  2. Generating a unique_id shared with partner sessions:
    This demonstrates a use case to share a state between two sessions. To avoid race conditions, we use std::mutex before setting/accessing the variable.
  3. Main session loop echoing back client messages with unique_id and message index:
    The main loop echoes back messages from the client, prepending unique_id and a unique message index. Since the message index uses a std::atomic variable, we don’t need to protect the access with std::mutex, unlike the previous code part.

After creating the Session class, let’s move on to the SessionManager class.

#include <map>

class SessionManager {
public:
    std::shared_ptr<Session> create_session(
        websocket::stream<beast::tcp_stream> &ws,
        net::yield_context yield)
    {
        auto session = std::make_shared<Session>(ws, yield);
        if (session == nullptr) {
            std::cerr << "Failed creating session\n";
            return nullptr;
        }

        if (bool ret = session->do_handshake(); !ret) {
            std::cerr << "Session handshake failed\n";
            return nullptr;
        }

        bool session_inserted = false;
        std::string id = session->get_id();
        if (session->get_session_type() == "left") {
            std::lock_guard<std::mutex> lock(left_mtx);
            if (left_sessions.find(id) == left_sessions.end()) {
                left_sessions.emplace(id, session);
                session_inserted = true;
            }
        } else if (session->get_session_type() == "right") {
            std::lock_guard<std::mutex> lock(right_mtx);
            if (right_sessions.find(id) == right_sessions.end()) {
                right_sessions.emplace(id, session);
                session_inserted = true;
            }
        }

        if (!session_inserted) {
            std::cerr << "Failed to insert session, possibly duplicate session existed\n";
            return nullptr;
        }

        // Prepare communicator
        {
            std::lock_guard<std::mutex> lock(comms_mtx);
            if (intersession_comms.find(id) == intersession_comms.end()) {
                intersession_comms.emplace(id, std::make_shared<Communicator<Data>>(id));
            }
        }
        session->attach_comm(intersession_comms.at(id));

        return session;
    }

    bool remove_session(std::shared_ptr<Session> session) {
        bool session_removed = false;
        
        std::string id = session->get_id();
        if (session->get_session_type() == "left") {
            std::lock_guard<std::mutex> lock(left_mtx);
            if (left_sessions.find(id) != left_sessions.end()) {
                left_sessions.erase(id);
                session_removed = true;
            }

            // Remove communicator if session counterpart is not exist anymore
            if (right_sessions.find(id) == right_sessions.end()) {
                std::lock_guard<std::mutex> inner_lock(comms_mtx);
                intersession_comms.erase(id);
            }
        } else if (session->get_session_type() == "right") {
            std::lock_guard<std::mutex> lock(right_mtx);
            if (right_sessions.find(id) != right_sessions.end()) {
                right_sessions.erase(id);
                session_removed = true;
            }

            // Remove communicator if session counterpart is not exist anymore
            if (left_sessions.find(id) == left_sessions.end()) {
                std::lock_guard<std::mutex> inner_lock(comms_mtx);
                intersession_comms.erase(id);
            }
        }

        if (!session_removed) {
            std::cerr << "Failed to remove session\n";
            return false;
        }

        return true;
    }

private:
    std::map<std::string, std::shared_ptr<Session>> left_sessions, right_sessions;
    std::mutex left_mtx, right_mtx;

    using intersession_comms_t = std::map<std::string,
        std::shared_ptr<Communicator<Data>>>;
    intersession_comms_t intersession_comms;
    std::mutex comms_mtx;
};

The SessionManager has two public methods:

  1. create_session:
    Session creation executes the session handshake, requiring the session type and session id parsed from the request endpoint. Then, we insert the session into a std::map of sessions for its respective type (left or right). During this, we create a shared communicator inserted into intersession_comms and attach it to the Session object.
  2. remove_session:
    Session removal is crucial for releasing the Session object once its connection no longer exists. This method also removes the shared Communicator object when the counterpart session has already been removed.

Finally, here’s how you use the session manager in the do_session function of the WebSocket coroutine code:

SessionManager session_manager;

/// Coroutine entry point for one accepted connection. It creates and runs a
/// session through the global `session_manager`, then unregisters it.
/// Exceptions derived from std::exception are logged instead of propagated.
void do_session(websocket::stream<beast::tcp_stream> &ws, net::yield_context yield) {
    std::shared_ptr<Session> session;

    try {
        session = session_manager.create_session(ws, yield);
        if (session) {
            session->run();
        } else {
            std::cerr << "Unable to run session\n";
        }
    } catch (const std::exception& error) {
        std::cerr << "Exception during do_session: " << error.what() << '\n';
    }

    // Unregister whenever a session was created, even if run() threw.
    if (!session) {
        return;
    }
    session_manager.remove_session(session);
}

Simply create a session, run it, and remove it at the end of the function.

In conclusion, to enable state sharing between WebSocket sessions on the Boost Beast WebSocket coroutine server, we write a session manager class managing session lifetimes while allowing shared Communicator objects between sessions. However, we must be careful with the use of shared pointers to avoid memory leaks and the use of mutexes to prevent data races.