scikit-learn包介绍#

scikit-learn官方文档(英文)

scikit-learn文档(中文)

scikit-learn 是一个提供各种机器学习算法的 Python 库。它是基于 NumPy, SciPy 和 Matplotlib 构建的,主要用于数据挖掘和数据分析。

为什么使用scikit-learn

  • 包含广泛的监督和非监督学习算法。

  • 提供统一和简洁的 API,方便快速开发。

  • 容易与 NumPy 和 Pandas 等库配合使用。

我在这里仅仅展示一个简单的用随机森林学习鸢尾花数据集的例子。后面的讲座中,将会有一位学长专门介绍机器学习。

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score


table = pd.read_csv("iris.csv")

X = table[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']] # 特征
y = table['species'] # 标签
y = y.astype('category').cat.codes  # 将标签转换为数字

#分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

#创建随机森林分类器,决策树数量为50
rfc = RandomForestClassifier(n_estimators=50)
rfc.fit(X_train, y_train)  #训练模型

# 评估模型
print("训练集准确率:", accuracy_score(y_train, rfc.predict(X_train)))
print("测试集准确率:", accuracy_score(y_test, rfc.predict(X_test)))
训练集准确率: 1.0
测试集准确率: 0.9