不平衡数据分类 tensorflow 笔记

https://tensorflow.google.cn/tutorials/structured_data/imbalanced_data

 

如何对高度不平衡的数据集进行分类,在该数据集中,一类中的示例数量大大超过另一类中的示例数量

清理,拆分和规范化数据

1- 首先,TimeAmount列的变量太大,无法直接使用。删除Time列(因为尚不清楚其含义),并获取Amount列的日志以减小其范围。

2- 将数据集分为训练,验证和测试集。在模型拟合过程中使用验证集来评估损失和任何度量,但是模型不适合此数据。该测试集在训练阶段完全未使用,仅在最后用于评估模型对新数据的推广程度

3-使用sklearn StandardScaler标准化输入功能。这会将平均值设置为0,标准偏差设置为1。

建立模型

请注意,该模型使用的批次大于默认的2048批次,这对于确保每个批次都有一定的机会容纳少量阳性样品非常重要。如果批量太小,则可能没有欺诈性交易可借鉴。

可选:设置正确的初始偏差

将其设置为初始偏差,模型将给出更合理的初始猜测。

model = make_model(output_bias = initial_bias)
model.predict(train_features[:10])

绘制ROC

现在绘制ROC。该图非常有用,因为它一眼就能显示出只要调整输出阈值就可以达到的性能范围。


目标是识别欺诈交易,但是您没有太多可用于处理的积极样本,因此您希望分类器对可用的几个例子进行重点研究。您可以通过将每个类的Keras权重传递给参数来实现。这些将导致模型“更多关注”来自代表性不足的类的示例。


用class重量训练模型

现在尝试使用类权重对模型进行重新训练和评估,以了解其如何影响预测。

 

weighted_model = make_model()
weighted_model.load_weights(initial_weights)

weighted_history = weighted_model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks = [early_stopping],
    validation_data=(val_features, val_labels),
    # The class weights go here
    class_weight=class_weight) 
 

发表评论

电子邮件地址不会被公开。 必填项已用*标注