A Design Guide for an FPGA-Based Q-Learning Reinforcement Learning Model
With the rapid progress of artificial intelligence, deep learning and reinforcement learning have achieved breakthroughs in image recognition, natural language processing, autonomous driving, robot control, and other fields. However, when these models are deployed on traditional GPU and CPU platforms, high power consumption, high latency, and large physical footprint often make it hard to meet the requirements of edge computing and real-time inference. With its massively parallel architecture, reconfigurability, low power consumption, and low latency, the FPGA is an attractive hardware platform for deploying AI models. This article takes Q-Learning as a representative reinforcement learning algorithm and walks through how to implement it on an FPGA. Starting from the mathematical principles, it analyzes each key computation step and gives the corresponding Verilog implementation, helping the reader build a complete mapping from algorithm to hardware.
1. Fundamentals of Q-Learning
Q-Learning is a model-free reinforcement learning algorithm in the temporal-difference (TD) family. Its core idea is to learn an action-value function Q(s, a) that represents the expected cumulative reward obtained by taking action a in state s.
1.1 Q-Value Update Formula
The core update rule of Q-Learning is:

$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_t + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) \right]$$

where:

- $s_t$, $a_t$ are the current state and the action taken in it
- $r_t$ is the immediate reward returned by the environment
- $s_{t+1}$ is the resulting next state
- $\alpha \in (0, 1]$ is the learning rate
- $\gamma \in [0, 1)$ is the discount factor

Expanding the formula:

$$Q(s_t, a_t) \leftarrow (1 - \alpha)\, Q(s_t, a_t) + \alpha \left[ r_t + \gamma \max_{a'} Q(s_{t+1}, a') \right]$$
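As a quick numerical illustration (values chosen purely for illustration): with α = 0.1, γ = 0.9, r_t = 1, Q(s_t, a_t) = 0.5 and max Q(s_{t+1}, a') = 2.0, the update gives Q(s_t, a_t) ← 0.9 × 0.5 + 0.1 × (1 + 0.9 × 2.0) = 0.45 + 0.28 = 0.73.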
1.2 ε-Greedy Policy
When choosing an action, Q-Learning typically uses an ε-greedy policy to balance exploration and exploitation:

$$a_t = \begin{cases} \text{a uniformly random action} & \text{with probability } \varepsilon \\ \arg\max_a Q(s_t, a) & \text{with probability } 1 - \varepsilon \end{cases}$$
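To preview how this maps onto the fixed-point hardware described later: ε = 0.1 is stored as round(0.1 × 256) = 26 in a Q0.8 fraction, and the circuit explores whenever an 8-bit uniform random value is below 26, i.e. with probability 26/256 ≈ 0.1016 rather than exactly 0.1.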
2. FPGA Storage Design for the Q-Table
Q-Learning maintains a Q-table holding one Q-value per state-action pair. With N_s states and N_a actions the table has N_s × N_a entries. On an FPGA the Q-table can be implemented in Block RAM:
module q_table #(
parameter NUM_STATES = 16,
parameter NUM_ACTIONS = 4,
parameter DATA_WIDTH = 16,
parameter ADDR_WIDTH = 6 // log2(16*4) = 6
)(
input wire clk,
input wire we, // write enable
input wire [ADDR_WIDTH-1:0] addr_rd, // read address
input wire [ADDR_WIDTH-1:0] addr_wr, // write address
input wire [DATA_WIDTH-1:0] data_in, // write data
output reg [DATA_WIDTH-1:0] data_out // read data
);
// Q-table storage, intended to be inferred as BRAM
reg [DATA_WIDTH-1:0] q_mem [0:NUM_STATES*NUM_ACTIONS-1];
// Initialize all Q-values to 0 (simulation / FPGA bitstream initialization)
integer i;
initial begin
for (i = 0; i < NUM_STATES * NUM_ACTIONS; i = i + 1)
q_mem[i] = 0;
end
// Synchronous write and synchronous read (one-cycle read latency)
always @(posedge clk) begin
if (we)
q_mem[addr_wr] <= data_in;
data_out <= q_mem[addr_rd];
end
endmodule
Address mapping: for state s and action a, the corresponding Q-table entry sits at address:
addr = s × N_a + a
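For example, with N_a = 4, state s = 5 and action a = 2 map to address 5 × 4 + 2 = 22.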
// Address calculation module
module addr_calc #(
parameter NUM_ACTIONS = 4
)(
input wire [3:0] state,
input wire [1:0] action,
output wire [5:0] addr
);
// When NUM_ACTIONS is a power of two, the multiplication reduces to a shift
// (shift amount = log2(NUM_ACTIONS) = 2 here)
assign addr = (state << 2) + action; // state * 4 + action
endmodule
3. Finding the Maximum Q-Value
Both the Q-value update and action selection require $\max_{a'} Q(s_{t+1}, a')$ together with the action that achieves it. Assuming 4 actions:

$$\max_{a'} Q(s_{t+1}, a') = \max\big(Q(s_{t+1}, 0),\ Q(s_{t+1}, 1),\ Q(s_{t+1}, 2),\ Q(s_{t+1}, 3)\big)$$

The corresponding Verilog design is as follows (note that the unpacked-array port q_values is a SystemVerilog feature; in plain Verilog-2001 the four values would have to be packed into one wide bus):
module find_max_q #(
parameter NUM_ACTIONS = 4
)(
input wire signed [15:0] q_values [0:NUM_ACTIONS-1],
output reg signed [15:0] max_q,
output reg [1:0] best_action
);
integer i;
always @(*) begin
max_q = q_values[0];
best_action = 2'd0;
for (i = 1; i < NUM_ACTIONS; i = i + 1) begin
if (q_values[i] > max_q) begin
max_q = q_values[i];
best_action = i[1:0];
end
end
end
endmodule
4. TD Error Calculation
The temporal-difference (TD) error is the core quantity in the Q-Learning update and is defined as:

$$\delta_t = r_t + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t)$$

Every operation is carried out in fixed point (Q7.8 in this design), broken into three steps:

1. Compute $\gamma \cdot \max_{a'} Q(s_{t+1}, a')$ (a Q7.8 × Q7.8 multiply, truncated back to Q7.8)
2. Form the target $r_t + \gamma \max_{a'} Q(s_{t+1}, a')$
3. Subtract the current value: $\delta_t = \text{target} - Q(s_t, a_t)$

The corresponding Verilog design is as follows:
module td_error_calc (
input wire signed [15:0] reward, // r_t (Q7.8)
input wire signed [15:0] gamma, // discount factor (Q7.8), e.g. 0.9 ≈ 230
input wire signed [15:0] max_q_next, // max Q(s_{t+1}, a')
input wire signed [15:0] q_current, // Q(s_t, a_t)
output wire signed [15:0] td_error // δ
);
// Step 1: gamma * max_q_next
wire signed [31:0] gamma_q_full;
wire signed [15:0] gamma_q;
assign gamma_q_full = gamma * max_q_next;
assign gamma_q = gamma_q_full[23:8]; // truncate back to Q7.8
// Step 2: r + gamma * max_q_next
wire signed [15:0] target;
assign target = reward + gamma_q;
// Step 3: td_error = target - q_current
assign td_error = target - q_current;
endmodule
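A quick numerical check of the truncation slice (the values are hypothetical, chosen only to exercise the arithmetic): with γ = 0.9 ≈ 230/256 and max_q_next = 2.0 = 512 in Q7.8, the full product is 230 × 512 = 117760; taking bits [23:8] is the same as shifting right by 8, which gives 460, i.e. 460/256 ≈ 1.797, matching 0.898 × 2.0 up to the quantization of γ.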
5. Q-Value Update Module
The complete Q-value update is:

$$Q_{\text{new}}(s_t, a_t) = Q_{\text{old}}(s_t, a_t) + \alpha \, \delta_t$$

where α is the learning rate and δ is the TD error.
module q_update (
input wire signed [15:0] q_old, // current Q-value
input wire signed [15:0] alpha, // learning rate (Q7.8), e.g. 0.1 ≈ 26
input wire signed [15:0] td_error, // TD error
output wire signed [15:0] q_new // updated Q-value
);
// alpha * td_error
wire signed [31:0] update_full;
wire signed [15:0] update_step;
assign update_full = alpha * td_error;
assign update_step = update_full[23:8]; // truncate back to Q7.8
// Q_new = Q_old + alpha * td_error
assign q_new = q_old + update_step;
endmodule
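Continuing the same hypothetical numbers: with α = 0.1 ≈ 26/256 and td_error ≈ 1.797 = 460, update_full = 26 × 460 = 11960 and update_step = 11960 >> 8 = 46, so a Q-value that started at 0 becomes 46/256 ≈ 0.18 after one update, in line with α·δ = 0.1 × 1.797.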
6. FPGA Implementation of the ε-Greedy Policy
The ε-greedy policy needs a source of random numbers. On an FPGA this is usually a linear feedback shift register (LFSR) generating pseudo-random values (the seed must be nonzero, since an all-zero state would lock a XOR-feedback LFSR at zero):
module lfsr_random #(
parameter WIDTH = 16
)(
input wire clk,
input wire rst_n,
input wire [WIDTH-1:0] seed,
output wire [WIDTH-1:0] rand_out
);
reg [WIDTH-1:0] lfsr_reg;
// 16-bit LFSR, feedback polynomial: x^16 + x^14 + x^13 + x^11 + 1 (taps assume WIDTH = 16)
wire feedback;
assign feedback = lfsr_reg[15] ^ lfsr_reg[13] ^ lfsr_reg[12] ^ lfsr_reg[10];
always @(posedge clk or negedge rst_n) begin
if (!rst_n)
lfsr_reg <= seed;
else
lfsr_reg <= {lfsr_reg[WIDTH-2:0], feedback};
end
assign rand_out = lfsr_reg;
endmodule
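With taps at positions 16, 14, 13 and 11 this is a maximal-length LFSR: from any nonzero seed it cycles through all 2^16 − 1 = 65535 nonzero states before repeating, which is ample for ε-greedy exploration at this scale.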
The ε-greedy action selection module:
module epsilon_greedy #(
parameter NUM_ACTIONS = 4
)(
input wire clk,
input wire rst_n,
input wire enable,
input wire signed [15:0] epsilon, // epsilon (Q7.8), e.g. 0.1 ≈ 26
input wire [15:0] rand_value, // pseudo-random number from the LFSR
input wire [1:0] best_action, // argmax Q(s,a)
output reg [1:0] selected_action,
output reg action_valid
);
// Map the random number to [0, 1): take the upper 8 bits as a Q0.8 fraction
wire [7:0] rand_normalized;
assign rand_normalized = rand_value[15:8];
// Fractional part of epsilon (assumes epsilon < 1.0)
wire [7:0] eps_frac;
assign eps_frac = epsilon[7:0];
always @(posedge clk or negedge rst_n) begin
if (!rst_n) begin
selected_action <= 2'd0;
action_valid <= 1'b0;
end else if (enable) begin
if (rand_normalized < eps_frac) begin
// Explore: pick a random action (low 2 bits of the random number)
selected_action <= rand_value[1:0];
end else begin
// Exploit: pick the action with the highest Q-value
selected_action <= best_action;
end
action_valid <= 1'b1;
end else begin
action_valid <= 1'b0;
end
end
endmodule
7. The Complete Q-Learning Controller
Putting the pieces together, one pass through the learning loop is: read all Q-values of the current state → select an action with the ε-greedy policy → wait for the environment to return the reward and the next state → read all Q-values of the next state → find their maximum → compute the TD error → update the Q-value → write it back into the Q-table.
The modules above are combined into a complete Q-Learning controller, with a finite state machine managing this flow:
module q_learning_controller #(
parameter NUM_STATES = 16,
parameter NUM_ACTIONS = 4,
parameter DATA_WIDTH = 16
)(
input wire clk,
input wire rst_n,
input wire start_episode,
input wire [3:0] current_state,
input wire signed [15:0] reward,
input wire [3:0] next_state,
input wire episode_done,
output reg [1:0] action_out,
output reg action_valid,
output reg update_done
);
// Hyperparameters (Q7.8 format)
localparam signed [15:0] ALPHA = 16'sd26; // 0.1
localparam signed [15:0] GAMMA = 16'sd230; // 0.9
localparam signed [15:0] EPSILON = 16'sd26; // 0.1
// FSM states
localparam S_IDLE = 4'd0;
localparam S_READ_Q_ALL = 4'd1;
localparam S_WAIT_READ = 4'd2;
localparam S_SELECT_ACTION = 4'd3;
localparam S_WAIT_ENV = 4'd4;
localparam S_READ_NEXT_Q = 4'd5;
localparam S_WAIT_NEXT = 4'd6;
localparam S_FIND_MAX = 4'd7;
localparam S_COMPUTE_TD = 4'd8;
localparam S_UPDATE_Q = 4'd9;
localparam S_WRITE_Q = 4'd10;
localparam S_DONE = 4'd11;
reg [3:0] fsm_state;
reg [1:0] action_idx; // index used to sweep over the actions
reg rd_wait; // extra wait flag: absorbs the Q-table's 1-cycle read latency
reg signed [15:0] q_vals_current [0:NUM_ACTIONS-1];
reg signed [15:0] q_vals_next [0:NUM_ACTIONS-1];
// Q-table interface signals
reg q_we;
reg [5:0] q_addr_rd, q_addr_wr;
reg signed [15:0] q_data_in;
wire signed [15:0] q_data_out;
// LFSR random number
wire [15:0] rand_val;
// Internal computation signals
wire signed [15:0] max_q_next;
wire [1:0] best_action;
wire signed [15:0] td_err;
wire signed [15:0] q_new;
// Q-table instance
q_table #(
.NUM_STATES(NUM_STATES),
.NUM_ACTIONS(NUM_ACTIONS)
) u_qtable (
.clk(clk), .we(q_we),
.addr_rd(q_addr_rd), .addr_wr(q_addr_wr),
.data_in(q_data_in), .data_out(q_data_out)
);
// LFSR instance
lfsr_random u_lfsr (
.clk(clk), .rst_n(rst_n),
.seed(16'hACE1), .rand_out(rand_val)
);
// Find-max instance: max Q over the next state's Q-values
find_max_q u_find_max (
.q_values(q_vals_next),
.max_q(max_q_next),
.best_action(best_action)
);
// Second find-max instance: combinational argmax over the current state's
// Q-values, used by the greedy branch of action selection
wire signed [15:0] max_q_current;
wire [1:0] greedy_action;
find_max_q u_find_max_cur (
.q_values(q_vals_current),
.max_q(max_q_current),
.best_action(greedy_action)
);
// TD error calculation instance
td_error_calc u_td (
.reward(reward), .gamma(GAMMA),
.max_q_next(max_q_next),
.q_current(q_vals_current[action_out]),
.td_error(td_err)
);
// Q-value update instance
q_update u_qupdate (
.q_old(q_vals_current[action_out]),
.alpha(ALPHA), .td_error(td_err),
.q_new(q_new)
);
// Main state machine
always @(posedge clk or negedge rst_n) begin
if (!rst_n) begin
fsm_state <= S_IDLE;
action_idx <= 0;
rd_wait <= 1'b0;
action_out <= 0;
action_valid <= 0;
update_done <= 0;
q_we <= 0;
end else begin
case (fsm_state)
S_IDLE: begin
update_done <= 0;
q_we <= 0;
if (start_episode) begin
action_idx <= 0;
fsm_state <= S_READ_Q_ALL;
end
end
S_READ_Q_ALL: begin
// Read the current state's Q-values one action at a time
q_addr_rd <= (current_state << 2) + action_idx;
fsm_state <= S_WAIT_READ;
end
S_WAIT_READ: begin
// The Q-table has a 1-cycle synchronous read latency, so wait one
// extra cycle before capturing data_out
if (!rd_wait) begin
rd_wait <= 1'b1;
end else begin
rd_wait <= 1'b0;
q_vals_current[action_idx] <= q_data_out;
if (action_idx == NUM_ACTIONS - 1) begin
action_idx <= 0;
fsm_state <= S_SELECT_ACTION;
end else begin
action_idx <= action_idx + 1;
fsm_state <= S_READ_Q_ALL;
end
end
end
S_SELECT_ACTION: begin
// ε-greedy selection
if (rand_val[15:8] < EPSILON[7:0]) begin
// Explore: random action
action_out <= rand_val[1:0];
end else begin
// Exploit: greedy_action is the combinational argmax over
// q_vals_current (second find_max_q instance above)
action_out <= greedy_action;
end
action_valid <= 1;
fsm_state <= S_WAIT_ENV;
end
S_WAIT_ENV: begin
action_valid <= 0;
// Wait for the environment to return reward and next_state.
// Simplified here: assume both are available on the next cycle.
action_idx <= 0;
fsm_state <= S_READ_NEXT_Q;
end
S_READ_NEXT_Q: begin
q_addr_rd <= (next_state << 2) + action_idx;
fsm_state <= S_WAIT_NEXT;
end
S_WAIT_NEXT: begin
// Same 1-cycle read latency handling as S_WAIT_READ
if (!rd_wait) begin
rd_wait <= 1'b1;
end else begin
rd_wait <= 1'b0;
q_vals_next[action_idx] <= q_data_out;
if (action_idx == NUM_ACTIONS - 1) begin
fsm_state <= S_FIND_MAX;
end else begin
action_idx <= action_idx + 1;
fsm_state <= S_READ_NEXT_Q;
end
end
end
S_FIND_MAX: begin
// find_max_q (combinational) has already produced max_q_next
fsm_state <= S_COMPUTE_TD;
end
S_COMPUTE_TD: begin
// td_error_calc (combinational) has already produced td_err
fsm_state <= S_UPDATE_Q;
end
S_UPDATE_Q: begin
// q_update (combinational) has already produced q_new
fsm_state <= S_WRITE_Q;
end
S_WRITE_Q: begin
q_we <= 1;
q_addr_wr <= (current_state << 2) + action_out;
q_data_in <= q_new;
fsm_state <= S_DONE;
end
S_DONE: begin
q_we <= 0;
update_done <= 1;
fsm_state <= S_IDLE;
end
default: fsm_state <= S_IDLE;
endcase
end
end
endmodule
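With this straightforward FSM, one complete update (reading the current state's Q-values, selecting an action, reading the next state's Q-values, and writing back the updated value) takes on the order of 30 clock cycles, dominated by the two sequential read loops; the optimizations discussed in Section 9 target exactly this bottleneck.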
8. ε Decay Mechanism
During training, ε is usually decayed gradually so that the agent shifts from mostly exploring to mostly exploiting:

$$\varepsilon \leftarrow \max\left(\varepsilon_{\min},\ \varepsilon \cdot \varepsilon_{\text{decay}}\right)$$

where ε_decay is typically 0.995 or 0.99.
module epsilon_decay (
input wire clk,
input wire rst_n,
input wire decay_trigger, // pulse to apply one decay step
input wire signed [15:0] decay_factor, // decay factor (Q7.8), e.g. 0.995 ≈ 255
input wire signed [15:0] epsilon_min, // lower bound on epsilon
output reg signed [15:0] epsilon // current epsilon
);
localparam signed [15:0] EPSILON_INIT = 16'sd256; // 1.0 in Q7.8
// epsilon * decay_factor: Q7.8 * Q7.8, shift right by 8 to return to Q7.8
wire signed [31:0] eps_scaled = (epsilon * decay_factor) >>> 8;
wire signed [15:0] eps_next = eps_scaled[15:0];
always @(posedge clk or negedge rst_n) begin
if (!rst_n) begin
epsilon <= EPSILON_INIT;
end else if (decay_trigger) begin
// Apply the decay and clamp at the configured minimum
if (eps_next < epsilon_min)
epsilon <= epsilon_min;
else
epsilon <= eps_next;
end
end
endmodule
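A side effect of the Q7.8 quantization is worth noting: 0.995 rounds to 255/256, and with truncation the update ε ← ⌊ε · 255/256⌋ lowers ε by exactly one LSB (1/256) per trigger for any ε between 1/256 and 1.0, so decaying from ε = 1.0 down to 0.1 (26 in Q7.8) takes 230 trigger pulses. A faster or finer-grained schedule needs either more fractional bits or a smaller decay factor.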
9. Summary
The baseline design above is functional, and several optimizations can push it further:
Parallel Q-value reads: use a multi-port RAM, or split the Q-table into per-action banks, so that all Q-values of a state can be read simultaneously, shrinking the read phase from several cycles to a single cycle (a sketch of this idea follows this list).
Lookup-table acceleration: for problems with small state and action spaces, the entire Q-table can live in distributed RAM (LUT RAM), giving single-cycle reads and writes.
Pipelining: pipeline the TD-error calculation and Q-value update so that, while one state-action pair is being updated, the Q-value reads for the next state can already start.
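As an illustration of the first point, a minimal sketch of a banked Q-table is shown below. It assumes the same 16-state, 4-action, 16-bit setup as the rest of this article; the module and port names (q_table_banked, state_rd, and so on) are illustrative, and the unpacked-array output port relies on the same SystemVerilog feature as find_max_q above. One memory bank is kept per action, and all banks share the state index as read address, so every Q-value of a state is available after a single synchronous read.
module q_table_banked #(
parameter NUM_STATES = 16,
parameter NUM_ACTIONS = 4,
parameter DATA_WIDTH = 16
)(
input wire clk,
input wire we, // write enable
input wire [3:0] state_rd, // state whose Q-values are being read
input wire [3:0] state_wr, // state being written
input wire [1:0] action_wr, // action (bank) being written
input wire [DATA_WIDTH-1:0] data_in, // write data
output reg [DATA_WIDTH-1:0] q_out [0:NUM_ACTIONS-1] // one Q-value per action
);
// One independent memory bank per action
reg [DATA_WIDTH-1:0] bank [0:NUM_ACTIONS-1][0:NUM_STATES-1];
integer i, j;
initial begin
for (i = 0; i < NUM_ACTIONS; i = i + 1)
for (j = 0; j < NUM_STATES; j = j + 1)
bank[i][j] = 0;
end
// Synchronous write to the selected bank; synchronous read of all banks in parallel
always @(posedge clk) begin
if (we)
bank[action_wr][state_wr] <= data_in;
for (i = 0; i < NUM_ACTIONS; i = i + 1)
q_out[i] <= bank[i][state_rd];
end
endmodule
With this organization, the controller's S_READ_Q_ALL / S_WAIT_READ loop collapses to a single read step, since all four Q-values of current_state arrive together one cycle after the address is applied.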