百度 推荐算法

一、LightGBM原理及其改进

1.1. 梯度提升

LightGBM的核心是一个梯度提升框架,它通过迭代地构建决策树来最小化损失函数。每棵树试图拟合前一棵树的残差(即预测值与真实值之间的差异),从而逐步提高模型的整体性能。

1.2. 基于梯度的单边采样 (GOSS - Gradient-based One-Side Sampling)

GOSS是一种用来减少数据量同时保持数据分布特性的技术。它通过保留具有较大梯度的数据点(这些点对确定分裂点更重要),并对剩余的小梯度数据进行随机抽样来实现。(这种方法能够在不影响最终模型准确率的情况下显著减少计算量)

1.3. 互斥特征捆绑 (EFB - Exclusive Feature Bundling)

  • EFB允许将稀疏特征合并在一起而不会丢失太多信息。通过这种方式,LightGBM能够减少特征维度,进而加快训练速度(它利用了某些特征之间很少同时非零的特点来进行捆绑)

1.4. 直方图算法 (Histogram-based Algorithm)

  • 与传统的预排序方法不同,LightGBM采用基于直方图的方法来寻找最佳分裂点。这种方法首先将连续的特征值离散化成若干个区间(即直方图),然后根据这些区间内的统计信息来决定如何分裂。(直方图算法不仅减少了存储需求,还加速了分裂点的选择过程)

1.5. Leaf-wise (Best-first) Tree Growth

  • 传统GBDT通常采用level-wise增长策略,即每次分裂都会扩展所有叶子节点。相比之下,LightGBM采用了leaf-wise增长策略,每次选择增益最大的节点进行分裂。(虽然这种方法可能会导致更深的树结构,但它能更有效地降低损失函数值)

二、怎么修改gbdt用来分类?

GBDT的分类算法从思想上和GBDT的回归算法没有区别,但是由于样本输出是连续的值,而不是离散的类别,导致我们无法直接从输出类别去拟合类别输出的误差。
为了解决这个问题,主要有两个方法,

  • 一个是用指数损失函数,此时GBDT退化为Adaboost算法(Boosting的代表算法)。
  • 另一种方法是用类似于逻辑回归的对数似然损失函数的方法。也就是说,我们用的是类别的预测概率值和真实概率值的差来拟合损失。

另外,首先明确一点,gbdt 无论用于分类还是回归一直都是使用的 CART 回归树。不会因为我们所选择的任务是分类任务就选用分类树,这里面的核心是因为gbdt 每轮的训练是在上一轮的训练的残差基础之上进行训练的。这里的残差就是当前模型的负梯度值 。这个要求每轮迭代的时候,弱分类器的输出的结果相减是有意义的,残差相减是有意义的。

二、Encoder 的结构

特征提取网络就是一个编码器。

  • 经典的Transformer架构的encoder:由**多头自注意力机制(Multi-head Self-Attention Mechanism)、前馈神经网络(Feed-forward Neural Networks)以及残差连接和层归一化(Layer Normalization)**组成。
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    # Attention and Normalization
    x = tf.keras.layers.MultiHeadAttention(key_dim=head_size, num_heads=num_heads, dropout=dropout)(inputs, inputs)
    x = tf.keras.layers.Dropout(dropout)(x)
    x = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x + inputs)
    
    # Feed Forward Part
    res = tf.keras.layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(x)
    res = tf.keras.layers.Dropout(dropout)(res)
    res = tf.keras.layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)(res)
    res = tf.keras.layers.LayerNormalization(epsilon=1e-6)(res + x)
    return res

三、手撕注意力机制和多头注意力机制

  1. 通过三个全连接生成q,k,v
  2. 通过reshape,分成多个q,k,v
  3. 计算自注意力分数(通过点积),并拼接全部的头
    见【搜广推校招面试三十四】
Logo

欢迎加入 MCP 技术社区!与志同道合者携手前行,一同解锁 MCP 技术的无限可能!

更多推荐