{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "OoasdhSAp0zJ" }, "outputs": [], "source": [ "##### Copyright 2019 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "cIrwotvGqsYh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "C81KT2D_j-xR" }, "source": [ "# Build a linear model with Estimators\n", "\n", "\n", " \n", " \n", " \n", " \n", "
View on TensorFlow.org\n", "Run in Google Colab View source on GitHub\n", "Download notebook
" ] }, { "cell_type": "markdown", "metadata": { "id": "JOccPOFMm5Tc" }, "source": [ "> 警告:不建议将 Estimator 用于新代码。Estimator 运行 `v1.Session` 风格的代码,此类代码更加难以正确编写,并且可能会出现意外行为,尤其是与 TF 2 代码结合使用时。Estimator 确实在我们的[兼容性保证](https://tensorflow.org/guide/versions)范围内,但除了安全漏洞之外不会得到任何修复。请参阅[迁移指南](https://tensorflow.org/guide/migrate)以了解详情。" ] }, { "cell_type": "markdown", "metadata": { "id": "tUP8LMdYtWPz" }, "source": [ "## 概述\n", "\n", "本端到端演示使用 `tf.estimator` API 来训练逻辑回归模型。该模型通常用作其他更复杂算法的基线。\n", "\n", "注:Keras 逻辑回归示例[已提供](https://tensorflow.org/guide/migrate/tutorials/keras/regression),并推荐在本教程中使用。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "vkC_j6VpqrDw" }, "source": [ "## 安装" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rutbJGmpqvm3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!pip install sklearn\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "54mb4J9PqqDh", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import os\n", "import sys\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from IPython.display import clear_output\n", "from six.moves import urllib" ] }, { "cell_type": "markdown", "metadata": { "id": "fsjkwfsGOBMT" }, "source": [ "## 加载 Titanic 数据集\n", "\n", "使用 Titanic 数据集的目的是在给定诸如性别、年龄、阶级等特征的情况下预测乘客能否生存(相当病态)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bNiwh-APcRVD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tensorflow.compat.v2.feature_column as fc\n", "\n", "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DSeMKcx03d5R", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Load dataset.\n", "dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')\n", "dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')\n", "y_train = 
dftrain.pop('survived')\n", "y_eval = dfeval.pop('survived')" ] }, { "cell_type": "markdown", "metadata": { "id": "jjm4Qj0u7_cp" }, "source": [ "## Explore the data" ] }, { "cell_type": "markdown", "metadata": { "id": "UrQzxKKh4d6u" }, "source": [ "The dataset contains the following features:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rTjugo3n308g", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y86q1fj44lZs", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain.describe()" ] }, { "cell_type": "markdown", "metadata": { "id": "8JSa_duD4tFZ" }, "source": [ "There are 627 and 264 examples in the training and evaluation sets, respectively." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Fs3Nu5pV4v5J", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain.shape[0], dfeval.shape[0]" ] }, { "cell_type": "markdown", "metadata": { "id": "RxCA4Nr45AfF" }, "source": [ "The majority of passengers are in their 20s and 30s." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RYeCMm7K40ZN", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain.age.hist(bins=20)" ] }, { "cell_type": "markdown", "metadata": { "id": "DItSwJ_B5B0f" }, "source": [ "There are approximately twice as many male passengers as female passengers." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b03dVV9q5Dv2", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain.sex.value_counts().plot(kind='barh')" ] }, { "cell_type": "markdown", "metadata": { "id": "rK6WQ29q5Jf5" }, "source": [ "The majority of passengers were in the \"third\" class." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dgpJVeCq5Fgd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "dftrain['class'].value_counts().plot(kind='barh')" ] }, { "cell_type": "markdown", "metadata": { "id": "FXJhGGL85TLp" }, "source": [ "Females have a much higher chance of surviving than males. This is clearly a predictive feature for the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"lSZYa7c45Ttt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "pd.concat([dftrain, y_train], axis=1).groupby('sex').survived.mean().plot(kind='barh').set_xlabel('% survive')" ] }, { "cell_type": "markdown", "metadata": { "id": "qCHvgeorEsHa" }, "source": [ "## 模型的特征工程" ] }, { "cell_type": "markdown", "metadata": { "id": "Dhcq8Ds4mCtm" }, "source": [ "> 警告:不推荐为新代码使用本教程中介绍的 tf.feature_columns 模块。Keras 预处理层介绍了此功能,有关迁移说明,请参阅[迁移特征列](https://tensorflow.google.cn/guide/migrate/migrating_feature_columns)指南。tf.feature_columns 模块旨在与 TF1 Estimators 结合使用。它不在我们的[兼容性保证](https://tensorflow.org/guide/versions)范围内,除了安全漏洞修正外,不会获得其他修正。" ] }, { "cell_type": "markdown", "metadata": { "id": "VqDKQLZn8L-B" }, "source": [ "Estimator 使用名为[特征列](https://tensorflow.google.cn/tutorials/structured_data/feature_columns)的系统来描述模型应如何解释每个原始输入特征。需要为 Estimator 提供数字输入向量, *特征列*描述了模型应如何转换各个特征。\n", "\n", "选择和制作一组正确的特征列是学习高效模型的关键。特征列可以是原始特征 `dict`(*基础特征列*)中的一项原始输入,也可以是使用一个或多个基础列定义的转换创建的任何新列(*派生特征列*)。\n", "\n", "线性 Estimator 同时使用数字和分类特征。特征列可与所有 TensorFlow Estimator 配合使用,其目的是定义用于建模的特征。此外,它们还提供了一些特征工程功能,例如独热编码、归一化和分桶。" ] }, { "cell_type": "markdown", "metadata": { "id": "puZFOhTDkblt" }, "source": [ "### 基础特征列" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GpveXYSsADS6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',\n", " 'embark_town', 'alone']\n", "NUMERIC_COLUMNS = ['age', 'fare']\n", "\n", "feature_columns = []\n", "for feature_name in CATEGORICAL_COLUMNS:\n", " vocabulary = dftrain[feature_name].unique()\n", " feature_columns.append(tf.feature_column.categorical_column_with_vocabulary_list(feature_name, vocabulary))\n", "\n", "for feature_name in NUMERIC_COLUMNS:\n", " feature_columns.append(tf.feature_column.numeric_column(feature_name, dtype=tf.float32))" ] }, { "cell_type": "markdown", "metadata": { "id": "Gt8HMtwOh9lJ" }, "source": [ "`input_function` 
specifies how data is converted to a `tf.data.Dataset` that feeds the input pipeline in a streaming fashion. `tf.data.Dataset` can take in multiple sources, such as a dataframe or a CSV-formatted file." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qVtrIHFnAe7w", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def make_input_fn(data_df, label_df, num_epochs=10, shuffle=True, batch_size=32):\n", "  def input_function():\n", "    # Build a dataset of (features dict, label) pairs from the dataframes.\n", "    ds = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))\n", "    if shuffle:\n", "      ds = ds.shuffle(1000)\n", "    ds = ds.batch(batch_size).repeat(num_epochs)\n", "    return ds\n", "  return input_function\n", "\n", "train_input_fn = make_input_fn(dftrain, y_train)\n", "eval_input_fn = make_input_fn(dfeval, y_eval, num_epochs=1, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": { "id": "P7UMVkQnkrgb" }, "source": [ "You can inspect the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ZcG_3KiCb1M", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "ds = make_input_fn(dftrain, y_train, batch_size=10)()\n", "for feature_batch, label_batch in ds.take(1):\n", "  print('Some feature keys:', list(feature_batch.keys()))\n", "  print()\n", "  print('A batch of class:', feature_batch['class'].numpy())\n", "  print()\n", "  print('A batch of Labels:', label_batch.numpy())" ] }, { "cell_type": "markdown", "metadata": { "id": "lMNBMyodjlW3" }, "source": [ "You can also inspect the result of a specific feature column using the `tf.keras.layers.DenseFeatures` layer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IMjlmbPlDmkB", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "age_column = feature_columns[7]\n", "tf.keras.layers.DenseFeatures([age_column])(feature_batch).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "f4zrAdCIjr3s" }, "source": [ "`DenseFeatures` only accepts dense tensors. To inspect a categorical column, you need to transform it to an indicator column first:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1VXmXFTSFEvv", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "gender_column = feature_columns[0]\n", 
"tf.keras.layers.DenseFeatures([tf.feature_column.indicator_column(gender_column)])(feature_batch).numpy()" ] }, { "cell_type": "markdown", "metadata": { "id": "MEp59g5UkHYY" }, "source": [ "将所有基础特征添加到模型后,让我们开始训练模型。训练模型仅为使用 `tf.estimator` API 的单个命令:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aGXjdnqqdgIs", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns)\n", "linear_est.train(train_input_fn)\n", "result = linear_est.evaluate(eval_input_fn)\n", "\n", "clear_output()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "3tOan4hDsG6d" }, "source": [ "### 派生特征列" ] }, { "cell_type": "markdown", "metadata": { "id": "NOG2FSTHlAMu" }, "source": [ "现在,您已达到 75% 的准确率。单独使用每个基本特征列可能不足以解释数据。例如,年龄和标签之间的相关性可能因性别不同而不同。因此,如果您只学习了 `gender=\"Male\"` 和 `gender=\"Female\"` 的单个模型权重,则将无法捕获每个年龄-性别组合(例如区分 `gender=\"Male\"` 和 `age=\"30\"` 以及 `gender=\"Male\"` 和 `age=\"40\"`)。\n", "\n", "要了解不同特征组合之间的区别,您可以向模型添加*交叉特征列*(也可以在添加交叉列之前对年龄列进行分桶):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AM-RsDzNfGlu", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "age_x_gender = tf.feature_column.crossed_column(['age', 'sex'], hash_bucket_size=100)" ] }, { "cell_type": "markdown", "metadata": { "id": "DqDFyPKQmGTN" }, "source": [ "将组合特征添加到模型后,让我们再次训练模型:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s8FV9oPQfS-g", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "derived_feature_columns = [age_x_gender]\n", "linear_est = tf.estimator.LinearClassifier(feature_columns=feature_columns+derived_feature_columns)\n", "linear_est.train(train_input_fn)\n", "result = linear_est.evaluate(eval_input_fn)\n", "\n", "clear_output()\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": { "id": "rwfdZj7ImLwb" }, "source": [ "现在,准确率已达 
77.6%, slightly better than training on the base features alone. You can try using more features and transformations to see if you can improve the accuracy further!" ] }, { "cell_type": "markdown", "metadata": { "id": "8_eyb9d-ncjH" }, "source": [ "You can now use the trained model to make predictions for the passengers in the evaluation set. TensorFlow models are optimized to make predictions on a batch, or collection, of examples at once. Earlier, `eval_input_fn` was defined using the entire evaluation set." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wiScyBcef6Dq", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "pred_dicts = list(linear_est.predict(eval_input_fn))\n", "probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])\n", "\n", "probs.plot(kind='hist', bins=20, title='predicted probabilities')" ] }, { "cell_type": "markdown", "metadata": { "id": "UEHRCd4sqrLs" }, "source": [ "Finally, look at the receiver operating characteristic (ROC) of the results, which gives a better idea of the tradeoff between the true positive rate and the false positive rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kqEjsezIokIe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve\n", "from matplotlib import pyplot as plt\n", "\n", "fpr, tpr, _ = roc_curve(y_eval, probs)\n", "plt.plot(fpr, tpr)\n", "plt.title('ROC curve')\n", "plt.xlabel('false positive rate')\n", "plt.ylabel('true positive rate')\n", "plt.xlim(0,)\n", "plt.ylim(0,)" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "linear.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }