基于图神经网络的查询代价估计:替代传统统计信息的新路径

cover

一、代价估计的"统计盲区":当直方图无法捕捉关联性

数据库查询优化器的核心任务是选择最优执行计划,而代价估计是决策的基础。传统方法依赖统计信息(直方图、NDV、MCV)估算选择率,但统计信息只能捕捉单列分布,无法建模列间相关性。一个 WHERE city='北京' AND age=25 的查询,优化器假设两列独立,估算选择率为 P(city='北京') × P(age=25),但实际中北京用户的年龄分布可能与全国分布显著不同。这种独立性假设导致选择率偏差可达 10 倍以上,进而选择错误的 Join 顺序和索引。图神经网络(GNN)为查询代价估计提供了新思路——将查询计划和数据关系建模为图结构,学习列间关联性。

二、查询代价估计的问题建模

2.1 从 SQL 到图结构的转换

flowchart TB
    A[SQL 查询] --> B[解析为查询计划树]
    B --> C[转换为异构图]

    subgraph 异构图节点类型
        D[表节点 Table]
        E[列节点 Column]
        F[谓词节点 Predicate]
        G[Join 节点 Join]
    end

    subgraph 异构图边类型
        H[表→列 包含边]
        I[列→谓词 引用边]
        J[谓词→Join 连接边]
    end

    C --> D & E & F & G
    C --> H & I & J

    subgraph GNN 推理
        K[消息传递<br/>邻居特征聚合]
        L[节点嵌入更新]
        M[图级读出<br/>全局代价预测]
    end

    D & E & F & G --> K --> L --> M

2.2 传统代价估计的局限性

-- 传统优化器的选择率估算
-- 假设 city 和 age 独立
-- P(city='北京') = 0.05, P(age=25) = 0.03
-- 估算选择率 = 0.05 × 0.03 = 0.0015

-- 实际选择率可能远高于此(北京年轻用户集中)
-- 真实选择率 = 0.008,偏差 5.3 倍

-- 结果:优化器可能选择全表扫描而非索引扫描
SELECT * FROM users
WHERE city = '北京' AND age = 25;

三、GNN 查询代价估计方案

3.1 查询图构建

import torch
import torch.nn as nn
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GATConv, Linear

class QueryGraphBuilder:
    """将 SQL 查询转换为异构图"""

    def __init__(self, schema_metadata: dict):
        self.schema = schema_metadata

    def build_graph(self, query_plan: dict) -> HeteroData:
        """从查询计划构建异构图"""
        data = HeteroData()

        # 表节点特征:行数、平均行长、索引数量
        table_features = []
        table_ids = {}
        for i, table in enumerate(query_plan['tables']):
            table_ids[table['name']] = i
            meta = self.schema.get(table['name'], {})
            table_features.append([
                meta.get('row_count', 0),
                meta.get('avg_row_len', 0),
                meta.get('index_count', 0),
            ])

        data['table'].x = torch.tensor(table_features, dtype=torch.float)

        # 列节点特征:NDV、空值率、数据类型编码
        column_features = []
        column_ids = {}
        col_idx = 0
        for table in query_plan['tables']:
            for col in table['columns']:
                column_ids[(table['name'], col['name'])] = col_idx
                meta = self.schema.get(table['name'], {}).get('columns', {}).get(col['name'], {})
                column_features.append([
                    meta.get('ndv', 0),
                    meta.get('null_ratio', 0),
                    self._encode_type(col['type']),
                ])
                col_idx += 1

        data['column'].x = torch.tensor(column_features, dtype=torch.float)

        # 谓词节点特征:操作符编码、常量值
        pred_features = []
        for pred in query_plan['predicates']:
            pred_features.append([
                self._encode_op(pred['op']),
                pred.get('selectivity_hint', 0),
            ])

        data['predicate'].x = torch.tensor(pred_features, dtype=torch.float)

        # 边:表→列
        table_to_col = []
        for (tname, cname), cidx in column_ids.items():
            table_to_col.append([table_ids[tname], cidx])
        data['table', 'contains', 'column'].edge_index = (
            torch.tensor(table_to_col, dtype=torch.long).t().contiguous()
        )

        # 边:列→谓词
        col_to_pred = []
        for pidx, pred in enumerate(query_plan['predicates']):
            key = (pred['table'], pred['column'])
            if key in column_ids:
                col_to_pred.append([column_ids[key], pidx])
        data['column', 'referenced_by', 'predicate'].edge_index = (
            torch.tensor(col_to_pred, dtype=torch.long).t().contiguous()
        )

        return data

    @staticmethod
    def _encode_type(type_str: str) -> int:
        type_map = {'int': 1, 'varchar': 2, 'date': 3, 'float': 4, 'bool': 5}
        return type_map.get(type_str.lower(), 0)

    @staticmethod
    def _encode_op(op: str) -> int:
        op_map = {'=': 1, '!=': 2, '>': 3, '<': 4, '>=': 5, '<=': 6, 'LIKE': 7, 'IN': 8}
        return op_map.get(op.upper(), 0)

3.2 异构图神经网络模型

class QueryCostModel(nn.Module):
    """基于异构图神经网络的查询代价估计模型"""

    def __init__(self, hidden_dim: int = 64):
        super().__init__()

        # 各节点类型的输入投影层
        self.table_proj = Linear(-1, hidden_dim)
        self.column_proj = Linear(-1, hidden_dim)
        self.predicate_proj = Linear(-1, hidden_dim)

        # 异构图卷积层
        self.conv1 = HeteroConv({
            ('table', 'contains', 'column'): GATConv((-1, -1), hidden_dim, heads=2),
            ('column', 'referenced_by', 'predicate'): GATConv((-1, -1), hidden_dim, heads=2),
        }, aggr='mean')

        self.conv2 = HeteroConv({
            ('table', 'contains', 'column'): GATConv((-1, -1), hidden_dim, heads=2),
            ('column', 'referenced_by', 'predicate'): GATConv((-1, -1), hidden_dim, heads=2),
        }, aggr='mean')

        # 代价预测头
        self.cost_head = nn.Sequential(
            Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            Linear(hidden_dim, 1),
            nn.Softplus(),  # 确保输出为正值
        )

    def forward(self, data: HeteroData) -> torch.Tensor:
        # 输入投影
        x_dict = {
            'table': self.table_proj(data['table'].x),
            'column': self.column_proj(data['column'].x),
            'predicate': self.predicate_proj(data['predicate'].x),
        }

        # 两层消息传递
        x_dict = self.conv1(x_dict, data.edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        x_dict = self.conv2(x_dict, data.edge_index_dict)

        # 图级读出:各节点类型的全局平均池化
        table_emb = x_dict['table'].mean(dim=0)
        column_emb = x_dict['column'].mean(dim=0)
        pred_emb = x_dict['predicate'].mean(dim=0)

        # 拼接后预测代价
        graph_emb = torch.cat([table_emb, column_emb, pred_emb], dim=-1)
        cost = self.cost_head(graph_emb)

        return cost

3.3 训练数据采集与标注

class CostLabelCollector:
    """从数据库采集真实执行代价作为训练标签"""

    def __init__(self, db_connection):
        self.db = db_connection

    def collect(self, query: str) -> dict:
        """执行查询并采集真实代价指标"""
        # 使用 EXPLAIN ANALYZE 获取真实执行统计
        explain_result = self.db.execute(f"EXPLAIN ANALYZE {query}")

        return {
            'actual_rows': explain_result['actual_rows'],
            'actual_time_ms': explain_result['actual_total_time'],
            'actual_io_reads': explain_result.get('shared_read_blocks', 0),
            'plan_nodes': explain_result['plan_nodes'],
        }

    def collect_workload(self, queries: list) -> list:
        """批量采集工作负载的代价标签"""
        dataset = []
        for query in queries:
            try:
                label = self.collect(query)
                dataset.append({
                    'query': query,
                    'label': label,
                })
            except Exception as e:
                # 跳过执行失败的查询
                continue
        return dataset

四、边界分析与架构权衡

4.1 训练数据的分布偏移

模型在工作负载 A 上训练后,对工作负载 B 的预测精度可能显著下降。原因是不同工作负载的谓词分布、Join 模式和数据访问模式差异大。解决方案:定期用最新工作负载重新训练,或采用在线学习持续更新模型。

4.2 推理延迟对优化器的影响

传统代价估计耗时微秒级,GNN 推理需要毫秒级(1-5ms)。在高并发短查询场景中,优化器本身的耗时占比显著增加。优化方案:对查询模板缓存预测结果,相似查询直接查表而非重新推理。

4.3 模型可解释性不足

GNN 的预测结果难以解释——优化器无法理解"为什么模型认为这个 Join 代价更高"。在生产环境中,模型预测与传统估计之间的差异需要可审计。建议:模型仅作为传统估计的校准信号,而非完全替代,保留传统估计作为回退。

4.4 冷启动问题

新表和新列没有历史统计信息,GNN 也无法获取足够的特征输入。冷启动阶段需要回退到传统统计信息估计,待积累足够查询样本后再启用 GNN 预测。

五、总结

基于图神经网络的查询代价估计,通过将查询计划建模为异构图(表、列、谓词节点及其关系边),学习列间相关性和谓词交互效应,弥补传统统计信息的独立性假设缺陷。GATConv 的消息传递机制让节点特征在图结构中传播,最终通过图级读出预测查询代价。工程实践中需注意训练数据的分布偏移、推理延迟对优化器的影响、模型可解释性不足和冷启动问题。GNN 代价估计最适合作为传统估计的校准层,而非完全替代,在保证可审计性的前提下提升选择率估算精度。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐