基于 XGBoost 的股票涨跌预测实战(附完整可运行 Streamlit 代码)
摘要
本文详细介绍了如何使用集成学习算法 XGBoost 构建股票次日涨跌预测模型,并通过 Streamlit 框架快速搭建可视化交互界面。针对传统股票技术分析和基本面分析存在的滞后性、主观性痛点,XGBoost 能够高效捕捉金融数据中的非线性关系,通过正则化机制有效防止过拟合。文章提供了从数据获取、特征工程到模型训练、结果评估的完整代码实现,读者可直接运行体验,快速上手量化交易入门级项目。
关键词:XGBoost;股票预测;量化交易;特征工程;Streamlit
目录
- 为什么选择 XGBoost 做股票预测
- XGBoost 核心原理简介
- 实战:搭建 XGBoost 股票涨跌预测系统3.1 环境准备3.2 完整可运行代码3.3 运行效果说明
- 模型局限性与改进方向
- 免责声明
1. 为什么选择 XGBoost 做股票预测
1.1 传统股票分析方法的局限性
传统股票分析主要分为两大流派,但均存在难以克服的缺陷:
- 滞后性:技术指标(如均线、MACD、RSI)均基于历史价格计算,无法实时反映市场情绪和突发消息的影响
- 主观性:基本面分析高度依赖分析师的个人经验和判断,容易受情绪干扰,缺乏统一的量化标准
- 难以处理复杂关系:股票价格受宏观经济、行业政策、资金流向、投资者情绪等多重因素影响,传统线性模型无法捕捉这些因素间的复杂交互作用
1.2 XGBoost 算法的核心优势
XGBoost(Extreme Gradient Boosting)作为梯度提升树的优化实现,在结构化数据预测任务中表现卓越,尤其适合金融量化场景:
- 高效处理非线性关系:通过多棵决策树的串行组合,能够自动学习特征间的复杂非线性交互
- 强大的抗过拟合能力:内置 L1/L2 正则化项、列采样和行采样机制,有效降低噪声数据的干扰
- 灵活的金融场景适配:原生支持缺失值处理,可自定义损失函数,完美适配高频交易数据的特点
- 可解释性强:能够输出特征重要性排序,帮助投资者理解哪些因素对股价涨跌影响最大
2. XGBoost 核心原理简介
XGBoost 基于 Boosting 集成学习框架,其核心思想是串行训练多棵决策树,每一棵新树都拟合前序所有树的残差(预测误差),通过不断迭代逐步降低模型偏差。
其目标函数由两部分组成:
目标函数 = 损失函数 + 正则化项
- 损失函数:衡量模型预测值与真实值的差异,分类任务通常使用对数损失函数
- 正则化项:控制模型复杂度,防止过拟合,包括对树的深度、叶子节点数量和权重的惩罚
与传统 GBDT 相比,XGBoost 在以下方面进行了优化:
- 二阶泰勒展开近似损失函数,收敛速度更快
- 内置正则化项,有效防止过拟合
- 支持并行计算,训练效率大幅提升
- 自动处理缺失值,无需额外的数据预处理
3. 实战:搭建 XGBoost 股票涨跌预测系统
3.1 环境准备
首先安装所需的 Python 依赖库:
pip install streamlit akshare pandas numpy plotly xgboost scikit-learn
3.2 完整可运行代码
import streamlit as st
import akshare as ak
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from datetime import datetime, timedelta
# -------------------------- 全局设置 --------------------------
today = datetime.now()
default_start = today - timedelta(days=365)
default_end = today
st.set_page_config(page_title="XGBoost股票涨跌预测", layout="wide")
st.title('📈 基于XGBoost的股票涨跌预测系统')
st.markdown("---")
# -------------------------- 侧边栏参数输入 --------------------------
with st.sidebar:
st.header("⚙️ 参数设置")
stock_code = st.text_input("股票代码", "600000", help="输入A股股票代码,如600000(浦发银行)")
start_date = st.date_input("开始日期", default_start)
end_date = st.date_input("结束日期", default_end)
train_ratio = st.slider("训练集比例", 0.6, 0.95, 0.8, 0.05, help="用于训练模型的数据占比")
adjust_type = st.radio(
"复权类型",
options=[("前复权", "qfq"), ("后复权", "hfq"), ("不复权", "")],
index=0,
format_func=lambda x: x[0],
help="前复权以当前价格为基准调整历史价格,更适合技术分析"
)[1]
st.markdown("---")
st.subheader("模型参数")
n_estimators = st.slider("决策树数量", 50, 500, 200, 10)
max_depth = st.slider("树的最大深度", 3, 10, 6, 1)
learning_rate = st.slider("学习率", 0.01, 0.3, 0.05, 0.01)
# -------------------------- 数据获取与清洗 --------------------------
@st.cache_data(ttl=3600, show_spinner="正在获取股票数据...")
def get_stock_data(code, start, end, adjust):
"""
使用akshare获取A股历史行情数据并进行清洗
"""
try:
df = ak.stock_zh_a_hist(
symbol=code,
period="daily",
start_date=start.strftime("%Y%m%d"),
end_date=end.strftime("%Y%m%d"),
adjust=adjust
)
# 数据预处理
df = df.set_index('日期').sort_index()
df.index = pd.to_datetime(df.index)
df = df.rename(columns={
'开盘': 'open',
'最高': 'high',
'最低': 'low',
'收盘': 'close',
'成交量': 'volume'
})
return df[['open', 'high', 'low', 'close', 'volume']]
except Exception as e:
st.error(f"数据获取失败: {str(e)}")
return pd.DataFrame()
# 加载数据
with st.spinner('正在加载数据...'):
data = get_stock_data(stock_code, start_date, end_date, adjust_type)
if data.empty:
st.error("无法获取数据,请检查:\n1. 股票代码是否正确\n2. 日期范围是否有效\n3. 网络连接是否正常")
st.stop()
# -------------------------- 特征工程(核心) --------------------------
def create_features(df):
"""
构建预测所需的技术指标特征和目标变量
目标变量:次日收盘价是否高于当日收盘价(1=上涨,0=下跌)
"""
# 创建目标变量(次日涨跌)
df["target"] = (df["close"].shift(-1) > df["close"]).astype(int)
# 移除最后一行(没有次日数据)
df = df.iloc[:-1]
# 1. 移动平均线 MA
windows = [5, 10, 20]
for window in windows:
df[f'ma{window}'] = df['close'].rolling(window).mean()
# 2. 相对强弱指数 RSI(14)
delta = df['close'].diff()
gain = delta.where(delta > 0, 0)
loss = -delta.where(delta < 0, 0)
avg_gain = gain.rolling(14).mean()
avg_loss = loss.rolling(14).mean()
rs = avg_gain / (avg_loss + 1e-10) # 防止除零错误
df['rsi'] = 100 - (100 / (1 + rs))
# 3. 指数平滑异同移动平均线 MACD
exp12 = df['close'].ewm(span=12, adjust=False).mean()
exp26 = df['close'].ewm(span=26, adjust=False).mean()
df['macd'] = exp12 - exp26
df['signal'] = df['macd'].ewm(span=9, adjust=False).mean()
# 移除包含缺失值的行
return df.dropna()
# 处理数据
processed_data = create_features(data)
if len(processed_data) < 100:
st.warning(f"数据量不足(仅{len(processed_data)}条),建议选择更长的时间范围以获得更准确的预测结果")
st.stop()
# -------------------------- 数据集划分 --------------------------
# 特征列表
features = ['open', 'high', 'low', 'close', 'volume',
'ma5', 'ma10', 'ma20', 'rsi', 'macd', 'signal']
# 按时间顺序划分训练集和测试集(金融数据不能随机划分!)
split_idx = int(len(processed_data) * train_ratio)
X = processed_data[features]
y = processed_data['target']
X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]
# -------------------------- 模型训练与评估 --------------------------
with st.spinner("正在训练XGBoost模型..."):
# 初始化XGBoost分类器
model = XGBClassifier(
n_estimators=n_estimators,
max_depth=max_depth,
learning_rate=learning_rate,
subsample=0.8,
colsample_bytree=0.9,
random_state=42,
use_label_encoder=False,
eval_metric='logloss'
)
# 训练模型
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)[:, 1]
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
# -------------------------- 结果展示 --------------------------
col1, col2 = st.columns(2)
with col1:
st.subheader("📊 模型评估结果")
st.metric("测试集准确率", f"{accuracy:.2%}")
# 混淆矩阵
cm = confusion_matrix(y_test, y_pred)
st.write("混淆矩阵:")
cm_df = pd.DataFrame(cm,
index=['实际下跌', '实际上涨'],
columns=['预测下跌', '预测上涨'])
st.dataframe(cm_df)
# 分类报告
st.write("分类报告:")
report = classification_report(y_test, y_pred, output_dict=True)
st.dataframe(pd.DataFrame(report).transpose())
with col2:
st.subheader("📈 特征重要性")
# 获取特征重要性
feature_importance = pd.DataFrame({
'特征': features,
'重要性': model.feature_importances_
}).sort_values('重要性', ascending=False)
# 可视化特征重要性
fig = go.Figure(go.Bar(
x=feature_importance['重要性'],
y=feature_importance['特征'],
orientation='h',
marker_color='royalblue'
))
fig.update_layout(
title='特征重要性排序',
xaxis_title='重要性得分',
yaxis_title='特征',
height=500
)
st.plotly_chart(fig, use_container_width=True)
# 预测结果可视化
st.subheader("📉 预测结果与实际走势对比")
test_data = processed_data.iloc[split_idx:].copy()
test_data['prediction'] = y_pred
test_data['pred_proba'] = y_pred_proba
fig = go.Figure()
# 实际收盘价
fig.add_trace(go.Scatter(
x=test_data.index,
y=test_data['close'],
name='实际收盘价',
line=dict(color='blue')
))
# 标记预测上涨的点
fig.add_trace(go.Scatter(
x=test_data[test_data['prediction'] == 1].index,
y=test_data[test_data['prediction'] == 1]['close'],
mode='markers',
name='预测上涨',
marker=dict(color='green', size=8, symbol='triangle-up')
))
# 标记预测下跌的点
fig.add_trace(go.Scatter(
x=test_data[test_data['prediction'] == 0].index,
y=test_data[test_data['prediction'] == 0]['close'],
mode='markers',
name='预测下跌',
marker=dict(color='red', size=8, symbol='triangle-down')
))
fig.update_layout(
title=f"{stock_code} 测试集预测结果对比",
xaxis_title='日期',
yaxis_title='收盘价',
height=600,
hovermode='x unified'
)
st.plotly_chart(fig, use_container_width=True)
# 明日预测
st.subheader("🔮 明日涨跌预测")
latest_data = processed_data.iloc[-1][features].values.reshape(1, -1)
tomorrow_pred = model.predict(latest_data)[0]
tomorrow_proba = model.predict_proba(latest_data)[0][tomorrow_pred]
pred_text = "上涨" if tomorrow_pred == 1 else "下跌"
pred_color = "green" if tomorrow_pred == 1 else "red"
st.markdown(f"模型预测明日 {stock_code} 股价将 <span style='color:{pred_color};font-weight:bold'>{pred_text}</span>,置信度为 {tomorrow_proba:.2%}", unsafe_allow_html=True)
st.info("⚠️ 以上预测仅基于历史数据统计分析,不构成任何投资建议!")
3.3 运行效果说明
将上述代码保存为stock_prediction.py,在终端运行以下命令即可启动 Web 应用:
streamlit run stock_prediction.py
应用启动后,浏览器会自动打开界面,你可以:
- 在左侧侧边栏输入股票代码、选择日期范围和模型参数
- 系统会自动获取数据、训练模型并展示评估结果
- 查看特征重要性排序,了解哪些技术指标对股价影响最大
- 直观对比预测结果与实际走势
- 获取明日股价涨跌预测及置信度
4. 模型局限性与改进方向
4.1 局限性
- 仅基于技术指标:本模型仅使用了价格和成交量数据,未考虑基本面、宏观经济、政策消息等重要因素
- 历史数据依赖性:模型基于历史数据训练,无法预测黑天鹅事件和市场突发变化
- 准确率上限:股票市场本质上是一个高度复杂的非线性系统,任何模型都无法达到 100% 的准确率
4.2 改进方向
- 丰富特征维度:加入基本面数据(市盈率、市净率、净利润等)、市场情绪数据(新闻舆情、股吧评论)、宏观经济数据(利率、汇率、CPI 等)
- 超参数优化:使用网格搜索(GridSearchCV)或贝叶斯优化自动寻找最优模型参数
- 模型融合:将 XGBoost 与其他模型(如 LightGBM、CatBoost、LSTM)进行融合,进一步提升预测准确率
- 加入风险控制:在预测结果中加入止损止盈策略,构建完整的交易系统
5. 免责声明
投资有风险,入市需谨慎。 本文所提供的所有内容仅为技术研究和学习交流之用,不构成任何投资建议、交易指导或收益承诺。任何依据本文内容进行的投资决策,其风险均由投资者自行承担。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐


所有评论(0)