{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ndo4ERqnwQOU" }, "outputs": [], "source": [ "##### Copyright 2020 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "MTKwbguKwT4R", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "xfNT-mlFwxVM" }, "source": [ "# 自编码器简介" ] }, { "cell_type": "markdown", "metadata": { "id": "0TD5ZrvEMbhZ" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
在 TensorFlow.org 上查看\n", " 在 Google Colab 运行\n", " 在 Github 上查看源代码\n", " 下载笔记本\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ITZuApL56Mny" }, "source": [ "本教程通过以下三个示例介绍自编码器:基础知识、图像降噪和异常检测。\n", "\n", "自编码器是一种特殊类型的神经网络,经过训练后可将其输入复制到其输出。例如,给定一个手写数字的图像,自编码器首先将图像编码为低维的潜在表示,然后将该潜在表示解码回图像。自编码器学习压缩数据,同时最大程度地减少重构误差。\n", "\n", "要详细了解自编码器,请考虑阅读 Ian Goodfellow、Yoshua Bengio 和 Aaron Courville 撰写的[《深度学习》](https://www.deeplearningbook.org/)一书的第 14 章。" ] }, { "cell_type": "markdown", "metadata": { "id": "e1_Y75QXJS6h" }, "source": [ "## 导入 TensorFlow 和其他库" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YfIk2es3hJEd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import tensorflow as tf\n", "\n", "from sklearn.metrics import accuracy_score, precision_score, recall_score\n", "from sklearn.model_selection import train_test_split\n", "from tensorflow.keras import layers, losses\n", "from tensorflow.keras.datasets import fashion_mnist\n", "from tensorflow.keras.models import Model" ] }, { "cell_type": "markdown", "metadata": { "id": "iYn4MdZnKCey" }, "source": [ "## 加载数据集\n", "\n", "首先,您将使用 Fashion MNIST 数据集训练基本自编码器。此数据集中的每个图像均为 28x28 像素。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YZm503-I_tji", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(x_train, _), (x_test, _) = fashion_mnist.load_data()\n", "\n", "x_train = x_train.astype('float32') / 255.\n", "x_test = x_test.astype('float32') / 255.\n", "\n", "print (x_train.shape)\n", "print (x_test.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "VEdCXSwCoKok" }, "source": [ "## 第一个示例:基本自编码器\n", "\n", "![Basic autoencoder results](images/intro_autoencoder_result.png)\n", "\n", "定义一个具有两个密集层的自编码器:一个将图像压缩为 64 维隐向量的 `encoder`,以及一个从隐空间重构原始图像的 `decoder`。\n", "\n", "要定义模型,请使用 [Keras Model Subclassing API](https://tensorflow.google.cn/guide/keras/custom_layers_and_models)。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0MUxidpyChjX", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class Autoencoder(Model):\n", " def __init__(self, latent_dim, shape):\n", " super(Autoencoder, self).__init__()\n", " self.latent_dim = latent_dim\n", " self.shape = shape\n", " self.encoder = tf.keras.Sequential([\n", " layers.Flatten(),\n", " layers.Dense(latent_dim, activation='relu'),\n", " ])\n", " self.decoder = tf.keras.Sequential([\n", " layers.Dense(tf.math.reduce_prod(shape), activation='sigmoid'),\n", " layers.Reshape(shape)\n", " ])\n", "\n", " def call(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded\n", "\n", "\n", "shape = x_test.shape[1:]\n", "latent_dim = 64\n", "autoencoder = Autoencoder(latent_dim, shape)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9I1JlqEIDCI4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())" ] }, { "cell_type": "markdown", "metadata": { "id": "7oJSeMTroABs" }, "source": [ "使用 `x_train` 作为输入和目标来训练模型。`encoder` 会学习将数据集从 784 个维度压缩到隐空间,而 `decoder` 将学习重构原始图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "h1RI9OfHDBsK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.fit(x_train, x_train,\n", " epochs=10,\n", " shuffle=True,\n", " validation_data=(x_test, x_test))" ] }, { "cell_type": "markdown", "metadata": { "id": "wAM1QBhtoC-n" }, "source": [ "现在,模型已经训练完成,我们通过对测试集中的图像进行编码和解码来测试该模型。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pbr5WCj7FQUi", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "encoded_imgs = autoencoder.encoder(x_test).numpy()\n", "decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "s4LlDOS6FUA1", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "n = 10\n", "plt.figure(figsize=(20, 4))\n", "for i in range(n):\n", " # display original\n", " ax = plt.subplot(2, n, i + 1)\n", " plt.imshow(x_test[i])\n", " plt.title(\"original\")\n", " plt.gray()\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", "\n", " # display reconstruction\n", " ax = plt.subplot(2, n, i + 1 + n)\n", " plt.imshow(decoded_imgs[i])\n", " plt.title(\"reconstructed\")\n", " plt.gray()\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "r4gv6G8PoRQE" }, "source": [ "## 第二个示例:图像降噪\n", "\n", "![Image denoising results](https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tutorials/generative/images/image_denoise_fmnist_results.png?raw=true)\n", "\n", "经过训练后,自编码器还可以去除图像中的噪点。在以下部分中,您将通过对每个图像应用随机噪声来创建有噪版本的 Fashion MNIST 数据集。随后,您将使用有噪图像作为输入并以原始图像作为目标来训练自编码器。\n", "\n", "我们重新导入数据集以忽略之前所做的修改:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gDYHJA2PCQ3m", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "(x_train, _), (x_test, _) = fashion_mnist.load_data()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uJZ-TcaqDBr5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "x_train = x_train.astype('float32') / 255.\n", "x_test = x_test.astype('float32') / 255.\n", "\n", "x_train = x_train[..., tf.newaxis]\n", "x_test = x_test[..., tf.newaxis]\n", "\n", "print(x_train.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "aPZl_6P65_8R" }, "source": [ "向图像添加随机噪声:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "axSMyxC354fc", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "noise_factor = 0.2\n", "x_train_noisy = x_train + noise_factor * tf.random.normal(shape=x_train.shape) \n", "x_test_noisy = x_test + noise_factor * tf.random.normal(shape=x_test.shape) \n", "\n", "x_train_noisy = tf.clip_by_value(x_train_noisy, clip_value_min=0., clip_value_max=1.)\n", "x_test_noisy = tf.clip_by_value(x_test_noisy, clip_value_min=0., clip_value_max=1.)" ] }, { "cell_type": "markdown", "metadata": { "id": "wRxHe4XXltNd" }, "source": [ "绘制有噪图像:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "thKUmbVVCQpt", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "n = 10\n", "plt.figure(figsize=(20, 2))\n", "for i in range(n):\n", " ax = plt.subplot(1, n, i + 1)\n", " plt.title(\"original + noise\")\n", " plt.imshow(tf.squeeze(x_test_noisy[i]))\n", " plt.gray()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "Sy9SY8jGl5aP" }, "source": [ "### 定义卷积自编码器" ] }, { "cell_type": "markdown", "metadata": { "id": "vT_BhZngWMwp" }, "source": [ "在此示例中,您将使用 `encoder` 中的 [Conv2D](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv2D) 层和 `decoder` 中的 [Conv2DTranspose](https://tensorflow.google.cn/api_docs/python/tf/keras/layers/Conv2DTranspose) 层来训练卷积自编码器。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R5KjoIlYCQko", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class Denoise(Model):\n", " def __init__(self):\n", " super(Denoise, self).__init__()\n", " self.encoder = tf.keras.Sequential([\n", " layers.Input(shape=(28, 28, 1)),\n", " layers.Conv2D(16, (3, 3), activation='relu', padding='same', strides=2),\n", " layers.Conv2D(8, (3, 3), activation='relu', padding='same', strides=2)])\n", "\n", " self.decoder = tf.keras.Sequential([\n", " layers.Conv2DTranspose(8, kernel_size=3, strides=2, activation='relu', padding='same'),\n", " layers.Conv2DTranspose(16, kernel_size=3, strides=2, activation='relu', padding='same'),\n", " layers.Conv2D(1, kernel_size=(3, 3), activation='sigmoid', padding='same')])\n", "\n", " def call(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded\n", "\n", "autoencoder = Denoise()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QYKbiDFYCQfj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.compile(optimizer='adam', loss=losses.MeanSquaredError())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IssFr1BNCQX3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.fit(x_train_noisy, x_train,\n", " epochs=10,\n", " shuffle=True,\n", " validation_data=(x_test_noisy, x_test))" ] }, { "cell_type": "markdown", "metadata": { "id": "G85xUVBGTAKp" }, "source": [ "我们来看一下编码器的摘要。请注意图像是如何从 28x28 像素下采样为 7x7 像素的。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oEpxlX6sTEQz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.encoder.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "DDZBfMx1UtXx" }, "source": [ "解码器将图像从 7x7 像素上采样为 28x28 像素。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pbeQtYMaUpro", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.decoder.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "A7-VAuEy_N6M" }, "source": [ "绘制由自编码器生成的有噪图像和去噪图像。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "t5IyPi1fCQQz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "encoded_imgs = autoencoder.encoder(x_test_noisy).numpy()\n", "decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sfxr9NdBCP_x", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "n = 10\n", "plt.figure(figsize=(20, 4))\n", "for i in range(n):\n", "\n", " # display original + noise\n", " ax = plt.subplot(2, n, i + 1)\n", " plt.title(\"original + noise\")\n", " plt.imshow(tf.squeeze(x_test_noisy[i]))\n", " plt.gray()\n", " ax.get_xaxis().set_visible(False)\n", " ax.get_yaxis().set_visible(False)\n", "\n", " # display reconstruction\n", " bx = plt.subplot(2, n, i + n + 1)\n", " plt.title(\"reconstructed\")\n", " plt.imshow(tf.squeeze(decoded_imgs[i]))\n", " plt.gray()\n", " bx.get_xaxis().set_visible(False)\n", " bx.get_yaxis().set_visible(False)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "ErGrTnWHoUYl" }, "source": [ "## 第三个示例:异常检测\n", "\n", "## 概述\n", "\n", "在此示例中,您将训练自编码器来检测 [ECG5000 数据集](http://www.timeseriesclassification.com/description.php?Dataset=ECG5000)上的异常。此数据集包含 5,000 个[心电图](https://en.wikipedia.org/wiki/Electrocardiography),每个心电图拥有 140 个数据点。您将使用简化版的数据集,其中每个样本都被标记为 `0`(对应于异常心律)或 `1`(对应于正常心律)。您需要关注如何识别异常心律。\n", "\n", "注:这是一个有标签的数据集,因此您可以将其表述为一个监督学习问题。此示例的目标是说明可应用于没有可用标签的大型数据集的异常检测概念(例如,如果您有成千上万个正常心律,而只有少量异常心律)。\n", "\n", "您将如何使用自编码器检测异常?回想一下,自编码器经过训练后可最大程度地减少重构误差。您将只基于正常心律训练自编码器,随后使用它来重构所有数据。我们的假设是,异常心律存在更高的重构误差。随后,如果重构误差超过固定阈值,则将心律分类为异常。" ] }, { "cell_type": "markdown", "metadata": { "id": "i5estNaur_Mh" }, "source": [ "### 加载心电图数据" ] }, { "cell_type": "markdown", "metadata": { "id": "y35nsXLPsDNX" }, "source": [ "您将使用的数据集基于 [timeseriesclassification.com](http://www.timeseriesclassification.com/description.php?Dataset=ECG5000) 中的数据集。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KmKRDJWgsFYa", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Download the dataset\n", "dataframe = pd.read_csv('http://storage.googleapis.com/download.tensorflow.org/data/ecg.csv', header=None)\n", "raw_data = dataframe.values\n", "dataframe.head()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UmuCPVYKsKKx", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# The last element contains the labels\n", "labels = raw_data[:, -1]\n", "\n", "# The other data points are the electrocadriogram data\n", "data = raw_data[:, 0:-1]\n", "\n", "train_data, test_data, train_labels, test_labels = train_test_split(\n", " data, labels, test_size=0.2, random_state=21\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "byK2vP7hsMbz" }, "source": [ "将数据归一化为 `[0,1]`。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tgMZVWRKsPx6", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "min_val = tf.reduce_min(train_data)\n", "max_val = tf.reduce_max(train_data)\n", "\n", "train_data = (train_data - min_val) / (max_val - min_val)\n", "test_data = (test_data - min_val) / (max_val - min_val)\n", "\n", "train_data = tf.cast(train_data, tf.float32)\n", "test_data = tf.cast(test_data, tf.float32)" ] }, { "cell_type": "markdown", "metadata": { "id": "BdSYr2IPsTiz" }, "source": [ "您将仅使用正常心律训练自编码器,在此数据集中,正常心律被标记为 `1`。将正常心律与异常心律分开。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VvK4NRe8sVhE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_labels = train_labels.astype(bool)\n", "test_labels = test_labels.astype(bool)\n", "\n", "normal_train_data = train_data[train_labels]\n", "normal_test_data = test_data[test_labels]\n", "\n", "anomalous_train_data = train_data[~train_labels]\n", "anomalous_test_data = test_data[~test_labels]" ] }, { "cell_type": "markdown", "metadata": { "id": "wVcTBDo-CqFS" }, "source": [ "绘制正常的心电图。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZTlMIrpmseYe", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.grid()\n", "plt.plot(np.arange(140), normal_train_data[0])\n", "plt.title(\"A Normal ECG\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "QpI9by2ZA0NN" }, "source": [ "绘制异常的心电图。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zrpXREF2siBr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.grid()\n", "plt.plot(np.arange(140), anomalous_train_data[0])\n", "plt.title(\"An Anomalous ECG\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "0DS6QKZJslZz" }, "source": [ "### 构建模型" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bf6owZQDsp9y", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class AnomalyDetector(Model):\n", " def __init__(self):\n", " super(AnomalyDetector, self).__init__()\n", " self.encoder = tf.keras.Sequential([\n", " layers.Dense(32, activation=\"relu\"),\n", " layers.Dense(16, activation=\"relu\"),\n", " layers.Dense(8, activation=\"relu\")])\n", " \n", " self.decoder = tf.keras.Sequential([\n", " layers.Dense(16, activation=\"relu\"),\n", " layers.Dense(32, activation=\"relu\"),\n", " layers.Dense(140, activation=\"sigmoid\")])\n", " \n", " def call(self, x):\n", " encoded = self.encoder(x)\n", " decoded = self.decoder(encoded)\n", " return decoded\n", "\n", "autoencoder = AnomalyDetector()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gwRpBBbg463S", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "autoencoder.compile(optimizer='adam', loss='mae')" ] }, { "cell_type": "markdown", "metadata": { "id": "zuTy60STBEy4" }, "source": [ "请注意,自编码器仅使用正常的心电图进行训练,但使用完整的测试集进行评估。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V6NFSs-jsty2", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "history = autoencoder.fit(normal_train_data, normal_train_data, \n", " epochs=20, \n", " batch_size=512,\n", " validation_data=(test_data, test_data),\n", " shuffle=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OEexphFwwTQS", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "plt.plot(history.history[\"loss\"], label=\"Training Loss\")\n", "plt.plot(history.history[\"val_loss\"], label=\"Validation Loss\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": { "id": "ceI5lKv1BT-A" }, "source": [ "如果重构误差比正常训练样本大一个标准差,您可以快速地将心电图归类为异常。首先,我们绘制训练集中的一个正常心电图,随后绘制自编码器对其进行编码和解码后的重构以及重构误差。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hmsk4DuktxJ2", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "encoded_data = autoencoder.encoder(normal_test_data).numpy()\n", "decoded_data = autoencoder.decoder(encoded_data).numpy()\n", "\n", "plt.plot(normal_test_data[0], 'b')\n", "plt.plot(decoded_data[0], 'r')\n", "plt.fill_between(np.arange(140), decoded_data[0], normal_test_data[0], color='lightcoral')\n", "plt.legend(labels=[\"Input\", \"Reconstruction\", \"Error\"])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "ocA_q9ufB_aF" }, "source": [ "创建一个类似的绘图,这次是一个异常的测试样本。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vNFTuPhLwTBn", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "encoded_data = autoencoder.encoder(anomalous_test_data).numpy()\n", "decoded_data = autoencoder.decoder(encoded_data).numpy()\n", "\n", "plt.plot(anomalous_test_data[0], 'b')\n", "plt.plot(decoded_data[0], 'r')\n", "plt.fill_between(np.arange(140), decoded_data[0], anomalous_test_data[0], color='lightcoral')\n", "plt.legend(labels=[\"Input\", \"Reconstruction\", \"Error\"])\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "ocimg3MBswdS" }, "source": [ "### 检测异常" ] }, { "cell_type": "markdown", "metadata": { "id": "Xnh8wmkDsypN" }, "source": [ "通过计算重构损失是否大于固定阈值来检测异常。在本教程中,您将计算训练集中正常样本的平均误差,如果重构误差比训练集大一个标准差,则将未来的样本分类为异常。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "TeuT8uTA5Y_w" }, "source": [ "根据训练集中的正常心电图绘制重构误差:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N7FltOnHu4-l", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "reconstructions = autoencoder.predict(normal_train_data)\n", "train_loss = tf.keras.losses.mae(reconstructions, normal_train_data)\n", "\n", "plt.hist(train_loss[None,:], bins=50)\n", "plt.xlabel(\"Train loss\")\n", "plt.ylabel(\"No of examples\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "mh-3ChEF5hog" }, "source": [ "选择一个比平均值高一个标准差的阈值。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "82hkl0Chs3P_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "threshold = np.mean(train_loss) + np.std(train_loss)\n", "print(\"Threshold: \", threshold)" ] }, { "cell_type": "markdown", "metadata": { "id": "uEGlA1Be50Nj" }, "source": [ "注:还有其他可用来选择阈值的策略,高于该阈值时,应将测试样本分类为异常,正确的方式将取决于您的数据集。您可以通过本教程末尾的链接了解更多信息。 " ] }, { "cell_type": "markdown", "metadata": { "id": "zpLSDAeb51D_" }, "source": [ "如果检查测试集中异常样本的重构误差,您会注意到大多数异常样本的重构误差都比阈值大。通过更改阈值,您可以调整分类器的[精确率](https://developers.google.com/machine-learning/glossary#precision)和[召回率](https://developers.google.com/machine-learning/glossary#recall)。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sKVwjQK955Wy", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "reconstructions = autoencoder.predict(anomalous_test_data)\n", "test_loss = tf.keras.losses.mae(reconstructions, anomalous_test_data)\n", "\n", "plt.hist(test_loss[None, :], bins=50)\n", "plt.xlabel(\"Test loss\")\n", "plt.ylabel(\"No of examples\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "PFVk_XGE6AX2" }, "source": [ "如果重构误差大于阈值,则将心电图分类为异常。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mkgJZfhh6CHr", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def predict(model, data, threshold):\n", " reconstructions = model(data)\n", " loss = tf.keras.losses.mae(reconstructions, data)\n", " return tf.math.less(loss, threshold)\n", "\n", "def print_stats(predictions, labels):\n", " print(\"Accuracy = {}\".format(accuracy_score(labels, predictions)))\n", " print(\"Precision = {}\".format(precision_score(labels, predictions)))\n", " print(\"Recall = {}\".format(recall_score(labels, predictions)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sOcfXfXq6FBd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "preds = predict(autoencoder, test_data, threshold)\n", "print_stats(preds, test_labels)" ] }, { "cell_type": "markdown", "metadata": { "id": "HrJRef8Ln945" }, "source": [ "## 后续步骤\n", "\n", "要详细了解如何使用自编码器检测异常,请查看 Victor Dibia 使用 TensorFlow.js 构建的出色[交互式示例](https://anomagram.fastforwardlabs.com/#/)。对于真实用例,您可以了解 [Airbus 如何使用 TensorFlow 检测 ISS 遥测数据中的异常](https://blog.tensorflow.org/2020/04/how-airbus-detects-anomalies-iss-telemetry-data-tfx.html)。要详细了解基础知识,请考虑阅读 François Chollet 撰写的这篇[博文](https://blog.keras.io/building-autoencoders-in-keras.html)。有关更多详细信息,请查看 Ian Goodfellow、Yoshua Bengio 和 Aaron Courville 撰写的[《深度学习》](https://www.deeplearningbook.org/)一书的第 14 章。\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "autoencoder.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }