Research in Brain-inspired Computing [5] - Two-Player Pong
SNN-based Pong Demo
Mathematical Principles and Network Model of the SNN-based Pong Demo
1. Spiking Neural Network (SNN) Overview
SNNs are a class of neural networks that more closely mimic biological neurons. Instead of continuous activation values, neurons communicate via discrete spikes (action potentials). Information is encoded in the timing and rate of spikes. The model used here is a leaky integrate-and-fire (LIF) variant extended with multiple thresholds (MSF – Multi‑Spike‑Factor), allowing a neuron to emit multiple spikes in a single time step based on the input current.
2. MSF Neuron Model
The MSF neuron is defined by the following equations:
2.1 Membrane Potential
The membrane potential $v$ is computed from the total input current $I_{\text{ext}}$ (which combines external input, recurrent feedback, and bias) using the steady-state approximation of the LIF model:

$$v = V_{\text{rest}} + R_m \cdot I_{\text{ext}}$$
where:
- $V_{\text{rest}} = -70\ \text{mV}$ is the resting potential.
- $R_m = 10\ \text{M}\Omega$ is the membrane resistance.
This simplification assumes the membrane potential settles to its steady-state value within each simulation step of $\Delta t = 1\ \text{ms}$; the transient dynamics governed by the membrane time constant $\tau_m = 10\ \text{ms}$ are not simulated explicitly.
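The steady-state value follows directly from setting $dv/dt = 0$ in the standard LIF membrane equation:

```latex
\tau_m \frac{dv}{dt} = -\left(v - V_{\text{rest}}\right) + R_m I_{\text{ext}}
\quad\Longrightarrow\quad
0 = -\left(v_\infty - V_{\text{rest}}\right) + R_m I_{\text{ext}}
\quad\Longrightarrow\quad
v_\infty = V_{\text{rest}} + R_m I_{\text{ext}}
```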
2.2 Multiple Thresholds
The neuron has $D = 4$ thresholds:

$$\theta_d = V_{\text{thresh\_base}} + d \cdot \Delta\theta, \qquad d = 0, 1, \dots, D-1$$
where:
- $V_{\text{thresh\_base}} = -55\ \text{mV}$
- $\Delta\theta = 1.0\ \text{mV}$
The number of spikes generated in one time step is the count of thresholds that are reached:

$$s = \#\{\, d \mid v \ge \theta_d \,\}$$

Because the thresholds increase with $d$, this is the number of consecutive thresholds crossed starting from $\theta_0$.
If $v$ is very high, it is clipped to $\theta_{D-1} + 10\ \text{mV}$ to prevent excessive spikes.
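Putting the steady-state potential, the clipping, and the threshold rule together, the per-step spike count can be sketched as a small standalone function (constants taken from the text above; the function name is illustrative):

```cpp
#include <cassert>

// Spike count of one MSF neuron for a given input current (one time step).
// Constants follow the text: V_rest = -70 mV, R_m = 10, and D = 4 thresholds
// starting at -55 mV with 1 mV spacing.
int msf_spike_count(double I_ext) {
    const double V_REST = -70.0, R_M = 10.0;
    const double V_THRESH_BASE = -55.0, THRESH_INTERVAL = 1.0;
    const int D = 4;
    double v = V_REST + R_M * I_ext;                         // steady-state potential
    const double v_max = V_THRESH_BASE + (D - 1) * THRESH_INTERVAL + 10.0;
    if (v > v_max) v = v_max;                                // clip very high potentials
    int spikes = 0;
    for (int d = 0; d < D; ++d) {
        if (v >= V_THRESH_BASE + d * THRESH_INTERVAL) ++spikes;
        else break;                                          // thresholds are increasing
    }
    return spikes;
}
```

For example, $I_{\text{ext}} = 1.5$ drives $v$ to exactly $-55\ \text{mV}$, the lowest threshold, producing one spike, while $I_{\text{ext}} = 1.8$ already saturates all four thresholds.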
3. Network Architecture
The network is a recurrent spiking neural network with the following layers (all fully connected):
- Input layer: 6 neurons – receives normalized game state.
- Hidden layer 1: 20 neurons – includes recurrent connections (self‑loops).
- Hidden layer 2: 10 neurons.
- Output layer: 3 neurons – each corresponds to one action: up, down, or stay.
All weights and biases are stored in the model file.
3.1 Input Encoding
The game state is normalized to the range $[-1, 1]$ or $[0, 1]$ and fed as a 6-dimensional vector to the input layer. The inputs (for player 0, the AI) are:

$$\begin{aligned}
x_0 &= \text{ball\_x} \\
x_1 &= \text{ball\_y} \\
x_2 &= \text{ball\_vx} / \text{BALL\_SPEED} \\
x_3 &= \text{ball\_vy} / \text{BALL\_SPEED} \\
x_4 &= \text{own\_paddle\_y} / \text{FIELD\_HEIGHT} \\
x_5 &= \text{opp\_paddle\_y} / \text{FIELD\_HEIGHT}
\end{aligned}$$
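For player 1 the same encoding is reused by mirroring the x-axis, so both agents see the game from the same perspective (this matches `Pong::get_state` in the program below; the free function here is only an illustration):

```cpp
#include <vector>

// Player-1 view: mirror x and vx so the own paddle always appears on the left.
std::vector<double> mirrored_state(double ball_x, double ball_y,
                                   double ball_vx, double ball_vy,
                                   double own_paddle_y, double opp_paddle_y,
                                   double ball_speed, double field_height) {
    return { 1.0 - ball_x, ball_y,
             -ball_vx / ball_speed, ball_vy / ball_speed,
             own_paddle_y / field_height, opp_paddle_y / field_height };
}
```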
3.2 Forward Propagation
At each simulation step, the network computes the activity of each layer sequentially.
Hidden layer 1:
The synaptic current $I_{h1,i}$ for neuron $i$ is:

$$I_{h1,i} = b_i^{(h1)} + \sum_{j=0}^{N_{\text{in}}-1} w_{ij}^{(\text{in}\to h1)} x_j + \sum_{k=0}^{N_{h1}-1} w_{ik}^{(\text{fb})} s_k^{(h1,\text{prev})}$$
where:
- $b_i^{(h1)}$ is the bias.
- $w_{ij}^{(\text{in}\to h1)}$ are input weights.
- $w_{ik}^{(\text{fb})}$ are recurrent (feedback) weights.
- $s_k^{(h1,\text{prev})}$ is the spike count of neuron $k$ from the previous time step.
The spike count $s_{h1,i}$ is computed by applying the MSF neuron model to $I_{h1,i}$.
Hidden layer 2:
$$I_{h2,i} = b_i^{(h2)} + \sum_{j=0}^{N_{h1}-1} w_{ij}^{(h1\to h2)} s_{h1,j}$$
Output layer:
$$I_{\text{out},i} = b_i^{(\text{out})} + \sum_{j=0}^{N_{h2}-1} w_{ij}^{(h2\to\text{out})} s_{h2,j}$$
The spike counts of the output neurons are then computed.
3.3 Action Selection
The output layer has three neurons, each representing one action:
- Neuron 0 → move up
- Neuron 1 → move down
- Neuron 2 → stay
The chosen action is the index of the output neuron with the highest spike count:
$$\text{action} = \arg\max_{i \in \{0,1,2\}} s_{\text{out},i}$$
If multiple neurons have the same spike count, the first one (lowest index) is selected.
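This lowest-index tie-break is exactly what `std::max_element` provides, since it returns an iterator to the first of several equal maxima (a minimal sketch; the demo program implements the same loop by hand):

```cpp
#include <algorithm>
#include <vector>

// Index of the output neuron with the most spikes; ties go to the lowest index.
int select_action(const std::vector<int>& out_spikes) {
    return static_cast<int>(std::max_element(out_spikes.begin(), out_spikes.end())
                            - out_spikes.begin());
}
```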
4. Synaptic Weights and Biases
All weights and biases are real numbers, initially drawn from a uniform distribution on $[-5, 5]$. They are optimized during training (here, loaded from a pre-trained file). The total number of parameters is:

$$\begin{aligned}
N_{\text{params}} =\;& N_{\text{in}} \cdot N_{h1} && (\text{input}\to h1)\\
&+ N_{h1} \cdot N_{h1} && (\text{recurrent})\\
&+ N_{h1} \cdot N_{h2} && (h1\to h2)\\
&+ N_{h2} \cdot N_{\text{out}} && (h2\to\text{out})\\
&+ N_{h1} + N_{h2} + N_{\text{out}} && (\text{biases})
\end{aligned}$$

With the given sizes ($N_{\text{in}}=6$, $N_{h1}=20$, $N_{h2}=10$, $N_{\text{out}}=3$), the total is $6\cdot 20 + 20\cdot 20 + 20\cdot 10 + 10\cdot 3 + 20 + 10 + 3 = 120 + 400 + 200 + 30 + 33 = 783$ parameters.
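The count can be checked at compile time with a small sketch using the layer sizes above:

```cpp
// Parameter count of the 6 -> 20 (recurrent) -> 10 -> 3 network.
constexpr int N_IN = 6, N_H1 = 20, N_H2 = 10, N_OUT = 3;
constexpr int N_PARAMS =
      N_IN * N_H1            // input -> hidden1
    + N_H1 * N_H1            // hidden1 recurrent
    + N_H1 * N_H2            // hidden1 -> hidden2
    + N_H2 * N_OUT           // hidden2 -> output
    + N_H1 + N_H2 + N_OUT;   // biases
static_assert(N_PARAMS == 783, "parameter count must match the model file");
```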
5. Temporal Dynamics
The network is recurrent: hidden layer 1 receives its own previous spikes as input. This introduces memory and allows the network to integrate information over time, which is essential for tracking ball movement and paddle position.
The simulation is discrete with a time step of $1\ \text{ms}$, matching the SNN simulation granularity. The environment physics (ball and paddle movement) advances once per network step, although its motion constants are expressed in game-field units rather than physical time. The agent processes one game state per environment step and produces an action for that step.
6. Training and Evaluation
The weights were trained using a co‑evolutionary genetic algorithm (not shown in the demo program). The training process involved evaluating individuals by playing against other agents and against an elite pool, with fitness based on win rate, score difference, hit count, and paddle movement (to encourage active behavior). The trained model is saved as a text file listing all parameters line by line.
The demo program simply loads the parameters, builds the network, and uses it to play Pong. The mathematical model remains exactly the same as during training, enabling the learned policy to be deployed.
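The model file itself is nothing more than the 783 parameters as whitespace-separated numbers, so loading it takes only a few lines. A sketch generalized to any input stream (so it works for both files and in-memory strings; the demo program reads an `ifstream` the same way):

```cpp
#include <istream>
#include <sstream>
#include <vector>

// Read all whitespace-separated numbers from a stream (the model file format).
std::vector<double> load_params(std::istream& in) {
    std::vector<double> params;
    double v;
    while (in >> v) params.push_back(v);  // stops at EOF or a non-numeric token
    return params;
}
```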
7. Summary of Mathematical Equations
| Component | Equation |
|---|---|
| Membrane potential | $v = V_{\text{rest}} + R_m I_{\text{ext}}$ |
| Spike count | $s = \#\{d \mid v \ge \theta_d\},\quad \theta_d = V_{\text{base}} + d\,\Delta\theta$ |
| Hidden layer 1 current | $I_{h1,i} = b_i + \sum_j w_{ij}^{(\text{in})} x_j + \sum_k w_{ik}^{(\text{fb})} s_k^{(\text{prev})}$ |
| Hidden layer 2 current | $I_{h2,i} = b_i + \sum_j w_{ij}^{(h1\to h2)} s_{h1,j}$ |
| Output layer current | $I_{\text{out},i} = b_i + \sum_j w_{ij}^{(h2\to\text{out})} s_{h2,j}$ |
| Action selection | $\text{action} = \arg\max_i s_{\text{out},i}$ |
These equations define a compact, biologically‑plausible spiking neural network capable of learning a reactive policy for the Pong game.
C++ Program
/**
* File: pong_demo.cpp
* Description: Two-player Pong + recurrent MSF spiking neural network (demo version).
* Loads a pre-trained model file; supports a real-time console demo and batch evaluation.
* Falls back to random weights with a warning if the model file is missing or invalid.
* Compile: g++ -O3 -std=c++11 pong_demo.cpp -o pong_demo -lm
* Run: ./pong_demo [model file path]
* Default model file: best_pong_model_adv_dyn.txt
*/
#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <chrono>
#include <algorithm>
#include <fstream>
#include <numeric>
#include <cstring>
#include <unistd.h>
#include <termios.h>
#include <fcntl.h>
#include <iomanip>
using namespace std;
// ==================== Game environment constants ====================
constexpr double PADDLE_SPEED = 0.1;
constexpr double BALL_SPEED = 0.03;
constexpr double PADDLE_HEIGHT = 0.2;
constexpr double FIELD_WIDTH = 1.0;
constexpr double FIELD_HEIGHT = 1.0;
constexpr int MAX_STEPS_PONG = 500;
constexpr int WIN_SCORE = 10;
constexpr double INIT_RANDOM_RANGE = 0.05;
// ==================== SNN parameters ====================
constexpr int D = 4;
constexpr double V_REST = -70.0;
constexpr double V_RESET = -75.0;
constexpr double V_THRESH_BASE = -55.0;
constexpr double THRESH_INTERVAL = 1.0;
constexpr double TAU_M = 10.0;
constexpr double R_M = 10.0;
constexpr double DT = 1.0;
constexpr int N_INPUT = 6;
constexpr int N_HIDDEN1 = 20;
constexpr int N_HIDDEN2 = 10;
constexpr int N_OUTPUT = 3;
constexpr double W_MAX = 5.0;
// ==================== Random number generator ====================
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
uniform_real_distribution<double> uniform_real(-W_MAX, W_MAX);
uniform_int_distribution<int> uniform_coin(0, 1);
uniform_real_distribution<double> uniform_pos(-INIT_RANDOM_RANGE, INIT_RANDOM_RANGE);
// ==================== MSF neuron class ====================
struct MSFNeuron {
vector<double> thresholds;
MSFNeuron() {
thresholds.resize(D);
for (int d = 0; d < D; ++d) thresholds[d] = V_THRESH_BASE + d * THRESH_INTERVAL;
}
int step(double I_ext) {
double v = V_REST + R_M * I_ext;
if (v > V_THRESH_BASE + (D-1)*THRESH_INTERVAL + 10) v = V_THRESH_BASE + (D-1)*THRESH_INTERVAL + 10;
int spike_count = 0;
for (int d = 0; d < D; ++d) {
if (v >= thresholds[d]) spike_count++;
else break;
}
return spike_count;
}
};
// ==================== Recurrent SNN class ====================
class RNN_SNN {
private:
vector<double> w_in_h1, w_fb_h1, w_h1_h2, w_h2_out, bias_h1, bias_h2, bias_out;
vector<MSFNeuron> hidden1, hidden2, output;
vector<int> prev_spike_h1, prev_spike_h2;
public:
RNN_SNN(const vector<double>& params) {
hidden1.resize(N_HIDDEN1); hidden2.resize(N_HIDDEN2); output.resize(N_OUTPUT);
size_t pos = 0;
w_in_h1.resize(N_HIDDEN1 * N_INPUT);
w_fb_h1.resize(N_HIDDEN1 * N_HIDDEN1);
w_h1_h2.resize(N_HIDDEN2 * N_HIDDEN1);
w_h2_out.resize(N_OUTPUT * N_HIDDEN2);
bias_h1.resize(N_HIDDEN1); bias_h2.resize(N_HIDDEN2); bias_out.resize(N_OUTPUT);
for (size_t i = 0; i < w_in_h1.size(); ++i) w_in_h1[i] = params[pos++];
for (size_t i = 0; i < w_fb_h1.size(); ++i) w_fb_h1[i] = params[pos++];
for (size_t i = 0; i < w_h1_h2.size(); ++i) w_h1_h2[i] = params[pos++];
for (size_t i = 0; i < w_h2_out.size(); ++i) w_h2_out[i] = params[pos++];
for (int i = 0; i < N_HIDDEN1; ++i) bias_h1[i] = params[pos++];
for (int i = 0; i < N_HIDDEN2; ++i) bias_h2[i] = params[pos++];
for (int i = 0; i < N_OUTPUT; ++i) bias_out[i] = params[pos++];
prev_spike_h1.assign(N_HIDDEN1, 0);
prev_spike_h2.assign(N_HIDDEN2, 0);
}
void reset_state() {
fill(prev_spike_h1.begin(), prev_spike_h1.end(), 0);
fill(prev_spike_h2.begin(), prev_spike_h2.end(), 0);
}
int forward(const vector<double>& input) {
vector<double> I_h1(N_HIDDEN1);
for (int i = 0; i < N_HIDDEN1; ++i) {
I_h1[i] = bias_h1[i];
for (int j = 0; j < N_INPUT; ++j) I_h1[i] += w_in_h1[i * N_INPUT + j] * input[j];
for (int j = 0; j < N_HIDDEN1; ++j) I_h1[i] += w_fb_h1[i * N_HIDDEN1 + j] * prev_spike_h1[j];
}
vector<int> spike_h1(N_HIDDEN1);
for (int i = 0; i < N_HIDDEN1; ++i) spike_h1[i] = hidden1[i].step(I_h1[i]);
vector<double> I_h2(N_HIDDEN2);
for (int i = 0; i < N_HIDDEN2; ++i) {
I_h2[i] = bias_h2[i];
for (int j = 0; j < N_HIDDEN1; ++j) I_h2[i] += w_h1_h2[i * N_HIDDEN1 + j] * spike_h1[j];
}
vector<int> spike_h2(N_HIDDEN2);
for (int i = 0; i < N_HIDDEN2; ++i) spike_h2[i] = hidden2[i].step(I_h2[i]);
vector<double> I_out(N_OUTPUT);
for (int i = 0; i < N_OUTPUT; ++i) {
I_out[i] = bias_out[i];
for (int j = 0; j < N_HIDDEN2; ++j) I_out[i] += w_h2_out[i * N_HIDDEN2 + j] * spike_h2[j];
}
vector<int> out_spikes(N_OUTPUT);
for (int i = 0; i < N_OUTPUT; ++i) out_spikes[i] = output[i].step(I_out[i]);
prev_spike_h1 = spike_h1;
prev_spike_h2 = spike_h2;
int best_action = 0;
int max_spike = out_spikes[0];
for (int i = 1; i < N_OUTPUT; ++i) {
if (out_spikes[i] > max_spike) {
max_spike = out_spikes[i];
best_action = i;
}
}
return best_action;
}
};
// ==================== Pong environment ====================
class Pong {
private:
double ball_x, ball_y, ball_vx, ball_vy, paddle1_y, paddle2_y;
int score1, score2, steps;
bool game_over;
void reset_ball() {
ball_x = 0.5; ball_y = 0.5;
uniform_real_distribution<double> dir(0.5, 1.5);
ball_vx = (uniform_coin(rng) == 0 ? -BALL_SPEED : BALL_SPEED);
ball_vy = (dir(rng) - 1.0) * BALL_SPEED;
}
public:
Pong() { reset(); }
void reset() {
ball_x = 0.5 + uniform_pos(rng); ball_y = 0.5 + uniform_pos(rng);
ball_x = max(0.1, min(0.9, ball_x)); ball_y = max(0.1, min(0.9, ball_y));
uniform_real_distribution<double> dir(0.5, 1.5);
ball_vx = (uniform_coin(rng) == 0 ? -BALL_SPEED : BALL_SPEED);
ball_vy = (dir(rng) - 1.0) * BALL_SPEED;
paddle1_y = (FIELD_HEIGHT - PADDLE_HEIGHT) / 2 + uniform_pos(rng);
paddle2_y = (FIELD_HEIGHT - PADDLE_HEIGHT) / 2 + uniform_pos(rng);
paddle1_y = max(0.0, min(FIELD_HEIGHT - PADDLE_HEIGHT, paddle1_y));
paddle2_y = max(0.0, min(FIELD_HEIGHT - PADDLE_HEIGHT, paddle2_y));
score1 = score2 = 0; steps = 0; game_over = false;
}
void step(int action1, int action2) {
if (game_over) return;
steps++;
const double move = PADDLE_SPEED;
if (action1 == 0) paddle1_y = max(0.0, paddle1_y - move);
if (action1 == 1) paddle1_y = min(FIELD_HEIGHT - PADDLE_HEIGHT, paddle1_y + move);
if (action2 == 0) paddle2_y = max(0.0, paddle2_y - move);
if (action2 == 1) paddle2_y = min(FIELD_HEIGHT - PADDLE_HEIGHT, paddle2_y + move);
ball_x += ball_vx; ball_y += ball_vy;
if (ball_y < 0) { ball_y = -ball_y; ball_vy = -ball_vy; }
if (ball_y > FIELD_HEIGHT) { ball_y = 2*FIELD_HEIGHT - ball_y; ball_vy = -ball_vy; }
if (ball_vx < 0 && ball_x <= 0) {
if (ball_y >= paddle1_y && ball_y <= paddle1_y + PADDLE_HEIGHT) {
ball_vx = -ball_vx;
double hit_pos = (ball_y - paddle1_y) / PADDLE_HEIGHT - 0.5;
ball_vy += hit_pos * BALL_SPEED * 0.5;
ball_vy = max(-BALL_SPEED*1.5, min(BALL_SPEED*1.5, ball_vy));
ball_x = 0.01;
} else { score2++; reset_ball(); }
}
else if (ball_vx > 0 && ball_x >= FIELD_WIDTH) {
if (ball_y >= paddle2_y && ball_y <= paddle2_y + PADDLE_HEIGHT) {
ball_vx = -ball_vx;
double hit_pos = (ball_y - paddle2_y) / PADDLE_HEIGHT - 0.5;
ball_vy += hit_pos * BALL_SPEED * 0.5;
ball_vy = max(-BALL_SPEED*1.5, min(BALL_SPEED*1.5, ball_vy));
ball_x = FIELD_WIDTH - 0.01;
} else { score1++; reset_ball(); }
}
if (steps >= MAX_STEPS_PONG || score1 >= WIN_SCORE || score2 >= WIN_SCORE) game_over = true;
}
vector<double> get_state(int player) const {
vector<double> state(N_INPUT);
if (player == 0) {
state[0] = ball_x; state[1] = ball_y;
state[2] = ball_vx / BALL_SPEED; state[3] = ball_vy / BALL_SPEED;
state[4] = paddle1_y / FIELD_HEIGHT; state[5] = paddle2_y / FIELD_HEIGHT;
} else {
state[0] = 1.0 - ball_x; state[1] = ball_y;
state[2] = -ball_vx / BALL_SPEED; state[3] = ball_vy / BALL_SPEED;
state[4] = paddle2_y / FIELD_HEIGHT; state[5] = paddle1_y / FIELD_HEIGHT;
}
return state;
}
bool is_game_over() const { return game_over; }
int get_score(int player) const { return (player == 0) ? score1 : score2; }
double get_paddle1_y() const { return paddle1_y; }
double get_paddle2_y() const { return paddle2_y; }
double get_ball_x() const { return ball_x; }
double get_ball_y() const { return ball_y; }
};
// ==================== Console graphics demo ====================
constexpr int SCREEN_WIDTH = 60;
constexpr int SCREEN_HEIGHT = 20;
void draw_pong(const Pong& env, int score1, int score2) {
cout << "\033[2J\033[H";
cout << "+" << string(SCREEN_WIDTH, '-') << "+\n";
auto map_y = [](double y) -> int { return static_cast<int>(y * (SCREEN_HEIGHT - 2)) + 1; };
auto map_x = [](double x) -> int { return static_cast<int>(x * (SCREEN_WIDTH - 2)) + 1; };
vector<string> screen(SCREEN_HEIGHT, string(SCREEN_WIDTH + 2, ' '));
for (int i = 0; i < SCREEN_HEIGHT; ++i) { screen[i][0] = '|'; screen[i][SCREEN_WIDTH+1] = '|'; }
int paddle1_top = map_y(env.get_paddle1_y());
int paddle1_bottom = map_y(env.get_paddle1_y() + PADDLE_HEIGHT);
for (int y = paddle1_top; y <= paddle1_bottom; ++y) if (y >= 0 && y < SCREEN_HEIGHT) screen[y][1] = '|';
int paddle2_top = map_y(env.get_paddle2_y());
int paddle2_bottom = map_y(env.get_paddle2_y() + PADDLE_HEIGHT);
for (int y = paddle2_top; y <= paddle2_bottom; ++y) if (y >= 0 && y < SCREEN_HEIGHT) screen[y][SCREEN_WIDTH] = '|';
int ball_x_screen = map_x(env.get_ball_x());
int ball_y_screen = map_y(env.get_ball_y());
if (ball_x_screen >= 1 && ball_x_screen <= SCREEN_WIDTH && ball_y_screen >= 0 && ball_y_screen < SCREEN_HEIGHT)
screen[ball_y_screen][ball_x_screen] = 'O';
for (int y = 0; y < SCREEN_HEIGHT; ++y) cout << screen[y] << "\n";
cout << "+" << string(SCREEN_WIDTH, '-') << "+\n";
cout << "Score: " << score1 << " - " << score2 << "\n";
cout << "Press any key to step, 'q' to quit, 'a' to toggle auto-play, space to pause auto.\n";
}
bool kbhit() {
struct termios oldt, newt; int ch; int oldf;
tcgetattr(STDIN_FILENO, &oldt);
newt = oldt; newt.c_lflag &= ~(ICANON | ECHO);
tcsetattr(STDIN_FILENO, TCSANOW, &newt);
oldf = fcntl(STDIN_FILENO, F_GETFL, 0);
fcntl(STDIN_FILENO, F_SETFL, oldf | O_NONBLOCK);
ch = getchar();
tcsetattr(STDIN_FILENO, TCSANOW, &oldt);
fcntl(STDIN_FILENO, F_SETFL, oldf);
if (ch != EOF) { ungetc(ch, stdin); return true; }
return false;
}
// Generate random weight parameters (same dimension as the trained model)
vector<double> generate_random_weights(int dim) {
vector<double> params(dim);
for (int i = 0; i < dim; ++i) {
params[i] = uniform_real(rng);
}
return params;
}
// ==================== Demo main function ====================
int main(int argc, char* argv[]) {
string model_file = "best_pong_model_adv_dyn.txt";
if (argc > 1) model_file = argv[1];
vector<double> params;
bool using_random = false;
// Try to load the model file
ifstream fin(model_file);
if (!fin.is_open()) {
cerr << "Warning: cannot open model file " << model_file << "; using random weights instead." << endl;
using_random = true;
} else {
double val;
while (fin >> val) params.push_back(val);
fin.close();
// Check the parameter dimension
int expected_dim = N_HIDDEN1 * N_INPUT
+ N_HIDDEN1 * N_HIDDEN1
+ N_HIDDEN2 * N_HIDDEN1
+ N_OUTPUT * N_HIDDEN2
+ N_HIDDEN1
+ N_HIDDEN2
+ N_OUTPUT;
if ((int)params.size() != expected_dim) {
cerr << "Warning: model parameter dimension mismatch! Expected " << expected_dim << ", got " << params.size()
<< "; using random weights instead." << endl;
using_random = true;
params.clear();
} else {
cout << "Model loaded successfully, parameter dimension: " << params.size() << endl;
}
}
// Generate random weights if needed
if (using_random) {
int expected_dim = N_HIDDEN1 * N_INPUT
+ N_HIDDEN1 * N_HIDDEN1
+ N_HIDDEN2 * N_HIDDEN1
+ N_OUTPUT * N_HIDDEN2
+ N_HIDDEN1
+ N_HIDDEN2
+ N_OUTPUT;
params = generate_random_weights(expected_dim);
cout << "Using random weights, parameter dimension: " << params.size() << endl;
cout << "Note: random weights will likely play poorly; this mode only verifies that the program runs." << endl;
}
cout << "Network structure: " << N_INPUT << " -> " << N_HIDDEN1 << " (recurrent) -> "
<< N_HIDDEN2 << " -> " << N_OUTPUT << endl;
RNN_SNN best_net(params);
int choice;
cout << "\nSelect a mode:\n";
cout << "1 - Real-time console demo (one game)\n";
cout << "2 - Batch evaluation (100 games vs. a random policy)\n";
cout << "3 - Quit\n";
cout << "Enter a number: ";
cin >> choice;
if (choice == 3) return 0;
if (choice == 1) {
cout << "Starting real-time demo: press any key to step, 'q' to quit, 'a' to toggle auto-play; press space to pause during auto-play.\n";
Pong env;
best_net.reset_state();
bool auto_mode = false;
struct termios oldt, newt;
tcgetattr(STDIN_FILENO, &oldt);
newt = oldt;
newt.c_lflag &= ~(ICANON | ECHO);
tcsetattr(STDIN_FILENO, TCSANOW, &newt);
while (!env.is_game_over()) {
vector<double> state = env.get_state(0);
int action = best_net.forward(state);
uniform_int_distribution<int> dist(0, 2);
int action2 = dist(rng);
// Debug output of the chosen action (optional; may stay commented out)
// cout << "AI action: " << action << " ";
env.step(action, action2);
draw_pong(env, env.get_score(0), env.get_score(1));
if (!auto_mode) {
char key = getchar();
if (key == 'q') break;
if (key == 'a') {
auto_mode = true;
cout << "\nAuto-play mode" << endl;
usleep(100000);
}
} else {
usleep(50000); // 20 frames per second
if (kbhit()) {
char key = getchar();
if (key == 'q') break;
if (key == ' ') {
auto_mode = false;
cout << "\nManual mode" << endl;
}
}
}
}
tcsetattr(STDIN_FILENO, TCSANOW, &oldt);
cout << "\nGame over! Final score: " << env.get_score(0) << " : " << env.get_score(1) << endl;
if (env.get_score(0) > env.get_score(1)) cout << "AI wins!\n";
else if (env.get_score(0) < env.get_score(1)) cout << "Random opponent wins!\n";
else cout << "Draw.\n";
} else if (choice == 2) {
const int DEMO_GAMES = 100;
vector<int> scores(DEMO_GAMES);
cout << "\nRunning batch evaluation (" << DEMO_GAMES << " games) ..." << endl;
for (int game = 0; game < DEMO_GAMES; ++game) {
RNN_SNN net(params);
Pong env;
net.reset_state();
while (!env.is_game_over()) {
vector<double> state = env.get_state(0);
int action = net.forward(state);
uniform_int_distribution<int> dist(0, 2);
int action2 = dist(rng);
env.step(action, action2);
}
scores[game] = env.get_score(0) - env.get_score(1); // score difference, so losses are negative
if ((game+1) % 20 == 0) cout << "." << flush;
}
cout << endl;
double avg_score = accumulate(scores.begin(), scores.end(), 0.0) / DEMO_GAMES;
int max_score = *max_element(scores.begin(), scores.end());
int min_score = *min_element(scores.begin(), scores.end());
int win_count = count_if(scores.begin(), scores.end(), [](int s) { return s > 0; });
int lose_count = count_if(scores.begin(), scores.end(), [](int s) { return s < 0; });
int draw_count = DEMO_GAMES - win_count - lose_count;
cout << fixed << setprecision(2);
cout << "\n========== Evaluation statistics ==========\n";
cout << "Games: " << DEMO_GAMES << endl;
cout << "Average score difference: " << avg_score << endl;
cout << "Best score difference: " << max_score << endl;
cout << "Worst score difference: " << min_score << endl;
cout << "Wins: " << win_count << " (" << 100.0*win_count/DEMO_GAMES << "%)" << endl;
cout << "Losses: " << lose_count << " (" << 100.0*lose_count/DEMO_GAMES << "%)" << endl;
cout << "Draws: " << draw_count << " (" << 100.0*draw_count/DEMO_GAMES << "%)" << endl;
cout << "===========================================\n";
}
return 0;
}
best_pong_model_adv_dyn.txt
-0.845223
0.450172
-0.37189
-0.59501
-0.159849
-0.406481
-0.21484
0.466675
0.885094
-0.117648
-0.483216
-0.235449
-0.974355
-0.806158
-0.370667
-1.16137
-1.84086
0.717713
-0.567877
-0.469363
0.77805
0.957042
-0.731945
-0.384637
0.626338
-0.270856
0.167624
-0.755187
0.388279
0.0561348
0.471746
-0.363879
0.240065
-0.806298
-0.186455
-0.739827
0.208061
0.627659
-0.773108
-0.832786
-0.107428
-0.285038
-0.86431
2.27606
-0.81213
-0.944545
-0.876118
-0.739064
0.888569
0.371942
0.594492
-0.166166
-0.965309
-0.871631
1.09123
-1.10489
0.58516
-2.504
-0.193341
-0.267171
-0.185436
0.680193
2.13502
0.596983
0.930974
-0.698685
-0.781148
0.0193871
0.394585
-0.210358
-0.279408
0.887115
-0.46254
0.702621
0.724535
0.280933
0.508904
-0.594344
0.757752
-0.549501
-0.226271
0.279525
0.768678
0.922795
0.088871
0.796042
-0.490886
0.573177
-0.408751
-0.581571
-0.937579
0.66995
-0.781717
0.579831
-0.645132
0.322304
0.474407
-0.100593
0.4956
-0.939859
0.529673
0.244073
0.83731
-0.444014
-0.67409
-0.0298417
-0.739371
0.0797288
1.17626
0.222897
-0.508395
0.883452
-0.663968
-1.66678
-0.773194
0.99932
-0.0812166
-0.403997
0.126641
-0.410407
0.993032
-0.985794
-0.569801
0.991068
-0.312539
0.74366
-0.528389
0.0188687
0.960639
-0.252885
-0.475342
0.0404859
1.82861
0.420788
-0.426696
-0.482164
0.671723
0.467481
0.214529
0.815292
-0.568554
-2.08137
0.419281
0.342648
0.577517
1.08927
0.428804
0.943426
0.929803
0.189169
-0.0773863
0.927423
0.0456721
0.315017
0.494776
-0.109564
0.0350206
-0.0209009
-0.141516
-0.172241
-0.460267
0.536731
-0.924756
0.351839
0.312773
0.910862
0.599692
0.506551
0.389037
0.409848
-0.910109
-0.855657
0.604555
-0.147664
0.452027
1.18204
0.721876
0.583576
-0.216638
0.990776
1.33092
0.463485
-0.297555
0.766627
-0.919949
0.920868
0.617668
-0.711902
3.02326
-0.0107832
-0.62842
0.582181
0.38914
0.881086
-0.534614
-0.753083
0.802814
-0.533887
0.560449
-0.744721
0.711614
-0.790792
2.23875
-0.399292
-0.663189
1.42803
0.252535
-0.660979
1.49921
0.209481
-0.448918
-0.297021
-0.687483
-0.773446
0.534336
0.319584
-0.306075
-0.0860927
-0.800088
0.0757956
-0.0362516
-2.01446
-0.230609
-0.169272
0.533377
0.804314
0.435831
-1.18072
-0.569063
-0.917841
0.0942862
-0.175105
0.484656
-0.127209
-0.70509
-0.152533
0.82992
-0.307969
0.329111
0.327227
0.901622
-0.932354
-0.586609
0.445773
-0.675437
0.758341
-0.353218
1.04135
-0.191014
-0.874475
0.907587
-0.498862
-0.00441651
0.354772
0.517626
-2.17604
-0.0461481
-0.0136672
-0.115877
-0.0218927
-0.507028
-0.320416
-0.529471
0.264548
0.333239
0.772867
0.913761
0.832162
-0.900749
-0.745885
0.622284
-2.8798
0.741933
0.193875
0.437192
-0.905997
0.360538
-0.35469
-0.816001
-0.372665
0.637191
0.35078
-0.301014
0.0056218
-0.114045
0.848494
0.845212
-0.0415133
0.772325
-0.232023
-0.454234
0.909179
0.792353
0.295075
0.767667
0.783744
0.0678642
0.269869
-0.727397
-2.38421
0.391669
0.631414
-0.525355
0.890285
-2.56352
0.494749
0.608066
-0.442151
0.989271
0.394687
0.947206
0.932868
1.67273
2.25916
-0.789433
0.774709
-0.81634
-0.30686
-0.599224
0.466211
-0.350123
-0.578659
-0.122054
-0.705275
0.771224
0.540704
0.5229
-0.196858
-0.646384
-0.426532
-0.343569
-0.414843
-0.352407
-0.961355
0.145707
0.510765
0.902293
0.67689
-0.9642
0.283126
-0.385613
0.633513
0.233989
0.931202
-0.650873
-0.775078
-1.58928
-0.277535
-0.149135
0.1112
0.558872
-0.167937
0.869047
-0.543405
1.65386
0.179738
2.94895
3.03545
0.151391
-0.400984
-0.322348
-0.623546
0.168031
-0.296142
-0.398759
-0.865799
0.393483
-0.0318196
0.137233
-0.375567
0.857927
-0.664084
-0.855475
0.23443
0.513704
-0.842453
0.193706
-0.444882
-0.151814
0.0243783
-0.327859
-1.18478
0.863145
-0.119771
0.821787
-0.777182
-0.745379
0.134681
-0.606881
-0.798093
-0.428933
0.879202
-0.511533
-0.230174
0.69407
-0.344302
-0.754295
-0.139796
-0.19495
-0.534916
0.647245
0.311893
0.792051
0.318321
-0.312969
-0.874292
0.86762
-0.940906
-0.118865
0.327476
0.523305
-2.03548
0.713117
-0.0496671
0.746458
-0.615896
-0.979613
1.77777
-0.62001
-0.682997
-0.884511
-0.965936
-0.738663
0.853514
0.813073
0.378534
-0.65531
-0.836507
0.927591
-0.307366
-0.0143172
0.997276
0.150748
0.259263
-0.642073
0.487967
0.53284
-0.496499
-1.68256
1.49184
-0.959185
0.408944
0.369751
-0.0604678
0.673219
-0.753938
0.402187
0.781955
0.499394
-0.495829
-0.961502
-0.211828
0.722816
0.780377
0.212932
-0.395453
0.685051
0.409686
0.881135
2.17551
-0.491071
0.879693
-0.952876
-0.503052
-1.14131
-0.344469
0.238566
0.141403
0.335216
0.783562
-0.496796
-0.633728
0.863007
0.903141
-0.750328
-0.638109
0.734706
-0.29538
0.15962
0.314778
-0.950769
-0.685422
0.87997
0.175495
0.286959
-0.0884309
-0.716682
0.514851
0.0557019
0.302701
0.11267
0.936113
0.0339423
1.64757
0.0520335
-0.791371
0.609351
-1.06717
0.296111
0.477795
0.325728
-0.243043
0.981193
-0.103387
-0.207897
2.38796
-2.54709
-0.618475
0.277671
-0.983911
0.291782
0.679395
0.131469
-0.74663
-0.184705
0.70719
-0.474425
-0.349968
0.710537
0.525839
0.152258
-0.530633
-0.396425
0.436796
1.02452
-0.293911
0.637053
-0.506141
0.745766
-1.80467
0.281285
-0.523673
-0.0195748
0.238338
0.722169
-0.405298
-0.333594
0.0204768
0.63415
0.15396
0.141428
-0.764908
-0.534823
-0.381687
0.851726
0.385916
-0.74927
-0.739825
0.63801
-0.456513
0.875277
-0.195984
0.734546
-0.928167
-0.896218
-0.920692
0.33085
0.57656
-0.747576
0.167785
0.896133
-0.977131
0.0344068
-0.170824
-0.812947
-0.988641
0.78811
0.14465
-0.210952
0.396222
-0.414837
-0.874335
-0.735436
-1.03886
1.92743
-0.0561282
-0.597759
-0.184168
-0.27808
-0.352918
0.274451
0.786998
-0.495613
-0.865739
0.378218
-0.484501
0.637441
-0.490801
0.162519
0.645687
0.584519
-1.00664
0.357289
-0.154886
0.0885934
0.958833
-0.147798
1.47499
-0.226955
0.031946
-0.77636
0.719524
-0.208633
-0.901911
-0.13754
0.231229
0.352528
-0.871752
-0.0253423
-1.46138
0.726356
-0.656574
0.995143
0.491567
-0.703124
0.102175
0.965727
-0.80097
-0.761266
0.0250178
-0.0727205
0.839854
-1.08806
0.29969
-0.565473
-0.0902629
-0.79494
-0.793689
-0.51158
1.7176
-0.930688
0.189507
0.824813
0.0724994
0.00563023
0.901364
0.662777
-0.761843
0.691055
0.000461067
0.807827
0.269156
0.824881
-0.0755446
-0.654397
0.0973714
1.45659
-0.77357
-0.898999
0.425723
-0.24012
0.95598
0.0752455
-0.566345
-0.643778
-0.973524
-0.773168
-0.141923
-0.855677
0.424536
-0.685144
-0.423323
-0.412702
-2.74763
0.696755
-0.323085
0.792402
-0.506173
-0.161943
-0.111499
-0.0395083
0.943707
-0.639767
-0.936713
-0.335374
-0.964586
-0.0298077
0.364335
-0.681536
-0.885071
-0.629372
0.0280756
-0.860915
-1.94873
0.573363
-0.157292
-0.627904
-0.0602238
0.193158
-0.177992
0.446692
0.21338
-0.417862
0.233335
-0.203406
-0.240368
0.16356
-0.157474
-0.993227
-0.0846531
-0.38953
-0.777936
-0.493502
-0.00719041
0.238833
0.566398
-0.34347
-0.323288
0.889177
0.0751694
1.44029
0.561478
-0.868241
-0.957497
0.056005
2.22348
-0.852466
1.36097
0.831047
-0.481661
0.713718
0.233111
0.316526
-0.613791
-0.272301
0.156418
0.731533
0.350733
0.166554
0.979122
-1.84546
0.940076
-0.35365
0.519781
0.552912
0.338849
0.853927
0.62054
0.601798
-0.0144602
2.56923
-0.696903
0.975829
0.198525
0.968336
-1.23508
0.200656
-0.793483
-0.162511
0.15459
-0.91255
0.230142
0.170046
-0.203899
2.1238
0.852522
-0.324195
0.160552
0.154282
0.146985
0.579369
0.257603
0.619904
0.782284
-0.692165
-0.472643
-0.874485
0.441287
0.376282
-0.383913
0.205716
-0.941977
0.653146
0.483057
-0.444005
0.762592
-0.847559
-0.0499932
-0.0290519
-0.551632
-0.924207
0.260241
0.867722
The left paddle is the AI and the right is the random opponent. The model has only gone through an initial round of training; repeated training would improve its play further.