{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Cb4espuLKJiA" }, "source": [ "##### Copyright 2021 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "DjZQV2njKJ3U" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "mTL0TERThT6z" }, "source": [ "\n", " \n", " \n", " \n", " \n", " \n", "
    "View on TensorFlow.org | Run in Google Colab | View on GitHub | Download notebook | See TF Hub model"
" ] }, { "cell_type": "markdown", "metadata": { "id": "K2madPFAGHb3" }, "source": [ "# 使用基于 YAMNet 的迁移学习进行环境声音分类\n", "\n", "[YAMNet](https://tfhub.dev/google/yamnet/1) 是一种预训练的深度神经网络,可以预测 [521 个类](https://github.com/tensorflow/models/blob/master/research/audioset/yamnet/yamnet_class_map.csv)的音频事件,例如笑声、吠叫或警笛声。\n", "\n", "在本教程中,您将学习如何:\n", "\n", "- 加载并使用 YAMNet 模型进行推断。\n", "- 使用 YAMNet 嵌入向量构建一个新模型来对猫和狗的声音进行分类。\n", "- 评估并导出模型。\n" ] }, { "cell_type": "markdown", "metadata": { "id": "5Mdp2TpBh96Y" }, "source": [ "## 导入 TensorFlow 和其他库\n" ] }, { "cell_type": "markdown", "metadata": { "id": "zCcKYqu_hvKe" }, "source": [ "首先安装 [TensorFlow I/O](https://tensorflow.google.cn/io),这将使您更轻松地从磁盘上加载音频文件。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "urBpRWDHTHHU" }, "outputs": [], "source": [ "!pip install -q \"tensorflow==2.11.*\"\n", "# tensorflow_io 0.28 is compatible with TensorFlow 2.11\n", "!pip install -q \"tensorflow_io==0.28.*\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7l3nqdWVF-kC" }, "outputs": [], "source": [ "import os\n", "\n", "from IPython import display\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import tensorflow_io as tfio" ] }, { "cell_type": "markdown", "metadata": { "id": "v9ZhybCnt_bM" }, "source": [ "## 关于 YAMNet\n", "\n", "[YAMNet](https://github.com/tensorflow/models/tree/master/research/audioset/yamnet) 是一种采用 [MobileNetV1](https://arxiv.org/abs/1704.04861) 深度可分离卷积架构的预训练神经网络。它可以使用音频波形作为输入,并对 [AudioSet](http://g.co/audioset) 语料库中的 521 个音频事件分别进行独立预测。\n", "\n", "在内部,模型会从音频信号中提取“帧”并批量处理这些帧。此版本的模型使用时长为 0.96 秒的帧,每 0.48 秒提取一帧。\n", "\n", "模型会接受包含任意长度波形的一维 float32 张量或 NumPy 数组,表示为 `[-1.0, +1.0]` 区间内的单通道(单声道)16 kHz 样本。本教程包含帮助您将 WAV 文件转换为受支持格式的代码。\n", "\n", "模型会返回 3 个输出,包括类分数、嵌入向量(将用于迁移学习)和对数梅尔语[谱图](https://tfhub.dev/google/yamnet/1)。您可以在[此处](https://tfhub.dev/google/yamnet/1)找到更多详细信息。\n", "\n", "YAMNet 的一种特定用途是作为高级特征提取器 - 1,024 维嵌入向量输出。您将使用基础 (YAMNet) 模型的输入特征并将它们馈送到由一个隐藏的 `tf.keras.layers.Dense` 层组成的浅层模型中。然后,您将在少量数据上训练网络进行音频分类,而*不*需要大量带标签的数据和端到端训练。(这类似于[使用 TensorFlow Hub 进行图像分类迁移学习](https://tensorflow.google.cn/tutorials/images/transfer_learning_with_hub),请参阅以了解更多信息。)\n", "\n", "首先,您将测试模型并查看音频分类结果。然后,您将构建数据预处理流水线。\n", "\n", "### 从 TensorFlow Hub 加载 YAMNet\n", "\n", "您将使用来自 [TensorFlow Hub](https://tfhub.dev/) 的预训练 YAMNet 从声音文件中提取嵌入向量。\n", "\n", "从 TensorFlow Hub 中加载模型非常简单:选择模型,复制其网址,然后使用 `load` 函数。\n", "\n", "注:要阅读模型的文档,请在浏览器中使用模型网址。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "06CWkBV5v3gr" }, "outputs": [], "source": [ "yamnet_model_handle = 'https://tfhub.dev/google/yamnet/1'\n", "yamnet_model = hub.load(yamnet_model_handle)" ] }, { "cell_type": "markdown", "metadata": { "id": "GmrPJ0GHw9rr" }, "source": [ "加载模型后,您可以遵循 [YAMNet 基本使用教程](https://tensorflow.google.cn/hub/tutorials/yamnet)并下载 WAV 样本文件以运行推断。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C5i6xktEq00P" }, "outputs": [], "source": [ "testing_wav_file_name = tf.keras.utils.get_file('miaow_16k.wav',\n", " 'https://storage.googleapis.com/audioset/miaow_16k.wav',\n", " cache_dir='./',\n", " cache_subdir='test_data')\n", "\n", "print(testing_wav_file_name)" ] }, { "cell_type": "markdown", "metadata": { "id": "mBm9y9iV2U_-" }, "source": [ 
"您将需要用于加载音频文件的函数,稍后在处理训练数据时也将使用该函数。(请参阅[简单音频识别](https://tensorflow.google.cn/tutorials/audio/simple_audio#reading_audio_files_and_their_labels)以详细了解如何读取音频文件及其标签。)\n", "\n", "注:从 `load_wav_16k_mono` 返回的 `wav_data` 已经归一化为 `[-1.0, 1.0]` 区间内的值(有关更多信息,请参阅 [TF Hub 上的 YAMNet 文档](https://tfhub.dev/google/yamnet/1))。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xwc9Wrdg2EtY" }, "outputs": [], "source": [ "# Utility functions for loading audio files and making sure the sample rate is correct.\n", "\n", "@tf.function\n", "def load_wav_16k_mono(filename):\n", " \"\"\" Load a WAV file, convert it to a float tensor, resample to 16 kHz single-channel audio. \"\"\"\n", " file_contents = tf.io.read_file(filename)\n", " wav, sample_rate = tf.audio.decode_wav(\n", " file_contents,\n", " desired_channels=1)\n", " wav = tf.squeeze(wav, axis=-1)\n", " sample_rate = tf.cast(sample_rate, dtype=tf.int64)\n", " wav = tfio.audio.resample(wav, rate_in=sample_rate, rate_out=16000)\n", " return wav" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FRqpjkwB0Jjw" }, "outputs": [], "source": [ "testing_wav_data = load_wav_16k_mono(testing_wav_file_name)\n", "\n", "_ = plt.plot(testing_wav_data)\n", "\n", "# Play the audio file.\n", "display.Audio(testing_wav_data, rate=16000)" ] }, { "cell_type": "markdown", "metadata": { "id": "6z6rqlEz20YB" }, "source": [ "### 加载类映射\n", "\n", "务必加载 YAMNet 能够识别的类名。映射文件以 CSV 格式记录在 `yamnet_model.class_map_path()` 中。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6Gyj23e_3Mgr" }, "outputs": [], "source": [ "class_map_path = yamnet_model.class_map_path().numpy().decode('utf-8')\n", "class_names =list(pd.read_csv(class_map_path)['display_name'])\n", "\n", "for name in class_names[:20]:\n", " print(name)\n", "print('...')" ] }, { "cell_type": "markdown", "metadata": { "id": "5xbycDnT40u0" }, "source": [ "### 运行推断\n", "\n", "YAMNet 提供帧级类分数(即每帧 521 个分数)。为了确定剪辑级预测,可以按类跨帧聚合分数(例如,使用平均值或最大值聚合)。这是通过 `scores_np.mean(axis=0)` 以如下方式完成的。最后,要在剪辑级找到分数最高的类,您需要在 521 个聚合分数中取最大值。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NT0otp-A4Y3u" }, "outputs": [], "source": [ "scores, embeddings, spectrogram = yamnet_model(testing_wav_data)\n", "class_scores = tf.reduce_mean(scores, axis=0)\n", "top_class = tf.math.argmax(class_scores)\n", "inferred_class = class_names[top_class]\n", "\n", "print(f'The main sound is: {inferred_class}')\n", "print(f'The embeddings shape: {embeddings.shape}')" ] }, { "cell_type": "markdown", "metadata": { "id": "YBaLNg5H5IWa" }, "source": [ "注:模型正确推断出动物的声音。您在本教程中的目标是提高模型针对特定类的准确率。此外,请注意该模型生成了 13 个嵌入向量,每帧 1 个。" ] }, { "cell_type": "markdown", "metadata": { "id": "fmthELBg1A2-" }, "source": [ "## ESC-50 数据集\n", "\n", "[ESC-50 数据集](https://github.com/karolpiczak/ESC-50#repository-content) ([Piczak, 2015](https://www.karolpiczak.com/papers/Piczak2015-ESC-Dataset.pdf)) 是一个包含 2,000 个时长为 5 秒的环境录音的带标签集合。该数据集由 50 个类组成,每个类有 40 个样本。\n", "\n", "下载并提取数据集。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MWobqK8JmZOU" }, "outputs": [], "source": [ "_ = tf.keras.utils.get_file('esc-50.zip',\n", " 'https://github.com/karoldvl/ESC-50/archive/master.zip',\n", " cache_dir='./',\n", " cache_subdir='datasets',\n", " extract=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "qcruxiuX1cO5" }, "source": [ "### 探索数据\n", "\n", "每个文件的元数据均在 `./datasets/ESC-50-master/meta/esc50.csv` 下的 csv 文件中指定\n", "\n", "所有音频文件均位于 `./datasets/ESC-50-master/audio/`\n", "\n", 
"您将创建支持映射的 pandas `DataFrame`,并使用它来更清晰地查看数据。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jwmLygPrMAbH" }, "outputs": [], "source": [ "esc50_csv = './datasets/ESC-50-master/meta/esc50.csv'\n", "base_data_path = './datasets/ESC-50-master/audio/'\n", "\n", "pd_data = pd.read_csv(esc50_csv)\n", "pd_data.head()" ] }, { "cell_type": "markdown", "metadata": { "id": "7d4rHBEQ2QAU" }, "source": [ "### 过滤数据\n", "\n", "现在,数据存储在 `DataFrame` 中,请应用一些转换:\n", "\n", "- 过滤掉行并仅使用所选类 - `dog` 和 `cat`。如果您想使用任何其他类,则可以在此处进行选择。\n", "- 修改文件名以获得完整路径。这将使后续加载更加容易。\n", "- 将目标更改到特定区间内。在此示例中,`dog` 将保持为 `0`,但 `cat` 将改为 `1`,而非其原始值 `5`。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "tFnEoQjgs14I" }, "outputs": [], "source": [ "my_classes = ['dog', 'cat']\n", "map_class_to_id = {'dog':0, 'cat':1}\n", "\n", "filtered_pd = pd_data[pd_data.category.isin(my_classes)]\n", "\n", "class_id = filtered_pd['category'].apply(lambda name: map_class_to_id[name])\n", "filtered_pd = filtered_pd.assign(target=class_id)\n", "\n", "full_path = filtered_pd['filename'].apply(lambda row: os.path.join(base_data_path, row))\n", "filtered_pd = filtered_pd.assign(filename=full_path)\n", "\n", "filtered_pd.head(10)" ] }, { "cell_type": "markdown", "metadata": { "id": "BkDcBS-aJdCz" }, "source": [ "### 加载音频文件并检索嵌入向量\n", "\n", "在这里,您将应用 `load_wav_16k_mono` 并为模型准备 WAV 数据。\n", "\n", "从 WAV 数据中提取嵌入向量时,您会得到一个形状为 `(N, 1024)` 的数组,其中 `N` 为 YAMNet 找到的帧数(每 0.48 秒音频一帧)。" ] }, { "cell_type": "markdown", "metadata": { "id": "AKDT5RomaDKO" }, "source": [ "您的模型将使用每一帧作为一个输入。因此,您需要创建一个新列,每行包含一帧。您还需要展开标签和 `fold` 列以正确反映这些新行。\n", "\n", "展开的 `fold` 列会保留原始值。您不能混合帧,因为在执行拆分时,最后可能会将同一个音频拆分为不同的部分,这会降低您的验证和测试步骤的效率。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u5Rq3_PyKLtU" }, "outputs": [], "source": [ "filenames = filtered_pd['filename']\n", "targets = filtered_pd['target']\n", "folds = filtered_pd['fold']\n", "\n", "main_ds = tf.data.Dataset.from_tensor_slices((filenames, targets, folds))\n", "main_ds.element_spec" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rsEfovDVAHGY" }, "outputs": [], "source": [ "def load_wav_for_map(filename, label, fold):\n", " return load_wav_16k_mono(filename), label, fold\n", "\n", "main_ds = main_ds.map(load_wav_for_map)\n", "main_ds.element_spec" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "k0tG8DBNAHcE" }, "outputs": [], "source": [ "# applies the embedding extraction model to a wav data\n", "def extract_embedding(wav_data, label, fold):\n", " ''' run YAMNet to extract embedding from the wav data '''\n", " scores, embeddings, spectrogram = yamnet_model(wav_data)\n", " num_embeddings = tf.shape(embeddings)[0]\n", " return (embeddings,\n", " tf.repeat(label, num_embeddings),\n", " tf.repeat(fold, num_embeddings))\n", "\n", "# extract embedding\n", "main_ds = main_ds.map(extract_embedding).unbatch()\n", "main_ds.element_spec" ] }, { "cell_type": "markdown", "metadata": { "id": "ZdfPIeD0Qedk" }, "source": [ "### 拆分数据\n", "\n", "您需要使用 `fold` 列将数据集拆分为训练集、验证集和测试集。\n", "\n", "ESC-50 被排列成五个大小一致的交叉验证 `fold`,这样,源自同一来源的剪辑就始终位于同一 `fold` 中 - 请参阅 [ESC: Dataset for Environmental Sound Classification](https://www.karolpiczak.com/papers/Piczak2015-ESC-Dataset.pdf) 论文以了解更多信息。\n", "\n", "最后一步是从数据集中移除 `fold` 列,因为您在训练期间不会用到它。\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1ZYvlFiVsffC" }, "outputs": [], "source": [ "cached_ds = main_ds.cache()\n", "train_ds = cached_ds.filter(lambda 
  {
   "cell_type": "markdown",
   "metadata": { "id": "ZdfPIeD0Qedk" },
   "source": [
    "### Split the data\n",
    "\n",
    "You will use the `fold` column to split the dataset into train, validation and test sets.\n",
    "\n",
    "ESC-50 is arranged into five uniformly-sized cross-validation `fold`s, such that clips from the same original source are always in the same `fold` - find out more in the [ESC: Dataset for Environmental Sound Classification](https://www.karolpiczak.com/papers/Piczak2015-ESC-Dataset.pdf) paper.\n",
    "\n",
    "The last step is to remove the `fold` column from the dataset, since you're not going to use it during training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "1ZYvlFiVsffC" },
   "outputs": [],
   "source": [
    "cached_ds = main_ds.cache()\n",
    "train_ds = cached_ds.filter(lambda embedding, label, fold: fold < 4)\n",
    "val_ds = cached_ds.filter(lambda embedding, label, fold: fold == 4)\n",
    "test_ds = cached_ds.filter(lambda embedding, label, fold: fold == 5)\n",
    "\n",
    "# Remove the fold column now that it's not needed anymore.\n",
    "remove_fold_column = lambda embedding, label, fold: (embedding, label)\n",
    "\n",
    "train_ds = train_ds.map(remove_fold_column)\n",
    "val_ds = val_ds.map(remove_fold_column)\n",
    "test_ds = test_ds.map(remove_fold_column)\n",
    "\n",
    "train_ds = train_ds.cache().shuffle(1000).batch(32).prefetch(tf.data.AUTOTUNE)\n",
    "val_ds = val_ds.cache().batch(32).prefetch(tf.data.AUTOTUNE)\n",
    "test_ds = test_ds.cache().batch(32).prefetch(tf.data.AUTOTUNE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "v5PaMwvtcAIe" },
   "source": [
    "## Create your model\n",
    "\n",
    "Most of the work is done! Next, define a very simple [Sequential](https://tensorflow.google.cn/guide/keras/sequential_model) model with one hidden layer and two outputs to recognize cats and dogs from sounds.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "JYCE0Fr1GpN3" },
   "outputs": [],
   "source": [
    "my_model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Input(shape=(1024,), dtype=tf.float32,\n",
    "                          name='input_embedding'),\n",
    "    tf.keras.layers.Dense(512, activation='relu'),\n",
    "    tf.keras.layers.Dense(len(my_classes))\n",
    "], name='my_model')\n",
    "\n",
    "my_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "l1qgH35HY0SE" },
   "outputs": [],
   "source": [
    "my_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "                 optimizer=\"adam\",\n",
    "                 metrics=['accuracy'])\n",
    "\n",
    "callback = tf.keras.callbacks.EarlyStopping(monitor='loss',\n",
    "                                            patience=3,\n",
    "                                            restore_best_weights=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "T3sj84eOZ3pk" },
   "outputs": [],
   "source": [
    "history = my_model.fit(train_ds,\n",
    "                       epochs=20,\n",
    "                       validation_data=val_ds,\n",
    "                       callbacks=callback)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "OAbraYKYpdoE" },
   "source": [
    "Let's run the `evaluate` method on the test data to make sure there's no overfitting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "H4Nh5nec3Sky" },
   "outputs": [],
   "source": [
    "loss, accuracy = my_model.evaluate(test_ds)\n",
    "\n",
    "print(\"Loss: \", loss)\n",
    "print(\"Accuracy: \", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "cid-qIrIpqHS" },
   "source": [
    "Great job!"
   ]
  },
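  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, you can plot the training curves from the `history` object returned by `fit` above (a sketch added here; it assumes both training and validation metrics were recorded, which is the case with `validation_data=val_ds` and `metrics=['accuracy']`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (sketch): visualize the loss and accuracy curves recorded by fit().\n",
    "metrics = history.history\n",
    "plt.figure(figsize=(10, 4))\n",
    "\n",
    "plt.subplot(1, 2, 1)\n",
    "plt.plot(metrics['loss'], label='train')\n",
    "plt.plot(metrics['val_loss'], label='validation')\n",
    "plt.title('Loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "\n",
    "plt.subplot(1, 2, 2)\n",
    "plt.plot(metrics['accuracy'], label='train')\n",
    "plt.plot(metrics['val_accuracy'], label='validation')\n",
    "plt.title('Accuracy')\n",
    "plt.xlabel('Epoch')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },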
  {
   "cell_type": "markdown",
   "metadata": { "id": "nCKZonrJcXab" },
   "source": [
    "## Test your model\n",
    "\n",
    "Next, try your model on the embeddings from the previous test, using YAMNet only.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "79AFpA3_ctCF" },
   "outputs": [],
   "source": [
    "scores, embeddings, spectrogram = yamnet_model(testing_wav_data)\n",
    "result = my_model(embeddings).numpy()\n",
    "\n",
    "inferred_class = my_classes[result.mean(axis=0).argmax()]\n",
    "print(f'The main sound is: {inferred_class}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "k2yleeev645r" },
   "source": [
    "## Save a model that can directly take a WAV file as input\n",
    "\n",
    "Your model works when you give it the embeddings as input.\n",
    "\n",
    "In a real-world scenario, however, you'll want to feed it audio data directly.\n",
    "\n",
    "To do that, you will combine YAMNet with your model into a single model that you can export for other applications.\n",
    "\n",
    "To make it easier to use the model's results, the final layer will be a `reduce_mean` operation. When using this model for serving (which you will learn about later in the tutorial), you will need the name of the final layer. If you don't define one, TensorFlow will auto-assign an incremental name, which makes testing hard because it changes every time you train the model. When using a raw TensorFlow operation, you can't assign a name to it. To address this issue, you'll create a custom layer that applies `reduce_mean` and name it `'classifier'`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "QUVCI2Suunpw" },
   "outputs": [],
   "source": [
    "class ReduceMeanLayer(tf.keras.layers.Layer):\n",
    "  def __init__(self, axis=0, **kwargs):\n",
    "    super(ReduceMeanLayer, self).__init__(**kwargs)\n",
    "    self.axis = axis\n",
    "\n",
    "  def call(self, input):\n",
    "    return tf.math.reduce_mean(input, axis=self.axis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "zE_Npm0nzlwc" },
   "outputs": [],
   "source": [
    "saved_model_path = './dogs_and_cats_yamnet'\n",
    "\n",
    "input_segment = tf.keras.layers.Input(shape=(), dtype=tf.float32, name='audio')\n",
    "embedding_extraction_layer = hub.KerasLayer(yamnet_model_handle,\n",
    "                                            trainable=False, name='yamnet')\n",
    "_, embeddings_output, _ = embedding_extraction_layer(input_segment)\n",
    "serving_outputs = my_model(embeddings_output)\n",
    "serving_outputs = ReduceMeanLayer(axis=0, name='classifier')(serving_outputs)\n",
    "serving_model = tf.keras.Model(input_segment, serving_outputs)\n",
    "serving_model.save(saved_model_path, include_optimizer=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "y-0bY5FMme1C" },
   "outputs": [],
   "source": [
    "tf.keras.utils.plot_model(serving_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": { "id": "btHQDN9mqxM_" },
   "source": [
    "Load your saved model to verify that it works as expected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": { "id": "KkYVpJS72WWB" },
   "outputs": [],
   "source": [
    "reloaded_model = tf.saved_model.load(saved_model_path)"
   ]
  },
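  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick optional check (a sketch added here; the signature name comes from the default SavedModel export above), you can list the serving signatures on the reloaded model before calling it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional (sketch): inspect the signatures exported with the SavedModel.\n",
    "print(list(reloaded_model.signatures.keys()))   # expected: ['serving_default']\n",
    "print(reloaded_model.signatures['serving_default'].structured_outputs)"
   ]
  },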
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xeXtD5HO28y-" }, "outputs": [], "source": [ "reloaded_results = reloaded_model(testing_wav_data)\n", "cat_or_dog = my_classes[tf.math.argmax(reloaded_results)]\n", "print(f'The main sound is: {cat_or_dog}')" ] }, { "cell_type": "markdown", "metadata": { "id": "ZRrOcBYTUgwn" }, "source": [ "如果您想在应用环境中尝试您的新模型,可以使用 'serving_default' 签名。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ycC8zzDSUG2s" }, "outputs": [], "source": [ "serving_results = reloaded_model.signatures['serving_default'](testing_wav_data)\n", "cat_or_dog = my_classes[tf.math.argmax(serving_results['classifier'])]\n", "print(f'The main sound is: {cat_or_dog}')\n" ] }, { "cell_type": "markdown", "metadata": { "id": "da7blblCHs8c" }, "source": [ "## (可选)更多测试\n", "\n", "模型已准备就绪。\n", "\n", "让我们基于测试数据集将它与 YAMNet 进行比较。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vDf5MASIIN1z" }, "outputs": [], "source": [ "test_pd = filtered_pd.loc[filtered_pd['fold'] == 5]\n", "row = test_pd.sample(1)\n", "filename = row['filename'].item()\n", "print(filename)\n", "waveform = load_wav_16k_mono(filename)\n", "print(f'Waveform values: {waveform}')\n", "_ = plt.plot(waveform)\n", "\n", "display.Audio(waveform, rate=16000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eYUzFxYJIcE1" }, "outputs": [], "source": [ "# Run the model, check the output.\n", "scores, embeddings, spectrogram = yamnet_model(waveform)\n", "class_scores = tf.reduce_mean(scores, axis=0)\n", "top_class = tf.math.argmax(class_scores)\n", "inferred_class = class_names[top_class]\n", "top_score = class_scores[top_class]\n", "print(f'[YAMNet] The main sound is: {inferred_class} ({top_score})')\n", "\n", "reloaded_results = reloaded_model(waveform)\n", "your_top_class = tf.math.argmax(reloaded_results)\n", "your_inferred_class = my_classes[your_top_class]\n", "class_probabilities = tf.nn.softmax(reloaded_results, axis=-1)\n", "your_top_score = class_probabilities[your_top_class]\n", "print(f'[Your model] The main sound is: {your_inferred_class} ({your_top_score})')" ] }, { "cell_type": "markdown", "metadata": { "id": "g8Tsym8Rq-0V" }, "source": [ "## 后续步骤\n", "\n", "您已创建可对狗或猫的叫声进行分类的模型。利用相同的想法和不同的数据集,您可以尝试构建诸如基于鸟鸣的[鸟类声学识别模型](https://www.kaggle.com/c/birdclef-2021/)。\n", "\n", "在社交媒体上与 TensorFlow 团队分享您的项目吧!\n" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "transfer_learning_audio.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }