Chapter 8 Classification, Case 1: Precision Marketing of Car Ads on a Social Network Platform

Background

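The dataset used in this case records each user's age, estimated salary, and whether they decided to buy a particular car. A simple KNN (k-nearest neighbors) classifier is used to predict whether a given person will buy the car.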
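The idea behind KNN fits in a few lines of NumPy. To classify a new point, measure its distance to every training point, take the k closest ones, and return the majority label among them. This is a minimal illustration only, not the scikit-learn implementation used later. The helper name `knn_predict` and the toy points are hypothetical.

```python
import numpy as np

def knn_predict(X_train, y_train, x_new, k=5):
    """Classify one point by majority vote of its k nearest training points (hypothetical helper)."""
    X_train = np.asarray(X_train, dtype=float)
    y_train = np.asarray(y_train)
    # Euclidean distance from x_new to every training point
    dists = np.sqrt(((X_train - np.asarray(x_new, dtype=float)) ** 2).sum(axis=1))
    # indices of the k smallest distances
    nearest = np.argsort(dists)[:k]
    # majority vote; np.unique sorts labels, so a tie goes to the smallest label
    labels, counts = np.unique(y_train[nearest], return_counts=True)
    return labels[np.argmax(counts)]

# toy, already-scaled points: class 0 near the origin, class 1 near (2, 2)
X_toy = [[0.0, 0.0], [0.2, 0.1], [0.1, 0.3], [2.0, 2.1], [1.9, 2.2], [2.2, 1.8]]
y_toy = [0, 0, 0, 1, 1, 1]
print(knn_predict(X_toy, y_toy, [1.8, 1.9], k=3))  # → 1
```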

Loading the Data

import numpy as np
import os
import matplotlib.pyplot as plt 
import pandas as pd 
import seaborn as sns
from sklearn.model_selection import train_test_split   
from sklearn.preprocessing import StandardScaler 
from matplotlib.colors import ListedColormap
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import confusion_matrix,classification_report,accuracy_score
import warnings
warnings.filterwarnings('ignore') 
plt.style.use('fivethirtyeight')
dataset=pd.read_csv('./Social_Network_Ads.csv')
dataset.head(2)
    User ID Gender  Age  EstimatedSalary  Purchased
0  15624510   Male   19            19000          0
1  15810944   Male   35            20000          0
dataset.describe().T
                 count          mean           std         min          25%         50%         75%         max
User ID          400.0  1.569154e+07  71658.321581  15566689.0  15626763.75  15694341.5  15750363.0  15815236.0
Age              400.0  3.765500e+01     10.482877        18.0        29.75        37.0        46.0        60.0
EstimatedSalary  400.0  6.974250e+04  34096.960282     15000.0     43000.00     70000.0     88000.0    150000.0
Purchased        400.0  3.575000e-01      0.479864         0.0         0.00         0.0         1.0         1.0

The average age is about 37.7 years, and the average estimated salary is about $69,742.
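The describe() output also shows that the two features live on very different scales. The standard deviation of EstimatedSalary (about 34,097) is over 3,000 times that of Age (about 10.5). KNN compares raw distances, so unscaled salary would dominate them. The sketch below uses the two users from dataset.head(2) plus a hypothetical third user C. Dividing each feature by its standard deviation from describe() roughly mimics what StandardScaler does; the mean cancels out in a difference.

```python
import numpy as np

# Users A and B are the rows shown by dataset.head(2); C is hypothetical
a = np.array([19.0, 19000.0])   # Age, EstimatedSalary
b = np.array([35.0, 20000.0])
c = np.array([60.0, 19000.0])   # hypothetical: 41 years older than A, same salary

# standard deviations of Age and EstimatedSalary from dataset.describe()
std = np.array([10.482877, 34096.960282])

def dist(p, q):
    """Euclidean distance between two feature vectors."""
    return np.sqrt(((p - q) ** 2).sum())

# raw features: the $1,000 salary gap outweighs a 41-year age gap, so C looks closer to A
print(round(dist(a, b), 1), round(dist(a, c), 1))
# after dividing by the std, B (16 years apart) is correctly the closer one
print(round(dist(a / std, b / std), 2), round(dist(a / std, c / std), 2))
```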

f,ax=plt.subplots(1,2,figsize=(18,8))
dataset['Gender'].value_counts().plot.pie(explode=[0,0.05],autopct='%1.1f%%',ax=ax[0],shadow=True)
ax[0].set_title('Gender Distribution')
ax[0].set_ylabel('Count')
# seaborn >= 0.12 expects the column name as the keyword argument x
sns.countplot(x='Gender',data=dataset,ax=ax[1],order=dataset['Gender'].value_counts().index)
ax[1].set_title('Gender Distribution')
plt.show()

[Figure: Gender Distribution, pie chart (left) and count plot (right)]

The numbers of male and female users are roughly equal.

Salary Distribution

plt.figure(figsize=(10,7))
# sns.distplot is deprecated; histplot with kde=True draws a histogram with a density curve
sns.histplot(dataset['EstimatedSalary'],kde=True,bins=50)
plt.show()

[Figure: histogram of EstimatedSalary with KDE curve]

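Salaries center around $70,000. The mean is about $69,742, the median is $70,000, and the middle half of users earn between $43,000 and $88,000.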

f,ax=plt.subplots(1,2,figsize=(18,8))
dataset['Purchased'].value_counts().plot.pie(explode=[0,0.05],autopct='%1.1f%%',ax=ax[0],shadow=True)
ax[0].set_title('Purchase Distribution')
ax[0].set_ylabel('Count')
sns.countplot(x='Purchased',data=dataset,ax=ax[1],order=dataset['Purchased'].value_counts().index)
ax[1].set_title('Purchase Distribution')
plt.show()

[Figure: Purchase Distribution, pie chart (left) and count plot (right)]

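So most people in the dataset did not buy the car. The mean of Purchased is 0.3575, meaning about 36% bought and 64% did not.
Everything that follows aims at selling more cars: by predicting who is likely to buy, the ads can be targeted at those users.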
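Because the classes are imbalanced, a model's accuracy should be compared against the trivial baseline of always predicting the majority class ("did not buy"). A minimal sketch follows; the helper name `majority_baseline` is hypothetical. In the notebook you would pass `dataset['Purchased']` (or `y_test`). Here the label vector is rebuilt from the describe() figures: 400 rows × 0.3575 = 143 buyers and 257 non-buyers.

```python
import numpy as np

def majority_baseline(y):
    """Accuracy of always predicting the most frequent class (hypothetical helper)."""
    y = np.asarray(y)
    _, counts = np.unique(y, return_counts=True)
    return counts.max() / counts.sum()

# rebuilt from describe(): 143 buyers (1) and 257 non-buyers (0)
y_all = np.array([1] * 143 + [0] * 257)
print(majority_baseline(y_all))  # 0.6425
```

Any classifier worth deploying here should clearly beat about 64% accuracy.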

Model Building and Training

X=dataset.iloc[:,[2,3]].values   # features: Age, EstimatedSalary
y=dataset.iloc[:,4].values       # label: Purchased
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=0)
# fit the scaler on the training set only, then apply the same transform to the test set
sc=StandardScaler()
X_train=sc.fit_transform(X_train)
X_test=sc.transform(X_test)
classifier=KNeighborsClassifier(n_neighbors=1,metric='minkowski',p=2)
classifier.fit(X_train,y_train)
KNeighborsClassifier(n_neighbors=1)
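`metric='minkowski'` with `p=2` is the Euclidean distance, and `p=1` would give the Manhattan distance. The Minkowski distance is d(a, b) = (Σ|a_i − b_i|^p)^(1/p). A small NumPy sketch of that formula follows; the function name `minkowski` is illustrative, not a scikit-learn API.

```python
import numpy as np

def minkowski(a, b, p=2):
    """Minkowski distance (sum |a_i - b_i|^p)^(1/p); p=2 is Euclidean, p=1 is Manhattan."""
    diff = np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))
    return (diff ** p).sum() ** (1.0 / p)

print(minkowski([0, 0], [3, 4], p=2))  # 5.0
print(minkowski([0, 0], [3, 4], p=1))  # 7.0
```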
y_pred=classifier.predict(X_test)
cm=confusion_matrix(y_test,y_pred)
f, ax = plt.subplots(figsize =(5,5))
sns.heatmap(cm,annot = True,linewidths=0.5,linecolor="red",fmt = ".0f",ax=ax)
plt.title("Confusion Matrix on the Test Set (K=1)")
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()

[Figure: confusion matrix heatmap on the test set, K=1]

print(accuracy_score(y_test,y_pred))
0.925
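Accuracy is the share of test points on the diagonal of the confusion matrix. With test_size=0.2 the test set holds 80 of the 400 rows, so 0.925 corresponds to 74 correct predictions. Since the classes are imbalanced, precision and recall for the "bought" class are also worth checking. `classification_report(y_test, y_pred)`, already imported above, prints them.

The sketch below computes these metrics by hand from a 2×2 matrix in scikit-learn's layout (rows = true class, columns = predicted class). The helper name and the matrix values are hypothetical, not the matrix produced above.

```python
import numpy as np

def binary_metrics(cm):
    """Accuracy, precision and recall for class 1 from a 2x2 matrix [[TN, FP], [FN, TP]]."""
    cm = np.asarray(cm)
    tn, fp, fn, tp = cm.ravel()
    accuracy = (tn + tp) / cm.sum()
    precision = tp / (tp + fp)   # of predicted buyers, how many really bought
    recall = tp / (tp + fn)      # of real buyers, how many the model found
    return accuracy, precision, recall

# hypothetical confusion matrix, for illustration only
acc, prec, rec = binary_metrics([[6, 2], [1, 3]])
print(acc, prec, rec)
```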

Tuning the Number of Neighbors K

error_rate=[]

for i in range(1,40):
    knn=KNeighborsClassifier(n_neighbors=i)
    knn.fit(X_train,y_train)
    pred_i=knn.predict(X_test)
    error_rate.append(np.mean(pred_i != y_test))
plt.figure(figsize=(10,6))
plt.plot(range(1,40),error_rate,color='blue',linestyle='dashed',marker='o',markerfacecolor='red',markersize=10)
plt.title('Error Rate Vs K Value')
plt.xlabel('K')
plt.ylabel('Error Rate')
plt.show()
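Choosing K from the test-set error, as above, lets the test set influence the model choice, so the final test accuracy can come out optimistic. A common alternative is k-fold cross-validation on the training data only; this is a swapped-in technique, not the author's method. Wrapping the scaler and the classifier in one pipeline refits the scaler inside each fold.

In the sketch below, `choose_k` is a hypothetical helper, and the synthetic data from `make_classification` stands in for the real training split. In the notebook you would pass the training split before scaling.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

def choose_k(X, y, k_values, cv=5):
    """Return the K with the highest mean cross-validated accuracy, plus all mean scores."""
    mean_scores = []
    for k in k_values:
        # scaler + KNN in one pipeline, so scaling is fitted on each training fold only
        model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=k))
        mean_scores.append(cross_val_score(model, X, y, cv=cv, scoring='accuracy').mean())
    best_k = k_values[int(np.argmax(mean_scores))]
    return best_k, mean_scores

# synthetic stand-in for the training split: 320 rows, 2 features
X_demo, y_demo = make_classification(n_samples=320, n_features=2, n_informative=2,
                                     n_redundant=0, random_state=0)
best_k, scores = choose_k(X_demo, y_demo, list(range(1, 40)))
print(best_k)
```

Only after K is fixed this way would the test set be used, once, to report the final accuracy.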

[Figure: Error Rate Vs K Value]

The classifier is then retrained with K=5:

classifier=KNeighborsClassifier(n_neighbors=5,metric='minkowski',p=2)
classifier.fit(X_train,y_train)
KNeighborsClassifier()
y_pred=classifier.predict(X_test)
cm=confusion_matrix(y_test,y_pred)
f, ax = plt.subplots(figsize =(5,5))
sns.heatmap(cm,annot = True,linewidths=0.5,linecolor="red",fmt = ".0f",ax=ax)
plt.title("Confusion Matrix on the Test Set (K=5)")
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.show()

[Figure: confusion matrix heatmap on the test set, K=5]

print(accuracy_score(y_test,y_pred))
0.9375
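To answer the case's original question (will a particular person buy?), a new user's age and salary must go through the same fitted scaler before reaching the classifier. Below is a minimal sketch. `predict_purchase` is a hypothetical helper, and the training rows are toy values, not rows from the dataset. In the notebook you would pass the scaler fitted on `X_train` and the trained `classifier`.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler

def predict_purchase(scaler, model, age, salary):
    """Scale one (age, salary) pair with the already-fitted scaler and return 0 or 1."""
    x_new = scaler.transform(np.array([[age, salary]], dtype=float))
    return int(model.predict(x_new)[0])

# toy training rows (Age, EstimatedSalary) and labels, for illustration only
X_toy = np.array([[22, 20000], [25, 30000], [28, 45000], [33, 52000], [30, 35000],
                  [48, 90000], [52, 110000], [45, 130000], [55, 70000], [50, 95000]], dtype=float)
y_toy = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

scaler = StandardScaler().fit(X_toy)
model = KNeighborsClassifier(n_neighbors=3).fit(scaler.transform(X_toy), y_toy)
print(predict_purchase(scaler, model, age=50, salary=100000))
```

Reusing the training-set scaler, rather than fitting a new one on the single query, keeps the new point in the same coordinate system the neighbors were measured in.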