强化学习,启动!

如果训练不起来,请复制文末的代码。(我转到服务器后做了一些修改,尽量按照文末的代码复制粘贴)

TIMESTEPS = 10000
iters = 0
while True:
    print(f'On iteration:{iters}')
    model.learn(total_timesteps=TIMESTEPS, tb_log_name='PPO', reset_num_timesteps=False)
    model.save(f'{model_dir}/{TIMESTEPS*iters}')
  • TIMESTEPS = 10000:每一轮训练的时间步长(每轮训练 1 万步,对应星际 2 多局游戏)
  • while True:无限循环训练,持续迭代优化模型
  • model.learn(...)
    • total_timesteps=TIMESTEPS:本轮训练 1 万步
    • tb_log_name='PPO':TensorBoard 日志标签
    • reset_num_timesteps=False关键! 不重置训练步数,保证多轮训练的日志连续、模型迭代不中断
  • model.save(...):每轮训练结束后,自动保存模型,命名为{TIMESTEPS*iters}(如 10000、20000、30000...),方便后续加载不同训练阶段的模型

出现这个就是开始了

 然后在vscode打开tensorboard,实时查看训练日志

tensorboard --logdir ./logs

中断恢复

1.最新模型文件

首先在model对应的文件夹中找到最新的模型,我们用时间戳命名的,找最大的就行。

model_name = '<最新的模型>' 
model_dir = f'models/{model_name}/'
logs_dir = f'logs/{model_name}/'

2. 加载指定步数的模型(断点续训)

然后确定继续训练步数,不会从零开始,不会浪费之前的计算

model = PPO.load(f'{model_dir}/<指定继训步数>.zip', env=env, tensorboard_log=logs_dir)

3. 继续训练

model.learn(total_timesteps=TIMESTEPS, ...)
model.save(f'{model_dir}/{TIMESTEPS*iters}')
  • 每轮训 10000 步
  • 保存为 260000, 270000, 280000...
  • 实现无缝续训,直到满意

训练效果

经过一周的训练,我们打开tensorflow面板:

在训练步数一栏,步数大概在1000-2000徘徊

在奖励一栏,可以看到在32万步的时候,奖励达到了-85,这是非常非常非常棒的结果!!!

我们用callbacks记录了最好批次的模型,可以后续调用。

train/approx_kl 指标解释

train/approx_kl 代表 近似 KL 散度

  • 含义:KL 散度(Kullback-Leibler Divergence)用于衡量两个概率分布之间的差异。在 PPO 训练中,它衡量的是当前策略(更新后)旧策略(更新前)之间的差异有多大。
  • 通俗理解:它表示模型每一次更新时,“步子”迈得有多大。
    • 如果数值很高,说明新策略和旧策略差别很大,模型正在剧烈地改变它的行为。
    • 如果数值很低,说明新策略和旧策略差不多,模型只是在做微调。

数值范围与稳定性

  • 图表中的数值主要集中在 0 到 0.015 之间。
  • 相比于上一张 clip_fraction 图,虽然这里依然有剧烈的锯齿状波动,但整体维持在一个相对低位的水平。
  • 早期(0 - 500k步):波动非常剧烈,甚至在某些点冲得很高,说明训练初期模型在学习时非常“激进”,尝试大幅改变策略。
  • 后期(2M - 5M步):波动依然存在,但整体趋势似乎略有下降或保持稳定,没有持续发散(即没有一直往上涨)。

train/clip_fraction 指标解释

train/clip_fraction 是 PPO(Proximal Policy Optimization,近端策略优化)算法中的一个重要指标。

  • 含义:它表示在当前的训练批次中,有多少比例的梯度更新被“裁剪”了。
  • PPO 原理简述:PPO 为了防止策略更新步幅过大导致训练崩溃,引入了一个“裁剪”机制。它限制了新旧策略概率比率的变化范围(通常在 0.8 到 1.2 之间)。如果某个更新的比率超出了这个范围,就会被强行拉回到边界值,这就是“裁剪”。
  • 这个指标的意义
    • 数值过高(接近 1.0):说明绝大多数更新都被裁剪了。这通常意味着学习率太高,或者 PPO 的裁剪范围(epsilon)太小。模型想要迈大步子,但被强行拉住,可能导致学习效率低下。
    • 数值过低(接近 0):说明几乎没有更新被裁剪。这可能意味着学习率太低,或者策略更新非常保守。虽然训练稳定,但可能收敛速度很慢。
    • 理想状态:通常希望这个值在一个适中的范围(例如 0.1 - 0.3 左右,但这取决于具体任务),表明模型正在积极地学习,同时受到 PPO 机制的有效约束。

train/explained_variance 核心指标解释

train/explained_variance 衡量的是你的价值网络(Critic)预测的值与实际获得的回报之间的相关性。这是强化学习(尤其是 PPO 算法)中用于评估价值函数(Value Function / Critic)拟合好坏的重要指标。

  • 数值为 1:完美预测。模型完全理解了环境,能准确预测未来的收益。
  • 数值为 0:预测效果就像“瞎猜平均值”。模型没有学到任何有用的规律,预测值和随便猜一个平均数没区别。
  • 数值为负数:预测比瞎猜还差。模型不仅没学会,甚至还在“误导”自己(预测值与实际值反向变动)。

train/loss 指标解释

train/loss,这是整个训练过程中最核心的指标之一。

简单来说,Loss(损失)代表了模型“犯错”的程度。Loss 越低,说明模型预测得越准,犯错越少;Loss 越高,说明模型越困惑,预测越离谱。

在 PPO 算法中,这个 Loss 通常是 策略损失(Policy Loss) 和 价值函数损失(Value Loss) 的总和。

train/policy_gradient_loss 指标解释

train/policy_gradient_loss,这是 PPO 算法中专门衡量策略网络(Actor)更新力度的指标。

简单来说,它反映了模型为了获得更多奖励,想要改变自身行为的强烈程度。

核心指标解释

  • 含义:这是 PPO 目标函数中的核心部分。它计算的是在考虑了“优势”(Advantage,即某个动作比平时好多少)之后,策略网络应该被更新多少。
  • 数值通常为负:在 PPO 的实现中,为了使用梯度下降算法来最大化奖励,通常会对目标函数取反,所以你会看到这个 Loss 是负数。
  • 解读
    • 绝对值越大(越负):说明模型认为当前的策略很糟糕,或者发现了巨大的改进空间,因此想要进行大幅度的修改。
    • 绝对值越小(接近 0):说明模型认为当前的策略已经不错了,或者在这个批次的数据中没有学到什么新的有用信息,不需要做大的改动。

train/value_loss 指标解释

train/value_loss,它直接衡量的是 Critic(价值网络) 的预测误差。

这张图是整个训练问题的“罪魁祸首”,它揭示了为什么之前的总 Loss 会那么高。

  • 含义:Critic 的工作是预测未来能获得多少奖励(Value)。value_loss 衡量的是 Critic 的预测值与实际获得的回报(Return)之间的差距(通常使用均方误差 MSE)。
  • 解读
    • 数值越低越好:说明 Critic 预测得很准。
    • 数值高:说明 Critic 完全是在“瞎猜”,预测值和真实结果天差地别。

在vscode也可以看到每个指定步数存下的模型

在result.txt中也可以看到victory的频次明显增加了

完整训练代码

WorkerRushBot.py

from sc2 import maps
from sc2.player import Bot, Computer
from sc2.main import run_game
from sc2.data import Race, Difficulty
from sc2.bot_ai import BotAI
import pickle
import time
import random
import numpy as np
from sc2.ids.unit_typeid import UnitTypeId 

import math
import cv2
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import nest_asyncio
import asyncio
nest_asyncio.apply()
class WorkerRushBot(BotAI):       
    async def on_step(self, iteration: int):
        try:
            with open('transaction.pkl', 'rb') as f:
                transaction = pickle.load(f)
        except:
            return

        action = transaction['action']
        if action is None:
            return  # 没有动作,直接跳过

        action=transaction['action']
        #print(f'迭代次数:{iteration}')
        await self.distribute_workers()
        if action==0:
            have_builded = False
        
            # 1. 优先补水晶塔
            if self.supply_left < 4:
                if self.can_afford(UnitTypeId.PYLON) and self.already_pending(UnitTypeId.PYLON) == 0:
                    if self.townhalls.exists:
                        await self.build(UnitTypeId.PYLON, near=self.townhalls.first)
                        have_builded = True

            # --------------------------
            # ✅ 【扩张分矿】永远不阻塞!
            # --------------------------
            if self.townhalls.amount < 4 and self.can_afford(UnitTypeId.NEXUS):
                if self.already_pending(UnitTypeId.NEXUS) == 0:
                    await self.expand_now()

            # --------------------------
            # 【农民 / 气矿】独立运行
            # --------------------------
            for nexus in self.townhalls:
                # 农民
                if self.workers.closer_than(10, nexus).amount < 22:
                    if self.can_afford(UnitTypeId.PROBE) and nexus.is_idle:
                        nexus.train(UnitTypeId.PROBE)
        
                # 气矿
                for vespene in self.vespene_geyser.closer_than(15, nexus):
                        if self.can_afford(UnitTypeId.ASSIMILATOR):
                            worker = self.select_build_worker(vespene.position)
                            if worker is not None:
                                # 神族造气矿 官方唯一正确方法
                                worker.build_gas(vespene)
                                worker.stop(queue=True)
                                have_builded=False
                                #print('建造吸收间')
        
            # --------------------------
            # 【防御】不阻塞扩张
            # --------------------------
            pylon = self.structures(UnitTypeId.PYLON).ready
            if pylon.exists and not have_builded:
                near_pylon = pylon.first
                # 熔炉
                if not self.structures(UnitTypeId.FORGE).exists:
                    if self.can_afford(UnitTypeId.FORGE):
                        await self.build(UnitTypeId.FORGE, near=near_pylon)
                        have_builded = True
                # 光子炮
                elif self.structures(UnitTypeId.FORGE).ready:
                    if self.structures(UnitTypeId.PHOTONCANNON).amount < 3:
                        if self.can_afford(UnitTypeId.PHOTONCANNON):
                            await self.build(UnitTypeId.PHOTONCANNON, near=near_pylon)
                            have_builded = True
        
        

        if action==1:
            # 1:传送门、控制核心、星门
            have_build = False  # 正确拼写
        
            # 全局最多造 4 个星门(暴兵效率最大化)
            max_stargates = 4
            current_stargates = self.structures(UnitTypeId.STARGATE).amount
        
            for nexus in self.townhalls:
                # 每个基地都造传送门
                if not have_build:
                    if not self.structures(UnitTypeId.GATEWAY).closer_than(10, nexus).exists:
                        if self.can_afford(UnitTypeId.GATEWAY) and self.already_pending(UnitTypeId.GATEWAY) == 0:
                            await self.build(UnitTypeId.GATEWAY, near=nexus)
                            have_build = True  # 正确拼写
        
                # 每个基地都造控制核心
                if not have_build:
                    if not self.structures(UnitTypeId.CYBERNETICSCORE).closer_than(10, nexus).exists:
                        if self.can_afford(UnitTypeId.CYBERNETICSCORE) and self.already_pending(UnitTypeId.CYBERNETICSCORE) == 0:
                            await self.build(UnitTypeId.CYBERNETICSCORE, near=nexus)
                            have_build = True
        
                # ✅ 关键修复:允许造多个星门,直到 4 个
                if not have_build:
                    if current_stargates < max_stargates:  # 不限制“是否已有”,只限制总数
                        if self.can_afford(UnitTypeId.STARGATE) and self.already_pending(UnitTypeId.STARGATE) == 0:
                            await self.build(UnitTypeId.STARGATE, near=nexus)
                            have_build = True

        if action==2:
            #print(f'action={action}')
            #2:虚空辉光舰
            try:
                # 遍历所有【已建成、空闲】的星门
                for sg in self.structures(UnitTypeId.STARGATE).ready.idle:
                    # 如果钱够造虚空辉光舰
                    if self.can_afford(UnitTypeId.VOIDRAY):
                        # 让星门训练虚空辉光舰
                        sg.train(UnitTypeId.VOIDRAY)
                        #print('训练虚空战舰')
            except Exception as e:
                print(e)

        if action==3:
            #print(f'action={action}')
            #3:侦查
            # 1. 初始化 last_sent 时间戳(防止第一次运行报错)
            try:
                self.last_sent
            except:
                self.last_sent = 0
        
            # 2. 控制侦查频率:距离上次侦查超过100帧才执行
            if (iteration - self.last_sent) > 100:
                try:
                    # 3. 优先选择空闲的探机
                    if self.units(UnitTypeId.PROBE).idle.exists:
                        probe = random.choice(self.units(UnitTypeId.PROBE).idle)
                    # 4. 没有空闲探机,就随机选一个探机
                    else:
                        probe = random.choice(self.units(UnitTypeId.PROBE))
                    
                    # 5. 命令探机攻击/移动到敌人出生点
                    probe.attack(self.enemy_start_locations[0])
                    # 6. 更新最后一次侦查的帧号
                    self.last_sent = iteration
                    #print('侦查')
                except:
                    pass

        #4:进攻
        if action == 4:
            #print(f'action={action}')
            try:
                if(self.units(UnitTypeId.VOIDRAY).amount>1):
                    for voidray in self.units(UnitTypeId.VOIDRAY).idle:
                        # 优先级1:身边10格内有敌人单位 → 随机选一个攻击
                        if self.enemy_units.closer_than(10, voidray):
                            voidray.attack(random.choice(self.enemy_units.closer_than(10, voidray)))
                        # 优先级2:身边10格内有敌人建筑 → 随机选一个攻击
                        elif self.enemy_structures.closer_than(10, voidray):
                            voidray.attack(random.choice(self.enemy_structures.closer_than(10, voidray)))
                        # 优先级3:地图上有敌人单位 → 随机选一个攻击(A地板)
                        elif self.enemy_units:
                            voidray.attack(random.choice(self.enemy_units))
                        # 优先级4:地图上有敌人建筑 → 随机选一个攻击(拆家)
                        elif self.enemy_structures:
                            voidray.attack(random.choice(self.enemy_structures))
                        # 优先级5:找不到敌人 → 去敌人出生点
                        elif self.enemy_start_locations:
                            voidray.attack(self.enemy_start_locations[0])
                    #print('虚空辉光舰进攻')
            except Exception as e:
                print(e)
        
        #5:撤退
        if action == 5:
            #print(f'action={action}')
            try:
                if self.units(UnitTypeId.VOIDRAY).amount > 0:
                    for voidray in self.units(UnitTypeId.VOIDRAY):
                        voidray.attack(self.start_location)
                #print('撤退')
            except Exception as e:
                print(e)

        # 画图:生成地图状态观测
        # 1. 初始化空白地图
        map = np.zeros(
            (self.game_info.map_size[0], self.game_info.map_size[1], 3),
            dtype=np.uint8
        )
        
        # 2. 绘制矿产资源(水晶矿)
        for mineral in self.mineral_field:
            pos = mineral.position  # 获取矿点的坐标(x,y)
            c = [175, 255, 255]     # 基础颜色:青蓝色(代表水晶矿)
            # 计算剩余矿量比例:当前矿量 / 初始满矿量(2250)
            fraction = mineral.mineral_contents / 2250
        
            if mineral.is_visible:
                # 可见矿:按剩余矿量比例,调整颜色亮度(矿越多越亮)
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]
            else:
                # 不可见/战争迷雾中的矿:显示灰色(代表未知)
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [50, 50, 50]

        #3:绘制瓦斯资源
        for vespene in self.vespene_geyser:
            pos = vespene.position  # 获取气矿的坐标(x,y)
            c = [255, 175, 255]     # 基础颜色:粉紫色(代表气矿)
            # 计算剩余气矿量比例:当前气矿量 / 初始满矿量(2250)
            fraction = vespene.vespene_contents / 2250
        
            if vespene.is_visible:
                # 可见气矿:按剩余气矿量比例,调整颜色亮度(气越多越亮)
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]
            else:
                # 不可见/战争迷雾中的气矿:显示灰色(代表未知)
                map[math.ceil(pos.y)][math.ceil(pos.x)] = [50, 50, 50]

        #4:绘制基础设施
        for structure in self.structures:
            pos = structure.position  # 获取建筑的坐标(x,y)
            
            # 区分建筑类型:基地(nexus)用特殊颜色,其他建筑用另一种颜色
            if structure.type_id == UnitTypeId.NEXUS:
                c = [255, 255, 175]  # 亮黄色(代表基地/主基地)
            else:
                c = [0, 255, 175]     # 青绿色(代表其他己方建筑,如水晶塔、传送门、星门等)
            
            # 计算血量比例:当前血量 / 最大血量(避免除零)
            fraction = structure.health_percentage
            
            # 按血量比例缩放颜色,绘制到地图上
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]

        #5:绘制我方单位
        for unit in self.units:
            pos = unit.position  # 获取单位的坐标(x,y)
            
            # 区分单位类型:虚空辉光舰用特殊蓝色,其他单位用亮绿色
            if unit.type_id == UnitTypeId.VOIDRAY:
                c = [255, 0, 0]  # 蓝色(代表核心作战单位:虚空辉光舰)
            else:
                c = [175, 255, 0]  # 亮绿色(代表其他己方单位:探机等)
            
            # 直接获取血量百分比(0~1),无需手动计算
            fraction = unit.health_percentage
            
            # 按血量比例缩放颜色,绘制到地图上
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]

        #6:绘制敌人的起始位置(出生点)
        for enemy_location in self.enemy_start_locations:
            pos = enemy_location  # 获取敌人出生点坐标
            # 纯红色(代表敌人老家,显眼)
            c = [0, 0, 255]
            # 直接赋值,不需要遍历i
            map[math.ceil(pos.y)][math.ceil(pos.x)] = c
        
        #7:绘制敌人的基础设施(建筑)
        for structure in self.enemy_structures:
            pos = structure.position
            # 亮红色(代表敌人建筑)
            c = [0, 100, 255] 
            # 按血量比例缩放颜色(满血最亮,残血变暗)
            fraction = structure.health_percentage
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]
        
        #8:绘制敌人的单位(兵力)
        for unit in self.enemy_units:
            pos = unit.position
            # 橙红色(代表敌人活跃的单位/部队)
            c = [100, 0, 255]
            # 按血量比例缩放颜色
            fraction = unit.health_percentage
            map[math.ceil(pos.y)][math.ceil(pos.x)] = [int(fraction * i) for i in c]

        # 计算奖励值
        reward = 0  # 初始化奖励为0
        try:
            # 遍历所有己方的虚空辉光舰
            for voidray in self.units(UnitTypeId.VOIDRAY):
                # 条件1:虚空舰正在攻击,且目标在攻击范围内(有效攻击)
                if voidray.is_attacking and voidray.target_in_range:
                    # 条件2:虚空舰8格范围内有敌人单位/建筑(在战场中,不是空跑)
                    if self.enemy_structures.closer_than(8, voidray) or self.enemy_units.closer_than(8, voidray):
                        # 满足所有条件,给奖励
                        reward += 0.015
        except Exception as e:
            # 捕获异常(比如没有虚空舰、敌人不存在),避免崩溃
            print(f'reward error:{e}')
            reward = 0  # 异常时奖励归零
        
        # 每10帧打印一次日志,方便调试
        if iteration % 10 == 0:
            print(f'iteration:{iteration},RW:{reward},VR:{self.units(UnitTypeId.VOIDRAY).amount}')
        
        # 9. 显示地图(缩放+翻转,适配OpenCV显示)
        cv2.imshow(
            'map',
            cv2.flip(
                cv2.resize(
                    map,
                    None,
                    fx=4, fy=4,  # 放大3倍,方便观察
                    interpolation=cv2.INTER_NEAREST  # 最近邻插值,保留像素块
                ),
                0  # 0=上下翻转,修正坐标系
            )
        )
        cv2.waitKey(1)  # 等待1ms,刷新窗口(必须加,否则窗口卡死)

        transaction['action']=None
        with open('transaction.pkl','wb') as f:
            pickle.dump(transaction,f)

        
            
if __name__ == "__main__":
    print('WorkerRushBot.py start')

    result=run_game(maps.get("2000AtmospheresAIE"), [
        Bot(Race.Protoss, WorkerRushBot()),
        Computer(Race.Zerg, Difficulty.Hard)
    ], realtime=False)

    # 1. 记录比赛结果到日志文件
    with open('result.txt', 'a') as f:
        f.write(f'{result}\n')

    # 2. 发放终局奖励/惩罚
    if str(result) == 'Result.Victory':
        print('Victory!')
        rwd = 500  # 胜利,给+500大额奖励
    else:
        rwd = -500  # 失败/平局,给-500大额惩罚

    # 3. 生成最终观测与交易数据,保存为pkl文件
    map = np.zeros((244,244,3), dtype = np.uint8)
    transaction = {'observation':map, 'reward':rwd, 'action':None, 'terminated':True,'truncated':False}
    with open('transaction.pkl', 'wb') as f:
        pickle.dump(transaction, f)

    # 4. 清理OpenCV窗口,避免残留
    cv2.destroyAllWindows()
    cv2.waitKey(1)
    time.sleep(1)

StarCraft2Env.py

import numpy as np
import gymnasium as gym
import time
import pickle
import subprocess
import os # 用于检查文件

class StarCraft2Env(gym.Env):
    def __init__(self):
        super(StarCraft2Env, self).__init__()
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(244, 244, 3), dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(6)
        
        # 1. 在类初始化时定义超时变量
        self.wait_time = 0
        self.process = None

    def step(self, action):
        # --- 阶段 1:发送动作 ---
        start_time = time.time()
        self.wait_time = 0
        
        while True:
            try:
                # 读取当前状态
                with open('transaction.pkl', 'rb') as f:
                    transaction = pickle.load(f)
                
                # 只有当 Bot 把 action 变回 None 时,我们才写入新动作
                if transaction.get('action') is None:
                    transaction['action'] = action
                    with open('transaction.pkl', 'wb') as f:
                        pickle.dump(transaction, f)
                    break # 动作发送成功,跳出循环
            except Exception as e:
                pass # 文件被占用或不存在,忽略
            
            time.sleep(0.05)
            self.wait_time = time.time() - start_time
            
            # 【修复】发送动作也要有超时,防止 Bot 死了不回 None
            if self.wait_time > 30:
                print("!!! 超时:Bot 未能在 30秒内接收动作,可能已卡死。")
                self.close()
                return self.reset()[0], 0, True, False, {}

        # --- 阶段 2:等待结果 ---
        start_time = time.time()
        self.wait_time = 0
        
        while True:
            try:
                with open('transaction.pkl', 'rb') as f:
                    transaction = pickle.load(f)
                
                # 当 Bot 处理完动作,会将 action 设回 None,并填入新数据
                if transaction.get('action') is None:
                    observation = transaction['observation']
                    reward = transaction['reward']
                    terminated = transaction['terminated']
                    truncated = transaction['truncated']
                    return observation, reward, terminated, truncated, {}
                    
            except Exception as e:
                pass # 忽略读取错误
            
            time.sleep(0.05)
            self.wait_time = time.time() - start_time

            # 【修复】等待结果超时
            if self.wait_time > 60:
                print(f"!!! ERROR: 环境超时 (等待结果 > 60s),强制重置。")
                self.close()
                # 返回随机/零观测,并标记 done=True,让 RL 重新开始
                dummy_obs = np.zeros((244, 244, 3), dtype=np.uint8)
                return dummy_obs, 0, True, False, {}

    def reset(self, seed=None, options=None):
        print('--- Resetting Environment ---')
        self.wait_time = 0 # 【修复】重置计时器
        
        # 1. 杀掉旧进程
        self.close()
        
        # 2. 初始化文件
        # 创建一个初始的 transaction 文件,确保 Bot 启动时能读到
        map_data = np.zeros((244, 244, 3), dtype=np.uint8)
        transaction = {
            'observation': map_data, 
            'reward': 0, 
            'action': None, 
            'terminated': False, 
            'truncated': False
        }
        with open('transaction.pkl', 'wb') as f:
            pickle.dump(transaction, f)

        # 3. 启动新进程
        # 使用 shell=True 和 start 命令在 Windows 下启动新窗口
        try:
            subprocess.Popen(
                ['cmd', '/c', 'start', 'python', 'WorkerRushBot.py'],
                shell=True
            )
            print("Bot 进程已启动...")
        except Exception as e:
            print(f"启动 Bot 失败: {e}")

        # 4. 等待 Bot 初始化(给 Bot 一点时间加载游戏)
        time.sleep(5) 

        return map_data, {}

    def close(self):
        # 【新增】确保能杀掉旧进程
        if self.process:
            try:
                self.process.kill()
                print("旧进程已杀掉")
            except:
                pass
        self.process = None

SC2_Training.ipynb

依赖库导入

# 1. 导入依赖库
from stable_baselines3 import PPO
import os
import time
from StarCraft2Env import StarCraft2Env

import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

定义保存路径

model_name = f'{int(time.time())}'
model_dir = f'models/{model_name}/'
logs_dir = f'logs/{model_name}/'

if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(logs_dir):
    os.makedirs(logs_dir)

创建环境实例

env = StarCraft2Env()

创建模型实例

model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=logs_dir)

开始训练

TIMESTEPS = 5000000  # 500万步

print(f"开始训练,目标:{TIMESTEPS} 步...")
try:
    model.learn(total_timesteps=TIMESTEPS)
except KeyboardInterrupt:
    print("手动停止训练")
finally:
    env.close()  # 确保调用环境的关闭函数,杀掉 SC2 进程

print("训练结束!")
model.save("final_model")

保存模型

model.save(f'{model_dir}/{TIMESTEPS*iters}')

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐