各位技术同仁,下午好!

今天,我们将共同深入探索 C++ 异步编程的奥秘,特别是如何为您的自定义异步网络库量身打造一个 awaiter 对象。在 C++20 协程(Coroutines)的强大能力下,我们得以用同步的思维编写异步的代码,极大地提升了开发效率和代码可读性。然而,要真正驾驭协程,特别是将其应用于网络 I/O 这种高并发、事件驱动的场景,我们必须理解并掌握其核心机制——尤其是 co_await 表达式背后的 awaiter

本次讲座,我将以一名编程专家的视角,带领大家从零开始,逐步构建一个功能完备的 awaiter,并将其融入一个简化的异步网络库中。我们将看到,自定义 awaiter 不仅仅是语法糖,它更是连接协程与底层异步事件的关键桥梁,是实现高度定制化挂起逻辑的基石。

异步编程的演进与 C++ Coroutines 的崛起

在现代软件开发中,异步编程已成为构建高性能、高响应性应用不可或缺的一部分。无论是 Web 服务器、数据库驱动程序还是 GUI 应用,都需要在等待耗时操作(如网络请求、磁盘 I/O)完成时,不阻塞主线程,从而保持应用的流畅和高效。

传统异步编程模式的挑战

在 C++ 引入协程之前,我们通常依赖以下模式来实现异步:

  1. 回调函数 (Callbacks):这是最原始的异步模式。当一个异步操作完成时,会调用预先注册的回调函数。

    • 优点:简单直接,易于理解基本概念。
    • 缺点:容易陷入“回调地狱”(Callback Hell),代码层层嵌套,难以阅读、维护和错误处理。状态管理复杂。
  2. Future/Promise 模式:通过 std::futurestd::promise(或 Boost.Future)将异步操作的结果封装起来。操作发起方获得一个 future,可以在将来某个时刻查询结果。

    • 优点:将结果与回调分离,一定程度上缓解了回调地狱。
    • 缺点:链式操作仍需依赖 then() 等方法,语法仍显笨重。错误传播和取消机制相对复杂。
  3. 事件循环 (Event Loop):如 libuvboost::asio。通过一个主循环不断监听 I/O 事件,当事件就绪时,分发到对应的处理器。

    • 优点:高性能,适用于高并发场景。
    • 缺点:编程模型与传统同步代码差异大,需要适应事件驱动的思维。逻辑分散在多个事件处理器中,难以追踪。

这些传统模式虽然有效,但在表达复杂异步逻辑时,往往导致代码结构扭曲、状态管理混乱,从而降低了开发效率和代码质量。

C++ Coroutines:同步语法的异步表达

C++20 标准引入的协程,为异步编程带来了革命性的改变。它允许函数在执行过程中暂停(挂起),并在稍后从暂停点恢复执行,而无需阻塞调用线程。最关键的是,这一切都以近乎同步的顺序流控制来表达,极大地提升了异步代码的可读性和可维护性。

协程的核心思想是无栈协程(Stackless Coroutines),这意味着协程的局部变量和状态不再存储在独立的调用栈中,而是由编译器在堆上分配的协程帧(Coroutine Frame)中管理。这使得协程的切换成本非常低,并且可以实现更灵活的调度。

C++ 协程主要通过三个关键字来操作:

  • co_await:用于挂起当前协程,等待一个可等待对象(Awaitable)完成,并在完成后恢复。
  • co_yield:用于生成一个值并挂起当前协程,通常用于实现生成器(Generator)。
  • co_return:用于从协程返回一个值,并销毁协程帧。

在本次讲座中,我们的焦点将是 co_await 表达式,以及它背后的 awaiter 对象。

深入理解 C++ Coroutines 的核心机制

要手写自定义 awaiter,我们必须对 C++ 协程的底层机制有深刻的理解。

协程的生命周期与核心组件

一个协程的生命周期涉及以下几个关键组件:

  1. 协程函数 (Coroutine Function):任何包含 co_await, co_yield, co_return 关键字的函数。
  2. 返回类型 (Return Type):协程函数必须返回一个可等待类型(Awaitable Type),例如 Task<T>Generator<T> 等。这个返回类型负责管理协程的生命周期和结果。
  3. promise_type:这是协程返回类型内部的一个嵌套类型,是协程状态机的核心。编译器会根据返回类型自动查找 promise_type。它定义了协程的初始挂起行为 (initial_suspend)、最终挂起行为 (final_suspend)、返回值处理 (return_value)、异常处理 (unhandled_exception) 以及如何获取协程的句柄 (get_return_object)。
  4. std::coroutine_handle<P>:协程的句柄。P 是协程的 promise_type。通过句柄,我们可以恢复(resume())或销毁(destroy())一个挂起的协程。它是连接外部世界与协程内部状态的桥梁。
  5. awaiter:一个实现了 await_ready()await_suspend()await_resume() 三个方法的类型。co_await 表达式就是通过与 awaiter 对象的交互来控制协程的挂起和恢复。

co_await 表达式的幕后机制

当编译器遇到 co_await awaitable_expr; 这样的表达式时,它会执行以下步骤:

  1. 获取 awaiter 对象

    • 如果 awaitable_expr 本身就是 awaiter 类型,则直接使用。
    • 如果 awaitable_expr可等待类型(Awaitable Type),则调用其 operator co_await() 方法来获取 awaiter 对象。如果 operator co_await() 不存在,则直接将 awaitable_expr 视为 awaiter
    • 这一步是为了提供一个转换点,允许将任何对象包装成 awaiter
  2. 调用 awaiter.await_ready()

    • 如果返回 true,表示操作已经完成,协程可以立即继续执行,无需挂起。await_resume() 会被立即调用。
    • 如果返回 false,表示操作尚未完成,协程需要挂起。
  3. 调用 awaiter.await_suspend(std::coroutine_handle<CurrentPromiseType> caller_handle)

    • 这是挂起协程的关键一步。caller_handle 是当前正在执行的协程的句柄。
    • 在这个方法中,awaiter 通常会将 caller_handle 存储起来,并注册一个回调,当底层异步操作完成时,该回调会使用存储的句柄来恢复协程。
    • await_suspend() 的返回值决定了协程的调度:
      • void:挂起当前协程,将控制权返回给调用方。
      • bool:如果返回 true,挂起当前协程;如果返回 false,表示操作已完成,不挂起,立即调用 await_resume()
      • std::coroutine_handle<>:挂起当前协程,并立即恢复由返回的句柄所代表的协程。
  4. 调用 awaiter.await_resume()

    • await_suspend() 返回 true,且底层异步操作完成并恢复了协程后,await_resume() 会被调用。
    • 这个方法负责获取异步操作的结果(例如,读取的字节数、操作是否成功),并将其返回给 co_await 表达式。如果操作失败,它也可以抛出异常。

协程的 promise_type 概览

promise_type 是协程的幕后指挥。这里我们简单回顾一下其关键方法:

方法签名 描述
get_return_object() 返回一个 Awaitable 对象,该对象会传递给协程的调用者。通常包含协程的句柄。
initial_suspend() 返回一个 Awaiter 对象(通常是 std::suspend_alwaysstd::suspend_never),决定协程是否在创建后立即挂起。
final_suspend() 返回一个 Awaiter 对象,决定协程在执行完 co_return 或抛出异常后是否挂起。通常用于资源清理或调度。
return_void()return_value(T value) 协程执行 co_return;co_return value; 时调用。用于处理协程的返回值。
unhandled_exception() 当协程内部抛出未捕获的异常时调用。可以在这里处理异常,例如存储异常信息。

理解这些基本概念是构建自定义 awaiter 的前提。

为何需要自定义 awaiter?标准库的局限性

C++ 标准库提供了一些基本的 awaiter,如 std::suspend_alwaysstd::suspend_never

  • std::suspend_alwaysawait_ready() 返回 falseawait_suspend() 返回 true。总是挂起。
  • std::suspend_neverawait_ready() 返回 true。从不挂起。

它们非常有用,但对于实际的异步操作,特别是网络 I/O,它们远远不够。网络 I/O 的特点是:

  1. 外部事件驱动:我们等待的不是简单的暂停,而是操作系统通知某个套接字已准备好读写。
  2. 不确定性:操作完成的时间不确定。
  3. 结果返回:操作完成后需要返回结果(例如,读取了多少字节,连接是否成功)。
  4. 错误处理:网络操作可能失败,需要捕获和传播错误。

一个自定义的 awaiter 正是为了解决这些问题。它的核心任务是:

  • 将协程的挂起与底层异步 I/O 操作绑定:当协程 co_await 一个网络操作时,awaiter 负责启动该操作,并将当前协程的句柄注册到事件循环中。
  • 在事件完成后恢复协程:当事件循环检测到 I/O 事件就绪时,它会通过之前注册的句柄来恢复对应的协程。
  • 传递操作结果awaiter 在协程恢复时,负责将 I/O 操作的结果(成功/失败、数据)传递回协程。

简而言之,自定义 awaiter 是我们连接 C++ 协程的高级抽象与操作系统底层异步 I/O 机制的“胶水”。

构建异步网络库的基石:事件循环与 I/O 多路复用

在设计 SocketAwaiter 之前,我们需要一个基础的异步网络库架构。这个架构的核心是一个事件循环(Event Loop),它负责监听和分发 I/O 事件。在 Linux 上,我们通常使用 epoll;在 macOS/FreeBSD 上是 kqueue;在 Windows 上则是 IOCP。为了保持通用性并简化示例,我们将抽象一个 IoContext 类来代表这个事件循环。

IoContext:事件循环的抽象

IoContext 将是我们的网络库的核心。它会:

  • 管理一个 I/O 多路复用器(如 epoll 实例)。
  • 注册和注销文件描述符(Socket FD)及其感兴趣的事件(可读、可写)。
  • 运行一个循环,等待 I/O 事件发生。
  • 当事件发生时,查找并调度对应的协程句柄。
// 伪代码,表示IoContext的核心功能
class IoContext {
public:
    // 定义操作类型,读或写
    enum class IoOperation { Read, Write };

    // 注册一个协程句柄到特定的文件描述符和操作类型
    // 当对应的I/O事件发生时,会恢复这个句柄
    void register_awaiter(int fd, IoOperation op, std::coroutine_handle<> handle) {
        // 实际实现会存储fd、op和handle的映射关系
        // 并将fd添加到epoll/kqueue等监听列表中
        // ...
    }

    // 从监听列表中移除一个文件描述符
    void unregister_fd(int fd) {
        // ...
    }

    // 运行事件循环
    void run() {
        // 这是一个阻塞调用,会持续监听I/O事件
        // 当事件发生时,会调用内部的回调机制来恢复协程
        // ...
    }

    // 内部方法,当I/O事件就绪时被调用
    void handle_event(int fd, IoOperation op, int result_bytes_or_error) {
        // 查找与fd和op关联的协程句柄
        // std::coroutine_handle<> handle = ...;
        // 如果找到,恢复协程 handle.resume();
        // 还需要将result_bytes_or_error传递给恢复的协程
        // 这需要一个更复杂的机制,通常awaiter会把自身指针传给IoContext
        // IoContext可以直接在awaiter上设置结果
        // ...
    }
};

Socket:封装网络操作

Socket 类将封装底层的文件描述符,并提供异步读写操作的接口。这些接口将返回我们的自定义 SocketAwaiter

// 伪代码
class Socket {
public:
    Socket(IoContext& io_context, int fd) : io_context_(io_context), fd_(fd) {}

    // 异步读取操作,返回一个SocketAwaiter
    SocketAwaiter async_read(std::span<char> buffer);

    // 异步写入操作,返回一个SocketAwaiter
    SocketAwaiter async_write(std::span<char> buffer);

    // 异步连接操作
    SocketAwaiter async_connect(const std::string& host, int port);

    int fd() const { return fd_; }

private:
    IoContext& io_context_;
    int fd_;
    // 存储一些临时的状态,如错误码、读取字节数等,供awaiter使用
    // 这通常会通过awaiter直接回调给awaiter自身
    // 或者awaiter从Socket获取这些结果
    int last_op_result_ = 0;
    int last_op_errno_ = 0;
};

设计并实现 SocketAwaiter:挂起与恢复的艺术

现在,我们来设计核心的 SocketAwaiter。这个 awaiter 将处理 async_readasync_write 操作的挂起和恢复。

为了简化,我们先定义一个 Task 类型,作为协程函数的返回类型。

#include <coroutine>
#include <iostream>
#include <string>
#include <vector>
#include <map>
#include <functional>
#include <stdexcept>
#include <span> // C++20 for std::span
#include <sys/socket.h> // For socket operations
#include <netinet/in.h> // For sockaddr_in
#include <arpa/inet.h>  // For inet_pton
#include <unistd.h>     // For close, read, write
#include <fcntl.h>      // For fcntl, O_NONBLOCK

// 简单的Task类型,用于包装协程的返回结果
template <typename T = void>
struct Task {
    struct promise_type;
    using handle_type = std::coroutine_handle<promise_type>;

    handle_type coro_handle;

    Task(handle_type h) : coro_handle(h) {}
    Task(Task&& other) noexcept : coro_handle(std::exchange(other.coro_handle, nullptr))) {}
    ~Task() { if (coro_handle) coro_handle.destroy(); }

    Task(const Task&) = delete;
    Task& operator=(const Task&) = delete;

    // 可等待接口,允许co_await Task
    auto operator co_await() const {
        struct TaskAwaiter {
            handle_type awaiter_handle;

            bool await_ready() const noexcept { return !awaiter_handle || awaiter_handle.done(); }
            void await_suspend(std::coroutine_handle<> caller) noexcept {
                // 如果Task已经被co_await,并且它本身还未完成,
                // 则将caller的句柄存储到Task的promise中,
                // 当Task完成时,会恢复caller
                awaiter_handle.promise().continuation = caller;
            }
            T await_resume() {
                if constexpr (!std::is_void_v<T>) {
                    return awaiter_handle.promise().get_result();
                } else {
                    awaiter_handle.promise().get_result(); // 检查异常
                }
            }
        };
        return TaskAwaiter{coro_handle};
    }

    // Promise Type 定义
    struct promise_type {
        T value_;
        std::exception_ptr exception_;
        std::coroutine_handle<> continuation; // 存储co_await此Task的协程句柄

        Task get_return_object() { return Task{handle_type::from_promise(*this)}; }
        std::suspend_always initial_suspend() { return {}; } // 协程创建后立即挂起
        std::suspend_always final_suspend() noexcept {
            if (continuation) {
                continuation.resume(); // 协程结束时恢复等待它的协程
            }
            return {};
        }
        void unhandled_exception() { exception_ = std::current_exception(); }

        // return_value for non-void Task
        template<typename U = T>
        void return_value(U value) {
            if constexpr (!std::is_void_v<U>) {
                value_ = value;
            }
        }
        // return_void for void Task
        void return_void() { }

        T get_result() {
            if (exception_) {
                std::rethrow_exception(exception_);
            }
            if constexpr (!std::is_void_v<T>) {
                return value_;
            }
            // For void T, just check exception
        }
    };
};

// =========================================================================
// IoContext 和 Socket 的简化实现
// =========================================================================

class IoContext; // 前向声明

// Socket操作的类型
enum class IoOperation { Read, Write, Connect, Accept };

// 用于存储挂起协程信息和结果的结构
struct PendingOperation {
    std::coroutine_handle<> handle;
    int result_code = 0; // 操作结果:读取/写入字节数,或错误码
    int error_code = 0;  // 系统错误码
};

class IoContext {
public:
    IoContext() {
        epoll_fd_ = epoll_create1(0);
        if (epoll_fd_ == -1) {
            throw std::runtime_error("epoll_create1 failed");
        }
    }

    ~IoContext() {
        if (epoll_fd_ != -1) {
            close(epoll_fd_);
        }
    }

    // 注册一个等待中的操作
    void register_operation(int fd, IoOperation op, std::coroutine_handle<> handle) {
        std::lock_guard<std::mutex> lock(mtx_); // 线程安全考虑
        auto& ops_for_fd = pending_operations_[fd];
        ops_for_fd[op].handle = handle;

        // 首次注册fd,需要添加到epoll监听
        if (ops_for_fd.size() == 1) { // 假设这是fd的第一个注册操作
            epoll_event event{};
            event.events = EPOLLIN | EPOLLOUT | EPOLLET; // 监听读写,边缘触发
            event.data.fd = fd;
            if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) == -1) {
                // 错误处理
                std::cerr << "epoll_ctl_add failed for fd " << fd << std::endl;
            }
        } else {
            // 如果fd已在epoll中,可能需要修改监听事件,这里简化处理
            // 实际应用中需要更精细的epoll_ctl_mod
        }
    }

    // 移除一个等待中的操作
    void unregister_operation(int fd, IoOperation op) {
        std::lock_guard<std::mutex> lock(mtx_);
        auto it_fd = pending_operations_.find(fd);
        if (it_fd != pending_operations_.end()) {
            it_fd->second.erase(op);
            if (it_fd->second.empty()) {
                // 如果fd上没有其他等待操作,从epoll中移除
                epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr);
                pending_operations_.erase(it_fd);
            }
        }
    }

    // 获取指定fd和操作的结果
    PendingOperation get_op_result(int fd, IoOperation op) {
        std::lock_guard<std::mutex> lock(mtx_);
        return pending_operations_[fd][op]; // 假设一定存在
    }

    void run_one() {
        epoll_event events[16]; // 最多处理16个事件
        int num_events = epoll_wait(epoll_fd_, events, 16, -1); // 阻塞等待
        if (num_events == -1) {
            if (errno == EINTR) return; // 被信号打断
            throw std::runtime_error("epoll_wait failed");
        }

        for (int i = 0; i < num_events; ++i) {
            int fd = events[i].data.fd;
            int event_flags = events[i].events;

            std::lock_guard<std::mutex> lock(mtx_); // 保护pending_operations_

            // 检查读事件
            if (event_flags & EPOLLIN) {
                auto it_fd = pending_operations_.find(fd);
                if (it_fd != pending_operations_.end()) {
                    auto it_op = it_fd->second.find(IoOperation::Read);
                    if (it_op != it_fd->second.end()) {
                        // 准备恢复协程
                        PendingOperation& op_info = it_op->second;
                        op_info.result_code = 0; // 实际值将在awaiter中获取
                        op_info.error_code = 0;
                        std::coroutine_handle<> handle_to_resume = op_info.handle;
                        it_fd->second.erase(it_op); // 移除此操作的注册

                        // 恢复协程 (可能在当前线程,也可能调度到其他线程)
                        if (handle_to_resume) {
                            handle_to_resume.resume();
                        }
                    }
                }
            }

            // 检查写事件
            if (event_flags & EPOLLOUT) {
                auto it_fd = pending_operations_.find(fd);
                if (it_fd != pending_operations_.end()) {
                    // 处理Connect操作
                    auto it_connect_op = it_fd->second.find(IoOperation::Connect);
                    if (it_connect_op != it_fd->second.end()) {
                        PendingOperation& op_info = it_connect_op->second;
                        int optval;
                        socklen_t optlen = sizeof(optval);
                        // 检查connect是否成功
                        if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == 0 && optval == 0) {
                            op_info.result_code = 0; // 成功
                        } else {
                            op_info.result_code = -1; // 失败
                            op_info.error_code = optval;
                            if (optval == 0) op_info.error_code = errno; // Fallback for some systems
                        }
                        std::coroutine_handle<> handle_to_resume = op_info.handle;
                        it_fd->second.erase(it_connect_op);
                        if (handle_to_resume) {
                            handle_to_resume.resume();
                        }
                    }

                    // 处理Write操作
                    auto it_write_op = it_fd->second.find(IoOperation::Write);
                    if (it_write_op != it_fd->second.end()) {
                        // 准备恢复协程
                        PendingOperation& op_info = it_write_op->second;
                        op_info.result_code = 0; // 实际值将在awaiter中获取
                        op_info.error_code = 0;
                        std::coroutine_handle<> handle_to_resume = op_info.handle;
                        it_fd->second.erase(it_write_op); // 移除此操作的注册
                        if (handle_to_resume) {
                            handle_to_resume.resume();
                        }
                    }
                }
            }

            // 如果fd上所有操作都已处理,且没有新的注册,则从epoll中移除
            // 这里为了简化,我们仅在unregister_operation中处理移除
            // 实际应用需要更细致的状态管理
        }
    }

    void run() {
        while (true) { // 真实应用中会有退出机制
            run_one();
        }
    }

private:
    int epoll_fd_;
    std::mutex mtx_; // 保护pending_operations_
    // map: fd -> (map: IoOperation -> PendingOperation)
    std::map<int, std::map<IoOperation, PendingOperation>> pending_operations_;
};

// =========================================================================
// SocketAwaiter 的实现
// =========================================================================

class SocketAwaiter {
public:
    SocketAwaiter(IoContext& io_context, int fd, IoOperation op, std::span<char> buffer = {},
                  const sockaddr_in* addr = nullptr, socklen_t addrlen = 0)
        : io_context_(io_context), fd_(fd), op_(op), buffer_(buffer),
          addr_(addr), addrlen_(addrlen) {}

    // 1. await_ready():检查是否可以立即继续
    bool await_ready() const noexcept {
        // 对于非阻塞I/O,如果操作可以立即完成,就避免挂起。
        // 例如,对于read,如果缓冲区已经有数据,可以直接读取。
        // 但对于网络I/O,通常我们倾向于直接挂起,让事件循环调度。
        // 简化起见,我们总是返回false,强制挂起,让事件循环来处理。
        // 实际应用中,可以尝试一次非阻塞read/write,如果成功则返回true。
        // 例如:
        // if (op_ == IoOperation::Read) {
        //     ssize_t bytes_read = ::recv(fd_, buffer_.data(), buffer_.size(), MSG_DONTWAIT);
        //     if (bytes_read >= 0) {
        //         // 立即完成,将结果存储,并返回true
        //         result_ = bytes_read;
        //         return true;
        //     }
        //     if (errno != EWOULDBLOCK && errno != EAGAIN) {
        //         // 发生错误,存储错误码,返回true (抛出异常)
        //         error_ = errno;
        //         return true;
        //     }
        // }
        return false;
    }

    // 2. await_suspend():挂起协程并注册I/O事件
    void await_suspend(std::coroutine_handle<> handle) {
        // 存储当前协程的句柄,以便事件就绪时恢复
        coro_handle_ = handle;

        // 将协程句柄注册到IoContext
        io_context_.register_operation(fd_, op_, handle);

        // 如果是连接操作,我们立即发起非阻塞连接
        if (op_ == IoOperation::Connect && addr_) {
            int ret = ::connect(fd_, (const sockaddr*)addr_, addrlen_);
            if (ret == -1 && errno != EINPROGRESS) {
                // 连接失败,不应该挂起,而是立即恢复并抛出异常
                // 此时需要取消在IoContext中的注册并立即恢复
                io_context_.unregister_operation(fd_, op_);
                error_ = errno;
                // 注意:这里不能直接抛出异常,只能设置状态让await_resume处理
                // 或者在这里直接恢复协程:handle.resume();
                // 为了简化,我们假设 connect 立即失败的情况会被 await_resume 处理
                // 或者 IoContext 会在 EPOLLOUT 触发后检查getsockopt(SO_ERROR)
                // 并且 IoContext 负责将错误码设置到 PendingOperation 中。
                // 实际上,为了避免二次挂起,这里应该直接恢复并处理错误
                // 但为了演示 awaiter 的基本流程,我们让它继续挂起,由 IoContext 调度
            }
        }
        // 对于读写,我们只是注册事件,实际读写在await_resume中完成或在IoContext的调度下完成
        // 对于 read/write,真正的操作应该在IoContext事件处理时进行,
        // 或者在await_resume中再次尝试非阻塞操作。
    }

    // 3. await_resume():协程恢复后获取结果
    int await_resume() {
        // 从IoContext获取操作结果
        PendingOperation op_result = io_context_.get_op_result(fd_, op_);

        // 检查是否有在await_suspend中发生的错误(如connect立即失败)
        if (error_ != 0) {
            throw std::runtime_error("Socket operation failed immediately: " + std::string(strerror(error_)));
        }

        // 针对不同操作类型处理结果
        if (op_ == IoOperation::Connect) {
            if (op_result.result_code == -1) {
                // 连接失败
                throw std::runtime_error("Connect failed: " + std::string(strerror(op_result.error_code)));
            }
            return 0; // 连接成功
        } else if (op_ == IoOperation::Read) {
            ssize_t bytes_read = ::recv(fd_, buffer_.data(), buffer_.size(), 0);
            if (bytes_read == -1) {
                throw std::runtime_error("Read failed: " + std::string(strerror(errno)));
            }
            return static_cast<int>(bytes_read);
        } else if (op_ == IoOperation::Write) {
            ssize_t bytes_written = ::send(fd_, buffer_.data(), buffer_.size(), 0);
            if (bytes_written == -1) {
                throw std::runtime_error("Write failed: " + std::string(strerror(errno)));
            }
            return static_cast<int>(bytes_written);
        }

        throw std::logic_error("Unsupported socket operation in await_resume");
    }

private:
    IoContext& io_context_;
    int fd_;
    IoOperation op_;
    std::span<char> buffer_; // 用于读写操作的缓冲区
    std::coroutine_handle<> coro_handle_; // 存储挂起的协程句柄

    const sockaddr_in* addr_; // 用于连接操作
    socklen_t addrlen_;

    int result_ = 0; // 存储操作结果
    int error_ = 0;  // 存储错误码
};

// =========================================================================
// Socket 类的完整实现
// =========================================================================

class Socket {
public:
    Socket(IoContext& io_context) : io_context_(io_context), fd_(-1) {
        fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
        if (fd_ == -1) {
            throw std::runtime_error("Failed to create socket");
        }
    }

    Socket(IoContext& io_context, int connected_fd) : io_context_(io_context), fd_(connected_fd) {
        // Ensure the accepted socket is non-blocking
        int flags = fcntl(fd_, F_GETFL, 0);
        if (flags == -1) throw std::runtime_error("fcntl F_GETFL failed");
        if (fcntl(fd_, F_SETFL, flags | O_NONBLOCK) == -1) {
            throw std::runtime_error("fcntl F_SETFL O_NONBLOCK failed");
        }
    }

    ~Socket() {
        if (fd_ != -1) {
            // 在IoContext中取消注册所有与此fd相关的操作
            // 这里简化,实际需要更精细的清理
            close(fd_);
            fd_ = -1;
        }
    }

    // 禁用拷贝,允许移动
    Socket(const Socket&) = delete;
    Socket& operator=(const Socket&) = delete;
    Socket(Socket&& other) noexcept : io_context_(other.io_context_), fd_(std::exchange(other.fd_, -1)) {}
    Socket& operator=(Socket&& other) noexcept {
        if (this != &other) {
            if (fd_ != -1) close(fd_);
            fd_ = std::exchange(other.fd_, -1);
        }
        return *this;
    }

    int fd() const { return fd_; }

    SocketAwaiter async_read(std::span<char> buffer) {
        return SocketAwaiter(io_context_, fd_, IoOperation::Read, buffer);
    }

    SocketAwaiter async_write(std::span<char> buffer) {
        return SocketAwaiter(io_context_, fd_, IoOperation::Write, buffer);
    }

    SocketAwaiter async_connect(const std::string& host, int port) {
        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_port = htons(port);
        if (inet_pton(AF_INET, host.c_str(), &addr.sin_addr) <= 0) {
            throw std::runtime_error("Invalid address");
        }
        return SocketAwaiter(io_context_, fd_, IoOperation::Connect, {}, &addr, sizeof(addr));
    }
};

// =========================================================================
// TcpListener 的实现 (用于接受连接)
// =========================================================================

class TcpListenerAwaiter {
public:
    TcpListenerAwaiter(IoContext& io_context, int listen_fd)
        : io_context_(io_context), listen_fd_(listen_fd) {}

    bool await_ready() const noexcept { return false; } // 总是挂起等待新连接

    void await_suspend(std::coroutine_handle<> handle) {
        coro_handle_ = handle;
        io_context_.register_operation(listen_fd_, IoOperation::Accept, handle);
    }

    // 返回新接受的客户端Socket
    Socket await_resume() {
        sockaddr_in client_addr{};
        socklen_t client_len = sizeof(client_addr);
        int client_fd = ::accept4(listen_fd_, (sockaddr*)&client_addr, &client_len, SOCK_NONBLOCK | SOCK_CLOEXEC);
        if (client_fd == -1) {
            throw std::runtime_error("Accept failed: " + std::string(strerror(errno)));
        }
        return Socket(io_context_, client_fd);
    }

private:
    IoContext& io_context_;
    int listen_fd_;
    std::coroutine_handle<> coro_handle_;
};

class TcpListener {
public:
    TcpListener(IoContext& io_context, int port) : io_context_(io_context), listen_fd_(-1) {
        listen_fd_ = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
        if (listen_fd_ == -1) {
            throw std::runtime_error("Failed to create listener socket");
        }

        int opt = 1;
        if (setsockopt(listen_fd_, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1) {
            close(listen_fd_);
            throw std::runtime_error("setsockopt SO_REUSEADDR failed");
        }

        sockaddr_in addr{};
        addr.sin_family = AF_INET;
        addr.sin_addr.s_addr = INADDR_ANY;
        addr.sin_port = htons(port);

        if (bind(listen_fd_, (sockaddr*)&addr, sizeof(addr)) == -1) {
            close(listen_fd_);
            throw std::runtime_error("Bind failed: " + std::string(strerror(errno)));
        }

        if (listen(listen_fd_, SOMAXCONN) == -1) {
            close(listen_fd_);
            throw std::runtime_error("Listen failed: " + std::string(strerror(errno)));
        }
    }

    ~TcpListener() {
        if (listen_fd_ != -1) {
            close(listen_fd_);
        }
    }

    TcpListenerAwaiter async_accept() {
        return TcpListenerAwaiter(io_context_, listen_fd_);
    }

private:
    IoContext& io_context_;
    int listen_fd_;
};

// =========================================================================
// 协程化的网络服务器和客户端示例
// =========================================================================

Task<void> echo_session(Socket client_socket) {
    std::vector<char> buffer(1024);
    try {
        while (true) {
            int bytes_read = co_await client_socket.async_read(buffer);
            if (bytes_read == 0) { // 客户端关闭连接
                std::cout << "Client disconnected." << std::endl;
                break;
            }
            std::cout << "Server received " << bytes_read << " bytes: " << std::string(buffer.data(), bytes_read) << std::endl;

            // 回写数据
            co_await client_socket.async_write(std::span<char>(buffer.data(), bytes_read));
            std::cout << "Server echoed " << bytes_read << " bytes." << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Session error: " << e.what() << std::endl;
    }
    co_return;
}

Task<void> server(IoContext& io_context, int port) {
    TcpListener listener(io_context, port);
    std::cout << "Server listening on port " << port << std::endl;

    try {
        while (true) {
            Socket client_socket = co_await listener.async_accept();
            std::cout << "Accepted new client on fd " << client_socket.fd() << std::endl;
            // 启动一个新的协程来处理客户端连接
            echo_session(std::move(client_socket)); // 注意:这里没有co_await,直接启动
        }
    } catch (const std::exception& e) {
        std::cerr << "Server error: " << e.what() << std::endl;
    }
    co_return;
}

Task<void> client(IoContext& io_context, const std::string& host, int port, const std::string& message) {
    Socket sock(io_context);
    try {
        std::cout << "Client connecting to " << host << ":" << port << std::endl;
        co_await sock.async_connect(host, port);
        std::cout << "Client connected." << std::endl;

        // 发送消息
        co_await sock.async_write(std::span<const char>(message.data(), message.size()));
        std::cout << "Client sent: " << message << std::endl;

        // 接收回显
        std::vector<char> buffer(1024);
        int bytes_read = co_await sock.async_read(buffer);
        std::cout << "Client received " << bytes_read << " bytes: " << std::string(buffer.data(), bytes_read) << std::endl;

    } catch (const std::exception& e) {
        std::cerr << "Client error: " << e.what() << std::endl;
    }
    co_return;
}

int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <server|client> [args...]" << std::endl;
        return 1;
    }

    IoContext io_context;

    if (std::string(argv[1]) == "server") {
        int port = 8080;
        if (argc > 2) port = std::stoi(argv[2]);
        server(io_context, port); // 启动服务器协程
        io_context.run(); // 运行事件循环,永不停止
    } else if (std::string(argv[1]) == "client") {
        if (argc < 5) {
            std::cerr << "Usage: " << argv[0] << " client <host> <port> <message>" << std::endl;
            return 1;
        }
        std::string host = argv[2];
        int port = std::stoi(argv[3]);
        std::string message = argv[4];
        client(io_context, host, port, message); // 启动客户端协程
        // 客户端通常只运行一次,所以我们让io_context运行并等待所有协程完成
        // 对于这个简单的例子,我们只运行一次epoll_wait来确保客户端操作完成
        // 真实场景下,客户端io_context也需要持续运行直到所有任务完成或退出
        io_context.run_one(); // 运行一次事件循环
    } else {
        std::cerr << "Unknown command: " << argv[1] << std::endl;
        return 1;
    }

    return 0;
}

代码解释:

  1. Task<T> 类型:这是一个简单的协程返回类型,封装了 promise_typepromise_type 负责在协程完成时(通过 final_suspend)恢复等待此 Task 的协程(通过 continuation 句柄)。initial_suspend 保证协程创建后立即挂起,等待被调度。get_result 用于从 Task 获取返回值或传播异常。
  2. IoContext
    • 使用 epoll 作为 I/O 多路复用机制。
    • pending_operations_ 维护了一个映射:fd -> operation_type -> PendingOperationPendingOperation 存储了等待该 I/O 事件的协程句柄以及操作结果。
    • register_operation:当 SocketAwaiter 挂起时,它会调用此方法将当前协程的句柄注册到 IoContext。同时,fd 会被添加到 epoll 监听。
    • run_one():执行一次 epoll_wait,处理所有就绪的 I/O 事件。当事件就绪时,它会从 pending_operations_ 中查找对应的协程句柄,并调用 handle.resume() 恢复协程。
    • get_op_resultSocketAwaiter 恢复后,需要从 IoContext 获取操作结果。
  3. SocketAwaiter
    • 构造函数:接收 IoContext 引用、文件描述符 fd、操作类型 op 和可选的缓冲区 buffer 或连接地址 addr
    • await_ready():这里为了简化,我们总是返回 false,强制协程挂起。在生产级代码中,你可以尝试一次非阻塞 I/O 操作,如果立即完成则返回 true,避免不必要的挂起和上下文切换。
    • await_suspend(std::coroutine_handle<> handle)
      • 存储当前协程的句柄 handle
      • 调用 io_context_.register_operation() 将句柄注册到事件循环。
      • 对于 Connect 操作,会立即尝试非阻塞 connect()。如果 connect 立即失败(非 EINPROGRESS),则会设置 error_,并在 await_resume 中抛出。如果返回 EINPROGRESS,表示连接正在进行中,需要等待 EPOLLOUT 事件。
    • await_resume()
      • IoContext 恢复协程时,此方法被调用。
      • 它首先从 IoContext 获取操作结果(虽然在本简化版本中 IoContext 只是恢复了协程,实际的 read/write 操作是在 await_resume 中再次执行非阻塞调用来获取结果)。
      • 对于 Connect 操作,会检查 getsockopt(SO_ERROR) 来确认连接结果。
      • 对于 Read/Write 操作,会再次调用 ::recv::send 来实际完成数据传输并获取字节数。
      • 如果操作失败,它会抛出 std::runtime_error
  4. Socket
    • 封装了文件描述符。
    • async_readasync_writeasync_connect 方法都返回一个 SocketAwaiter 实例,使其可以直接被 co_await
  5. TcpListenerTcpListenerAwaiter
    • TcpListener 封装了 bindlisten 操作。
    • TcpListenerAwaiterasync_accept 方法返回的可等待对象,它在 await_suspend 中注册 EPOLLIN 事件到监听套接字,在 await_resume 中调用 ::accept4 来接受新连接并返回一个 Socket 对象。
  6. server()client() 协程函数
    • 展示了如何使用 co_await 来编写异步的服务器和客户端逻辑。代码看起来几乎是同步的,非常直观。
    • echo_session 协程处理单个客户端的请求,读取数据并回写。

编译与运行:

在 Linux 环境下,你需要使用支持 C++20 协程的 GCC 或 Clang 版本(例如 GCC 10+ 或 Clang 10+)。

g++ -std=c++20 -fcoroutines -Wall -O2 your_file_name.cpp -o coroutine_net

然后,你可以启动服务器:

./coroutine_net server 8080

并在另一个终端启动客户端:

./coroutine_net client 127.0.0.1 8080 "Hello, Coroutines!"

你将看到客户端发送消息,服务器接收并回显,客户端再接收回显,整个过程通过协程以同步流的方式优雅地完成。

错误处理、取消与资源管理

一个健壮的异步网络库需要仔细考虑错误处理、操作取消和资源管理。

错误处理

  • 异常传播:如示例所示,await_resume() 可以抛出异常。这些异常会沿着协程调用栈向上冒泡,直到被 try-catch 块捕获,或者最终在 promise_type::unhandled_exception() 中处理。我们的 Task 类型就包含了 exception_ptr 来捕获和重抛异常。
  • 错误码:在底层 IoContextSocketAwaiter 中,I/O 操作的失败通常通过错误码(如 errno)表示。在 await_resume() 中,我们应该检查这些错误码,并将其转换为 C++ 异常。

取消机制

取消一个挂起的协程是一个复杂但重要的功能,例如当操作超时或用户主动取消时。实现取消通常有几种方式:

  1. 返回 falseawait_suspend:如果 await_suspend 在注册操作后发现操作已被取消或立即失败,它可以返回 false,这样 co_await 表达式就不会真正挂起,而是立即调用 await_resume()
  2. std::coroutine_handle::destroy():外部可以持有协程句柄,并在需要时调用 destroy() 来销毁协程帧,这会跳过 await_resume() 和后续的协程逻辑。但需要确保资源得到正确释放。
  3. 协同取消SocketAwaiter 可以接收一个取消令牌。在 await_suspend 中,除了注册协程句柄,还可以将取消令牌注册到事件循环或一个取消管理器中。当取消请求到达时,取消管理器会通知 IoContext 移除对应的 I/O 监听,并恢复协程,让 await_resume 返回一个取消异常。

资源管理

  • RAII (Resource Acquisition Is Initialization):这是 C++ 的黄金法则。Socket 类在构造时打开文件描述符,在析构时关闭。IoContext 在构造时创建 epoll 实例,在析构时关闭。
  • 协程句柄生命周期std::coroutine_handledestroy() 方法负责释放协程帧占用的内存。确保在协程不再需要时调用它。我们的 Task 类型在析构函数中处理了 coro_handle.destroy()
  • IoContext 清理:当 SocketTcpListener 对象被销毁时,它们应该通知 IoContext 移除所有相关的 I/O 监听和挂起的协程句柄,防止野指针和资源泄露。

性能考量与高级技巧

优化 await_ready()

如前所述,await_ready() 的目的是避免不必要的挂起。如果一个 I/O 操作可以在不阻塞的情况下立即完成(例如,本地缓冲区中有数据可读,或发送缓冲区有空间可写),那么 await_ready() 返回 true 可以节省一次上下文切换的开销。这对于高吞吐量、低延迟的场景非常重要。

无堆分配的协程帧

默认情况下,协程帧可能在堆上分配。对于性能敏感的应用,可以通过自定义 promise_typenewdelete 操作符,或者通过提供自定义的内存分配器,来实现无堆分配(或使用自定义内存池)的协程帧,从而减少内存碎片和分配/释放开销。

批处理 I/O 与 io_uring

现代操作系统提供了更高效的 I/O 接口,如 Linux 的 io_uring。它允许一次提交多个 I/O 请求,并在一个事件中获取所有完成结果,极大地减少了系统调用次数和上下文切换。将 io_uring 与协程 awaiter 结合,可以实现极致的性能。一个 io_uringawaiter 会将协程句柄与 io_uring 的完成队列(Completion Queue)中的一个条目关联起来。

集成现有异步库

许多成熟的异步网络库(如 boost::asio)已经提供了强大的功能。我们完全可以将这些库的异步操作包装成 C++ 协程的 awaiter。例如,boost::asio::awaitable 便是 boost::asio 官方提供的协程集成方案,它内部也是通过自定义 awaiter 来实现的。这种方法可以让我们享受协程的编程便利性,同时利用成熟库的稳定性和功能。

C++ 异步编程的未来图景与实践价值

C++ 协程为异步编程带来了范式转变,使得我们能够以更直观、更可维护的方式编写复杂的并发和异步逻辑。通过手写 awaiter,我们不仅深入理解了协程的底层机制,更掌握了将协程与各种自定义异步源(如网络 I/O、文件 I/O、数据库操作、硬件交互)无缝集成的能力。

这种能力对于构建高性能、高并发的服务端应用,以及响应式、流畅的客户端应用都具有极其重要的价值。随着 C++ 标准的不断演进,协程相关的库支持和工具链将日益完善,其在异步编程领域的地位也将愈发巩固。掌握 awaiter 的自定义,是每一个 C++ 专家通向现代异步编程的必经之路。它让您能够根据具体需求,设计出最适合自己应用场景的挂起和恢复策略,真正做到对异步行为的精细控制。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐