Text2SQL 开源框架Vanna综合指南

概述

Vanna 是一个开源的 Text2SQL 框架,旨在将自然语言问题转换为 SQL 查询。它利用大型语言模型(LLM)的强大能力,结合数据库模式信息,为用户提供准确的 SQL 查询生成服务。Vanna 支持多种数据库和 LLM,并提供了简单易用的 Python API。

Vanna 编程接口

基本接口

Vanna 提供了简洁的 Python API,主要接口包括:

初始化 Vanna 实例
import vanna
from vanna.remote import VannaDefault

# 使用默认远程实例
vn = VannaDefault(api_key='your-api-key', model='your-model-name')

# 或者使用本地实例
from vanna.local import VannaLocal
vn = VannaLocal(model='your-model-name')
连接数据库
# 连接到 PostgreSQL
vn.connect_to_postgres(host='localhost', dbname='your_db', user='your_user', password='your_password')

# 连接到 MySQL
vn.connect_to_mysql(host='localhost', dbname='your_db', user='your_user', password='your_password')

# 连接到 SQLite
vn.connect_to_sqlite(url='sqlite:///path/to/your/database.db')

# 连接到 SQL Server
vn.connect_to_sql_server(server='your_server', database='your_db', user='your_user', password='your_password')
训练模型
# 使用 DDL 训练
vn.train(ddl="CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), email VARCHAR(100));")

# 使用文档训练
vn.train(documentation="我们的客户表存储所有客户信息,包括姓名和电子邮件。")

# 使用查询训练
vn.train(question="有多少客户?", sql="SELECT COUNT(*) FROM customers;")
生成 SQL
# 生成 SQL 查询
sql = vn.generate_sql("有多少客户?")
print(sql)

# 直接运行查询
df = vn.run_sql("有多少客户?")
print(df)

# 生成查询并附带可视化
vn.ask("有多少客户?")
高级功能
# 获取查询解释
explanation = vn.generate_explanation(sql, "有多少客户?")

# 获取查询可视化
fig = vn.generate_plotly_code(sql, "有多少客户?")

# 获取查询建议
suggestions = vn.generate_followup_questions("有多少客户?")

# 保存和加载训练数据
vn.save_training_data('training_data.json')
vn.load_training_data('training_data.json')

Vanna 支持的大模型

Vanna 支持多种大型语言模型,包括:

OpenAI 模型

  • GPT-4
  • GPT-4 Turbo
  • GPT-3.5 Turbo

Anthropic 模型

  • Claude 3 Opus
  • Claude 3 Sonnet
  • Claude 3 Haiku

Google 模型

  • Gemini Pro
  • Gemini Pro Vision

开源模型

  • Llama 2
  • Llama 3
  • Mistral
  • Mixtral

本地模型

Vanna 还支持通过 Ollama 等工具运行本地模型:

from vanna.local import VannaLocal
vn = VannaLocal(model='llama2')

自定义模型

Vanna 允许集成自定义的 LLM 端点:

from vanna.custom import VannaCustom
vn = VannaCustom(llm_endpoint='your-custom-endpoint')

Vanna 支持的数据库

Vanna 支持多种主流数据库系统:

关系型数据库

  • PostgreSQL - 完全支持,包括高级功能
  • MySQL - 完全支持
  • SQLite - 完全支持
  • SQL Server - 完全支持
  • Oracle - 基本支持
  • MariaDB - 完全支持

云数据库

  • Amazon RDS (PostgreSQL, MySQL, SQL Server)
  • Google Cloud SQL (PostgreSQL, MySQL, SQL Server)
  • Azure Database (PostgreSQL, MySQL, SQL Server)
  • Snowflake - 完全支持
  • BigQuery - 完全支持
  • Redshift - 基本支持

NoSQL 数据库

  • DuckDB - 完全支持
  • ClickHouse - 基本支持

数据仓库

  • Apache Hive - 基本支持
  • Apache Spark SQL - 基本支持

连接示例

# 连接到不同的数据库
vn.connect_to_postgres(...)
vn.connect_to_mysql(...)
vn.connect_to_sqlite(...)
vn.connect_to_sql_server(...)
vn.connect_to_snowflake(...)
vn.connect_to_bigquery(...)
vn.connect_to_duckdb(...)

Vanna 安装

基本安装

使用 pip 安装 Vanna:

pip install vanna

特定数据库支持安装

根据需要安装特定数据库的连接器:

# PostgreSQL
pip install psycopg2-binary

# MySQL
pip install pymysql

# SQL Server
pip install pyodbc

# Oracle
pip install cx_Oracle

# Snowflake
pip install snowflake-connector-python

# BigQuery
pip install google-cloud-bigquery

# DuckDB
pip install duckdb

完整安装(包含所有依赖)

pip install vanna[all]

开发环境安装

如果需要从源代码安装或参与开发:

git clone https://github.com/vanna-ai/vanna.git
cd vanna
pip install -e .

验证安装

import vanna
print(vanna.__version__)

Vanna 简单示例

示例 1: 基本使用

import vanna
from vanna.remote import VannaDefault

# 初始化 Vanna
vn = VannaDefault(api_key='your-api-key', model='your-model-name')

# 连接到 SQLite 数据库
vn.connect_to_sqlite(url='sqlite:///chinook.db')

# 使用 DDL 训练
vn.train(ddl="""
CREATE TABLE customers (
    CustomerId INT PRIMARY KEY,
    FirstName VARCHAR(40),
    LastName VARCHAR(40),
    Company VARCHAR(80),
    Address VARCHAR(70),
    City VARCHAR(40),
    State VARCHAR(40),
    Country VARCHAR(40),
    PostalCode VARCHAR(10),
    Phone VARCHAR(24),
    Fax VARCHAR(24),
    Email VARCHAR(60),
    SupportRepId INT
);
""")

# 使用文档训练
vn.train(documentation="customers 表存储所有客户信息,包括姓名、地址和联系方式。")

# 生成并执行查询
question = "有多少来自美国的客户?"
sql = vn.generate_sql(question)
print(f"生成的 SQL: {sql}")

# 执行查询
df = vn.run_sql(sql)
print(f"查询结果: {df}")

# 使用 ask 方法一次性完成
vn.ask("有多少来自美国的客户?")

示例 2: 多表查询

import vanna
from vanna.local import VannaLocal

# 初始化本地 Vanna 实例
vn = VannaLocal(model='chinook-model')

# 连接到数据库
vn.connect_to_sqlite(url='sqlite:///chinook.db')

# 训练多个表
vn.train(ddl="""
CREATE TABLE invoices (
    InvoiceId INT PRIMARY KEY,
    CustomerId INT,
    InvoiceDate DATETIME,
    BillingAddress VARCHAR(70),
    BillingCity VARCHAR(40),
    BillingState VARCHAR(40),
    BillingCountry VARCHAR(40),
    BillingPostalCode VARCHAR(10),
    Total DECIMAL(10, 2)
);

CREATE TABLE invoice_items (
    InvoiceLineId INT PRIMARY KEY,
    InvoiceId INT,
    TrackId INT,
    UnitPrice DECIMAL(10, 2),
    Quantity INT
);

CREATE TABLE tracks (
    TrackId INT PRIMARY KEY,
    Name VARCHAR(200),
    AlbumId INT,
    MediaTypeId INT,
    GenreId INT,
    Composer VARCHAR(220),
    Milliseconds INT,
    Bytes INT,
    UnitPrice DECIMAL(10, 2)
);
""")

# 使用文档训练
vn.train(documentation="""
- invoices 表存储客户发票信息
- invoice_items 表存储发票中的具体项目
- tracks 表存储音轨信息
- 可以通过 invoices 和 invoice_items 连接查询客户的购买详情
- 可以通过 invoice_items 和 tracks 连接查询音轨销售情况
""")

# 复杂查询
question = "哪个客户花费最多?请列出客户名称和总金额。"
vn.ask(question)

# 另一个复杂查询
question = "最畅销的5首歌曲是什么?"
vn.ask(question)

示例 3: 自定义训练数据

import vanna
from vanna.remote import VannaDefault

vn = VannaDefault(api_key='your-api-key', model='your-model-name')
vn.connect_to_postgres(host='localhost', dbname='sales', user='postgres', password='password')

# 从文件加载 DDL
with open('schema.sql', 'r') as f:
    ddl = f.read()
vn.train(ddl=ddl)

# 从文件加载文档
with open('documentation.md', 'r') as f:
    documentation = f.read()
vn.train(documentation=documentation)

# 从 CSV 文件加载问答对
import pandas as pd
qa_df = pd.read_csv('qa_pairs.csv')
for _, row in qa_df.iterrows():
    vn.train(question=row['question'], sql=row['sql'])

# 现在可以使用训练好的模型
vn.ask("上个月销售额最高的产品是什么?")

示例 4: 可视化查询结果

import vanna
from vanna.remote import VannaDefault

vn = VannaDefault(api_key='your-api-key', model='your-model-name')
vn.connect_to_sqlite(url='sqlite:///sales.db')

# 训练模型(省略训练步骤)

# 生成带可视化的查询
question = "每月销售额趋势如何?"
vn.ask(question, visualize=True)

# 或者分别生成 SQL 和可视化
sql = vn.generate_sql(question)
df = vn.run_sql(sql)
fig = vn.generate_plotly_code(sql, question)
print(fig)

示例 5: 错误处理和重试

import vanna
from vanna.remote import VannaDefault

vn = VannaDefault(api_key='your-api-key', model='your-model-name')
vn.connect_to_sqlite(url='sqlite:///chinook.db')

# 训练模型(省略训练步骤)

def safe_ask(question, max_retries=3):
    for i in range(max_retries):
        try:
            result = vn.ask(question)
            return result
        except Exception as e:
            print(f"尝试 {i+1} 失败: {str(e)}")
            if i == max_retries - 1:
                print("达到最大重试次数,查询失败")
                return None
            print("正在重试...")
    return None

# 使用安全查询
result = safe_ask("有多少来自加拿大的客户?")
if result:
    print("查询成功")
else:
    print("查询失败")

高级功能

自定义提示模板

from vanna.remote import VannaDefault

vn = VannaDefault(api_key='your-api-key', model='your-model-name')

# 自定义 SQL 生成提示
custom_prompt = """
你是一个 SQL 专家。根据以下数据库模式和用户问题生成准确的 SQL 查询。

数据库模式:
{schema}

用户问题: {question}

请生成 SQL 查询:
"""

vn.set_sql_generation_prompt(custom_prompt)

查询验证

# 在执行前验证 SQL
sql = vn.generate_sql("有多少客户?")
is_valid = vn.validate_sql(sql)
if is_valid:
    df = vn.run_sql(sql)
else:
    print("生成的 SQL 无效")

批量处理

questions = [
    "有多少客户?",
    "最畅销的产品是什么?",
    "上个月的销售额是多少?"
]

results = []
for question in questions:
    result = vn.ask(question)
    results.append(result)

最佳实践

  1. 充分训练模型:提供足够的 DDL、文档和示例查询对
  2. 明确问题:使用清晰、具体的自然语言问题
  3. 验证结果:对生成的 SQL 进行验证,特别是复杂查询
  4. 处理错误:实现适当的错误处理和重试机制
  5. 性能优化:对于大型数据库,考虑限制查询结果集大小

总结

Vanna 是一个功能强大的 Text2SQL 框架,它通过结合大型语言模型和数据库模式信息,能够准确地将自然语言问题转换为 SQL 查询。其简洁的 API、多数据库支持和灵活的模型集成使其成为数据分析和商业智能应用的理想选择。通过适当的训练和配置,Vanna 可以显著降低数据查询的技术门槛,使非技术用户也能轻松访问和分析数据。

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐