1. References

pantanal-distill-birdclef2026

2. Notebook

BirdCLEF+ 2026 | Improved Ensemble with Parameter Tuning

[Expected Kaggle input datasets]

  • Official competition dataset BirdCLEF+ 2026
  • TensorFlow 2.20 wheel dataset
  • perch_v2_cpu model dataset
  • Full-file cached Perch outputs dataset perch-meta (files: full_perch_meta.parquet, full_perch_arrays.npz)

[Processing pipeline] Perch feature extraction → ProtoSSM_v5 (stage 1) + MLP probe → model ensembling → ResidualSSM (stage 2) → test-time augmentation (TTA) → per-taxon temperature scaling → file-level confidence scaling → rank-aware calibration → delta-shift smoothing → per-class thresholding → final output

[Component | Implementation]

Component | Implementation
Feature extractor | Google Perch v2 (frozen weights, 1536-dim embeddings)
Sequence model | Bidirectional selective state-space model (SSM, 4 layers, d_model=320) + 8-head cross-attention
Classification head | Prototype cosine head + gated Perch knowledge distillation
Unmapped species | Genus-level proxy logit averaging
Metadata fusion | Bayesian site/hour prior tables
Post-processing | Rank-aware scaling + delta-shift smoothing + per-class thresholds

[Model configuration]

  • ProtoSSM: d_model=320, d_state=32, 4 SSM layers, 2 prototypes per class, 8-head cross-attention
  • MLP probe: hidden layers (256, 128), focal loss (γ=2.5), label smoothing = 0.03
  • ResidualSSM: d_model=128, d_state=16, 2 layers, correction weight = 0.35

[Training configuration] 80 epochs, cosine annealing with warm restarts (period = 20), stochastic weight averaging from epoch 52 (SWA, lr=4e-4), 5-fold out-of-fold (OOF) training

[Post-processing pipeline] (a sketch of how the stages compose follows this list)

  1. Per-taxon temperature scaling: Aves 1.10, texture taxa (Amphibia, Insecta) 0.95
  2. File-level confidence scaling: top-2-window mean scaling (established 2025-edition technique)
  3. Test-time augmentation (TTA): circular temporal shifts over 5 offsets, [0, 1, −1, 2, −2]
  4. Rank-aware scaling: power transform of each file's max prediction, power = 0.4 (2025 LB rank-3 technique)
  5. Delta-shift smoothing: temporal smoothing with α=0.20 (2025 LB rank-1 technique)
  6. Per-class thresholds: decision thresholds grid-searched over [0.25 … 0.70] and tuned on out-of-fold (OOF) samples
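
A sketch of how these stages compose (order only; assumes window-level probabilities probs of shape (n_files*12, n_classes), a per-class temperature vector taxon_temperature, and the helper functions defined later in this notebook; TTA is omitted because it runs inside model inference):

probs = probs / taxon_temperature                      # step 1: per-taxon temperature scaling
probs = file_level_confidence_scale(probs, top_k=2)    # step 2: file-level confidence scaling
# step 3 (TTA) happens during inference; see temporal_shift_tta
probs = rank_aware_scaling(probs, power=0.4)           # step 4: rank-aware power transform
probs = delta_shift_smooth(probs, alpha=0.20)          # step 5: delta-shift smoothing
probs = apply_per_class_thresholds(probs, thresholds)  # step 6: per-class sharpening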

[Key optimizations in this notebook]

  • Larger ProtoSSM: d_model raised from 256 to 320, SSM depth from 3 to 4 layers (stronger temporal modeling)
  • Rank-aware post-processing: scaling by a power transform of each file's max prediction
  • Delta-shift smoothing: temporal smoothing across the 12-window sequence
  • Per-class thresholds: per-species decision thresholds tuned on out-of-fold (OOF) samples
  • Mixup + CutMix: mixed data augmentation on temporal embedding sequences
  • Focal loss: class-weighted, species-frequency-aware focal binary cross-entropy
  • Cross-attention: 8-head temporal cross-attention between SSM layers
  • Stochastic weight averaging (SWA): enabled from the 65% point of training
  • Cosine warm restarts: learning-rate schedule with restart period 20
  • MLP ensembling: ProtoSSM and MLP probe fusion weights tuned on OOF samples

[Public leaderboard (LB) score: 0.929]

# Install ONNX Runtime, used in place of TensorFlow for Perch inference (~9x faster)
import os, glob

# Disable the GPU (CPU-only is the more stable path here)
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# Look for an onnxruntime wheel among the Kaggle input datasets
_whl_candidates = [w for w in glob.glob("/kaggle/input/**/onnxruntime*cp312*x86_64*.whl", recursive=True) if "gpu" not in w]

if _whl_candidates:
    # Install the newest matching wheel
    _whl_candidates.sort(reverse=True)
    print(f"Installing onnxruntime from: {_whl_candidates[0]}")
    os.system(f"pip install -q --no-deps '{_whl_candidates[0]}'")
else:
    print("ERROR: No onnxruntime wheel found!")

Configuration

Mode switch and hyperparameter settings

# Global mode switch: train / submit
MODE = "submit" 

# Guard: only these two modes are valid; anything else fails fast
assert MODE in {"train", "submit"}

print("MODE =", MODE)

BirdCLEF+ 2026 — ProtoSSM v5: Maximal Ensemble

Pipeline: Perch feature extraction → ProtoSSM_v5 (stage 1) + MLP probe → model ensembling → ResidualSSM (stage 2) → test-time augmentation (TTA) → per-taxon temperature scaling → file-level confidence scaling → rank-aware calibration → delta-shift smoothing → per-class thresholding → final output

# Cell 2 — Imports & run configuration
# Environment preparation + global configuration: import every library we need
# and collect all parameters in one place.

# ---------------------- 1. Environment settings (logging and GPU) ----------------------
import os  # environment variables and file paths

# TensorFlow log level: 3 = errors only (suppress info/warning noise)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Disable the GPU (empty string); a competition constraint, everything runs on CPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# ---------------------- 2. Standard library ----------------------
import gc  # garbage collection, to release memory between heavy steps
import json  # save/load configs and logs
import re  # regular expressions for filename parsing
import time  # wall-clock timing
import warnings  # control warning output
from collections import defaultdict  # dict that supplies a default for missing keys
from pathlib import Path  # object-oriented path handling

# ---------------------- 3. Data & audio libraries ----------------------
import numpy as np  # arrays and numerical computing
import pandas as pd  # tabular data handling
import soundfile as sf  # reading/writing audio files

# ---------------------- 4. Deep-learning frameworks ----------------------
import tensorflow as tf  # loads the Perch model and runs part of the computation
import torch  # builds and trains the SSM models
import torch.nn as nn  # neural-network layers
import torch.nn.functional as F  # activations, losses, and other functional ops

# ---------------------- 5. Machine-learning libraries (sklearn) ----------------------
from sklearn.decomposition import PCA  # dimensionality reduction
from sklearn.linear_model import LogisticRegression  # baseline classifier
from sklearn.neural_network import MLPClassifier  # MLP probe head
from sklearn.metrics import roc_auc_score  # ROC-AUC evaluation metric
try:
    from lightgbm import LGBMClassifier  # optional gradient-boosted trees
    _LGBM_AVAILABLE = True  # import succeeded → available
except ImportError:
    _LGBM_AVAILABLE = False  # import failed → unavailable
from sklearn.model_selection import GroupKFold  # grouped K-fold CV splits
from sklearn.preprocessing import StandardScaler  # feature normalization

# ---------------------- 6. Progress bars ----------------------
from tqdm.auto import tqdm  # progress bars that adapt to the environment

# ---------------------- 7. Extra environment settings ----------------------
warnings.filterwarnings("ignore")  # silence warnings to keep the output clean
tf.experimental.numpy.experimental_enable_numpy_behavior()  # NumPy-style semantics on TF tensors

# ---------------------- 8. Global: record the start time ----------------------
_WALL_START = time.time()  # used to report total runtime at the end

# ---------------------- 9. Paths ----------------------
BASE = Path("/kaggle/input/competitions/birdclef-2026")  # competition data root
MODEL_DIR = Path("/kaggle/input/models/google/bird-vocalization-classifier/tensorflow2/perch_v2_cpu/1")  # Perch v2 model directory

# ---------------------- 10. Audio parameters ----------------------
SR = 32000  # sample rate: 32,000 samples per second
WINDOW_SEC = 5  # analysis window length: 5 seconds
WINDOW_SAMPLES = SR * WINDOW_SEC  # samples per 5 s window (32000*5 = 160,000)
FILE_SAMPLES = 60 * SR  # samples per 60 s file (32000*60 = 1,920,000)
N_WINDOWS = 12  # windows per file (60 s / 5 s = 12)

# ---------------------- 11. Device ----------------------
DEVICE = torch.device("cpu")  # PyTorch device: CPU only (competition constraint)

# ---------------------- 12. Logging dict ----------------------
LOGS = {}  # collects run-time logs for later inspection

# ---------------------- 13. Core configuration dict (CFG): every parameter lives here ----------------------
CFG = {
    # Basic mode settings
    "mode": MODE,  # "train" or "submit", from Cell 1
    "verbose": MODE == "train",  # more output in train mode, less in submit mode

    # Switches for expensive research blocks (train mode only, to save submit time)
    "run_oof_baseline": MODE == "train",  # run the OOF baseline?
    "run_probe_check": False,  # run the probe sanity check?
    "run_probe_grid": False,  # run the probe grid search?

    # Inference parameters
    "batch_files": 32,  # files per inference batch (speeds things up)
    "proxy_reduce_grid": ["max", "mean"],  # candidate proxy reductions
    "proxy_reduce": "max",  # proxy reduction actually used
    "run_proxy_reduce_grid": False,  # run the proxy-reduce grid search?
    "dryrun_n_files": 50 if MODE == "train" else 20,  # dry-run file count for quick smoke tests

    # Cache behavior
    "require_full_cache_in_submit": False,  # must the full cache exist in submit mode?
    "full_cache_input_dir": Path("/kaggle/input/perch-meta"),  # precomputed Perch feature cache (input)
    "full_cache_work_dir": Path("/kaggle/working/perch_cache"),  # working cache directory

    # Frozen baseline fusion parameters (for blending model predictions)
    "best_fusion": {
        "lambda_event": 0.4,  # fusion weight for event classes
        "lambda_texture": 1.0,  # fusion weight for texture classes
        "lambda_proxy_texture": 0.8,  # fusion weight for texture proxies
        "smooth_texture": 0.35,  # smoothing coefficient for texture classes
        "smooth_event": 0.15,  # smoothing coefficient for event classes
    },

    # Base ProtoSSM config (overridden later by the CFG upgrades cell)
    "proto_ssm": {
        "d_model": 256,               # base model width (upgraded to 320 later)
        "d_state": 16,  # SSM internal state size
        "n_ssm_layers": 3,            # base depth (upgraded to 4 later)
        "dropout": 0.15,  # drop 15% of activations against overfitting
        "n_prototypes": 1,  # prototypes per class
        "n_sites": 20,  # number of sites
        "meta_dim": 16,  # metadata embedding dim
        "use_cross_attn": True,  # use cross-attention
        "cross_attn_heads": 4,  # attention heads
    },

    # ProtoSSM v5 training config
    "proto_ssm_train": {
        "n_epochs": 60 if MODE == "train" else 40,   # epochs: 60 in train mode, 40 in submit mode
        "lr": 1e-3,  # initial learning rate
        "weight_decay": 2e-3,  # L2 regularization
        "val_ratio": 0.15,  # validation fraction
        "patience": 15 if MODE == "train" else 8,    # early-stopping patience
        "pos_weight_cap": 30.0,  # cap on positive-class weights (limits imbalance blow-up)
        "distill_weight": 0.1,  # knowledge-distillation loss weight
        "proto_margin": 0.1,  # prototype margin
        "label_smoothing": 0.02,  # label smoothing
        "oof_n_splits": 3,  # OOF folds
        "mixup_alpha": 0.3,  # Mixup alpha
        "focal_gamma": 2.0,  # focal-loss gamma (focus on hard samples)
        "swa_start_frac": 0.7,  # start SWA at 70% of training
        "swa_lr": 5e-4,  # SWA learning rate
    },

    # Frozen probe parameters
    "frozen_best_probe": {
        "pca_dim": 64,  # PCA output dim
        "min_pos": 8,  # minimum positives per class
        "C": 0.50,  # logistic-regression regularization C
        "alpha": 0.40,  # blending coefficient
    },

    # Residual SSM config (stage-2 correction model)
    "residual_ssm": {
        "d_model": 64,  # model width
        "d_state": 8,  # state size
        "n_ssm_layers": 1,  # depth
        "dropout": 0.1,  # dropout rate
        "correction_weight": 0.3,  # fraction of the stage-1 prediction corrected
        "n_epochs": 30,  # epochs
        "lr": 1e-3,  # learning rate
        "patience": 8,  # early-stopping patience
    },

    # Per-taxon temperature scaling (post-processing)
    "temperature": {
        "aves": 1.10,  # Aves temperature
        "texture": 0.95,  # texture taxa (Amphibia, Insecta) temperature
    },

    # Post-processing parameters
    "file_level_top_k": 0,  # file-level top-k: 0 = disabled
    "tta_shifts": [0, 1, -1],  # TTA shift offsets

    # Rank-aware post-processing
    "rank_aware_scale": True,  # apply rank-aware scaling?
    "rank_aware_power": 0.5,  # power applied to each file's max prediction

    # Delta-shift smoothing
    "delta_shift_alpha": 0.15,  # smoothing alpha

    # Per-class thresholds (grid-search range)
    "threshold_grid": [0.3, 0.4, 0.5, 0.6, 0.7],  # candidate thresholds

    # Probe backend
    "probe_backend": "mlp",  # probe model: MLP
    "mlp_params": {  # MLP parameters
        "hidden_layer_sizes": (128,),  # one hidden layer of 128 units
        "activation": "relu",  # activation function
        "max_iter": 300,  # max iterations
        "early_stopping": True,  # early stopping enabled
        "validation_fraction": 0.15,  # validation fraction
        "n_iter_no_change": 15,  # stop after 15 iterations without improvement
        "random_state": 42,  # seed for reproducibility
        "learning_rate_init": 0.001,  # initial learning rate
        "alpha": 0.01,  # L2 regularization alpha
    },
}

# ---------------------- 14. Create the cache working directory ----------------------
# parents=True creates missing parents; exist_ok=True tolerates an existing directory
CFG["full_cache_work_dir"].mkdir(parents=True, exist_ok=True)

# ---------------------- 15. Print environment and configuration ----------------------
print("TensorFlow:", tf.__version__)  # TensorFlow version
print("PyTorch:", torch.__version__)  # PyTorch version
print("Competition dir exists:", BASE.exists())  # competition directory present?
print("Model dir exists:", MODEL_DIR.exists())  # model directory present?
print("Base CFG loaded (ProtoSSM d_model=256, n_ssm_layers=3; overridden by CFG upgrades cell)")
# Dump the full CFG (Path objects stringified, since json cannot serialize Path)
print(json.dumps(
    {k: (str(v) if isinstance(v, Path) else v) for k, v in CFG.items()},
    indent=2
))
TensorFlow: 2.20.0
PyTorch: 2.9.0+cpu
Competition dir exists: True
Model dir exists: True
Base CFG loaded (ProtoSSM d_model=256, n_ssm_layers=3; overridden by CFG upgrades cell)
{
  "mode": "submit",
  "verbose": false,
  "run_oof_baseline": false,
  "run_probe_check": false,
  "run_probe_grid": false,
  "batch_files": 32,
  "proxy_reduce_grid": [
    "max",
    "mean"
  ],
  "proxy_reduce": "max",
  "run_proxy_reduce_grid": false,
  "dryrun_n_files": 20,
  "require_full_cache_in_submit": false,
  "full_cache_input_dir": "/kaggle/input/perch-meta",
  "full_cache_work_dir": "/kaggle/working/perch_cache",
  "best_fusion": {
    "lambda_event": 0.4,
    "lambda_texture": 1.0,
    "lambda_proxy_texture": 0.8,
    "smooth_texture": 0.35,
    "smooth_event": 0.15
  },
  "proto_ssm": {
    "d_model": 256,
    "d_state": 16,
    "n_ssm_layers": 3,
    "dropout": 0.15,
    "n_prototypes": 1,
    "n_sites": 20,
    "meta_dim": 16,
    "use_cross_attn": true,
    "cross_attn_heads": 4
  },
  "proto_ssm_train": {
    "n_epochs": 40,
    "lr": 0.001,
    "weight_decay": 0.002,
    "val_ratio": 0.15,
    "patience": 8,
    "pos_weight_cap": 30.0,
    "distill_weight": 0.1,
    "proto_margin": 0.1,
    "label_smoothing": 0.02,
    "oof_n_splits": 3,
    "mixup_alpha": 0.3,
    "focal_gamma": 2.0,
    "swa_start_frac": 0.7,
    "swa_lr": 0.0005
  },
  "frozen_best_probe": {
    "pca_dim": 64,
    "min_pos": 8,
    "C": 0.5,
    "alpha": 0.4
  },
  "residual_ssm": {
    "d_model": 64,
    "d_state": 8,
    "n_ssm_layers": 1,
    "dropout": 0.1,
    "correction_weight": 0.3,
    "n_epochs": 30,
    "lr": 0.001,
    "patience": 8
  },
  "temperature": {
    "aves": 1.1,
    "texture": 0.95
  },
  "file_level_top_k": 0,
  "tta_shifts": [
    0,
    1,
    -1
  ],
  "rank_aware_scale": true,
  "rank_aware_power": 0.5,
  "delta_shift_alpha": 0.15,
  "threshold_grid": [
    0.3,
    0.4,
    0.5,
    0.6,
    0.7
  ],
  "probe_backend": "mlp",
  "mlp_params": {
    "hidden_layer_sizes": [
      128
    ],
    "activation": "relu",
    "max_iter": 300,
    "early_stopping": true,
    "validation_fraction": 0.15,
    "n_iter_no_change": 15,
    "random_state": 42,
    "learning_rate_init": 0.001,
    "alpha": 0.01
  }
}
# ── CFG upgrades: ProtoSSM d_model=320, 4 SSM layers, 5-fold OOF, cosine restarts ──

# Upgrade the core ProtoSSM parameters
CFG["proto_ssm"] = {
    "d_model": 320,             # model width (up from 256)
    "d_state": 32,              # state size
    "n_ssm_layers": 4,          # SSM depth (up from 3)
    "dropout": 0.12,            # dropout against overfitting
    "n_prototypes": 2,          # 2 prototypes per class
    "n_sites": 20,              # site-embedding size
    "meta_dim": 24,             # metadata dim
    "use_cross_attn": True,     # use cross-attention
    "cross_attn_heads": 8,      # 8 attention heads (up from 4)
}

# Upgrade the ProtoSSM training parameters
CFG["proto_ssm_train"] = {
    "n_epochs": 50,               # 50 epochs
    "lr": 8e-4,                   # learning rate 0.0008
    "weight_decay": 1e-3,         # weight decay
    "val_ratio": 0.15,            # 15% validation
    "patience": 20,               # early-stopping patience
    "pos_weight_cap": 25.0,       # imbalance weight cap
    "distill_weight": 0.15,       # knowledge-distillation weight
    "proto_margin": 0.15,         # prototype margin
    "label_smoothing": 0.03,      # label smoothing
    "oof_n_splits": 5,            # 5-fold CV (upgraded)
    "mixup_alpha": 0.4,           # augmentation strength
    "focal_gamma": 2.5,           # focal loss γ=2.5
    "swa_start_frac": 0.65,       # SWA start fraction
    "swa_lr": 4e-4,               # SWA learning rate
    "use_cosine_restart": True,   # cosine annealing with restarts
    "restart_period": 20,         # restart period: 20 epochs
}

# Upgrade the ResidualSSM (stage-2 correction model)
CFG["residual_ssm"] = {
    "d_model": 128,              # width 128
    "d_state": 16,               # state size 16
    "n_ssm_layers": 2,           # 2 SSM layers
    "dropout": 0.1,              # dropout
    "correction_weight": 0.35,   # correction weight 0.35
    "n_epochs": 25,              # 25 epochs
    "lr": 8e-4,                  # learning rate
    "patience": 12,              # early stopping
}

# Upgrade the fusion weights
CFG["best_fusion"]["lambda_event"]         = 0.45
CFG["best_fusion"]["lambda_texture"]       = 1.1
CFG["best_fusion"]["lambda_proxy_texture"] = 0.9

# Finer threshold search grid (0.25–0.70)
CFG["threshold_grid"] = [0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.60,0.65,0.70]

# TTA shift offsets
CFG["tta_shifts"]        = [0, 1, -1]

# Disable rank-aware scaling (power = 0 is a no-op, so rankings stay untouched)
CFG["rank_aware_power"]  = 0.0

# Lower the delta-smoothing coefficient
CFG["delta_shift_alpha"] = 0.10

# Upgrade the MLP probe (bigger, stronger)
CFG["mlp_params"] = {
    "hidden_layer_sizes": (256, 128),  # two hidden layers
    "activation": "relu",
    "max_iter": 500,
    "early_stopping": True,
    "validation_fraction": 0.15,
    "n_iter_no_change": 20,
    "random_state": 42,
    "learning_rate_init": 5e-4,
    "alpha": 0.005,
}

# Upgrade the probe parameters
CFG["frozen_best_probe"] = {
    "pca_dim": 128, "min_pos": 5, "C": 0.75, "alpha": 0.45
}

# Confirm the upgrades
print("✅ CFG upgrades loaded: d_model=320, n_ssm_layers=4, oof_n_splits=5, cosine_restart=True")
✅ CFG upgrades loaded: d_model=320, n_ssm_layers=4, oof_n_splits=5, cosine_restart=True
# PyTorch's official cosine-annealing-with-warm-restarts LR scheduler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Factory for the cosine-restart scheduler.
# Effect: the learning rate periodically decays and jumps back up,
# which helps the model converge.
def get_cosine_restart_scheduler(optimizer, restart_period=20):
    return CosineAnnealingWarmRestarts(
        optimizer,           # the wrapped optimizer (SGD/Adam/...)
        T_0=restart_period,  # restart period: restart the LR every 20 epochs
        T_mult=1,            # keep every cycle fixed at 20 epochs
        eta_min=1e-5         # LR floor: 0.00001
    )

# Confirm the function is defined
print("✅ Cosine Restart Scheduler defined")
✅ Cosine Restart Scheduler defined
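A minimal usage sketch (toy parameter, not part of the real training loop): step the scheduler once per epoch, so the LR falls along a cosine within each 20-epoch cycle and jumps back to the base value at every restart.

_params = [torch.nn.Parameter(torch.zeros(1))]
_opt = torch.optim.AdamW(_params, lr=8e-4)
_sched = get_cosine_restart_scheduler(_opt, restart_period=20)
for _epoch in range(40):
    _opt.step()       # a real loop would run the training step here
    _sched.step()     # advance the schedule by one epoch
    if _epoch in (0, 19, 20, 39):
        print(_epoch, _opt.param_groups[0]["lr"])  # low at 19/39, reset around 20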
# ── Step 3: Mixup + CutMix augmentation ──
# Randomly blends pairs of samples (features and labels) to enlarge the effective
# dataset and fight overfitting (a Kaggle staple).
def mixup_cutmix(emb, logits, labels, alpha=0.4, cutmix_prob=0.3):
    B, T, D = emb.shape  # B = batch size, T = time frames, D = feature dim (1536)
    
    # Sample the mixing ratio lam from a Beta(alpha, alpha) distribution
    lam = np.random.beta(alpha, alpha)
    
    # Random permutation of the batch (selects the partner sample to mix with)
    idx = torch.randperm(B)

    # With probability cutmix_prob, do CutMix (mix a temporal segment)
    if np.random.rand() < cutmix_prob:
        # CutMix: cut a segment along the time axis and replace it with the partner's
        cut_len = max(1, int(T * (1 - lam)))  # segment length
        cut_start = np.random.randint(0, T - cut_len + 1)  # random start
        
        new_emb = emb.clone()
        # Swap in the partner's features over the segment
        new_emb[:, cut_start:cut_start+cut_len, :] = emb[idx, cut_start:cut_start+cut_len, :]
        
        new_logits = logits.clone()
        new_logits[:, cut_start:cut_start+cut_len, :] = logits[idx, cut_start:cut_start+cut_len, :]
        
        lam_actual = 1.0 - cut_len / T  # effective mixing ratio
        new_labels = lam_actual * labels + (1-lam_actual) * labels[idx]  # mix the labels accordingly
    
    # Otherwise do standard Mixup (global convex combination)
    else:
        # Weighted blend of the whole samples
        new_emb    = lam * emb    + (1-lam) * emb[idx]
        new_logits = lam * logits + (1-lam) * logits[idx]
        new_labels = lam * labels + (1-lam) * labels[idx]

    # Return the mixed features, logits, and labels
    return new_emb, new_logits, new_labels

print("✅ Mixup+CutMix defined")
✅ Mixup+CutMix defined
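A quick shape check with random tensors (toy batch of 4; the notebook's real dimensions are T=12 windows, D=1536 embedding dims, C=234 classes):

_e = torch.randn(4, 12, 1536)               # embeddings (B, T, D)
_l = torch.randn(4, 12, 234)                # window logits (B, T, C)
_y = torch.randint(0, 2, (4, 234)).float()  # file-level labels (B, C)
_e2, _l2, _y2 = mixup_cutmix(_e, _l, _y, alpha=0.4, cutmix_prob=0.3)
print(_e2.shape, _l2.shape, _y2.shape)      # shapes unchanged; contents mixed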
# ── Step 4: species-frequency-aware focal loss ──
# Counters class imbalance: rare species get larger weights so the model
# does not ignore them.

# Function 1: per-class weights from species frequency
def build_class_freq_weights(Y_FULL, cap=10.0):
    pos_count = Y_FULL.sum(axis=0).astype(np.float32) + 1.0  # positives per class (+1 smoothing)
    total     = Y_FULL.shape[0]                              # total samples
    freq      = pos_count / total                            # class frequency
    weights   = 1.0 / (freq ** 0.5)                          # rarer class → larger weight
    weights   = np.clip(weights, 1.0, cap)                   # cap the weights
    weights   = weights / weights.mean()                     # normalize to mean 1
    return torch.tensor(weights, dtype=torch.float32)        # return as a tensor

# Function 2: class-weighted focal loss (the core training loss)
def species_focal_loss(logits, targets, class_weights, 
                       gamma=2.5, label_smoothing=0.03):
    # Label smoothing against overconfidence
    targets_smooth = targets * (1 - label_smoothing) + label_smoothing / 2.0
    
    # Base binary cross-entropy
    bce    = F.binary_cross_entropy_with_logits(
                 logits, targets_smooth, reduction="none")
    
    # Focal term: down-weight easy examples, focus on hard ones
    pt     = torch.exp(-bce)
    focal  = ((1 - pt) ** gamma) * bce
    
    # Apply the per-class weights (rare species matter more)
    w      = class_weights.to(logits.device).unsqueeze(0)
    
    # Mean loss
    return (focal * w).mean()

print("✅ Species Focal Loss defined")
✅ Species Focal Loss defined
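A toy example of the weighting: on a 6-sample, 3-class label matrix, the rarest class gets the largest (capped, mean-normalized) weight.

_Y = np.array([[1,0,0],[1,0,0],[1,1,0],[1,0,0],[1,0,1],[1,1,0]], dtype=np.float32)
print(build_class_freq_weights(_Y, cap=10.0))  # weights rise as class frequency falls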

Data loading and preprocessing

Load the competition data, parse the labels, and build the ground-truth label matrix

# ---------------------- 1. Read the 3 core tables ----------------------
taxonomy = pd.read_csv(BASE / "taxonomy.csv")                # species taxonomy table
sample_sub = pd.read_csv(BASE / "sample_submission.csv")      # submission template (defines the target species)
soundscape_labels = pd.read_csv(BASE / "train_soundscapes_labels.csv")  # training labels

# ---------------------- 2. Basic info ----------------------
PRIMARY_LABELS = sample_sub.columns[1:].tolist()  # target species names (all columns after the first)
N_CLASSES = len(PRIMARY_LABELS)                    # number of classes

# Force the label columns to string to avoid dtype surprises
taxonomy["primary_label"] = taxonomy["primary_label"].astype(str)
soundscape_labels["primary_label"] = soundscape_labels["primary_label"].astype(str)

# ---------------------- 3. Helper: parse label strings ----------------------
def parse_soundscape_labels(x):
    if pd.isna(x):
        return []  # missing value → empty list
    # Split a semicolon-separated label string into a list ("a;b" → ["a", "b"])
    return [t.strip() for t in str(x).split(";") if t.strip()]

# ---------------------- 4. Helper: parse filenames ----------------------
# Regex for names of the form BC2026_Train_123_S45_20260101_123456.ogg
FNAME_RE = re.compile(r"BC2026_(?:Train|Test)_(\d+)_(S\d+)_(\d{8})_(\d{6})\.ogg")

def parse_soundscape_filename(name):
    m = FNAME_RE.match(name)
    if not m:  # no match → empty metadata
        return {
            "file_id": None, "site": None, "date": pd.NaT,
            "time_utc": None, "hour_utc": -1, "month": -1,
        }
    file_id, site, ymd, hms = m.groups()  # file ID, site, date, time
    dt = pd.to_datetime(ymd, format="%Y%m%d", errors="coerce")  # parse the date
    return {
        "file_id": file_id, "site": site, "date": dt,
        "time_utc": hms, "hour_utc": int(hms[:2]),  # hour of day
        "month": int(dt.month) if pd.notna(dt) else -1,  # month
    }
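
# Example on a synthetic filename of the documented format (not a real file):
# parse_soundscape_filename("BC2026_Train_123_S45_20260101_123456.ogg")
# → {"file_id": "123", "site": "S45", "hour_utc": 12, "month": 1, ...}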

# ---------------------- 5. Helper: merge labels ----------------------
def union_labels(series):
    # Union, dedupe, and sort every label falling in the same window
    return sorted(set(lbl for x in series for lbl in parse_soundscape_labels(x)))

# ---------------------- 6. Dedupe & aggregate labels (per 5 s window) ----------------------
sc_clean = (
    soundscape_labels
    .groupby(["filename", "start", "end"])["primary_label"]  # group by file + time window
    .apply(union_labels)                                      # merge labels within each window
    .reset_index(name="label_list")                           # back to a flat frame
)

# ---------------------- 7. Window times & row IDs ----------------------
sc_clean["start_sec"] = pd.to_timedelta(sc_clean["start"]).dt.total_seconds().astype(int)  # window start (s)
sc_clean["end_sec"] = pd.to_timedelta(sc_clean["end"]).dt.total_seconds().astype(int)      # window end (s)
sc_clean["row_id"] = sc_clean["filename"].str.replace(".ogg", "", regex=False) + "_" + sc_clean["end_sec"].astype(str)  # unique row ID

# ---------------------- 8. Parse filename metadata & merge ----------------------
meta = sc_clean["filename"].apply(parse_soundscape_filename).apply(pd.Series)  # parse every filename
sc_clean = pd.concat([sc_clean, meta], axis=1)  # attach the metadata

# ---------------------- 9. Select fully labeled files (all 12 windows of the 60 s labeled) ----------------------
windows_per_file = sc_clean.groupby("filename").size()  # windows per file
full_files = sorted(windows_per_file[windows_per_file == N_WINDOWS].index.tolist())  # files with exactly 12 windows
sc_clean["file_fully_labeled"] = sc_clean["filename"].isin(full_files)  # flag them

# ---------------------- 10. Build the multi-hot label matrix (0/1, model-ready) ----------------------
label_to_idx = {c: i for i, c in enumerate(PRIMARY_LABELS)}  # species name → column index
Y_SC = np.zeros((len(sc_clean), N_CLASSES), dtype=np.uint8)  # all-zero matrix

for i, labels in enumerate(sc_clean["label_list"]):
    idxs = [label_to_idx[lbl] for lbl in labels if lbl in label_to_idx]  # map names to indices
    if idxs:
        Y_SC[i, idxs] = 1  # set the matching positions

# ---------------------- 11. Ground truth for the fully labeled files ----------------------
full_truth = (
    sc_clean[sc_clean["file_fully_labeled"]]
    .sort_values(["filename", "end_sec"])  # sort by file + time
    .reset_index(drop=False)
)

Y_FULL_TRUTH = Y_SC[full_truth["index"].to_numpy()]  # label matrix of the fully labeled windows

# ---------------------- 12. Print statistics ----------------------
print("sc_clean:", sc_clean.shape)
print("Y_SC:", Y_SC.shape, Y_SC.dtype)
print("Full files:", len(full_files))
print("Trusted full windows:", len(full_truth))
print("Active classes in full windows:", int((Y_FULL_TRUTH.sum(axis=0) > 0).sum()))
sc_clean: (739, 14)
Y_SC: (739, 234) uint8
Full files: 59
Trusted full windows: 708
Active classes in full windows: 71
# Use the fully labeled ground-truth matrix to compute per-class weights
# (rare species → larger weight, common species → smaller weight)
CLASS_WEIGHTS = build_class_freq_weights(Y_FULL_TRUTH)

# Confirm
print("✅ Class weights built")
✅ Class weights built
# ── Step 5: isotonic calibration + threshold optimization ──
# Calibrates the model's probabilities and finds a per-class decision threshold;
# both typically lift the final score.

from sklearn.isotonic import IsotonicRegression  # sklearn's isotonic regression

def calibrate_and_optimize_thresholds(oof_probs, Y_FULL, 
                                       threshold_grid, n_windows=12):
    n_samples, n_cls = oof_probs.shape  # sample and class counts
    thresholds = np.full(n_cls, 0.5, dtype=np.float32)  # default every threshold to 0.5
    
    # Window-level → file-level probabilities (max over each file's 12 windows)
    n_files  = n_samples // n_windows
    file_oof = oof_probs.reshape(n_files, n_windows, n_cls).max(axis=1)
    file_y   = Y_FULL.reshape(n_files, n_windows, n_cls).max(axis=1)

    # Per class: calibrate, then search for the best threshold
    for c in range(n_cls):
        y_true, y_prob = file_y[:, c], file_oof[:, c]
        if y_true.sum() < 3:  # too few positives → skip
            continue
        
        # 1. Isotonic calibration: monotonically remap the probabilities
        try:
            ir = IsotonicRegression(out_of_bounds="clip")  # clip out-of-range inputs
            ir.fit(y_prob, y_true)  # fit
            y_cal = ir.transform(y_prob)  # calibrated probabilities
        except Exception:
            y_cal = y_prob  # calibration failed → use the raw probabilities

        # 2. Grid-search the threshold that maximizes F1
        best_f1, best_t = 0.0, 0.5
        for t in threshold_grid:  # candidate thresholds
            pred = (y_cal >= t).astype(int)  # predict 1 at or above the threshold
            tp = ((pred==1)&(y_true==1)).sum()  # true positives
            fp = ((pred==1)&(y_true==0)).sum()  # false positives
            fn = ((pred==0)&(y_true==1)).sum()  # false negatives
            prec = tp/(tp+fp+1e-8)  # precision
            rec  = tp/(tp+fn+1e-8)  # recall
            f1   = 2*prec*rec/(prec+rec+1e-8)  # F1 score
            if f1 > best_f1:  # track the best
                best_f1, best_t = f1, t
        thresholds[c] = best_t  # keep this class's best threshold

    # Threshold statistics
    print(f"Mean threshold: {thresholds.mean():.3f}")
    print(f"Range: [{thresholds.min():.2f}, {thresholds.max():.2f}]")
    return thresholds  # per-class optimal thresholds

print("✅ Calibration + Threshold function defined")
✅ Calibration + Threshold function defined
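A quick synthetic sanity check (random scores and labels, toy sizes) just to exercise the function; real usage passes window-level OOF probabilities with the matching label matrix:

_rng = np.random.default_rng(0)
_p = _rng.random((48, 5)).astype(np.float32)        # 4 files × 12 windows, 5 classes
_y = (_rng.random((48, 5)) < 0.3).astype(np.uint8)  # random multi-hot labels
_t = calibrate_and_optimize_thresholds(_p, _y, threshold_grid=[0.3, 0.5, 0.7])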

Models and mapping

Load the Perch v2 model and build the species-to-bird-classifier mapping

# Cell 3 — Load Perch, label mappings, and selective frog proxies
BEST = CFG["best_fusion"]  # frozen best fusion parameters from the config

# Load the Perch v2 TensorFlow SavedModel
birdclassifier = tf.saved_model.load(str(MODEL_DIR))
infer_fn = birdclassifier.signatures["serving_default"]  # inference signature

# Read Perch's internal label table (scientific name → internal index)
bc_labels = (
    pd.read_csv(MODEL_DIR / "assets" / "labels.csv")
    .reset_index()
    .rename(columns={"index": "bc_index", "inat2024_fsd50k": "scientific_name"})
)

NO_LABEL_INDEX = len(bc_labels)  # sentinel index for unmapped species (= number of Perch labels)

# Manual scientific-name map (reserved for fixing synonyms)
MANUAL_SCIENTIFIC_NAME_MAP = {
    # Optional future synonym fixes (add manual name corrections here)
}

taxonomy = taxonomy.copy()
# Apply the manual map to the competition taxonomy's scientific names
taxonomy["scientific_name_lookup"] = taxonomy["scientific_name"].replace(MANUAL_SCIENTIFIC_NAME_MAP)

# Prepare the merge key on the Perch label table
bc_lookup = bc_labels.rename(columns={"scientific_name": "scientific_name_lookup"})

# Merge the competition taxonomy with the Perch labels → species → bc_index mapping
mapping = taxonomy.merge(
    bc_lookup[["scientific_name_lookup", "bc_index"]],
    on="scientific_name_lookup",
    how="left"
)

# Unmapped species get bc_index = NO_LABEL_INDEX
mapping["bc_index"] = mapping["bc_index"].fillna(NO_LABEL_INDEX).astype(int)

# Dict: competition label name → bc_index
label_to_bc_index = mapping.set_index("primary_label")["bc_index"]
# bc_index array aligned with PRIMARY_LABELS
BC_INDICES = np.array([int(label_to_bc_index.loc[c]) for c in PRIMARY_LABELS], dtype=np.int32)

# Positions of mapped / unmapped classes
MAPPED_MASK = BC_INDICES != NO_LABEL_INDEX
MAPPED_POS = np.where(MAPPED_MASK)[0].astype(np.int32)
UNMAPPED_POS = np.where(~MAPPED_MASK)[0].astype(np.int32)
MAPPED_BC_INDICES = BC_INDICES[MAPPED_MASK].astype(np.int32)

# Dict: competition label name → taxonomic class (Aves, Amphibia, ...)
CLASS_NAME_MAP = taxonomy.set_index("primary_label")["class_name"].to_dict()
TEXTURE_TAXA = {"Amphibia", "Insecta"}  # the "texture" taxa (Amphibia, Insecta)

# Active classes: those that actually occur in the training labels
ACTIVE_CLASSES = [PRIMARY_LABELS[i] for i in np.where(Y_SC.sum(axis=0) > 0)[0]]

# Split active classes into texture taxa vs event taxa (birds etc.)
idx_active_texture = np.array(
    [label_to_idx[c] for c in ACTIVE_CLASSES if CLASS_NAME_MAP.get(c) in TEXTURE_TAXA],
    dtype=np.int32
)
idx_active_event = np.array(
    [label_to_idx[c] for c in ACTIVE_CLASSES if CLASS_NAME_MAP.get(c) not in TEXTURE_TAXA],
    dtype=np.int32
)

# Further split: mapped / unmapped within the active texture / event groups
idx_mapped_active_texture = idx_active_texture[MAPPED_MASK[idx_active_texture]]
idx_mapped_active_event = idx_active_event[MAPPED_MASK[idx_active_event]]

idx_unmapped_active_texture = idx_active_texture[~MAPPED_MASK[idx_active_texture]]
idx_unmapped_active_event = idx_active_event[~MAPPED_MASK[idx_active_event]]

# Classes that are both unmapped and inactive
idx_unmapped_inactive = np.array(
    [i for i in UNMAPPED_POS if PRIMARY_LABELS[i] not in ACTIVE_CLASSES],
    dtype=np.int32
)

# ---------------------- Genus-level proxies for unmapped species ----------------------
# Unmapped species
unmapped_df = mapping[mapping["bc_index"] == NO_LABEL_INDEX].copy()
# Unmapped non-sonotype species (exclude sonotype labels containing "son")
unmapped_non_sonotype = unmapped_df[
    ~unmapped_df["primary_label"].astype(str).str.contains("son", na=False)
].copy()

# Helper: find Perch labels in the same genus
def get_genus_hits(scientific_name):
    genus = str(scientific_name).split()[0]  # genus = first word of the binomial
    hits = bc_labels[
        bc_labels["scientific_name"].astype(str).str.match(rf"^{re.escape(genus)}\s", na=False)
    ].copy()  # Perch labels sharing the genus
    return genus, hits

# Build the proxy map: unmapped species → same-genus bc_indices
proxy_map = {}
for _, row in unmapped_non_sonotype.iterrows():
    target = row["primary_label"]
    sci = row["scientific_name"]
    genus, hits = get_genus_hits(sci)
    if len(hits) > 0:
        proxy_map[target] = {
            "target_scientific_name": sci,
            "genus": genus,
            "bc_indices": hits["bc_index"].astype(int).tolist(),
            "proxy_scientific_names": hits["scientific_name"].tolist(),
        }

# Restrict proxy targets to Amphibia, Insecta, and Aves
PROXY_TAXA = {"Amphibia", "Insecta", "Aves"}
SELECTED_PROXY_TARGETS = sorted([
    t for t in proxy_map.keys()
    if CLASS_NAME_MAP.get(t) in PROXY_TAXA
])
# Count proxy targets per taxonomic class
print(f"Proxy targets by class: { {cls: sum(1 for t in SELECTED_PROXY_TARGETS if CLASS_NAME_MAP.get(t)==cls) for cls in PROXY_TAXA} }")

# Proxy target positions, and position → bc_indices map
selected_proxy_pos = np.array([label_to_idx[c] for c in SELECTED_PROXY_TARGETS], dtype=np.int32)

selected_proxy_pos_to_bc = {
    label_to_idx[target]: np.array(proxy_map[target]["bc_indices"], dtype=np.int32)
    for target in SELECTED_PROXY_TARGETS
}

# Further split: proxy-covered active texture classes, and prior-only active classes
idx_selected_proxy_active_texture = np.intersect1d(selected_proxy_pos, idx_active_texture)
idx_selected_prioronly_active_texture = np.setdiff1d(idx_unmapped_active_texture, selected_proxy_pos)
idx_selected_prioronly_active_event = np.setdiff1d(idx_unmapped_active_event, selected_proxy_pos)

# ---------------------- Print statistics ----------------------
print(f"Mapped classes: {MAPPED_MASK.sum()} / {N_CLASSES}")
print(f"Unmapped classes: {(~MAPPED_MASK).sum()}")
print("Selected frog proxy targets:", SELECTED_PROXY_TARGETS)
print("Active texture classes:", len(idx_active_texture))
print("Selected proxy active texture:", len(idx_selected_proxy_active_texture))
print("Prior-only active texture:", len(idx_selected_prioronly_active_texture))
print("Prior-only active event:", len(idx_selected_prioronly_active_event))
Proxy targets by class: {'Insecta': 0, 'Amphibia': 3, 'Aves': 0}
Mapped classes: 203 / 234
Unmapped classes: 31
Selected frog proxy targets: ['1491113', '1595929', '25073']
Active texture classes: 42
Selected proxy active texture: 2
Prior-only active texture: 25
Prior-only active event: 2

Utility functions

Metric computation, temporal smoothing, and feature-processing helpers

# Cell 4 — Metrics and helper utilities

# Function 1: macro AUC, skipping classes with no positives
def macro_auc_skip_empty(y_true, y_score):
    keep = y_true.sum(axis=0) > 0  # classes with at least one positive
    return roc_auc_score(y_true[:, keep], y_score[:, keep], average="macro")  # macro AUC over those only

# Function 2: neighbor-average smoothing over fixed 12-window files, on selected columns
def smooth_cols_fixed12(scores, cols, alpha=0.35):
    if alpha <= 0 or len(cols) == 0:  # nothing to do → return a copy
        return scores.copy()

    s = scores.copy()
    assert len(s) % N_WINDOWS == 0, "Expected full-file blocks of 12 windows"
    view = s.reshape(-1, N_WINDOWS, s.shape[1])  # (files, 12 windows, classes)

    x = view[:, :, cols]  # columns to smooth
    prev_x = np.concatenate([x[:, :1, :], x[:, :-1, :]], axis=1)  # previous window (first window repeats itself)
    next_x = np.concatenate([x[:, 1:, :], x[:, -1:, :]], axis=1)  # next window (last window repeats itself)

    # Smoothing: (1-alpha)*current + 0.5*alpha*(prev + next)
    view[:, :, cols] = (1.0 - alpha) * x + 0.5 * alpha * (prev_x + next_x)
    return s

# Function 3: smoothing for event classes (birds) via local max, preserving transient calls
def smooth_events_fixed12(scores, cols, alpha=0.15):
    """Soft max-pool context for event birds (Aves).
    Uses local_max instead of average neighbor, preserving transient call detection."""
    if alpha <= 0 or len(cols) == 0:  # nothing to do → return a copy
        return scores.copy()
    s = scores.copy()
    assert len(s) % N_WINDOWS == 0  # full-file blocks of 12 windows
    view = s.reshape(-1, N_WINDOWS, s.shape[1])  # (files, 12 windows, classes)
    x = view[:, :, cols]  # columns to smooth
    prev_x = np.concatenate([x[:, :1, :], x[:, :-1, :]], axis=1)  # previous window
    next_x = np.concatenate([x[:, 1:, :], x[:, -1:, :]], axis=1)  # next window
    local_max = np.maximum(x, np.maximum(prev_x, next_x))  # max of current/prev/next
    # Smoothing: (1-alpha)*current + alpha*local_max
    view[:, :, cols] = (1.0 - alpha) * x + alpha * local_max
    return s

# Function 4: 1-D sequence features (prev window, next window, file mean/max/std)
def seq_features_1d(v):
    """
    v: shape (n_rows,), ordered as full-file blocks of 12 windows
    Extended: adds std_v to capture the temporal variance within each file
    """
    assert len(v) % N_WINDOWS == 0, "Expected full-file blocks of 12 windows"
    x = v.reshape(-1, N_WINDOWS)  # (files, 12 windows)

    prev_v = np.concatenate([x[:, :1], x[:, :-1]], axis=1).reshape(-1)  # previous window
    next_v = np.concatenate([x[:, 1:], x[:, -1:]], axis=1).reshape(-1)  # next window
    mean_v = np.repeat(x.mean(axis=1), N_WINDOWS)  # file mean, repeated for every window
    max_v  = np.repeat(x.max(axis=1),  N_WINDOWS)  # file max
    std_v  = np.repeat(x.std(axis=1),  N_WINDOWS)  # file std

    return prev_v, next_v, mean_v, max_v, std_v
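
# Tiny demo of seq_features_1d (toy data): one 12-window file with a single
# spike at window 5; prev/next shift the spike, mean/max summarize the file.
_v = np.zeros(12, dtype=np.float32); _v[5] = 1.0
_prev, _next, _mean, _max, _std = seq_features_1d(_v)
print(_prev[6], _next[4], round(float(_mean[0]), 4), _max[0])  # 1.0 1.0 0.0833 1.0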
# Post-processing utilities: focal loss, file-level scaling, TTA, rank-aware scaling, delta-shift smoothing, per-class thresholds

# Function 1: focal loss for multi-label classification
def focal_bce_with_logits(logits, targets, gamma=2.0, pos_weight=None, reduction="mean"):
    """Focal loss for multi-label classification.
    Reduces contribution of easy examples, focuses on hard ones."""
    # Base binary cross-entropy (with or without positive-class weights)
    if pos_weight is not None:
        bce = F.binary_cross_entropy_with_logits(
            logits, targets, pos_weight=pos_weight, reduction="none"
        )
    else:
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    
    p = torch.sigmoid(logits)  # probabilities
    pt = targets * p + (1 - targets) * (1 - p)  # probability assigned to the true outcome
    focal_weight = (1 - pt) ** gamma  # easy samples weighted down, hard samples up
    loss = focal_weight * bce  # weighted loss
    
    if reduction == "mean":
        return loss.mean()  # mean loss
    return loss


# Function 2: file-level confidence scaling (2025 LB rank-1/2 technique)
def file_level_confidence_scale(preds, n_windows=12, top_k=2):
    """Rank 1/2 technique: scale each window's predictions by the file's top-K mean confidence."""
    N, C = preds.shape
    assert N % n_windows == 0  # full-file blocks of 12 windows
    view = preds.reshape(-1, n_windows, C)  # (files, 12 windows, classes)
    sorted_view = np.sort(view, axis=1)  # sort the windows per class
    top_k_mean = sorted_view[:, -top_k:, :].mean(axis=1, keepdims=True)  # mean of the top-K windows
    scaled = view * top_k_mean  # scale every window by it
    return scaled.reshape(N, C)


# Function 3: temporal-shift test-time augmentation (TTA)
def temporal_shift_tta(emb_files, logits_files, model, site_ids, hours, shifts=[0, 1, -1]):
    """TTA by circular-shifting the 12-window embedding sequence."""
    all_preds = []
    model.eval()  # evaluation mode
    
    # One forward pass per shift (0, 1, -1)
    for shift in shifts:
        if shift == 0:
            e = emb_files  # no shift: original data
            l = logits_files
        else:
            e = np.roll(emb_files, shift, axis=1)  # circularly shift the embedding sequence
            l = np.roll(logits_files, shift, axis=1)  # and the logits sequence
        
        # Inference without gradients
        with torch.no_grad():
            out, _, _ = model(
                torch.tensor(e, dtype=torch.float32),
                torch.tensor(l, dtype=torch.float32),
                site_ids=torch.tensor(site_ids, dtype=torch.long),
                hours=torch.tensor(hours, dtype=torch.long),
            )
            pred = out.numpy()  # to numpy
        
        # Undo the shift on the predictions
        if shift != 0:
            pred = np.roll(pred, -shift, axis=1)
        
        all_preds.append(pred)  # keep this pass
    
    return np.mean(all_preds, axis=0)  # average over shifts
# Post-processing utilities (continued)

# Function 1: rank-aware scaling (2025 LB rank-3 technique)
def rank_aware_scaling(scores, n_windows=12, power=0.5):
    """2025 Rank 3 technique. Scale each window by (file_max)^power.
    Suppresses predictions in uncertain files, boosts confident files."""
    N, C = scores.shape
    assert N % n_windows == 0  # full-file blocks of 12 windows
    n_files = N // n_windows
    
    view = scores.reshape(n_files, n_windows, C)  # (files, 12 windows, classes)
    file_max = view.max(axis=1, keepdims=True)  # per-file, per-class max
    
    # Power transform of the file max
    scale = np.power(file_max, power)
    
    # Scale every window
    scaled = view * scale
    return scaled.reshape(N, C)


# Function 2: delta-shift smoothing (2025 LB rank-1 technique)
def delta_shift_smooth(scores, n_windows=12, alpha=0.15):
    """2025 Rank 1 technique. Temporal smoothing across windows.
    new[t] = (1-alpha)*old[t] + 0.5*alpha*(old[t-1] + old[t+1])"""
    N, C = scores.shape
    assert N % n_windows == 0  # full-file blocks of 12 windows
    n_files = N // n_windows
    
    view = scores.reshape(n_files, n_windows, C)  # (files, 12 windows, classes)
    
    # Shifted copies (first/last window repeat themselves at the edges)
    prev_view = np.concatenate([view[:, :1, :], view[:, :-1, :]], axis=1)
    next_view = np.concatenate([view[:, 1:, :], view[:, -1:, :]], axis=1)
    
    # The delta-shift smoothing formula above
    smoothed = (1 - alpha) * view + 0.5 * alpha * (prev_view + next_view)
    
    return smoothed.reshape(N, C)


# Function 3: per-class threshold optimization from OOF predictions
def optimize_per_class_thresholds(oof_scores, y_true, n_windows=12, thresholds=[0.3, 0.4, 0.5, 0.6, 0.7]):
    """Find optimal decision threshold per class from OOF predictions.
    Optimizes F1-like metric (precision-recall balance) for each species."""
    n_classes = oof_scores.shape[1]
    best_thresholds = np.zeros(n_classes)  # best threshold per class
    best_scores = np.zeros(n_classes)  # best F1 per class
    
    # Per class
    for c in range(n_classes):
        y_c = y_true[:, c]
        scores_c = oof_scores[:, c]
        
        # Skip classes with no positives
        if y_c.sum() == 0:
            best_thresholds[c] = 0.5
            continue
            
        # Search for the best threshold
        best_f1 = 0
        best_t = 0.5
        
        # Try every candidate threshold
        for t in thresholds:
            pred_c = (scores_c > t).astype(int)  # predict 1 above the threshold
            tp = ((pred_c == 1) & (y_c == 1)).sum()  # true positives
            fp = ((pred_c == 1) & (y_c == 0)).sum()  # false positives
            fn = ((pred_c == 0) & (y_c == 1)).sum()  # false negatives
            
            # Skip degenerate cases (zero denominators)
            if tp + fp == 0 or tp + fn == 0:
                continue
                
            precision = tp / (tp + fp)  # precision
            recall = tp / (tp + fn)  # recall
            f1 = 2 * precision * recall / (precision + recall + 1e-8)  # F1 score
            
            # Track the best
            if f1 > best_f1:
                best_f1 = f1
                best_t = t
        
        best_thresholds[c] = best_t
        best_scores[c] = best_f1
    
    return best_thresholds, best_scores


# Function 4: apply per-class thresholds (sharpen the scores)
def apply_per_class_thresholds(scores, thresholds, n_windows=12):
    """Apply per-class thresholds to convert scores to binary predictions."""
    N, C = scores.shape
    assert C == len(thresholds)  # one threshold per class
    
    # We submit probabilities, but use the thresholds to shape the metric:
    # each threshold acts as a pivot that pushes confident predictions apart
    scaled = np.copy(scores)
    
    # Per class
    for c in range(C):
        t = thresholds[c]
        # Sharpen: push scores above the threshold up, those below down
        mask_above = scores[:, c] > t
        scaled[mask_above, c] = 0.5 + 0.5 * (scores[mask_above, c] - t) / (1 - t + 1e-8)
        scaled[~mask_above, c] = 0.5 * scores[~mask_above, c] / (t + 1e-8)
    
    return np.clip(scaled, 0, 1)  # keep within [0, 1]


print("Post-processing utilities defined: focal_bce_with_logits, file_level_confidence_scale, temporal_shift_tta,")
print("  rank_aware_scaling, delta_shift_smooth, optimize_per_class_thresholds, apply_per_class_thresholds")
Post-processing utilities defined: focal_bce_with_logits, file_level_confidence_scale, temporal_shift_tta,
  rank_aware_scaling, delta_shift_smooth, optimize_per_class_thresholds, apply_per_class_thresholds
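A tiny demo of the two scaling helpers on one 12-window, 1-class file (toy values):

_s = np.zeros((12, 1), dtype=np.float32); _s[5, 0] = 0.81
print(rank_aware_scaling(_s, power=0.5)[5, 0])   # 0.81 * 0.81**0.5 = 0.729
print(delta_shift_smooth(_s, alpha=0.2)[4, 0])   # 0.5 * 0.2 * 0.81 = 0.081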

Perch inference engine

Core inference functions: embedding extraction and genus-level proxies

# Cell 5 — Perch inference: embeddings + selective proxy predictions
def read_soundscape_60s(path):
    # Read a 60 s soundscape, normalizing length and sample rate
    y, sr = sf.read(path, dtype="float32", always_2d=False)
    
    # Stereo → mono
    if y.ndim == 2:
        y = y.mean(axis=1)
    
    # The sample rate must be 32,000 Hz
    if sr != SR:
        raise ValueError(f"Unexpected sample rate {sr} in {path}; expected {SR}")
    
    # Zero-pad short files, truncate long ones to exactly 60 s
    if len(y) < FILE_SAMPLES:
        y = np.pad(y, (0, FILE_SAMPLES - len(y)))
    elif len(y) > FILE_SAMPLES:
        y = y[:FILE_SAMPLES]
    return y


def infer_perch_with_embeddings(paths, batch_files=16, verbose=True, proxy_reduce="max"):
    paths = [Path(p) for p in paths]
    n_files = len(paths)
    n_rows = n_files * N_WINDOWS  # 12 windows per file

    # Output arrays
    row_ids = np.empty(n_rows, dtype=object)
    filenames = np.empty(n_rows, dtype=object)
    sites = np.empty(n_rows, dtype=object)
    hours = np.empty(n_rows, dtype=np.int16)

    scores = np.zeros((n_rows, N_CLASSES), dtype=np.float32)      # prediction scores
    embeddings = np.zeros((n_rows, 1536), dtype=np.float32)       # 1536-dim embeddings

    write_row = 0
    iterator = range(0, n_files, batch_files)
    if verbose:
        iterator = tqdm(iterator, total=(n_files + batch_files - 1) // batch_files, desc="Perch batches")

    # Process the audio in batches
    for start in iterator:
        batch_paths = paths[start:start + batch_files]
        batch_n = len(batch_paths)

        # Model input: (batch*12, 5 s windows)
        x = np.empty((batch_n * N_WINDOWS, WINDOW_SAMPLES), dtype=np.float32)
        batch_row_start = write_row
        x_pos = 0

        # Read each file
        for path in batch_paths:
            y = read_soundscape_60s(path)
            x[x_pos:x_pos + N_WINDOWS] = y.reshape(N_WINDOWS, WINDOW_SAMPLES)

            # Parse the filename: site, hour, name
            meta = parse_soundscape_filename(path.name)
            stem = path.stem

            # Fill in the metadata rows
            row_ids[write_row:write_row + N_WINDOWS] = [f"{stem}_{t}" for t in range(5, 65, 5)]
            filenames[write_row:write_row + N_WINDOWS] = path.name
            sites[write_row:write_row + N_WINDOWS] = meta["site"]
            hours[write_row:write_row + N_WINDOWS] = int(meta["hour_utc"])

            x_pos += N_WINDOWS
            write_row += N_WINDOWS

        # ---------------------- Core: Perch inference ----------------------
        outputs = infer_fn(inputs=tf.convert_to_tensor(x))
        logits = outputs["label"].numpy()      # raw model scores
        emb = outputs["embedding"].numpy()     # 1536-dim embeddings

        # Map the Perch outputs onto the competition labels
        scores[batch_row_start:write_row, MAPPED_POS] = logits[:, MAPPED_BC_INDICES]
        embeddings[batch_row_start:write_row] = emb

        # ---------------------- Proxies: fill unmapped species from same-genus scores ----------------------
        for pos, bc_idx_arr in selected_proxy_pos_to_bc.items():
            sub = logits[:, bc_idx_arr]
            if proxy_reduce == "max":
                proxy_score = sub.max(axis=1)        # max over the genus
            elif proxy_reduce == "mean":
                proxy_score = sub.mean(axis=1)       # mean over the genus
            else:
                raise ValueError("proxy_reduce must be 'max' or 'mean'")
            scores[batch_row_start:write_row, pos] = proxy_score.astype(np.float32)

        # Free memory
        del x, outputs, logits, emb
        gc.collect()

    # Metadata frame
    meta_df = pd.DataFrame({
        "row_id": row_ids,
        "filename": filenames,
        "site": sites,
        "hour_utc": hours,
    })

    # Return: metadata + prediction scores + 1536-dim embeddings
    return meta_df, scores, embeddings

Cache management

Load precomputed Perch outputs, or recompute them from scratch.

# Cell 6 — Load or compute the full-file Perch cache

# Function 1: resolve the paths of the full cache files
def resolve_full_cache_paths():
    candidates = []

    # Candidate 1: cache in the current working directory
    candidates.append((
        CFG["full_cache_work_dir"] / "full_perch_meta.parquet",
        CFG["full_cache_work_dir"] / "full_perch_arrays.npz"
    ))

    # Candidate 2: legacy Kaggle working path
    candidates.append((
        Path("/kaggle/working/full_perch_meta.parquet"),
        Path("/kaggle/working/full_perch_arrays.npz")
    ))

    # Candidate 3: mounted input dataset (if present)
    if CFG["full_cache_input_dir"].exists():
        candidates.append((
            CFG["full_cache_input_dir"] / "full_perch_meta.parquet",
            CFG["full_cache_input_dir"] / "full_perch_arrays.npz"
        ))

    # Return the first candidate whose files both exist
    for meta_path, npz_path in candidates:
        if meta_path.exists() and npz_path.exists():
            return meta_path, npz_path

    # Nothing found
    return None, None

# Resolve the cache
cache_meta, cache_npz = resolve_full_cache_paths()

# ---------------------- Case 1: cache found → load it ----------------------
if cache_meta is not None and cache_npz is not None:
    print("Loading cached full-file Perch outputs from:")
    print("  ", cache_meta)
    print("  ", cache_npz)

    meta_full = pd.read_parquet(cache_meta)  # metadata
    arr = np.load(cache_npz)                  # compressed arrays
    scores_full_raw = arr["scores_full_raw"].astype(np.float32)  # raw Perch scores
    emb_full = arr["emb_full"].astype(np.float32)                # 1536-dim embeddings

# ---------------------- Case 2: no cache ----------------------
else:
    # Case 2a: submit mode with a hard cache requirement → fail fast
    if CFG["mode"] == "submit" and CFG["require_full_cache_in_submit"]:
        raise FileNotFoundError(
            "Submit mode requires cached full-file Perch outputs. "
            "Attach the cache dataset or place full_perch_meta.parquet/full_perch_arrays.npz in working dir."
        )

    # Case 2b: recompute by running Perch on the trusted full files
    print("No cache found. Running Perch on trusted full files...")
    full_paths = [BASE / "train_soundscapes" / fn for fn in full_files]  # fully labeled files

    # Run Perch inference (with the configured proxy_reduce)
    meta_full, scores_full_raw, emb_full = infer_perch_with_embeddings(
        full_paths,
        batch_files=CFG["batch_files"],
        verbose=CFG["verbose"],
        proxy_reduce=CFG["proxy_reduce"],
    )

    # Cache output paths
    out_meta = CFG["full_cache_work_dir"] / "full_perch_meta.parquet"
    out_npz = CFG["full_cache_work_dir"] / "full_perch_arrays.npz"

    # Save the metadata and arrays
    meta_full.to_parquet(out_meta, index=False)
    np.savez_compressed(
        out_npz,
        scores_full_raw=scores_full_raw,
        emb_full=emb_full,
    )

    print("Saved cache to:")
    print("  ", out_meta)
    print("  ", out_npz)

# ---------------------- Align ground truth with the cache order (critical: prevents misalignment) ----------------------
full_truth_aligned = full_truth.set_index("row_id").loc[meta_full["row_id"]].reset_index()
Y_FULL = Y_SC[full_truth_aligned["index"].to_numpy()]

# Assert that filenames and row_ids line up exactly
assert np.all(full_truth_aligned["filename"].values == meta_full["filename"].values)
assert np.all(full_truth_aligned["row_id"].values == meta_full["row_id"].values)

# Shapes
print("meta_full:", meta_full.shape)
print("scores_full_raw:", scores_full_raw.shape, scores_full_raw.dtype)
print("emb_full:", emb_full.shape, emb_full.dtype)
print("Y_FULL:", Y_FULL.shape, Y_FULL.dtype)

# ---------------------- [Optional] Grid search: max vs mean for the genus proxies ----------------------
PROXY_REDUCE_CACHE = CFG["full_cache_work_dir"] / "proxy_reduce_grid.json"

# If the grid search is enabled in the config
if CFG.get("run_proxy_reduce_grid", False):
    print("\n[Option 3] Running proxy_reduce grid search: max vs mean...")
    proxy_reduce_results = {}

    # Try each candidate strategy (max, mean)
    for pr in CFG["proxy_reduce_grid"]:
        full_paths = [BASE / "train_soundscapes" / fn for fn in full_files]
        # Re-run Perch with this strategy
        _meta, _scores, _emb = infer_perch_with_embeddings(
            full_paths,
            batch_files=CFG["batch_files"],
            verbose=False,
            proxy_reduce=pr,
        )

        # OOF baseline AUC
        _oof_b, _oof_p, _ = build_oof_base_prior(
            scores_full_raw=_scores,
            meta_full=_meta,
            sc_clean=sc_clean,
            Y_SC=Y_SC,
            n_splits=5,
            verbose=False,
        )
        auc = macro_auc_skip_empty(Y_FULL, _oof_b)
        proxy_reduce_results[pr] = float(auc)
        print(f"  proxy_reduce={pr!r:6s} → OOF baseline AUC = {auc:.6f}")

    # Keep the strategy with the highest AUC
    best_pr = max(proxy_reduce_results, key=proxy_reduce_results.get)
    CFG["proxy_reduce"] = best_pr
    print(f"\n  Best proxy_reduce = {best_pr!r} (AUC={proxy_reduce_results[best_pr]:.6f})")

    # Persist the results
    PROXY_REDUCE_CACHE.write_text(json.dumps({
        "results": proxy_reduce_results,
        "best_proxy_reduce": best_pr,
    }, indent=2))
    print("  Saved to:", PROXY_REDUCE_CACHE)

# Search disabled but results cached → load the best strategy
elif PROXY_REDUCE_CACHE.exists():
    _pr_data = json.loads(PROXY_REDUCE_CACHE.read_text())
    CFG["proxy_reduce"] = _pr_data["best_proxy_reduce"]
    print(f"[Option 3] Loaded proxy_reduce from cache: {CFG['proxy_reduce']!r}")
    print("  Grid results:", _pr_data["results"])

# Neither → fall back to the default strategy
else:
    print(f"[Option 3] Using default proxy_reduce={CFG['proxy_reduce']!r} (submit mode or no cache)")
Loading cached full-file Perch outputs from:
   /kaggle/input/perch-meta/full_perch_meta.parquet
   /kaggle/input/perch-meta/full_perch_arrays.npz
meta_full: (708, 4)
scores_full_raw: (708, 234) float32
emb_full: (708, 1536) float32
Y_FULL: (708, 234) uint8
[Option 3] Using default proxy_reduce='max' (submit mode or no cache)

Prior tables

Fit site / hour (and joint site-hour) prior tables from the labeled data

# Cell 7 — Fold-safe metadata prior tables

# Function 1: fit the prior tables
def fit_prior_tables(prior_df, Y_prior):
    prior_df = prior_df.reset_index(drop=True)

    global_p = Y_prior.mean(axis=0).astype(np.float32)  # global prior: mean over all samples

    # ---------------------- Site prior ----------------------
    site_keys = sorted(prior_df["site"].dropna().astype(str).unique().tolist())  # all sites
    site_to_i = {k: i for i, k in enumerate(site_keys)}  # site → index
    site_n = np.zeros(len(site_keys), dtype=np.float32)  # sample count per site
    site_p = np.zeros((len(site_keys), Y_prior.shape[1]), dtype=np.float32)  # prior probability per site

    for s in site_keys:
        i = site_to_i[s]
        mask = prior_df["site"].astype(str).values == s  # samples from this site
        site_n[i] = mask.sum()  # count them
        site_p[i] = Y_prior[mask].mean(axis=0)  # mean label frequency at this site

    # ---------------------- Hour prior ----------------------
    hour_keys = sorted(prior_df["hour_utc"].dropna().astype(int).unique().tolist())  # all hours
    hour_to_i = {h: i for i, h in enumerate(hour_keys)}  # hour → index
    hour_n = np.zeros(len(hour_keys), dtype=np.float32)  # sample count per hour
    hour_p = np.zeros((len(hour_keys), Y_prior.shape[1]), dtype=np.float32)  # prior probability per hour

    for h in hour_keys:
        i = hour_to_i[h]
        mask = prior_df["hour_utc"].astype(int).values == h  # samples from this hour
        hour_n[i] = mask.sum()  # count them
        hour_p[i] = Y_prior[mask].mean(axis=0)  # mean label frequency at this hour

    # ---------------------- Joint site-hour prior ----------------------
    sh_to_i = {}  # (site, hour) → index
    sh_n_list = []  # sample count per (site, hour)
    sh_p_list = []  # prior probability per (site, hour)

    for (s, h), idx in prior_df.groupby(["site", "hour_utc"]).groups.items():
        sh_to_i[(str(s), int(h))] = len(sh_n_list)  # record the index
        idx = np.array(list(idx))
        sh_n_list.append(len(idx))  # count the samples
        sh_p_list.append(Y_prior[idx].mean(axis=0))  # mean label frequency

    sh_n = np.array(sh_n_list, dtype=np.float32)
    sh_p = np.stack(sh_p_list).astype(np.float32) if len(sh_p_list) else np.zeros((0, Y_prior.shape[1]), dtype=np.float32)

    # Return all the tables
    return {
        "global_p": global_p,
        "site_to_i": site_to_i,
        "site_n": site_n,
        "site_p": site_p,
        "hour_to_i": hour_to_i,
        "hour_n": hour_n,
        "hour_p": hour_p,
        "sh_to_i": sh_to_i,
        "sh_n": sh_n,
        "sh_p": sh_p,
    }


# Function 2: prior logits from the tables
def prior_logits_from_tables(sites, hours, tables, eps=1e-4):
    n = len(sites)
    p = np.repeat(tables["global_p"][None, :], n, axis=0).astype(np.float32, copy=True)  # start from the global prior

    # Look up each sample's site, hour, and site-hour indices
    site_idx = np.fromiter(
        (tables["site_to_i"].get(str(s), -1) for s in sites),
        dtype=np.int32,
        count=n
    )
    hour_idx = np.fromiter(
        (tables["hour_to_i"].get(int(h), -1) if int(h) >= 0 else -1 for h in hours),
        dtype=np.int32,
        count=n
    )
    sh_idx = np.fromiter(
        (tables["sh_to_i"].get((str(s), int(h)), -1) if int(h) >= 0 else -1 for s, h in zip(sites, hours)),
        dtype=np.int32,
        count=n
    )

    # ---------------------- Weighted blend: hour prior ----------------------
    valid = hour_idx >= 0
    if valid.any():
        nh = tables["hour_n"][hour_idx[valid]][:, None]  # samples behind this hour's estimate
        wh = nh / (nh + 8.0)  # more data → more weight (smoothing constant 8)
        p[valid] = wh * tables["hour_p"][hour_idx[valid]] + (1.0 - wh) * p[valid]  # blend

    # ---------------------- Weighted blend: site prior ----------------------
    valid = site_idx >= 0
    if valid.any():
        ns = tables["site_n"][site_idx[valid]][:, None]  # samples behind this site's estimate
        ws = ns / (ns + 8.0)  # more data → more weight
        p[valid] = ws * tables["site_p"][site_idx[valid]] + (1.0 - ws) * p[valid]  # blend

    # ---------------------- Weighted blend: joint site-hour prior ----------------------
    valid = sh_idx >= 0
    if valid.any():
        nsh = tables["sh_n"][sh_idx[valid]][:, None]  # samples behind this (site, hour) estimate
        wsh = nsh / (nsh + 4.0)  # more data → more weight (smoothing constant 4)
        p[valid] = wsh * tables["sh_p"][sh_idx[valid]] + (1.0 - wsh) * p[valid]  # blend

    # Clip probabilities and convert to logits
    np.clip(p, eps, 1.0 - eps, out=p)
    return (np.log(p) - np.log1p(-p)).astype(np.float32, copy=False)


# Function 3: fuse base scores with the prior tables
def fuse_scores_with_tables(base_scores, sites, hours, tables,
                            lambda_event=BEST["lambda_event"],
                            lambda_texture=BEST["lambda_texture"],
                            lambda_proxy_texture=BEST["lambda_proxy_texture"],
                            smooth_texture=BEST["smooth_texture"],
                            smooth_event=BEST["smooth_event"]):
    scores = base_scores.copy()
    prior = prior_logits_from_tables(sites, hours, tables)  # prior logits

    # ---------------------- Fusion: mapped active classes ----------------------
    if len(idx_mapped_active_event):  # mapped event classes (birds)
        scores[:, idx_mapped_active_event] += lambda_event * prior[:, idx_mapped_active_event]

    if len(idx_mapped_active_texture):  # mapped texture classes (Amphibia, Insecta)
        scores[:, idx_mapped_active_texture] += lambda_texture * prior[:, idx_mapped_active_texture]

    # ---------------------- Fusion: selected frog proxies ----------------------
    if len(idx_selected_proxy_active_texture):
        scores[:, idx_selected_proxy_active_texture] += lambda_proxy_texture * prior[:, idx_selected_proxy_active_texture]

    # ---------------------- Fusion: prior-only active unmapped classes ----------------------
    if len(idx_selected_prioronly_active_event):
        scores[:, idx_selected_prioronly_active_event] = lambda_event * prior[:, idx_selected_prioronly_active_event]

    if len(idx_selected_prioronly_active_texture):
        scores[:, idx_selected_prioronly_active_texture] = lambda_texture * prior[:, idx_selected_prioronly_active_texture]

    # ---------------------- Inactive unmapped classes: clamp to a very low score ----------------------
    if len(idx_unmapped_inactive):
        scores[:, idx_unmapped_inactive] = -8.0

    # ---------------------- Apply temporal smoothing ----------------------
    scores = smooth_cols_fixed12(scores, idx_active_texture, alpha=smooth_texture)  # texture classes
    scores = smooth_events_fixed12(scores, idx_active_event, alpha=smooth_event)  # event classes
    return scores.astype(np.float32, copy=False), prior
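
The shrinkage weights inside prior_logits_from_tables trust each table in proportion to its support: a cell estimated from n windows enters with weight n/(n+8) (n/(n+4) for the joint site-hour table), and the remainder falls back to the broader prior. A quick worked example:

for _n in (2, 8, 24):
    print(_n, "windows → weight", round(_n / (_n + 8.0), 3))  # 0.2, 0.5, 0.75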

Out-of-fold stacking ensemble

Build honest out-of-fold base and prior meta-features

# Cell 8 — Honest out-of-fold base/prior meta-features (needed to fit the final stacker)

# Function 1: build the OOF base/prior meta-features
def build_oof_base_prior(scores_full_raw, meta_full, sc_clean, Y_SC, n_splits=5, verbose=True):
    groups_full = meta_full["filename"].to_numpy()  # group by filename (a file's windows never split across folds)
    gkf = GroupKFold(n_splits=n_splits)  # grouped K-fold CV

    # OOF arrays
    oof_base = np.zeros_like(scores_full_raw, dtype=np.float32)  # OOF fused base scores
    oof_prior = np.zeros_like(scores_full_raw, dtype=np.float32)  # OOF prior logits
    fold_id = np.full(len(meta_full), -1, dtype=np.int16)  # fold index per sample

    splits = list(gkf.split(scores_full_raw, groups=groups_full))  # fold indices
    iterator = tqdm(splits, desc="OOF base/prior folds", disable=not verbose)

    # Per fold
    for fold, (tr_idx, va_idx) in enumerate(iterator, 1):
        tr_idx = np.sort(tr_idx)
        va_idx = np.sort(va_idx)

        val_files = set(meta_full.iloc[va_idx]["filename"].tolist())  # validation files

        # ---------------------- Key: fold-safe prior tables (exclude all validation files) ----------------------
        prior_mask = ~sc_clean["filename"].isin(val_files).values  # training files only
        prior_df_fold = sc_clean.loc[prior_mask].reset_index(drop=True)
        Y_prior_fold = Y_SC[prior_mask]

        tables = fit_prior_tables(prior_df_fold, Y_prior_fold)  # fit priors on the training folds only

        # ---------------------- Fused base scores and prior logits for the validation fold ----------------------
        va_base, va_prior = fuse_scores_with_tables(
            scores_full_raw[va_idx],
            sites=meta_full.iloc[va_idx]["site"].to_numpy(),
            hours=meta_full.iloc[va_idx]["hour_utc"].to_numpy(),
            tables=tables,
        )

        # Write into the OOF arrays
        oof_base[va_idx] = va_base
        oof_prior[va_idx] = va_prior
        fold_id[va_idx] = fold

    assert (fold_id >= 0).all()  # every sample must have a fold
    return oof_base, oof_prior, fold_id


# ---------------------- Cache handling ----------------------
OOF_META_CACHE = CFG["full_cache_work_dir"] / "full_oof_meta_features.npz"

# Cached → load
if OOF_META_CACHE.exists():
    print("Loading cached OOF meta-features from:", OOF_META_CACHE)
    arr = np.load(OOF_META_CACHE)
    oof_base = arr["oof_base"].astype(np.float32)
    oof_prior = arr["oof_prior"].astype(np.float32)
    oof_fold_id = arr["fold_id"].astype(np.int16)

# Not cached → build and save
else:
    print("Building OOF meta-features...")
    oof_base, oof_prior, oof_fold_id = build_oof_base_prior(
        scores_full_raw=scores_full_raw,
        meta_full=meta_full,
        sc_clean=sc_clean,
        Y_SC=Y_SC,
        n_splits=5,
        verbose=CFG["verbose"],
    )

    np.savez_compressed(
        OOF_META_CACHE,
        oof_base=oof_base,
        oof_prior=oof_prior,
        fold_id=oof_fold_id,
    )
    print("Saved OOF meta-features to:", OOF_META_CACHE)


# ---------------------- AUC comparison ----------------------
baseline_oof_auc = macro_auc_skip_empty(Y_FULL, oof_base)  # honest OOF baseline AUC

if MODE == "train":
    raw_local_auc = macro_auc_skip_empty(Y_FULL, scores_full_raw)  # raw local AUC (not OOF)
    print(f"Raw local AUC (not OOF-dependent): {raw_local_auc:.6f}")
    print(f"Honest OOF baseline AUC: {baseline_oof_auc:.6f}")
Building OOF meta-features...
Saved OOF meta-features to: /kaggle/working/perch_cache/full_oof_meta_features.npz

逐类别探针模型

在 PCA 压缩后的嵌入特征上,训练逐类别的 MLP / 逻辑回归探针模型

# Cell 9 — 逐类别嵌入探针辅助函数

# 函数1:构建类别特征
def build_class_features(emb_proj, raw_col, prior_col, base_col):
    """
    emb_proj: (n, d)
    raw_col, prior_col, base_col: (n,)
    returns: (n, d + 14)

    特征:嵌入 + 3个分数列(raw/prior/base) + 5个时序统计(prev/next/mean/max/std) + 3个差值 + 3个交互项
    """
    prev_base, next_base, mean_base, max_base, std_base = seq_features_1d(base_col)  # 提取时序特征

    # 差值特征:窗口相对于文件上下文的位置
    diff_mean = base_col - mean_base   # 这个窗口是否高于文件平均?
    diff_prev = base_col - prev_base   # 起始:比前一个窗口高?
    diff_next = base_col - next_base   # 结束:比后一个窗口低?

    # 拼接所有特征
    feats = np.concatenate([
        emb_proj,                          # PCA降维后的嵌入
        raw_col[:, None],                  # Perch原始分数
        prior_col[:, None],                # 先验logits
        base_col[:, None],                 # 基础融合分数
        prev_base[:, None],                # 前一个窗口的基础分数
        next_base[:, None],                # 后一个窗口的基础分数
        mean_base[:, None],                # 文件均值
        max_base[:, None],                 # 文件最大值
        std_base[:, None],                 # 文件内的时序方差
        diff_mean[:, None],                # 与文件均值的偏差
        diff_prev[:, None],                # 检测起始
        diff_next[:, None],                # 检测结束
        # 交互项
        (raw_col * prior_col)[:, None],    # 原始×先验
        (raw_col * base_col)[:, None],     # 原始×基础
        (prior_col * base_col)[:, None],   # 先验×基础
    ], axis=1)

    return feats.astype(np.float32, copy=False)
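
# 说明:上面调用的 seq_features_1d 定义于前文未展示的单元格。
# 下面给一个最小示意(假设:按固定12窗口/文件的布局计算统计量;命名与细节以前文实现为准)
def seq_features_1d_sketch(col, n_windows=12):
    x = col.reshape(-1, n_windows)                        # (文件数, 12)
    prev = np.concatenate([x[:, :1], x[:, :-1]], axis=1)  # 前一窗口(首窗口用自身填充)
    nxt = np.concatenate([x[:, 1:], x[:, -1:]], axis=1)   # 后一窗口(末窗口用自身填充)
    mean = np.repeat(x.mean(axis=1, keepdims=True), n_windows, axis=1)  # 文件均值
    mx = np.repeat(x.max(axis=1, keepdims=True), n_windows, axis=1)     # 文件最大值
    std = np.repeat(x.std(axis=1, keepdims=True), n_windows, axis=1)    # 文件内标准差
    return tuple(a.reshape(-1).astype(np.float32) for a in (prev, nxt, mean, mx, std))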


# 函数2:运行折外嵌入探针
def run_oof_embedding_probe(
    scores_raw,
    emb,
    meta_df,
    y_true,
    pca_dim=64,
    min_pos=8,
    C=0.25,
    alpha=0.5,
):
    groups = meta_df["filename"].to_numpy()  # 按文件名分组
    gkf = GroupKFold(n_splits=5)  # 分组K折

    # 初始化OOF数组
    oof_base_local = np.zeros_like(scores_raw, dtype=np.float32)  # 折外基础分数
    oof_final = np.zeros_like(scores_raw, dtype=np.float32)       # 折外最终融合分数

    modeled_counts = np.zeros(scores_raw.shape[1], dtype=np.int32)  # 每个类别被建模的折数

    split_list = list(gkf.split(scores_raw, groups=groups))

    # 遍历每折
    for fold, (tr_idx, va_idx) in enumerate(tqdm(split_list, desc="Embedding-probe folds", disable=not CFG["verbose"]), 1):
        tr_idx = np.sort(tr_idx)
        va_idx = np.sort(va_idx)

        val_files = set(meta_df.iloc[va_idx]["filename"].tolist())  # 验证集文件

        # ---------------------- 折安全的先验表 ----------------------
        prior_mask = ~sc_clean["filename"].isin(val_files).values
        prior_df_fold = sc_clean.loc[prior_mask].reset_index(drop=True)
        Y_prior_fold = Y_SC[prior_mask]
        tables = fit_prior_tables(prior_df_fold, Y_prior_fold)

        # 生成训练/验证集的基础融合分数和先验
        base_tr, prior_tr = fuse_scores_with_tables(
            scores_raw[tr_idx],
            sites=meta_df.iloc[tr_idx]["site"].to_numpy(),
            hours=meta_df.iloc[tr_idx]["hour_utc"].to_numpy(),
            tables=tables,
        )
        base_va, prior_va = fuse_scores_with_tables(
            scores_raw[va_idx],
            sites=meta_df.iloc[va_idx]["site"].to_numpy(),
            hours=meta_df.iloc[va_idx]["hour_utc"].to_numpy(),
            tables=tables,
        )

        oof_base_local[va_idx] = base_va
        oof_final[va_idx] = base_va

        # ---------------------- 嵌入预处理(只用训练集) ----------------------
        scaler = StandardScaler()
        emb_tr_s = scaler.fit_transform(emb[tr_idx])  # 标准化
        emb_va_s = scaler.transform(emb[va_idx])

        n_comp = min(pca_dim, emb_tr_s.shape[0] - 1, emb_tr_s.shape[1])
        pca = PCA(n_components=n_comp)
        Z_tr = pca.fit_transform(emb_tr_s).astype(np.float32)  # PCA降维
        Z_va = pca.transform(emb_va_s).astype(np.float32)

        # 筛选:训练集中正样本>=min_pos的类别
        class_iterator = np.where(y_true[tr_idx].sum(axis=0) >= min_pos)[0].tolist()

        # ---------------------- 逐类别训探针 ----------------------
        for cls_idx in tqdm(class_iterator, desc=f"Fold {fold} classes", leave=False, disable=not CFG["verbose"]):
            y_tr = y_true[tr_idx, cls_idx]

            if y_tr.sum() == 0 or y_tr.sum() == len(y_tr):  # 全正或全负,跳过
                continue

            # 构建该类别的训练/验证特征
            X_tr_cls = build_class_features(
                Z_tr,
                raw_col=scores_raw[tr_idx, cls_idx],
                prior_col=prior_tr[:, cls_idx],
                base_col=base_tr[:, cls_idx],
            )
            X_va_cls = build_class_features(
                Z_va,
                raw_col=scores_raw[va_idx, cls_idx],
                prior_col=prior_va[:, cls_idx],
                base_col=base_va[:, cls_idx],
            )

            # 选择探针后端:mlp | lgbm | logreg
            backend = CFG.get("probe_backend", "mlp")
            n_pos = int(y_tr.sum())
            n_neg = len(y_tr) - n_pos

            if backend == "mlp":
                # MLP不支持样本权重,用过采样平衡正负样本
                if n_pos > 0 and n_neg > n_pos:
                    repeat = max(1, n_neg // n_pos)
                    pos_idx = np.where(y_tr == 1)[0]
                    X_bal = np.vstack([X_tr_cls, np.tile(X_tr_cls[pos_idx], (repeat, 1))])
                    y_bal = np.concatenate([y_tr, np.ones(len(pos_idx) * repeat, dtype=y_tr.dtype)])
                else:
                    X_bal, y_bal = X_tr_cls, y_tr
                clf = MLPClassifier(**CFG["mlp_params"])
                clf.fit(X_bal, y_bal)
                pred_va = clf.predict_proba(X_va_cls)[:, 1].astype(np.float32)
                pred_va = np.log(pred_va + 1e-7) - np.log(1 - pred_va + 1e-7)  # 转logits
            elif backend == "lgbm" and _LGBM_AVAILABLE:
                scale_pos = max(1.0, n_neg / max(n_pos, 1))  # 正负样本权重
                clf = LGBMClassifier(
                    **CFG["lgbm_params"],
                    scale_pos_weight=scale_pos,
                )
                clf.fit(X_tr_cls, y_tr)
                pred_va = clf.predict_proba(X_va_cls)[:, 1].astype(np.float32)
                pred_va = np.log(pred_va + 1e-7) - np.log(1 - pred_va + 1e-7)  # 转logits
            else:
                clf = LogisticRegression(
                    C=C, max_iter=400, solver="liblinear",
                    class_weight="balanced",
                )
                clf.fit(X_tr_cls, y_tr)
                pred_va = clf.decision_function(X_va_cls).astype(np.float32)  # 直接输出logits

            # 融合基础分数和探针预测
            oof_final[va_idx, cls_idx] = (
                (1.0 - alpha) * base_va[:, cls_idx] +
                alpha * pred_va
            )

            modeled_counts[cls_idx] += 1

    # 计算AUC对比效果
    score_base = macro_auc_skip_empty(y_true, oof_base_local)
    score_final = macro_auc_skip_empty(y_true, oof_final)

    return {
        "oof_base": oof_base_local,
        "oof_final": oof_final,
        "modeled_counts": modeled_counts,
        "score_base": score_base,
        "score_final": score_final,
    }

ProtoSSM v2 模型架构

带原型分类头与元数据感知能力的双向选择性状态空间模型

核心设计:鸟类鸣唱具有鲜明的时序规律 —— 包括黎明合唱的起始特征、物种专属的鸣叫频次,以及领地对唱的动态模式。当前主流方案将每个 5 秒音频窗口视为完全独立的样本,彻底丢弃了这一关键信号。ProtoSSM v2 通过双向选择性状态空间模型对时序动态进行建模,同时融合了站点与小时的元数据嵌入、逐类别校准偏置,以及嵌入空间中的原型度量学习。
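
为便于对照下方 SelectiveSSM.forward 中的逐步扫描,这里补一组选择性 SSM 的离散化递推公式(记号为本文补充,按逐元素 / 外积广播理解,非官方 Mamba 记法):

$$\bar{A}_t = e^{\Delta_t \odot A},\qquad h_t = \bar{A}_t \odot h_{t-1} + (\Delta_t \odot x_t)\,B_t^{\top},\qquad y_t = h_t\,C_t + D \odot x_t$$

其中 $\Delta_t=\mathrm{softplus}(W_\Delta x_t^{\mathrm{conv}})$ 是输入依赖的离散步长,$B_t$、$C_t$ 同样由输入决定,“选择性”一词即来源于此;$A$ 经 $-e^{A_{\log}}$ 参数化以保证负定。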

# ProtoSSM v4 — 增强版交叉注意力层

# 类1:简化版Mamba风格选择性状态空间模型
class SelectiveSSM(nn.Module):
    # 简化版Mamba风格选择性状态空间模型
    # 输入依赖的(选择性)连续时间SSM离散化
    # 对于T=12的生物声学窗口,顺序扫描在CPU上也高效

    def __init__(self, d_model, d_state=16, d_conv=4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        self.in_proj = nn.Linear(d_model, 2 * d_model, bias=False)  # 输入投影:拆成SSM部分和门控部分
        self.conv1d = nn.Conv1d(
            d_model, d_model, d_conv,
            padding=d_conv - 1, groups=d_model
        )  # 深度可分离卷积:捕捉局部时序
        self.dt_proj = nn.Linear(d_model, d_model, bias=True)  # 输入依赖的离散化步长

        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        A = A.unsqueeze(0).expand(d_model, -1)
        self.A_log = nn.Parameter(torch.log(A))  # 可学习的A矩阵(对数空间)
        self.D = nn.Parameter(torch.ones(d_model))  # 可学习的直接连接项
        self.B_proj = nn.Linear(d_model, d_state, bias=False)  # 输入依赖的B矩阵
        self.C_proj = nn.Linear(d_model, d_state, bias=False)  # 输入依赖的C矩阵
        self.out_proj = nn.Linear(d_model, d_model, bias=False)  # 输出投影

    def forward(self, x):
        B_size, T, D = x.shape
        xz = self.in_proj(x)
        x_ssm, z = xz.chunk(2, dim=-1)  # 拆成SSM输入和门控(注:本简化版forward未使用门控z与out_proj)

        x_conv = self.conv1d(x_ssm.transpose(1, 2))[:, :, :T].transpose(1, 2)  # 局部卷积
        x_conv = F.silu(x_conv)

        dt = F.softplus(self.dt_proj(x_conv))  # 输入依赖的离散化步长
        A = -torch.exp(self.A_log)  # 恢复A矩阵(保证负定)
        B = self.B_proj(x_conv)  # 输入依赖的B
        C = self.C_proj(x_conv)  # 输入依赖的C

        # 顺序扫描:更新隐藏状态
        h = torch.zeros(B_size, D, self.d_state, device=x.device)
        ys = []
        for t in range(T):
            dt_t = dt[:, t, :]
            dA = torch.exp(A[None, :, :] * dt_t[:, :, None])  # 离散化的A
            dB = dt_t[:, :, None] * B[:, t, None, :]  # 离散化的B
            h = h * dA + x[:, t, :, None] * dB  # 更新隐藏状态(注:简化实现,扫描注入的是原始x而非x_conv)
            y_t = (h * C[:, t, None, :]).sum(-1)  # 输出
            ys.append(y_t)

        y = torch.stack(ys, dim=1)
        return y + x * self.D[None, None, :]  # 残差连接


# 类2:时序交叉注意力(新增!)
class TemporalCrossAttention(nn.Module):
    """窗口间的多头交叉注意力
    捕捉顺序SSM可能遗漏的非局部模式(如黎明合唱起始、对唱)"""
    
    def __init__(self, d_model, n_heads=4, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)  # 多头自注意力
        self.norm = nn.LayerNorm(d_model)  # 层归一化
        self.ffn = nn.Sequential(  # 前馈网络
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 2, d_model),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(d_model)  # 第二层归一化
    
    def forward(self, x):
        # x: (B, T, D)
        residual = x
        x = self.norm(x)
        attn_out, _ = self.attn(x, x, x)  # 自注意力:每个窗口看所有窗口
        x = residual + attn_out  # 残差连接
        
        residual = x
        x = self.norm2(x)
        x = residual + self.ffn(x)  # 前馈网络+残差
        return x


# 类3:ProtoSSMv2(实际为v4,带交叉注意力)
class ProtoSSMv2(nn.Module):
    # 带交叉注意力和元数据感知的原型状态空间模型v4
    #
    # 新增组件:
    # - SSM后的交叉注意力层,捕捉非局部时序模式
    # - 保留v2所有其他功能(元数据、原型、门控融合)
    
    def __init__(self, d_input=1536, d_model=192, d_state=16,
                 n_ssm_layers=2, n_classes=234, n_windows=12,
                 dropout=0.2, n_sites=20, meta_dim=16,
                 use_cross_attn=True, cross_attn_heads=4):
        super().__init__()
        self.d_model = d_model
        self.n_classes = n_classes
        self.n_windows = n_windows

        # 1. 特征投影:1536维Perch嵌入→d_model维
        self.input_proj = nn.Sequential(
            nn.Linear(d_input, d_model),
            nn.LayerNorm(d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        # 2. 可学习的位置编码
        self.pos_enc = nn.Parameter(torch.randn(1, n_windows, d_model) * 0.02)

        # 3. 元数据嵌入:站点、小时
        self.site_emb = nn.Embedding(n_sites, meta_dim)
        self.hour_emb = nn.Embedding(24, meta_dim)
        self.meta_proj = nn.Linear(2 * meta_dim, d_model)

        # 4. 双向SSM层
        self.ssm_fwd = nn.ModuleList()  # 前向SSM
        self.ssm_bwd = nn.ModuleList()  # 反向SSM
        self.ssm_merge = nn.ModuleList()  # 双向融合
        self.ssm_norm = nn.ModuleList()  # 层归一化
        for _ in range(n_ssm_layers):
            self.ssm_fwd.append(SelectiveSSM(d_model, d_state))
            self.ssm_bwd.append(SelectiveSSM(d_model, d_state))
            self.ssm_merge.append(nn.Linear(2 * d_model, d_model))
            self.ssm_norm.append(nn.LayerNorm(d_model))
        self.ssm_drop = nn.Dropout(dropout)

        # 4b. 新增:SSM后的交叉注意力
        self.use_cross_attn = use_cross_attn
        if use_cross_attn:
            self.cross_attn = TemporalCrossAttention(d_model, n_heads=cross_attn_heads, dropout=dropout)

        # 5. 可学习的类别原型
        self.prototypes = nn.Parameter(torch.randn(n_classes, d_model) * 0.02)
        self.proto_temp = nn.Parameter(torch.tensor(5.0))  # 温度参数

        # 6. 逐类别校准偏置
        self.class_bias = nn.Parameter(torch.zeros(n_classes))

        # 7. 与Perch logits的逐类别门控融合
        self.fusion_alpha = nn.Parameter(torch.zeros(n_classes))

        # 8. 分类学辅助头(可选)
        self.n_families = 0
        self.family_head = None

    # 用数据初始化原型(可选)
    def init_prototypes_from_data(self, embeddings, labels):
        with torch.no_grad():
            h = self.input_proj(embeddings)
            for c in range(self.n_classes):
                mask = labels[:, c] > 0.5
                if mask.sum() > 0:
                    self.prototypes.data[c] = F.normalize(h[mask].mean(0), dim=0)

    # 初始化分类学辅助头(可选)
    def init_family_head(self, n_families, class_to_family):
        self.n_families = n_families
        self.family_head = nn.Linear(self.d_model, n_families)
        self.register_buffer('class_to_family', torch.tensor(class_to_family, dtype=torch.long))

    def forward(self, emb, perch_logits=None, site_ids=None, hours=None):
        B, T, _ = emb.shape

        # 投影嵌入+位置编码
        h = self.input_proj(emb)
        h = h + self.pos_enc[:, :T, :]

        # 加入元数据嵌入
        if site_ids is not None and hours is not None:
            s_emb = self.site_emb(site_ids)
            h_emb = self.hour_emb(hours)
            meta = self.meta_proj(torch.cat([s_emb, h_emb], dim=-1))
            h = h + meta[:, None, :]

        # 双向SSM
        for fwd, bwd, merge, norm in zip(
            self.ssm_fwd, self.ssm_bwd, self.ssm_merge, self.ssm_norm
        ):
            residual = h
            h_f = fwd(h)  # 前向
            h_b = bwd(h.flip(1)).flip(1)  # 反向(翻转输入→SSM→翻转回来)
            h = merge(torch.cat([h_f, h_b], dim=-1))  # 融合双向
            h = self.ssm_drop(h)
            h = norm(h + residual)  # 残差+归一化

        # 新增:交叉注意力捕捉非局部时序模式
        if self.use_cross_attn:
            h = self.cross_attn(h)

        h_temporal = h

        # 原型余弦相似度+类别偏置
        h_norm = F.normalize(h, dim=-1)
        p_norm = F.normalize(self.prototypes, dim=-1)
        temp = F.softplus(self.proto_temp)
        sim = torch.matmul(h_norm, p_norm.T) * temp + self.class_bias[None, None, :]

        # 与Perch logits的门控融合
        if perch_logits is not None:
            alpha = torch.sigmoid(self.fusion_alpha)[None, None, :]
            species_logits = alpha * sim + (1 - alpha) * perch_logits
        else:
            species_logits = sim

        # 分类学辅助预测(可选)
        family_logits = None
        if self.family_head is not None:
            h_pool = h.mean(dim=1)
            family_logits = self.family_head(h_pool)

        return species_logits, family_logits, h_temporal

    # 统计参数量
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# 测试模型
ssm_cfg = CFG["proto_ssm"]
print("ProtoSSMv4 architecture defined (with cross-attention).")
test_model = ProtoSSMv2(
    d_model=ssm_cfg["d_model"], n_ssm_layers=2,
    n_sites=ssm_cfg["n_sites"], meta_dim=ssm_cfg["meta_dim"],
    use_cross_attn=ssm_cfg.get("use_cross_attn", True),
    cross_attn_heads=ssm_cfg.get("cross_attn_heads", 4),
)
print(f"Parameter count: {test_model.count_parameters():,}")
del test_model
ProtoSSMv4 architecture defined (with cross-attention).
Parameter count: 3,531,445

ProtoSSM v2 训练

多任务训练,损失包含物种二元交叉熵(带标签平滑)、原型对比损失、Perch logits 知识蒸馏损失,以及分类学辅助损失,同时做 5 折分组 K 折(GroupKFold)折外(OOF)交叉验证和集成权重优化。
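
与下方 train_proto_ssm_single 的实现对照,总损失可写为(示意;$\hat z$ 为物种 logits,$\hat f$ 为科级 logits。注意:原型对比损失并未出现在本单元格的代码中,实际出现的是以下三项):

$$\mathcal{L} = \mathcal{L}_{\text{focal-BCE}}(\hat z,\, y) + w_{\text{distill}}\cdot \mathrm{MSE}(\hat z,\, z_{\text{Perch}}) + 0.1\cdot \mathrm{BCE}(\hat f,\, y_{\text{family}})$$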

# ProtoSSM v4 训练循环 — 带Mixup、焦点损失、随机权重平均(SWA)

# 函数1:构建分类学分组(科/目/纲)
def build_taxonomy_groups(taxonomy_df, primary_labels):
    for col in ["family", "order", "class_name"]:
        if col in taxonomy_df.columns:
            group_map = taxonomy_df.set_index("primary_label")[col].to_dict()  # 物种→科/目/纲映射
            break
    else:
        group_map = {label: "Unknown" for label in primary_labels}

    groups = sorted(set(group_map.values()))
    grp_to_idx = {g: i for i, g in enumerate(groups)}  # 分组→索引
    class_to_group = []
    for label in primary_labels:
        grp = group_map.get(label, "Unknown")
        class_to_group.append(grp_to_idx.get(grp, 0))
    return len(groups), class_to_group, grp_to_idx


# 函数2:构建站点映射
def build_site_mapping(meta_df):
    sites = meta_df["site"].unique().tolist()
    site_to_idx = {s: i + 1 for i, s in enumerate(sites)}  # 站点→索引(+1留0给未知)
    n_sites = len(sites) + 1
    return site_to_idx, n_sites


# 函数3:把扁平数组重塑为文件级格式(文件数, 12窗口, ...)
def reshape_to_files(flat_array, meta_df, n_windows=N_WINDOWS):
    filenames = meta_df["filename"].to_numpy()
    unique_files = []
    seen = set()
    for f in filenames:
        if f not in seen:
            unique_files.append(f)
            seen.add(f)

    n_files = len(unique_files)
    assert len(flat_array) == n_files * n_windows, \
        f"Expected {n_files * n_windows} rows, got {len(flat_array)}"

    new_shape = (n_files, n_windows) + flat_array.shape[1:]
    return flat_array.reshape(new_shape), unique_files


# 函数4:获取文件级元数据(站点、小时)
def get_file_metadata(meta_df, file_list, site_to_idx, n_sites_max):
    file_to_row = {}
    filenames = meta_df["filename"].to_numpy()
    sites = meta_df["site"].to_numpy()
    hours = meta_df["hour_utc"].to_numpy()
    for i, f in enumerate(filenames):
        if f not in file_to_row:
            file_to_row[f] = i

    site_ids = np.zeros(len(file_list), dtype=np.int64)
    hour_ids = np.zeros(len(file_list), dtype=np.int64)
    for fi, fname in enumerate(file_list):
        row = file_to_row.get(fname)
        if row is not None:
            sid = site_to_idx.get(sites[row], 0)
            site_ids[fi] = min(sid, n_sites_max - 1)
            hour_ids[fi] = int(hours[row]) % 24
    return site_ids, hour_ids


# 函数5:文件级Mixup增强(新增!)
def mixup_files(emb, logits, labels, site_ids, hours, families, alpha=0.3):
    """ProtoSSM训练的文件级Mixup增强
    用Beta(alpha, alpha)随机lambda混合文件对
    返回所有输入的增强版本"""
    n = len(emb)
    if alpha <= 0 or n < 2:
        return emb, logits, labels, site_ids, hours, families
    
    lam = np.random.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # 保证lambda>=0.5(主样本保持主导)
    
    perm = np.random.permutation(n)  # 随机打乱索引
    
    # 混合连续特征
    emb_mix = lam * emb + (1 - lam) * emb[perm]
    logits_mix = lam * logits + (1 - lam) * logits[perm]
    labels_mix = lam * labels + (1 - lam) * labels[perm]
    
    # 离散特征(站点、小时)保留主样本的值
    families_mix = lam * families + (1 - lam) * families[perm] if families is not None else None
    
    return emb_mix, logits_mix, labels_mix, site_ids, hours, families_mix


# 函数6:单折ProtoSSM v4训练(核心!)
def train_proto_ssm_single(model, emb_train, logits_train, labels_train,
                           site_ids_train=None, hours_train=None,
                           emb_val=None, logits_val=None, labels_val=None,
                           site_ids_val=None, hours_val_val=None,
                           file_families_train=None, file_families_val=None,
                           cfg=None, verbose=True):
    """训练单个ProtoSSM v4模型,带Mixup、焦点损失、SWA"""
    if cfg is None:
        cfg = CFG["proto_ssm_train"]

    label_smoothing = cfg.get("label_smoothing", 0.0)
    mixup_alpha = cfg.get("mixup_alpha", 0.0)
    focal_gamma = cfg.get("focal_gamma", 0.0)
    swa_start_frac = cfg.get("swa_start_frac", 1.0)  # 1.0=禁用
    n_epochs = cfg["n_epochs"]
    swa_start_epoch = int(n_epochs * swa_start_frac)

    # 转numpy数组(基础版,未混合)
    labels_np = labels_train.copy()
    
    # 应用标签平滑
    if label_smoothing > 0:
        labels_np = labels_np * (1.0 - label_smoothing) + label_smoothing / 2.0

    has_val = emb_val is not None
    if has_val:
        emb_v = torch.tensor(emb_val, dtype=torch.float32)
        logits_v = torch.tensor(logits_val, dtype=torch.float32)
        labels_v = torch.tensor(labels_val, dtype=torch.float32)
        site_v = torch.tensor(site_ids_val, dtype=torch.long) if site_ids_val is not None else None
        hour_v = torch.tensor(hours_val_val, dtype=torch.long) if hours_val_val is not None else None

    fam_v = torch.tensor(file_families_val, dtype=torch.float32) if (has_val and file_families_val is not None) else None

    # 类别权重:处理样本不均衡
    labels_tr_t = torch.tensor(labels_np, dtype=torch.float32)
    pos_counts = labels_tr_t.sum(dim=(0, 1))
    total = labels_tr_t.shape[0] * labels_tr_t.shape[1]
    pos_weight = ((total - pos_counts) / (pos_counts + 1)).clamp(max=cfg["pos_weight_cap"])

    # 优化器和学习率调度器
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"]
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=cfg["lr"],
        epochs=n_epochs, steps_per_epoch=1,
        pct_start=0.1, anneal_strategy='cos'
    )

    best_val_loss = float('inf')
    best_state = None
    wait = 0
    history = {"train_loss": [], "val_loss": [], "val_auc": []}

    # SWA权重累加器
    swa_state = None
    swa_count = 0

    # 训练循环
    for epoch in range(n_epochs):
        # === Mixup增强(每个epoch重新采样) ===
        if mixup_alpha > 0 and epoch > 5:  # 前6个epoch(0~5)不用Mixup(起训热身)
            emb_mix, logits_mix, labels_mix, _, _, fam_mix = mixup_files(
                emb_train, logits_train, labels_np,
                site_ids_train, hours_train, file_families_train,
                alpha=mixup_alpha,
            )
        else:
            emb_mix, logits_mix, labels_mix = emb_train, logits_train, labels_np
            fam_mix = file_families_train

        # 转tensor
        emb_tr = torch.tensor(emb_mix, dtype=torch.float32)
        logits_tr = torch.tensor(logits_mix, dtype=torch.float32)
        labels_tr = torch.tensor(labels_mix, dtype=torch.float32)
        site_tr = torch.tensor(site_ids_train, dtype=torch.long) if site_ids_train is not None else None
        hour_tr = torch.tensor(hours_train, dtype=torch.long) if hours_train is not None else None
        fam_tr = torch.tensor(fam_mix, dtype=torch.float32) if fam_mix is not None else None

        # === 训练 ===
        model.train()
        species_out, family_out, _ = model(emb_tr, logits_tr, site_ids=site_tr, hours=hour_tr)

        # 主损失:焦点BCE或加权BCE
        if focal_gamma > 0:
            loss_main = focal_bce_with_logits(
                species_out, labels_tr,
                gamma=focal_gamma,
                pos_weight=pos_weight[None, None, :],
            )
        else:
            loss_main = F.binary_cross_entropy_with_logits(
                species_out, labels_tr,
                pos_weight=pos_weight[None, None, :]
            )

        # 知识蒸馏损失:拟合Perch logits
        loss_distill = F.mse_loss(species_out, logits_tr)

        # 总损失
        loss = loss_main + cfg["distill_weight"] * loss_distill

        # 分类学辅助损失
        if family_out is not None and fam_tr is not None:
            loss_family = F.binary_cross_entropy_with_logits(family_out, fam_tr)
            loss = loss + 0.1 * loss_family

        # 反向传播+优化
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # 梯度裁剪
        optimizer.step()
        scheduler.step()

        # === SWA权重累加 ===
        if epoch >= swa_start_epoch:
            if swa_state is None:
                swa_state = {k: v.clone() for k, v in model.state_dict().items()}
                swa_count = 1
            else:
                for k in swa_state:
                    swa_state[k] += model.state_dict()[k]
                swa_count += 1

        # === 验证 ===
        model.eval()
        with torch.no_grad():
            if has_val:
                val_out, val_fam, _ = model(emb_v, logits_v, site_ids=site_v, hours=hour_v)
                val_loss = F.binary_cross_entropy_with_logits(
                    val_out, labels_v,
                    pos_weight=pos_weight[None, None, :]
                )

                val_pred = val_out.reshape(-1, val_out.shape[-1]).numpy()
                val_true = labels_v.reshape(-1, labels_v.shape[-1]).numpy()
                try:
                    val_auc = macro_auc_skip_empty(val_true, val_pred)
                except Exception:
                    val_auc = 0.0
            else:
                val_loss = loss
                val_auc = 0.0

        history["train_loss"].append(loss.item())
        history["val_loss"].append(val_loss.item())
        history["val_auc"].append(val_auc)

        # 早停逻辑
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1

        if verbose and (epoch + 1) % 20 == 0:
            lr_now = optimizer.param_groups[0]['lr']
            swa_info = f" swa={swa_count}" if swa_count > 0 else ""
            print(f"  Epoch {epoch+1:3d}: train={loss.item():.4f} val={val_loss.item():.4f} "
                  f"auc={val_auc:.4f} lr={lr_now:.6f} wait={wait}{swa_info}")

        if wait >= cfg["patience"]:
            if verbose:
                print(f"  Early stopping at epoch {epoch+1} (best val_loss={best_val_loss:.4f})")
            break

    # 应用SWA(如果累加了足够的checkpoint)
    if swa_state is not None and swa_count >= 3:
        if verbose:
            print(f"  Applying SWA (averaged {swa_count} checkpoints)")
        avg_state = {k: v / swa_count for k, v in swa_state.items()}
        model.load_state_dict(avg_state)
    elif best_state is not None:
        model.load_state_dict(best_state)

    if verbose:
        print(f"  Training complete. Best val_loss={best_val_loss:.4f}")
        with torch.no_grad():
            alphas = torch.sigmoid(model.fusion_alpha).numpy()
            print(f"  Fusion alpha: mean={alphas.mean():.3f} min={alphas.min():.3f} max={alphas.max():.3f}")
            print(f"  Proto temperature: {F.softplus(model.proto_temp).item():.3f}")

    return model, history


# 函数7:运行ProtoSSM v4的5折GroupKFold OOF交叉验证
def run_proto_ssm_oof(emb_files, logits_files, labels_files,
                      site_ids_all, hours_all,
                      file_families, file_groups,
                      n_families, class_to_family,
                      cfg=None, verbose=True):
    """运行ProtoSSM v4的分组K折OOF交叉验证"""
    if cfg is None:
        cfg = CFG["proto_ssm_train"]

    n_splits = cfg.get("oof_n_splits", 5)
    n_files = len(emb_files)
    ssm_cfg = CFG["proto_ssm"]

    oof_preds = np.zeros((n_files, N_WINDOWS, N_CLASSES), dtype=np.float32)
    fold_histories = []
    fold_alphas = []

    # 如果分组数少于折数,减少折数
    n_unique_groups = len(set(file_groups))
    if n_unique_groups < n_splits:
        print(f"  WARNING: Only {n_unique_groups} groups, reducing n_splits from {n_splits} to {n_unique_groups}")
        n_splits = n_unique_groups
    gkf = GroupKFold(n_splits=n_splits)
    dummy_y = np.zeros(n_files)

    # 遍历每折
    for fold_i, (train_idx, val_idx) in enumerate(gkf.split(dummy_y, dummy_y, file_groups)):
        if verbose:
            print(f"\n--- Fold {fold_i+1}/{n_splits} (train={len(train_idx)}, val={len(val_idx)}) ---")

        # 初始化模型
        fold_model = ProtoSSMv2(
            d_input=emb_files.shape[2],
            d_model=ssm_cfg["d_model"],
            d_state=ssm_cfg["d_state"],
            n_ssm_layers=ssm_cfg["n_ssm_layers"],
            n_classes=N_CLASSES,
            n_windows=N_WINDOWS,
            dropout=ssm_cfg["dropout"],
            n_sites=ssm_cfg["n_sites"],
            meta_dim=ssm_cfg["meta_dim"],
            use_cross_attn=ssm_cfg.get("use_cross_attn", True),
            cross_attn_heads=ssm_cfg.get("cross_attn_heads", 4),
        ).to(DEVICE)

        # 用数据初始化原型
        emb_flat_fold = emb_files[train_idx].reshape(-1, emb_files.shape[2])
        labels_flat_fold = labels_files[train_idx].reshape(-1, N_CLASSES)
        fold_model.init_prototypes_from_data(
            torch.tensor(emb_flat_fold, dtype=torch.float32),
            torch.tensor(labels_flat_fold, dtype=torch.float32)
        )
        fold_model.init_family_head(n_families, class_to_family)

        # 训练这一折
        fold_model, fold_hist = train_proto_ssm_single(
            fold_model,
            emb_files[train_idx], logits_files[train_idx], labels_files[train_idx].astype(np.float32),
            site_ids_train=site_ids_all[train_idx], hours_train=hours_all[train_idx],
            emb_val=emb_files[val_idx], logits_val=logits_files[val_idx],
            labels_val=labels_files[val_idx].astype(np.float32),
            site_ids_val=site_ids_all[val_idx], hours_val_val=hours_all[val_idx],
            file_families_train=file_families[train_idx],
            file_families_val=file_families[val_idx],
            cfg=cfg, verbose=verbose,
        )

        # OOF预测(带TTA)
        fold_model.eval()
        tta_shifts = CFG.get("tta_shifts", [0])
        if len(tta_shifts) > 1:
            oof_preds[val_idx] = temporal_shift_tta(
                emb_files[val_idx], logits_files[val_idx], fold_model,
                site_ids_all[val_idx], hours_all[val_idx], shifts=tta_shifts
            )
        else:
            with torch.no_grad():
                val_emb = torch.tensor(emb_files[val_idx], dtype=torch.float32)
                val_logits = torch.tensor(logits_files[val_idx], dtype=torch.float32)
                val_sites = torch.tensor(site_ids_all[val_idx], dtype=torch.long)
                val_hours = torch.tensor(hours_all[val_idx], dtype=torch.long)
                val_out, _, _ = fold_model(val_emb, val_logits, site_ids=val_sites, hours=val_hours)
                oof_preds[val_idx] = val_out.numpy()

        fold_alphas.append(torch.sigmoid(fold_model.fusion_alpha).detach().numpy().copy())
        fold_histories.append(fold_hist)

    return oof_preds, fold_histories, fold_alphas


# 函数8:优化集成权重(网格搜索)
def optimize_ensemble_weight(oof_proto_flat, oof_mlp_flat, y_true_flat):
    """网格搜索融合权重,找到最优的ProtoSSM集成权重"""
    weights = np.arange(0.0, 1.05, 0.05)
    results = []

    for w in weights:
        blended = w * oof_proto_flat + (1.0 - w) * oof_mlp_flat
        try:
            auc = macro_auc_skip_empty(y_true_flat, blended)
        except Exception:
            auc = 0.0
        results.append((w, auc))

    best_w, best_auc = max(results, key=lambda x: x[1])
    return best_w, best_auc, results


print("ProtoSSM v4 training functions defined (with mixup, focal loss, SWA, TTA).")
ProtoSSM v4 training functions defined (with mixup, focal loss, SWA, TTA).

1. Mixup:把两个文件的音频 / 标签混在一起让模型学,就像 “给题目加干扰项”,模型更抗干扰

2. 焦点损失:重点关注 “总是认错的鸟”,降低 “一眼就能认出的鸟” 的权重(实现示意见本列表之后)

3. SWA:把最后几个 epoch 的模型 “平均一下”,就像 “多个老师一起改卷”,结果更稳

4. TTA:推理时把音频滑动几次,多次预测取平均,就像 “多次考试取平均分”,成绩更准

5. 5 折 GroupKFold:按文件分 5 折考试,绝对不让模型提前看答案

6. 集成优化:最后试 “ProtoSSM 说的话占几分,MLP 说的话占几分”,组合起来效果最好
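
上面第2点提到的焦点损失对应前文调用的 focal_bce_with_logits(定义于未展示的单元格)。下面是常见形式的一个最小示意,函数名与细节为假设,仅供理解:

import torch
import torch.nn.functional as F

def focal_bce_with_logits_sketch(logits, targets, gamma=2.5, pos_weight=None):
    # 逐元素BCE(不聚合),pos_weight 可按类别维度广播
    bce = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=pos_weight, reduction="none"
    )
    p = torch.sigmoid(logits)
    p_t = targets * p + (1.0 - targets) * (1.0 - p)  # 模型对真实标签的置信度
    return ((1.0 - p_t) ** gamma * bce).mean()       # 易分样本(p_t大)的损失被压低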

探针调优

仅在训练模式下,对探针超参数进行网格搜索

# Cell 10 — 探针调优(仅训练模式)
grid_results = None
BEST_PROBE = None

# ---------------------- 1. 探针检查:用默认参数跑一次,看看有没有效果 ----------------------
if CFG["run_probe_check"]:
    probe_result = run_oof_embedding_probe(
        scores_raw=scores_full_raw,
        emb=emb_full,
        meta_df=meta_full,
        y_true=Y_FULL,
        pca_dim=64,
        min_pos=8,
        C=0.25,
        alpha=0.5,
    )

    # 打印结果对比
    print(f"Honest OOF baseline AUC: {probe_result['score_base']:.6f}")
    print(f"Honest OOF embedding-probe AUC: {probe_result['score_final']:.6f}")
    print(f"Delta: {probe_result['score_final'] - probe_result['score_base']:.6f}")

    # 打印被建模的类别
    modeled_classes = np.where(probe_result["modeled_counts"] > 0)[0]
    print("Modeled classes:", len(modeled_classes))
    print([PRIMARY_LABELS[i] for i in modeled_classes[:20]])

# ---------------------- 2. 探针网格搜索:遍历候选超参,找最优 ----------------------
if CFG["run_probe_grid"]:
    # 候选超参组合:6组
    param_grid = [
        {"pca_dim": 32, "min_pos": 8,  "C": 0.25, "alpha": 0.4},
        {"pca_dim": 64, "min_pos": 8,  "C": 0.25, "alpha": 0.4},
        {"pca_dim": 64, "min_pos": 8,  "C": 0.25, "alpha": 0.5},
        {"pca_dim": 64, "min_pos": 12, "C": 0.25, "alpha": 0.4},
        {"pca_dim": 96, "min_pos": 8,  "C": 0.25, "alpha": 0.4},
        {"pca_dim": 64, "min_pos": 8,  "C": 0.50, "alpha": 0.4},
    ]

    results = []
    # 遍历每组超参
    for params in tqdm(param_grid, desc="Probe grid", disable=not CFG["verbose"]):
        out = run_oof_embedding_probe(
            scores_raw=scores_full_raw,
            emb=emb_full,
            meta_df=meta_full,
            y_true=Y_FULL,
            pca_dim=params["pca_dim"],
            min_pos=params["min_pos"],
            C=params["C"],
            alpha=params["alpha"],
        )
        # 记录结果
        results.append({
            **params,
            "baseline_oof_auc": out["score_base"],
            "probe_oof_auc": out["score_final"],
            "delta": out["score_final"] - out["score_base"],
            "n_modeled_classes": int((out["modeled_counts"] > 0).sum()),
        })

    # 转成DataFrame,按探针AUC降序排序
    grid_results = pd.DataFrame(results).sort_values("probe_oof_auc", ascending=False).reset_index(drop=True)
    display(grid_results)

    # 选最优参数
    BEST_PROBE = {
        "pca_dim": int(grid_results.iloc[0]["pca_dim"]),
        "min_pos": int(grid_results.iloc[0]["min_pos"]),
        "C": float(grid_results.iloc[0]["C"]),
        "alpha": float(grid_results.iloc[0]["alpha"]),
    }

    # 保存最优参数,方便后续冻结
    best_probe_path = CFG["full_cache_work_dir"] / "best_probe_params.json"
    best_probe_path.write_text(json.dumps(BEST_PROBE, indent=2))
    print("Saved best probe params to:", best_probe_path)

# ---------------------- 3. 未做网格搜索时(如提交模式):直接加载冻结的最优参数 ----------------------
else:
    BEST_PROBE = CFG["frozen_best_probe"]
    print("Using frozen BEST_PROBE in submit mode:")
    print(BEST_PROBE)

# ---------------------- 4. 保存网格搜索结果为CSV ----------------------
if grid_results is not None:
    grid_results.to_csv(CFG["full_cache_work_dir"] / "probe_grid_results.csv", index=False)
Using frozen BEST_PROBE in submit mode:
{'pca_dim': 128, 'min_pos': 5, 'C': 0.75, 'alpha': 0.45}

1. 探针检查:先用默认参数给每个鸟训小老师,看看成绩有没有比之前好

2. 网格搜索:把 “教案”(超参)列成 6 组,一组组试,用 OOF AUC 当考试成绩,选分数最高的那组

3. 保存 / 加载:训练模式把最好的 “教案” 存起来,提交模式直接用,不用重新试

4. 结果留存:把所有试的结果存成表格,方便以后看哪组 “教案” 适合哪类鸟

最终探针拟合

冻结探针超参数,在全部数据上拟合最终模型。

# Cell 11 — 冻结最终探针参数

# 双重保险:如果BEST_PROBE为空,加载配置里的冻结最优参数
if BEST_PROBE is None:
    BEST_PROBE = CFG["frozen_best_probe"]

# 打印最终确认的探针超参数
print("Final BEST_PROBE =", BEST_PROBE)

# 初始化最优OOF结果
BEST_OOF_RESULT = None

# 训练模式:可选重跑一次最优OOF探针,做诊断/缓存
if MODE == "train":
    BEST_OOF_RESULT = run_oof_embedding_probe(
        scores_raw=scores_full_raw,
        emb=emb_full,
        meta_df=meta_full,
        y_true=Y_FULL,
        pca_dim=int(BEST_PROBE["pca_dim"]),
        min_pos=int(BEST_PROBE["min_pos"]),
        C=float(BEST_PROBE["C"]),
        alpha=float(BEST_PROBE["alpha"]),
    )

    # 打印重跑后的OOF AUC对比
    print(f"Honest OOF baseline AUC (BEST_PROBE rerun): {BEST_OOF_RESULT['score_base']:.6f}")
    print(f"Honest OOF probe AUC   (BEST_PROBE rerun): {BEST_OOF_RESULT['score_final']:.6f}")
Final BEST_PROBE = {'pca_dim': 128, 'min_pos': 5, 'C': 0.75, 'alpha': 0.45}

1. 双重保险:如果之前没找到最好的 “教案”,直接用之前存好的 “金牌教案”

2. 训练模式可选再考一次试:用金牌教案再跑 5 折 OOF,一是再确认成绩,二是把这次的考试答案(OOF 结果)存下来,后面和 ProtoSSM 的答案一起改

3. 提交模式直接跳过:不用浪费时间再考试

# Cell 12 — 在所有标注声景上拟合最终先验表

# 用全部清洗后的标注数据,拟合最终的先验表(全局/站点/小时/站点-小时联合)
final_prior_tables = fit_prior_tables(sc_clean.reset_index(drop=True), Y_SC)

# 打印提示:推理用的先验表已构建完成
print("Built final prior tables for inference.")
# 打印:用于堆叠器训练的OOF基线AUC
print("OOF baseline AUC used for stacker training:", baseline_oof_auc)
Built final prior tables for inference.
OOF baseline AUC used for stacker training: 0.8031805065371714

之前考试(OOF)时,每折只敢用部分同学的 “作息表”(先验表),怕提前看答案;现在要上考场了,把所有同学的作息表都统计一遍,覆盖的站点、时间、鸟类更多,更准,同时把之前的考试成绩(OOF 基线 AUC)记下来,后面给堆叠器当参考
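
fit_prior_tables 的实现在前文单元格;其思路可用下面的最小示意理解(假设:按站点(以及小时、站点×小时)统计各类别出现率,经拉普拉斯平滑后取 logit 作为先验,具体以前文实现为准):

import numpy as np

def fit_site_prior_sketch(sites, Y, smoothing=1.0):
    # sites: (n,) 每个窗口的站点;Y: (n, C) 多标签0/1矩阵
    tables = {}
    for s in np.unique(sites):
        mask = sites == s
        rate = (Y[mask].sum(axis=0) + smoothing) / (mask.sum() + 2.0 * smoothing)
        tables[s] = np.log(rate / (1.0 - rate)).astype(np.float32)  # 出现率转先验logit
    return tables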

# Cell 13 — 在所有可信全窗口上拟合嵌入标准化器 + PCA

# 1. 初始化标准化器:把特征缩放到均值0、方差1
emb_scaler = StandardScaler()
# 在全部嵌入上拟合并转换
emb_full_scaled = emb_scaler.fit_transform(emb_full)

# 2. 确定PCA维度:取最优参数、样本数-1、特征数的最小值(保证PCA有效)
n_comp = min(
    int(BEST_PROBE["pca_dim"]),
    emb_full_scaled.shape[0] - 1,
    emb_full_scaled.shape[1]
)

# 3. 初始化PCA并拟合转换:把高维嵌入降到低维
emb_pca = PCA(n_components=n_comp)
Z_FULL = emb_pca.fit_transform(emb_full_scaled).astype(np.float32)

# 4. 打印形状和方差解释率,验证效果
print("emb_full:", emb_full.shape)
print("Z_FULL:", Z_FULL.shape)
print("Explained variance ratio sum:", emb_pca.explained_variance_ratio_.sum())
emb_full: (708, 1536)
Z_FULL: (708, 128)
Explained variance ratio sum: 0.8997401

之前考试(OOF)时,每折只敢用部分同学的 “尺子”(标准化器)和 “简化笔记”(PCA),怕提前看答案;现在要训最终模型了,把所有同学的尺子和简化笔记都用上,预处理更准,同时看看简化笔记有没有丢太多重点(方差解释率)

ProtoSSM 训练

实例化 ProtoSSM 模型,从类别均值初始化原型向量,设置分类学辅助头,并用多任务损失进行训练。

# 实例化并训练ProtoSSM v4

# --- 步骤1:重塑为文件级格式 ---
emb_files, file_list = reshape_to_files(emb_full, meta_full)  # 嵌入转文件级
logits_files, _ = reshape_to_files(scores_full_raw, meta_full)  # Perch logits转文件级
labels_files, _ = reshape_to_files(Y_FULL, meta_full)  # 标签转文件级

print(f"Reshaped to file-level: emb={emb_files.shape}, logits={logits_files.shape}, labels={labels_files.shape}")
print(f"Files: {len(file_list)}")

# --- 步骤2:构建分类学分组、站点映射、文件元数据 ---
n_families, class_to_family, fam_to_idx = build_taxonomy_groups(taxonomy, PRIMARY_LABELS)  # 分类学分组
print(f"Taxonomic groups: {n_families}")

site_to_idx, n_sites_mapped = build_site_mapping(meta_full)  # 站点映射
n_sites_cfg = CFG["proto_ssm"]["n_sites"]
print(f"Sites mapped: {n_sites_mapped} (capped to {n_sites_cfg})")

site_ids_all, hours_all = get_file_metadata(meta_full, file_list, site_to_idx, n_sites_cfg)  # 文件级站点、小时

# 构建文件级家族标签(多热)
file_families = np.zeros((len(file_list), n_families), dtype=np.float32)
for fi in range(len(file_list)):
    active_classes = np.where(labels_files[fi].sum(axis=0) > 0)[0]
    for ci in active_classes:
        file_families[fi, class_to_family[ci]] = 1.0

# --- OOF交叉验证(仅训练模式) ---
ENSEMBLE_WEIGHT_PROTO = 0.5  # 默认权重,训练模式用OOF覆盖
oof_proto_flat = None
fold_alphas = []

if MODE == "train":
    # 从文件名提取分组(用于GroupKFold)
    file_groups = np.array([f.split("_")[3] if len(f.split("_")) > 3 else f for f in file_list])
    print(f"File groups for OOF: {len(set(file_groups))} unique groups: {sorted(set(file_groups))}")

    # 运行ProtoSSM OOF
    t0_oof = time.time()
    oof_proto_preds, fold_histories, fold_alphas = run_proto_ssm_oof(
        emb_files, logits_files, labels_files,
        site_ids_all, hours_all,
        file_families, file_groups,
        n_families, class_to_family,
        cfg=CFG["proto_ssm_train"],
        verbose=CFG["verbose"],
    )
    oof_time = time.time() - t0_oof
    print(f"\nOOF cross-validation time: {oof_time:.1f}s")

    # 扁平化OOF预测,计算AUC
    oof_proto_flat = oof_proto_preds.reshape(-1, N_CLASSES)
    y_flat = labels_files.reshape(-1, N_CLASSES).astype(np.float32)

    # 计算逐类别AUC
    per_class_auc_proto = {}
    for ci in range(N_CLASSES):
        if y_flat[:, ci].sum() > 0 and y_flat[:, ci].sum() < len(y_flat):
            try:
                per_class_auc_proto[ci] = roc_auc_score(y_flat[:, ci], oof_proto_flat[:, ci])
            except Exception:
                pass

    # 计算整体宏观AUC
    overall_oof_auc_proto = macro_auc_skip_empty(y_flat, oof_proto_flat)
    print(f"ProtoSSM OOF macro AUC: {overall_oof_auc_proto:.4f}")

    # 记录日志
    LOGS["oof_auc_proto"] = overall_oof_auc_proto
    LOGS["per_class_auc_proto"] = {PRIMARY_LABELS[k]: v for k, v in per_class_auc_proto.items()}
    LOGS["oof_time"] = oof_time
else:
    print("Submit mode: skipping OOF cross-validation")

# --- 用全部数据训最终模型 ---
ssm_cfg = CFG["proto_ssm"]
model = ProtoSSMv2(
    d_input=emb_full.shape[1],
    d_model=ssm_cfg["d_model"],
    d_state=ssm_cfg["d_state"],
    n_ssm_layers=ssm_cfg["n_ssm_layers"],
    n_classes=N_CLASSES,
    n_windows=N_WINDOWS,
    dropout=ssm_cfg["dropout"],
    n_sites=ssm_cfg["n_sites"],
    meta_dim=ssm_cfg["meta_dim"],
    use_cross_attn=ssm_cfg.get("use_cross_attn", True),
    cross_attn_heads=ssm_cfg.get("cross_attn_heads", 4),
).to(DEVICE)

# 原型热启动:用全部数据的类别均值初始化
emb_flat_tensor = torch.tensor(emb_full, dtype=torch.float32)
labels_flat_tensor = torch.tensor(Y_FULL, dtype=torch.float32)
model.init_prototypes_from_data(emb_flat_tensor, labels_flat_tensor)
model.init_family_head(n_families, class_to_family)  # 初始化分类学辅助头

print(f"\nProtoSSM v4 parameters: {model.count_parameters():,}")

# 训最终模型
t0_final = time.time()
model, train_history = train_proto_ssm_single(
    model,
    emb_files, logits_files, labels_files.astype(np.float32),
    site_ids_train=site_ids_all, hours_train=hours_all,
    cfg=CFG["proto_ssm_train"],
    verbose=True,
)
train_time = time.time() - t0_final
print(f"Final model training time: {train_time:.1f}s")

# 打印最终融合权重
with torch.no_grad():
    final_alphas = torch.sigmoid(model.fusion_alpha).numpy()
    print(f"Fusion alpha: mean={final_alphas.mean():.4f} min={final_alphas.min():.4f} max={final_alphas.max():.4f}")

# --- 训MLP探针 ---
# 筛选正样本足够的类别
PROBE_CLASS_IDX = np.where(Y_FULL.sum(axis=0) >= int(CFG["frozen_best_probe"]["min_pos"]))[0].astype(np.int32)

probe_models = {}
for cls_idx in tqdm(PROBE_CLASS_IDX, desc="Training MLP probes", disable=not CFG["verbose"]):
    y = Y_FULL[:, cls_idx]
    if y.sum() == 0 or y.sum() == len(y):  # 全正或全负,跳过
        continue
    # 构建类别特征
    X_cls = build_class_features(
        Z_FULL,
        raw_col=scores_full_raw[:, cls_idx],
        prior_col=oof_prior[:, cls_idx],
        base_col=oof_base[:, cls_idx],
    )
    # 过采样平衡正负样本
    n_pos = int(y.sum())
    n_neg = len(y) - n_pos
    if n_pos > 0 and n_neg > n_pos:
        repeat = max(1, n_neg // n_pos)
        pos_idx = np.where(y == 1)[0]
        X_bal = np.vstack([X_cls, np.tile(X_cls[pos_idx], (repeat, 1))])
        y_bal = np.concatenate([y, np.ones(len(pos_idx) * repeat, dtype=y.dtype)])
    else:
        X_bal, y_bal = X_cls, y
    # 训MLP
    clf = MLPClassifier(**CFG["mlp_params"])
    clf.fit(X_bal, y_bal)
    probe_models[cls_idx] = clf

print(f"MLP probes trained: {len(probe_models)}")

# --- 优化集成权重(仅训练模式) ---
if MODE == "train" and oof_proto_flat is not None:
    # 生成MLP的OOF预测
    oof_mlp_flat = oof_base.copy()
    for cls_idx, clf in probe_models.items():
        X_cls = build_class_features(
            Z_FULL,
            raw_col=scores_full_raw[:, cls_idx],
            prior_col=oof_prior[:, cls_idx],
            base_col=oof_base[:, cls_idx],
        )
        # 生成预测并转logits
        if hasattr(clf, "predict_proba"):
            prob = clf.predict_proba(X_cls)[:, 1].astype(np.float32)
            pred = np.log(prob + 1e-7) - np.log(1 - prob + 1e-7)
        else:
            pred = clf.decision_function(X_cls).astype(np.float32)
        # 融合基础分数和探针预测
        alpha_probe = float(CFG["frozen_best_probe"]["alpha"])
        oof_mlp_flat[:, cls_idx] = (1.0 - alpha_probe) * oof_base[:, cls_idx] + alpha_probe * pred

    # 网格搜索最优集成权重
    y_flat = labels_files.reshape(-1, N_CLASSES).astype(np.float32)
    best_w, best_auc, weight_results = optimize_ensemble_weight(oof_proto_flat, oof_mlp_flat, y_flat)
    ENSEMBLE_WEIGHT_PROTO = best_w

    # 打印结果
    mlp_only_auc = macro_auc_skip_empty(y_flat, oof_mlp_flat)
    print(f"\n=== Ensemble Optimization ===")
    print(f"Best ProtoSSM weight: {ENSEMBLE_WEIGHT_PROTO:.2f}")
    print(f"Best ensemble OOF AUC: {best_auc:.4f}")
    print(f"MLP-only OOF AUC: {mlp_only_auc:.4f}")

    for w, auc in weight_results:
        marker = " <-- best" if abs(w - best_w) < 0.01 else ""
        print(f"  w={w:.2f}: AUC={auc:.4f}{marker}")

    # 记录日志
    LOGS["ensemble_weight"] = ENSEMBLE_WEIGHT_PROTO
    LOGS["ensemble_auc"] = best_auc
    LOGS["mlp_only_auc"] = mlp_only_auc
else:
    print(f"\nUsing default ensemble weight: ProtoSSM={ENSEMBLE_WEIGHT_PROTO:.2f}")

# 记录更多日志
LOGS["train_time_final"] = train_time
LOGS["n_probe_models"] = len(probe_models)

# 打印跨折的平均融合权重
if fold_alphas:
    mean_alphas = np.stack(fold_alphas).mean(axis=0)
    print(f"\nFusion alpha (mean across folds):")
    print(f"  ProtoSSM-dominant (alpha>0.5): {(mean_alphas > 0.5).sum()} classes")
    print(f"  Perch-dominant (alpha<=0.5): {(mean_alphas <= 0.5).sum()} classes")
Reshaped to file-level: emb=(59, 12, 1536), logits=(59, 12, 234), labels=(59, 12, 234)
Files: 59
Taxonomic groups: 5
Sites mapped: 9 (capped to 20)
Submit mode: skipping OOF cross-validation

ProtoSSM v4 parameters: 5,776,250
  Epoch  20: train=0.4892 val=0.4892 auc=0.0000 lr=0.000737 wait=4
  Epoch  40: train=0.4878 val=0.4878 auc=0.0000 lr=0.000452 wait=12
  Epoch  60: train=0.4928 val=0.4928 auc=0.0000 lr=0.000130 wait=13 swa=8
  Early stopping at epoch 67 (best val_loss=0.4594)
  Applying SWA (averaged 15 checkpoints)
  Training complete. Best val_loss=0.4594
  Fusion alpha: mean=0.501 min=0.493 max=0.506
  Proto temperature: 5.031
Final model training time: 74.5s
Fusion alpha: mean=0.5010 min=0.4935 max=0.5063
MLP probes trained: 58

Using default ensemble weight: ProtoSSM=0.50

1. 把数据整理成 “一份文件 12 个窗口” 的格式,适合 ProtoSSM 听完整段音频

2. 准备好 “站点、时间、鸟类亲戚关系” 这些辅助信息

3. 训练模式:先按文件分 5 折考试,看看 ProtoSSM 的成绩

4. 用所有数据训出 “见过最多题目的” 最终 ProtoSSM 和 MLP 小老师

5. 训练模式:试 “ProtoSSM 说的话占几分,MLP 说的话占几分”,找到组合起来成绩最好的比例

6. 提交模式直接用之前存好的比例,不用重新考试

残差 SSM 二次增强

在第一遍 ProtoSSM + MLP 集成的残差(误差)上,训练一个轻量级 SSM。残差模型学习系统性的修正模式:持续被高估 / 低估的物种、时序误差模式,以及站点特定偏差。
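
残差学习的两个关键公式如下(训练目标与下方代码一致;推理端按加法融合是合理假设,实际以后文融合单元格为准):

import numpy as np

def residual_target_sketch(labels, first_pass_logits):
    # 训练目标:真实标签减去第一遍的sigmoid概率,取值范围[-1, 1]
    return labels - 1.0 / (1.0 + np.exp(-first_pass_logits))

def apply_correction_sketch(first_pass_logits, correction, weight=0.35):
    # 推理融合(假设为加法):第一遍logits + 修正权重 × 残差SSM输出
    return first_pass_logits + weight * correction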

# 残差SSM:基于第一遍误差的二次增强
# 时间安全:如果已耗时>4分钟则跳过(为测试推理预留充足时间)
_wall_min = (time.time() - _WALL_START) / 60.0
print(f"Wall time: {_wall_min:.1f} min")

res_model = None
CORRECTION_WEIGHT = 0.0

if _wall_min < 4.0:
    print("Training ResidualSSM...")
    
    # 定义轻量级残差SSM类
    class ResidualSSM(nn.Module):
        # 轻量级SSM,输入第一遍分数+嵌入,预测修正
        # 架构:投影(拼接(嵌入, 第一遍)) → 1层双向SSM → 线性头
        def __init__(self, d_input=1536, d_scores=234, d_model=64, d_state=8,
                     n_classes=234, n_windows=12, dropout=0.1, n_sites=20, meta_dim=8):
            super().__init__()
            self.d_model = d_model
            self.n_classes = n_classes
    
            # 投影:嵌入+第一遍分数
            self.input_proj = nn.Sequential(
                nn.Linear(d_input + d_scores, d_model),
                nn.LayerNorm(d_model),
                nn.GELU(),
                nn.Dropout(dropout),
            )
    
            # 元数据嵌入
            self.site_emb = nn.Embedding(n_sites, meta_dim)
            self.hour_emb = nn.Embedding(24, meta_dim)
            self.meta_proj = nn.Linear(2 * meta_dim, d_model)
    
            # 位置编码
            self.pos_enc = nn.Parameter(torch.randn(1, n_windows, d_model) * 0.02)
    
            # 单层双向SSM(轻量)
            self.ssm_fwd = SelectiveSSM(d_model, d_state)
            self.ssm_bwd = SelectiveSSM(d_model, d_state)
            self.ssm_merge = nn.Linear(2 * d_model, d_model)
            self.ssm_norm = nn.LayerNorm(d_model)
            self.ssm_drop = nn.Dropout(dropout)
    
            # 输出:逐类别修正(加法)
            self.output_head = nn.Linear(d_model, n_classes)
    
            # 初始化输出接近0(修正从小开始)
            nn.init.zeros_(self.output_head.weight)
            nn.init.zeros_(self.output_head.bias)
    
        def forward(self, emb, first_pass_scores, site_ids=None, hours=None):
            # emb: (B, T, d_input), first_pass_scores: (B, T, n_classes)
            B, T, _ = emb.shape
    
            # 拼接嵌入和第一遍分数
            x = torch.cat([emb, first_pass_scores], dim=-1)  # (B, T, d_input + d_scores)
            h = self.input_proj(x)
    
            # 加入元数据
            if site_ids is not None and hours is not None:
                site_e = self.site_emb(site_ids.clamp(0, self.site_emb.num_embeddings - 1))
                hour_e = self.hour_emb(hours.clamp(0, 23))
                meta = self.meta_proj(torch.cat([site_e, hour_e], dim=-1))
                h = h + meta.unsqueeze(1)
    
            h = h + self.pos_enc[:, :T, :]
    
            # 双向SSM
            residual = h
            h_f = self.ssm_fwd(h)
            h_b = self.ssm_bwd(h.flip(1)).flip(1)
            h = self.ssm_merge(torch.cat([h_f, h_b], dim=-1))
            h = self.ssm_drop(h)
            h = self.ssm_norm(h + residual)
    
            # 输出修正
            correction = self.output_head(h)  # (B, T, n_classes)
            return correction
    
        def count_parameters(self):
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    
    # --- 在第一遍误差上训练残差SSM ---
    
    # 步骤1:计算训练数据上的第一遍分数
    model.eval()
    with torch.no_grad():
        emb_train_t = torch.tensor(emb_files, dtype=torch.float32)
        logits_train_t = torch.tensor(logits_files, dtype=torch.float32)
        site_train_t = torch.tensor(site_ids_all, dtype=torch.long)
        hour_train_t = torch.tensor(hours_all, dtype=torch.long)
    
        proto_train_out, _, _ = model(emb_train_t, logits_train_t,
                                       site_ids=site_train_t, hours=hour_train_t)
        proto_train_scores = proto_train_out.numpy()  # (n_files, 12, 234)
    
    # 训练数据上的MLP探针分数(扁平),初值取下方的先验融合基础分数
    
    # 获取MLP用的先验融合基础分数
    train_base_scores, train_prior_scores = fuse_scores_with_tables(
        scores_full_raw,
        sites=meta_full["site"].to_numpy(),
        hours=meta_full["hour_utc"].to_numpy(),
        tables=final_prior_tables,
    )
    mlp_train_scores_flat = train_base_scores.copy()
    
    # 生成MLP的训练分数
    for cls_idx, clf in probe_models.items():
        X_cls = build_class_features(
            Z_FULL,
            raw_col=scores_full_raw[:, cls_idx],
            prior_col=train_prior_scores[:, cls_idx],
            base_col=train_base_scores[:, cls_idx],
        )
        if hasattr(clf, "predict_proba"):
            prob = clf.predict_proba(X_cls)[:, 1].astype(np.float32)
            pred = np.log(prob + 1e-7) - np.log(1 - prob + 1e-7)
        else:
            pred = clf.decision_function(X_cls).astype(np.float32)
        alpha_p = float(CFG["frozen_best_probe"]["alpha"])
        mlp_train_scores_flat[:, cls_idx] = (1.0 - alpha_p) * train_base_scores[:, cls_idx] + alpha_p * pred
    
    # 重塑MLP分数为文件级
    mlp_train_scores_files, _ = reshape_to_files(mlp_train_scores_flat, meta_full)
    
    # 第一遍集成(和测试时一样的公式)
    first_pass_files = (
        ENSEMBLE_WEIGHT_PROTO * proto_train_scores +
        (1 - ENSEMBLE_WEIGHT_PROTO) * mlp_train_scores_files
    ).astype(np.float32)
    
    # 步骤2:计算残差(第一遍错了多少)
    # 目标:Y_FULL重塑为文件级。残差 = 目标 - sigmoid(第一遍)
    labels_float = labels_files.astype(np.float32)
    first_pass_probs = 1.0 / (1.0 + np.exp(-first_pass_files))
    residuals = labels_float - first_pass_probs  # 范围[-1, 1]
    
    print(f"First-pass training scores: {first_pass_files.shape}")
    print(f"Residuals: mean={residuals.mean():.4f}, std={residuals.std():.4f}, "
          f"abs_mean={np.abs(residuals).mean():.4f}")
    
    # 步骤3:训练残差SSM
    res_cfg = CFG["residual_ssm"]
    res_model = ResidualSSM(
        d_input=emb_full.shape[1],
        d_scores=N_CLASSES,
        d_model=res_cfg["d_model"],
        d_state=res_cfg["d_state"],
        n_classes=N_CLASSES,
        n_windows=N_WINDOWS,
        dropout=res_cfg["dropout"],
        n_sites=CFG["proto_ssm"]["n_sites"],
        meta_dim=8,
    ).to(DEVICE)
    
    print(f"ResidualSSM parameters: {res_model.count_parameters():,}")
    
    # 用MSE损失训练残差SSM
    n_files = len(file_list)
    n_val = max(1, int(n_files * 0.15))
    perm = torch.randperm(n_files, generator=torch.Generator().manual_seed(123))
    val_i = perm[:n_val].numpy()
    train_i = perm[n_val:].numpy()
    
    # 转tensor
    emb_tr = torch.tensor(emb_files[train_i], dtype=torch.float32)
    fp_tr = torch.tensor(first_pass_files[train_i], dtype=torch.float32)
    res_tr = torch.tensor(residuals[train_i], dtype=torch.float32)
    site_tr = torch.tensor(site_ids_all[train_i], dtype=torch.long)
    hour_tr = torch.tensor(hours_all[train_i], dtype=torch.long)
    
    emb_va = torch.tensor(emb_files[val_i], dtype=torch.float32)
    fp_va = torch.tensor(first_pass_files[val_i], dtype=torch.float32)
    res_va = torch.tensor(residuals[val_i], dtype=torch.float32)
    site_va = torch.tensor(site_ids_all[val_i], dtype=torch.long)
    hour_va = torch.tensor(hours_all[val_i], dtype=torch.long)
    
    # 优化器和调度器
    optimizer = torch.optim.AdamW(res_model.parameters(), lr=res_cfg["lr"], weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=res_cfg["lr"],
        epochs=res_cfg["n_epochs"], steps_per_epoch=1,
        pct_start=0.1, anneal_strategy='cos'
    )
    
    best_val_loss = float('inf')
    best_state = None
    wait = 0
    
    # 训练循环
    t0_res = time.time()
    for epoch in range(res_cfg["n_epochs"]):
        res_model.train()
        correction = res_model(emb_tr, fp_tr, site_ids=site_tr, hours=hour_tr)
        loss = F.mse_loss(correction, res_tr)  # MSE损失:拟合残差
    
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(res_model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    
        # 验证
        res_model.eval()
        with torch.no_grad():
            val_corr = res_model(emb_va, fp_va, site_ids=site_va, hours=hour_va)
            val_loss = F.mse_loss(val_corr, res_va)
    
        # 早停逻辑
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_state = {k: v.clone() for k, v in res_model.state_dict().items()}
            wait = 0
        else:
            wait += 1
    
        if (epoch + 1) % 20 == 0:
            print(f"  ResidualSSM epoch {epoch+1}: train={loss.item():.6f} val={val_loss.item():.6f} wait={wait}")
    
        if wait >= res_cfg["patience"]:
            print(f"  ResidualSSM early stop at epoch {epoch+1}")
            break
    
    # 加载最佳权重
    if best_state is not None:
        res_model.load_state_dict(best_state)
    
    res_time = time.time() - t0_res
    print(f"ResidualSSM training time: {res_time:.1f}s")
    print(f"Best val MSE: {best_val_loss:.6f}")
    
    # 验证修正幅度
    res_model.eval()
    with torch.no_grad():
        all_corr = res_model(emb_train_t, torch.tensor(first_pass_files, dtype=torch.float32),
                             site_ids=site_train_t, hours=hour_train_t)
        corr_np = all_corr.numpy()
        print(f"Correction magnitude: mean_abs={np.abs(corr_np).mean():.4f}, max={np.abs(corr_np).max():.4f}")
    
    # 设置修正权重
    CORRECTION_WEIGHT = res_cfg["correction_weight"]
    print(f"Correction weight: {CORRECTION_WEIGHT}")
    LOGS["residual_ssm"] = {
        "params": res_model.count_parameters(),
        "train_time": res_time,
        "best_val_mse": best_val_loss,
        "correction_mean_abs": float(np.abs(corr_np).mean()),
        "correction_weight": CORRECTION_WEIGHT,
    }
    
else:
    print("SKIPPED ResidualSSM (wall time safety)")
    LOGS["residual_ssm"] = {"skipped": True, "wall_min": _wall_min}
Wall time: 2.0 min
Training ResidualSSM...
First-pass training scores: (59, 12, 234)
Residuals: mean=-0.4299, std=0.3209, abs_mean=0.4383
ResidualSSM parameters: 439,498
  ResidualSSM epoch 20: train=0.020213 val=0.021534 wait=0
  ResidualSSM epoch 40: train=0.016882 val=0.018425 wait=0
ResidualSSM training time: 3.4s
Best val MSE: 0.018425
Correction magnitude: mean_abs=0.4492, max=1.0420
Correction weight: 0.35

1. 先看时间:如果训练已经超过 4 分钟,直接跳过残差 SSM,把时间留给考试(测试推理)

2. 定义一个很小的错题小老师(残差 SSM),只学 “怎么改错题”,不学新东西

3. 先让 ProtoSSM+MLP 做一遍训练题,看看它们错了多少(残差 = 真实答案 - 第一遍的概率)

4. 让错题小老师学这些错误的规律 —— 比如 “这个鸟总是被低估,要加 0.1 分”“这个站点的鸟要减 0.05 分”

5. 用较小的修正权重(本文为 0.35)把错题小老师的修改加回去,避免改过头

6. 记录日志,看看错题小老师的表现

# Cell 15 — 诊断
if MODE == "train":  # 训练模式
    if grid_results is not None:  # 如果有网格搜索结果
        best_row = grid_results.iloc[0]  # 拿第一名的结果
        print(f"Best honest OOF probe AUC: {best_row['probe_oof_auc']:.6f}")  # 打印最优探针OOF AUC
        print(f"Delta over honest OOF baseline: {best_row['delta']:.6f}")  # 打印比基线的提升
else:  # 提交模式
    print("Skipping train diagnostics in submit mode.")  # 跳过训练诊断
Skipping train diagnostics in submit mode.

训练模式:如果之前给小老师调过 “教案”(网格搜索),就拿出最好的那份教案的成绩,看看小老师考了多少分,比之前提升了多少;提交模式直接跳过,不用看这些

Test inference

Run Perch on the hidden test soundscapes.

# Cell 16 — Perch inference on the hidden test set (with embeddings)
# Collect the hidden test soundscape paths, sorted
test_paths = sorted((BASE / "test_soundscapes").glob("*.ogg"))

# If the hidden test set is not mounted, dry-run on the first few train soundscapes
if len(test_paths) == 0:
    print(f"Hidden test not mounted. Dry-run on first {CFG['dryrun_n_files']} train soundscapes.")
    test_paths = sorted((BASE / "train_soundscapes").glob("*.ogg"))[:CFG["dryrun_n_files"]]
else:
    print(f"Hidden test files: {len(test_paths)}")

# Run Perch inference with the best proxy_reduce from config (not hard-coded max)
meta_test, scores_test_raw, emb_test = infer_perch_with_embeddings(
    test_paths,
    batch_files=CFG["batch_files"],
    verbose=CFG["verbose"],
    proxy_reduce=CFG["proxy_reduce"],  # grid-search result, default "max"
)
print(f"proxy_reduce used for test inference: {CFG['proxy_reduce']!r}")

# Print shapes as a sanity check
print("meta_test:", meta_test.shape)
print("scores_test_raw:", scores_test_raw.shape)
print("emb_test:", emb_test.shape)
Hidden test not mounted. Dry-run on first 20 train soundscapes.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1776588231.389336      73 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
proxy_reduce used for test inference: 'max'
meta_test: (240, 4)
scores_test_raw: (240, 234)
emb_test: (240, 1536)

This is the first step of test inference, the Perch preprocessing stage. Core logic:

  1. Test file resolution: prefer the hidden test set; if it is absent, dry-run on the first few train soundscapes so the code never crashes
  2. Best inference config: use the proxy_reduce found by the earlier grid search (the proxy reduction mode, e.g. "max"/"mean") instead of a hard-coded value, so test-time preprocessing matches training (see the sketch after this list)
  3. Three outputs: Perch inference yields the test metadata (site, hour, ...), raw logits, and 1536-dim embeddings, which feed every downstream model
  4. Shape check: print the three output shapes to confirm inference ran correctly
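
For intuition, a hedged sketch of what the proxy_reduce option controls, per the genus-level proxy fusion described in the pipeline overview. The names perch_logits and genus_members are hypothetical placeholders; the real reduction happens inside infer_perch_with_embeddings.

import numpy as np

# Hypothetical Perch logits over Perch's own label space: (n_windows, n_perch_classes)
perch_logits = np.random.randn(240, 10).astype(np.float32)

# Hypothetical mapping: a competition species Perch cannot score -> columns of its congeners
genus_members = {"unmapped_species_x": [2, 5, 7]}

def proxy_logits(logits, member_cols, reduce="max"):
    # Proxy score for an unmapped species: reduce over its genus members' logits
    block = logits[:, member_cols]
    return block.max(axis=1) if reduce == "max" else block.mean(axis=1)

proxy = proxy_logits(perch_logits, genus_members["unmapped_species_x"], reduce="max")
print(proxy.shape)  # (240,)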

Score fusion: ProtoSSM v4 + MLP ensemble

Combine ProtoSSM v4's temporal predictions with the MLP probe scores using the OOF-optimized ensemble weight (the weight search itself is sketched below). Includes metadata-aware inference and comprehensive diagnostics.
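
A minimal sketch of how such a weight can be found on OOF predictions. Hypothetical names: oof_proto and oof_mlp are window-level OOF score matrices, Y is the binary label matrix; the notebook's actual search may differ in metric and grid.

import numpy as np
from sklearn.metrics import roc_auc_score

def search_ensemble_weight(oof_proto, oof_mlp, Y, grid=np.linspace(0.0, 1.0, 21)):
    # Keep only classes with both positives and negatives (AUC is undefined otherwise)
    pos = Y.sum(axis=0)
    valid = (pos > 0) & (pos < len(Y))
    best_w, best_auc = 0.5, -1.0
    for w in grid:
        blend = w * oof_proto + (1.0 - w) * oof_mlp
        auc = roc_auc_score(Y[:, valid], blend[:, valid], average="macro")
        if auc > best_auc:
            best_w, best_auc = float(w), auc
    return best_w, best_auc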

# Score fusion: ProtoSSM v4 + MLP probes + priors + TTA (OOF-optimized weights)

# --- Step 1: ProtoSSM v4 inference on the test set (with TTA) ---
emb_test_files, test_file_list = reshape_to_files(emb_test, meta_test)  # reshape to file-level layout
logits_test_files, _ = reshape_to_files(scores_test_raw, meta_test)

# Build test metadata
test_site_ids, test_hours = get_file_metadata(meta_test, test_file_list, site_to_idx, CFG["proto_ssm"]["n_sites"])

# Convert to tensors
emb_test_tensor = torch.tensor(emb_test_files, dtype=torch.float32)
logits_test_tensor = torch.tensor(logits_test_files, dtype=torch.float32)
test_site_tensor = torch.tensor(test_site_ids, dtype=torch.long)
test_hour_tensor = torch.tensor(test_hours, dtype=torch.long)

# TTA: average multiple predictions over temporal shifts
model.eval()
tta_shifts = CFG.get("tta_shifts", [0])
if len(tta_shifts) > 1:
    print(f"Running TTA with shifts: {tta_shifts}")
    proto_scores = temporal_shift_tta(
        emb_test_files, logits_test_files, model,
        test_site_ids, test_hours, shifts=tta_shifts
    )
else:
    with torch.no_grad():
        proto_out, _, h_test = model(emb_test_tensor, logits_test_tensor,
                                      site_ids=test_site_tensor, hours=test_hour_tensor)
        proto_scores = proto_out.numpy()

# Flatten back to window level
proto_scores_flat = proto_scores.reshape(-1, N_CLASSES).astype(np.float32)

print(f"ProtoSSM v4 test scores: {proto_scores_flat.shape}")
print(f"Score range: {proto_scores_flat.min():.3f} to {proto_scores_flat.max():.3f}")

# --- Step 2: prior-fused base scores ---
test_base_scores, test_prior_scores = fuse_scores_with_tables(
    scores_test_raw,
    sites=meta_test["site"].to_numpy(),
    hours=meta_test["hour_utc"].to_numpy(),
    tables=final_prior_tables,
)

# --- Step 3: MLP probe scores ---
emb_test_scaled = emb_scaler.transform(emb_test)  # apply the fitted scaler
Z_TEST = emb_pca.transform(emb_test_scaled).astype(np.float32)  # apply the fitted PCA

mlp_scores = test_base_scores.copy()

# Generate probe scores per class
for cls_idx, clf in probe_models.items():
    X_cls_test = build_class_features(
        Z_TEST,
        raw_col=scores_test_raw[:, cls_idx],
        prior_col=test_prior_scores[:, cls_idx],
        base_col=test_base_scores[:, cls_idx],
    )

    # Predict and convert probabilities to logits
    if hasattr(clf, "predict_proba"):
        prob = clf.predict_proba(X_cls_test)[:, 1].astype(np.float32)
        pred = np.log(prob + 1e-7) - np.log(1 - prob + 1e-7)
    else:
        pred = clf.decision_function(X_cls_test).astype(np.float32)

    # Blend with the base scores
    alpha = float(CFG["frozen_best_probe"]["alpha"])
    mlp_scores[:, cls_idx] = (1.0 - alpha) * test_base_scores[:, cls_idx] + alpha * pred

# --- Step 4: ensemble with the OOF-optimized weight ---
print(f"\nUsing OOF-optimized ensemble weight: {ENSEMBLE_WEIGHT_PROTO:.2f}")

final_test_scores = (
    ENSEMBLE_WEIGHT_PROTO * proto_scores_flat +
    (1.0 - ENSEMBLE_WEIGHT_PROTO) * mlp_scores
).astype(np.float32)

# --- Step 5: residual SSM correction (second stage) ---
if res_model is not None and CORRECTION_WEIGHT > 0:
    first_pass_test_files, _ = reshape_to_files(final_test_scores, meta_test)
    first_pass_test_t = torch.tensor(first_pass_test_files, dtype=torch.float32)

    res_model.eval()
    with torch.no_grad():
        test_correction = res_model(
            emb_test_tensor, first_pass_test_t,
            site_ids=test_site_tensor, hours=test_hour_tensor
        ).numpy()

    test_correction_flat = test_correction.reshape(-1, N_CLASSES).astype(np.float32)

    print(f"\nResidual correction: mean_abs={np.abs(test_correction_flat).mean():.4f}, "
          f"max={np.abs(test_correction_flat).max():.4f}")

    # Add the correction back
    final_test_scores = final_test_scores + CORRECTION_WEIGHT * test_correction_flat
    print(f"Final scores (after residual): range [{final_test_scores.min():.3f}, {final_test_scores.max():.3f}]")
else:
    print("\nResidual correction: SKIPPED")

print(f"Final scores: {final_test_scores.shape}")

# --- Logging ---
test_logs = {}
window_scores = proto_scores.reshape(-1, N_WINDOWS, N_CLASSES).mean(axis=(0, 2))
test_logs["window_position_scores"] = window_scores.tolist()
print(f"\nWindow position mean scores: {[f'{s:.3f}' for s in window_scores]}")

# Log per-taxon scores
if hasattr(model, 'class_to_family'):
    taxon_scores = defaultdict(list)
    idx_to_fam = {v: k for k, v in fam_to_idx.items()}
    for ci in range(N_CLASSES):
        fam_idx = class_to_family[ci]
        fam_name = idx_to_fam.get(fam_idx, f"group_{fam_idx}")
        taxon_scores[fam_name].append(float(proto_scores_flat[:, ci].mean()))

    test_logs["taxon_mean_scores"] = {k: float(np.mean(v)) for k, v in taxon_scores.items()}
    for k, v in sorted(taxon_scores.items(), key=lambda x: -np.mean(x[1]))[:5]:
        print(f"  {k}: mean_score={np.mean(v):.4f} (n_classes={len(v)})")

# Log prototype similarity
with torch.no_grad():
    p_norm = F.normalize(model.prototypes, dim=-1)
    cos_sim = torch.matmul(p_norm, p_norm.T)
    cos_sim.fill_diagonal_(0)
    top_sims = cos_sim.max(dim=1)[0].numpy()
    test_logs["prototype_max_similarity"] = {
        "mean": float(top_sims.mean()),
        "max": float(top_sims.max()),
        "min": float(top_sims.min()),
    }
    print(f"\nPrototype nearest-neighbor similarity: mean={top_sims.mean():.3f}, max={top_sims.max():.3f}")


LOGS["test_inference"] = test_logs
Running TTA with shifts: [0, 1, -1, 2, -2]
ProtoSSM v4 test scores: (240, 234)
Score range: -5.996 to 5.944

Using OOF-optimized ensemble weight: 0.50

Residual correction: mean_abs=0.4587, max=1.0420
Final scores (after residual): range [-8.930, 9.079]
Final scores: (240, 234)

Window position mean scores: ['-0.133', '-0.151', '-0.138', '-0.150', '-0.138', '-0.140', '-0.165', '-0.171', '-0.156', '-0.155', '-0.147', '-0.142']
  Aves: mean_score=0.2467 (n_classes=162)
  Reptilia: mean_score=-0.1553 (n_classes=1)
  Insecta: mean_score=-0.3281 (n_classes=28)
  Mammalia: mean_score=-0.6537 (n_classes=8)
  Amphibia: mean_score=-1.7200 (n_classes=35)

Prototype nearest-neighbor similarity: mean=0.480, max=1.000

This is the full score-fusion pipeline for test inference. Core logic:

  1. ProtoSSM v4 TTA inference: reshape the test data to file level and run temporal-shift TTA (average over multiple shifted predictions) for more stable temporal predictions (a sketch follows this list)
  2. Prior-fused base scores: fuse the raw Perch scores with the final prior tables to produce base scores
  3. MLP probe scores: process the test embeddings with the fitted scaler + PCA, generate probe scores per class, and blend them with the base scores
  4. Ensemble with the OOF-optimized weight: fuse the ProtoSSM and MLP scores with the best weight found in the earlier 5-fold search
  5. Residual SSM correction: if a residual model exists, generate corrections and add them back to the final scores
  6. Comprehensive logging: record window-position scores, per-taxon scores, prototype similarity, and other diagnostics
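
A simplified stand-in for temporal_shift_tta, assuming each shift circularly rolls the 12-window axis, predictions are rolled back before averaging, and the model signature matches the call above; these details are assumptions, not the notebook's exact implementation.

import numpy as np
import torch

def temporal_shift_tta_sketch(model, emb_files, logit_files, site_ids, hours,
                              shifts=(0, 1, -1, 2, -2)):
    # Average model outputs over circular shifts of the window axis (axis=1)
    model.eval()
    outs = []
    site_t = torch.tensor(site_ids, dtype=torch.long)
    hour_t = torch.tensor(hours, dtype=torch.long)
    with torch.no_grad():
        for s in shifts:
            emb_s = torch.tensor(np.roll(emb_files, s, axis=1), dtype=torch.float32)
            log_s = torch.tensor(np.roll(logit_files, s, axis=1), dtype=torch.float32)
            out, _, _ = model(emb_s, log_s, site_ids=site_t, hours=hour_t)
            # Roll predictions back so windows realign before averaging
            outs.append(np.roll(out.numpy(), -s, axis=1))
    return np.mean(outs, axis=0)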

Submission

Temperature-scaled calibration and generation of the submission CSV.

# Optimize per-class thresholds from OOF (train mode only)
PER_CLASS_THRESHOLDS = np.full(N_CLASSES, 0.5, dtype=np.float32)
if MODE == "train" and oof_proto_flat is not None:
    print("Optimizing per-class thresholds from OOF...")
    # Use the OOF results to find the best threshold for each class
    best_thresholds, best_scores = optimize_per_class_thresholds(
        oof_proto_flat, Y_FULL, n_windows=N_WINDOWS, thresholds=CFG["threshold_grid"]
    )
    PER_CLASS_THRESHOLDS = best_thresholds.astype(np.float32)
    print(f"  Mean threshold: {best_thresholds.mean():.3f}")
    print(f"  Threshold range: [{best_thresholds.min():.2f}, {best_thresholds.max():.2f}]")
    print(f"  Mean F1 (proxy): {best_scores.mean():.3f}")
    
    # Show classes with extreme thresholds
    high_t = np.where(best_thresholds > 0.6)[0]
    low_t = np.where(best_thresholds < 0.4)[0]
    if len(high_t) > 0:
        print(f"  High threshold classes (>0.6): {len(high_t)}")
    if len(low_t) > 0:
        print(f"  Low threshold classes (<0.4): {len(low_t)}")
else:
    # Submit mode: use the default threshold 0.5 for every class
    print("Using default per-class thresholds (0.5) for submit mode")


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))

# --- Step 1: per-taxon temperature scaling ---
temp_cfg = CFG["temperature"]
T_AVES = temp_cfg["aves"]
T_TEXTURE = temp_cfg["texture"]

# Aves uses T_AVES; texture taxa use T_TEXTURE
class_temperatures = np.ones(N_CLASSES, dtype=np.float32) * T_AVES
for ci, label in enumerate(PRIMARY_LABELS):
    cn = CLASS_NAME_MAP.get(label, "Aves")
    if cn in TEXTURE_TAXA:
        class_temperatures[ci] = T_TEXTURE

print(f"\nPer-taxon temperature: Aves={T_AVES}, Texture={T_TEXTURE}")

# Temperature-scale and convert to probabilities
scaled_scores = final_test_scores / class_temperatures[None, :]
probs = sigmoid(scaled_scores)

# --- Step 2: file-level confidence scaling ---
top_k = CFG.get("file_level_top_k", 0)
if top_k > 0:
    print(f"Applying file-level confidence scaling (top_k={top_k})")
    probs = file_level_confidence_scale(probs, n_windows=N_WINDOWS, top_k=top_k)
    probs = np.clip(probs, 0.0, 1.0)

# --- Step 3: rank-aware post-processing ---
if CFG.get("rank_aware_scale", False):
    power = CFG.get("rank_aware_power", 0.5)
    print(f"Applying rank-aware scaling (power={power})")
    probs = rank_aware_scaling(probs, n_windows=N_WINDOWS, power=power)
    probs = np.clip(probs, 0.0, 1.0)

# --- Step 4: delta shift smoothing ---
def adaptive_delta_smooth(probs, n_windows, base_alpha=0.20):
    n_files = probs.shape[0] // n_windows
    result = probs.copy()
    view = result.reshape(n_files, n_windows, -1)
    p_view = probs.reshape(n_files, n_windows, -1)
    for i in range(1, n_windows - 1):
        conf = p_view[:, i, :].max(axis=-1, keepdims=True)
        a = base_alpha * (1.0 - conf)
        neighbor_avg = (p_view[:, i-1, :] + p_view[:, i+1, :]) / 2.0
        view[:, i, :] = (1.0 - a) * p_view[:, i, :] + a * neighbor_avg
    return result.reshape(probs.shape)

alpha = CFG.get("delta_shift_alpha", 0.0)
if alpha > 0:
    print(f"Applying delta shift smoothing (alpha={alpha})")
    probs = adaptive_delta_smooth(probs, n_windows=N_WINDOWS, base_alpha=alpha)
    probs = np.clip(probs, 0.0, 1.0)

# --- Step 5: per-class threshold sharpening ---
print(f"Applying per-class threshold sharpening...")
probs = apply_per_class_thresholds(probs, PER_CLASS_THRESHOLDS, n_windows=N_WINDOWS)

# --- Build the submission file ---
submission = pd.DataFrame(probs, columns=PRIMARY_LABELS)
submission.insert(0, "row_id", meta_test["row_id"].values)
submission[PRIMARY_LABELS] = submission[PRIMARY_LABELS].astype(np.float32)

# Format validation
expected_rows = len(test_paths) * N_WINDOWS
assert len(submission) == expected_rows, f"Expected {expected_rows}, got {len(submission)}"
assert submission.columns.tolist() == ["row_id"] + PRIMARY_LABELS
assert not submission.isna().any().any()

# Save the CSV
submission.to_csv("submission.csv", index=False)

print("\nSaved submission.csv")
print("Submission shape:", submission.shape)
print(f"Final score range: {probs.min():.6f} to {probs.max():.6f}")
print(f"Final mean: {probs.mean():.4f}")
print(submission.iloc[:3, :8])
Using default per-class thresholds (0.5) for submit mode

Per-taxon temperature: Aves=1.1, Texture=0.95
Applying file-level confidence scaling (top_k=2)
Applying rank-aware scaling (power=0.4)
Applying delta shift smoothing (alpha=0.2)
Applying per-class threshold sharpening...

Saved submission.csv
Submission shape: (240, 235)
Final score range: 0.000000 to 0.999794
Final mean: 0.2452
                                     row_id   1161364    116570   1176823  \
0   BC2026_Train_0001_S08_20250606_030007_5  0.101031  0.218902  0.021989   
1  BC2026_Train_0001_S08_20250606_030007_10  0.118506  0.209171  0.018334   
2  BC2026_Train_0001_S08_20250606_030007_15  0.112893  0.190834  0.019861   

        1491113       1595929    209233     22930  
0  7.440146e-07  2.491770e-08  0.297159  0.012316  
1  7.597976e-07  2.313390e-08  0.305450  0.010445  
2  6.123422e-07  2.330344e-08  0.287989  0.011266  

Core summary:

  1. Threshold optimization: in train mode, learn a per-class decision threshold from OOF; in submit mode, use 0.5
  2. Per-taxon temperature calibration: birds and texture taxa get different temperature scales to calibrate the probability distributions
  3. Layered post-processing: file-level confidence scaling → rank-aware scaling → temporal smoothing → threshold sharpening, progressively refining the predicted probabilities (plausible sketches of the two scaling helpers follow this list)
  4. Format validation + save: generate submission.csv in the exact competition format, checking rows, columns, and NaNs before export
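
file_level_confidence_scale and rank_aware_scaling are defined earlier in the notebook; one plausible reading of each, matching the descriptions above (top-2 window mean as a per-file multiplier; a power transform on the per-file maximum), is sketched here under those assumptions.

import numpy as np

def file_level_confidence_scale_sketch(probs, n_windows, top_k=2):
    # Scale each file's windows by the mean of its top-k window probs (per class)
    view = probs.reshape(-1, n_windows, probs.shape[-1])
    top = np.sort(view, axis=1)[:, -top_k:, :].mean(axis=1, keepdims=True)
    return (view * top).reshape(probs.shape)

def rank_aware_scaling_sketch(probs, n_windows, power=0.4):
    # Multiply each window prob by the per-file class max raised to a power (<1 softens the penalty)
    view = probs.reshape(-1, n_windows, probs.shape[-1])
    file_max = view.max(axis=1, keepdims=True)
    return (view * file_max ** power).reshape(probs.shape)
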
# Save the full logs
wall_time = time.time() - _WALL_START
LOGS["wall_time_seconds"] = wall_time
LOGS["temperature"] = CFG["temperature"]
LOGS["ensemble_weight_proto"] = ENSEMBLE_WEIGHT_PROTO
LOGS["n_classes"] = N_CLASSES
LOGS["n_windows"] = N_WINDOWS
LOGS["cfg_proto_ssm"] = CFG["proto_ssm"]
LOGS["cfg_proto_ssm_train"] = {k: v for k, v in CFG["proto_ssm_train"].items() if not isinstance(v, (np.ndarray,))}
LOGS["improvements"] = [
    "d_model_256", "n_ssm_layers_3", "cross_attention", "mixup", "focal_loss", "swa",
    "per_taxon_temperature", "file_level_scaling", "tta", "rank_aware_scaling",
    "delta_shift_smooth", "per_class_thresholds"
]
LOGS["per_class_thresholds"] = PER_CLASS_THRESHOLDS.tolist()

try:
    with open("/kaggle/working/logs.json", "w") as f:
        json.dump(LOGS, f, indent=2, default=str)
    print("Saved /kaggle/working/logs.json")
except Exception as e:
    print(f"Warning: could not save logs: {e}")

if MODE == "train":
    print("=== ProtoSSM v5 Training Summary ===")
    print(f"Parameters: {model.count_parameters():,}")
    print(f"d_model: {CFG['proto_ssm']['d_model']}, n_ssm_layers: {CFG['proto_ssm']['n_ssm_layers']}")
    print(f"Wall time: {wall_time:.1f}s")
    print(f"OOF CV time: {LOGS.get('oof_time', 0):.1f}s")
    print(f"Final model training time: {LOGS.get('train_time_final', 0):.1f}s")
    print(f"Final train loss: {train_history['train_loss'][-1]:.4f}")
    print(f"Best val loss: {min(train_history['val_loss']):.4f}")
    print(f"Best val AUC: {max(train_history['val_auc']):.4f}")

    print(f"\n=== OOF Results ===")
    print(f"ProtoSSM OOF AUC: {LOGS.get('oof_auc_proto', 0):.4f}")
    print(f"MLP-only OOF AUC: {LOGS.get('mlp_only_auc', 0):.4f}")
    print(f"Ensemble OOF AUC: {LOGS.get('ensemble_auc', 0):.4f}")
    print(f"Optimized ProtoSSM weight: {ENSEMBLE_WEIGHT_PROTO:.2f}")

    with torch.no_grad():
        alphas = torch.sigmoid(model.fusion_alpha).numpy()
        high_proto = (alphas > 0.5).sum()
        high_perch = (alphas <= 0.5).sum()
        print(f"\nFusion alpha distribution (final model):")
        print(f"  ProtoSSM-dominant (alpha>0.5): {high_proto} classes")
        print(f"  Perch-dominant (alpha<=0.5): {high_perch} classes")

    print(f"\nPer-class calibration bias stats:")
    with torch.no_grad():
        cb = model.class_bias.numpy()
        print(f"  mean={cb.mean():.4f} std={cb.std():.4f} min={cb.min():.4f} max={cb.max():.4f}")

    print(f"\nMLP probes: {len(probe_models)} classes")

    if "per_class_auc_proto" in LOGS and LOGS["per_class_auc_proto"]:
        sorted_aucs = sorted(LOGS["per_class_auc_proto"].items(), key=lambda x: x[1], reverse=True)
        print(f"\nTop 10 classes by ProtoSSM OOF AUC:")
        for label, auc in sorted_aucs[:10]:
            print(f"  {label}: {auc:.4f}")
        print(f"\nBottom 10 classes by ProtoSSM OOF AUC:")
        for label, auc in sorted_aucs[-10:]:
            print(f"  {label}: {auc:.4f}")

    print("\nSubmission probability stats:")
    print(submission.iloc[:, 1:].stack().describe())
else:
    print("Submit mode completed.")
    print(f"ProtoSSM v5 parameters: {model.count_parameters():,}")
    print(f"Ensemble weight: {ENSEMBLE_WEIGHT_PROTO:.2f}")
    print(f"Wall time: {wall_time:.1f}s")
    print(f"Improvements: {LOGS['improvements']}")
Saved /kaggle/working/logs.json
Submit mode completed.
ProtoSSM v5 parameters: 5,776,250
Ensemble weight: 0.50
Wall time: 333.0s
Improvements: ['d_model_256', 'n_ssm_layers_3', 'cross_attention', 'mixup', 'focal_loss', 'swa', 'per_taxon_temperature', 'file_level_scaling', 'tta', 'rank_aware_scaling', 'delta_shift_smooth', 'per_class_thresholds']

Core summary

  1. Full log dump: record total wall time, model config, ensemble weight, improvement tags, and per-class thresholds, saved as logs.json
  2. Train-mode summary: print parameter count, training time, losses, AUCs, OOF ensemble results, fusion-alpha distribution, number of MLP probes, and the best/worst species by OOF AUC
  3. Submit-mode summary: print only the essentials (parameter count, ensemble weight, wall time, improvement tags)
  4. Failure safety: if the log save fails, emit a warning instead of aborting the run
