树模型 simple_ml.tree

一、ID3算法

from simple_ml.base.base_model import BaseClassifier


class ID3(BaseClassifier):

    __doc__ = "ID3 Decision Tree"

    def __init__(self, max_depth=None, min_samples_leaf=3):
        """
        决策树ID3算法
        :param max_depth:        树最大深度
        :param min_samples_leaf: 叶子节点最大样本数(最好是奇数,用以投票)
        """
        pass

ID3模型,支持解决:


1.1 初始化

  名称 类型 描述
Parameters: max_depth int 树最大深度
  min_samples_leaf int 叶子节点最大样本数(奇数)

1.2 类方法

1 拟合

def fit(self, x, y)

拟合特征

  名称 类型 描述
Parameters: x np.2darray 训练集特征
  y np.array 训练集标签
Returns:   Void  

2 预测

def predict(self, x)

给定测试集特征x,进行预测

  名称 类型 描述
Parameters: x np.2darray 测试集特征
Returns:   np.array 预测的结果

3 结果评价

def score(self, x, y)

拟合并进行预测,最后给出预测效果的得分

  名称 类型 描述
Parameters: x np.2darray 测试集特征
  y np.array 测试集标签
Returns:   float 预测结果评分,二分类给出F1值,多分类给出Macro F1值

4 分类作图

绘制分类效果图,如果维度大于2,则通过PCA降至两维

def classify_plot(self, x, y, title="")
  名称 类型 描述
Parameters: x np.2darray 测试集特征
  y np.array 测试集标签
Returns:   Void  

1.3 类属性

名称 类型 描述
root MultiTreeNode 树的根节点

二、分类回归树算法 (CART)

from simple_ml.base.base_model import BaseClassifier


class CART(BaseClassifier):

    __doc__ = "Classify and Regression Tree"

    def __init__(self, max_depth=10, min_samples_leaf=5):
        """
        分类回归树
        :param max_depth:        树最大深度
        :param min_samples_leaf: 叶子节点最大样本数(最好是奇数,用以投票)
        """
        pass

CART模型,支持解决:


2.1 初始化

  名称 类型 描述
Parameters: max_depth int 树最大深度
  min_samples_leaf int 叶子节点最大样本数(奇数)

2.2 类方法

1 拟合

def fit(self, x, y)

拟合特征

  名称 类型 描述
Parameters: x np.2darray 训练集特征
  y np.array 训练集标签
Returns:   Void  

2 预测

def predict(self, x)

给定测试集特征x,进行预测

  名称 类型 描述
Parameters: x np.2darray 测试集特征
Returns:   np.array 预测的结果

3 结果评价

def score(self, x, y)

拟合并进行预测,最后给出预测效果的得分

  名称 类型 描述
Parameters: x np.2darray 测试集特征
  y np.array 测试集标签
Returns:   float 预测结果评分,二分类给出F1值,多分类给出Macro F1值

4 分类作图

绘制分类效果图,如果维度大于2,则通过PCA降至两维

def classify_plot(self, x, y, title="")
  名称 类型 描述
Parameters: x np.2darray 测试集特征
  y np.array 测试集标签
Returns:   Void  

2.3 类属性

名称 类型 描述
root BinaryTreeNode 树的根节点

Examples

二分类

from simple_ml.classify_data import get_watermelon
from simple_ml.data_handle import train_test_split
from simple_ml.tree import ID3

x, y = get_watermelon()
x = x[:, :4]
x_train, y_train, x_test, y_test = train_test_split(x, y, 0.3, 918)
id3 = ID3()
id3.fit(x_train, y_train)
print(id3.score(x_test, y_test))

多分类

from simple_ml.classify_data import get_wine
from simple_ml.data_handle import train_test_split
from simple_ml.tree import CART


x, y = get_wine()
x_train, y_train, x_test, y_test = train_test_split(x, y, 0.3, 918)

cart = CART()
cart.fit(x_train, y_train)
print(cart.score(x_test, y_test))
cart.classify_plot(x_test, y_test)

回归

from simple_ml.classify_data import get_wine
from simple_ml.data_handle import train_test_split
from simple_ml.tree import CART


x, y = get_wine()

y = x[:, -1]
x = x[:, :-1]
x_train, y_train, x_test, y_test = train_test_split(x, y, 0.3, 918)

cart = CART()
cart.fit(x_train, y_train)
print(cart.score(x_test, y_test))

返回主页