{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "mt9dL5dIir8X" }, "outputs": [], "source": [ "##### Copyright 2022 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ufPx7EiCiqgR", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4StGz9ynOEL6" }, "source": [ "# 加载视频数据" ] }, { "cell_type": "markdown", "metadata": { "id": "KwQtSOz0VrVX" }, "source": [ "\n", " \n", " \n", " \n", " \n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "F-SqCosJ6-0H" }, "source": [ "本教程演示如何使用 [UCF101 人体动作数据集](https://en.wikipedia.org/wiki/Audio_Video_Interleave)加载和预处理 [AVI](https://tensorflow.google.cn/datasets/catalog/ucf101) 视频数据。当您对数据进行预处理后,就可以将其用于视频分类/识别、字幕或聚类等任务。原始数据集包含从 YouTube 收集的具有 101 个类别的真实动作视频,包括演奏大提琴、刷牙和化眼妆。您将学习如何:\n", "\n", "- 从 ZIP 文件加载数据。\n", "\n", "- 从视频文件中读取帧序列。\n", "\n", "- 呈现视频数据。\n", "\n", "- 封装帧生成器 [`tf.data.Dataset`](https://tensorflow.google.cn/guide/data)。\n", "\n", "本视频加载和预处理教程是 TensorFlow 视频教程系列的第一部分。下面是其他三个教程:\n", "\n", "- [构建用于视频分类的 3D CNN 模型](https://tensorflow.google.cn/tutorials/video/video_classification):请注意,本教程使用分解 3D 数据的空间和时间方面的 (2+1)D CNN;如果使用 MRI 扫描等体数据,请考虑使用 3D CNN 而不是 (2+1)D CNN。\n", "- [用于流式动作识别的 MoViNet](https://tensorflow.google.cn/hub/tutorials/movinet):熟悉 TF Hub 上提供的 MoViNet 模型。\n", "- [使用 MoViNet 进行视频分类的迁移学习](https://tensorflow.google.cn/tutorials/video/transfer_learning_with_movinet):本教程介绍了如何使用预训练的视频分类模型,该模型是在具有 UCF-101 数据集的不同数据集上训练的。" ] }, { "cell_type": "markdown", "metadata": { "id": "PnpPjKVD68eH" }, "source": [ "## 安装\n", "\n", "首先,安装和导入一些必要的库,包括:用于检查 ZIP 文件内容的 [remotezip](https://github.com/gtsystem/python-remotezip),用于使用进度条的 [tqdm](https://github.com/tqdm/tqdm),用于处理视频文件的 [OpenCV](https://opencv.org/),以及用于在 Jupyter 笔记本中嵌入数据的 [`tensorflow_docs`](https://github.com/tensorflow/docs/tree/master/tools/tensorflow_docs)。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SjI3AaaO16bd", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# The way this tutorial uses the `TimeDistributed` layer requires TF>=2.10\n", "!pip install -U \"tensorflow>=2.10.0\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P5SBasQcbwQA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!pip install remotezip tqdm opencv-python\n", "!pip install -q git+https://github.com/tensorflow/docs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { 
"id": "9RYQIJ9C6BVH", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "import tqdm\n", "import random\n", "import pathlib\n", "import itertools\n", "import collections\n", "\n", "import os\n", "import cv2\n", "import numpy as np\n", "import remotezip as rz\n", "\n", "import tensorflow as tf\n", "\n", "# Some modules to display an animation using imageio.\n", "import imageio\n", "from IPython import display\n", "from urllib import request\n", "from tensorflow_docs.vis import embed" ] }, { "cell_type": "markdown", "metadata": { "id": "KbhwWLLM7FXo" }, "source": [ "## 下载 UCF101 数据集的子集\n", "\n", "[UCF101 数据集](https://tensorflow.google.cn/datasets/catalog/ucf101)包含 101 类不同动作的视频,主要用于动作识别。您将在此演示中使用这些类别的一个子集。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gVIgj-jIA8U8", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "URL = 'https://storage.googleapis.com/thumos14_files/UCF101_videos.zip'" ] }, { "cell_type": "markdown", "metadata": { "id": "2tm8aBzw6Md7" }, "source": [ "上面的网址包含一个带有 UCF 101 数据集的 ZIP 文件。创建一个使用 `remotezip` 库的函数来检查该 URL 中 ZIP 文件的内容:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lY-x7TaZlK6O", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def list_files_from_zip_url(zip_url):\n", " \"\"\" List the files in each class of the dataset given a URL with the zip file.\n", "\n", " Args:\n", " zip_url: A URL from which the files can be extracted from.\n", "\n", " Returns:\n", " List of files in each of the classes.\n", " \"\"\"\n", " files = []\n", " with rz.RemoteZip(zip_url) as zip:\n", " for zip_info in zip.infolist():\n", " files.append(zip_info.filename)\n", " return files" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lYErXAdUr-rk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "files = list_files_from_zip_url(URL)\n", "files = [f for f in files if f.endswith('.avi')]\n", "files[:10]" ] }, { 
"cell_type": "markdown", "metadata": { "id": "rQ4l8D9dFPS7" }, "source": [ "先从几个视频和有限数量的类开始训练。运行上述代码块后,请注意类名包含在每个视频的文件名中。\n", "\n", "定义从文件名中检索类名的 `get_class` 函数。然后,创建一个名为 `get_files_per_class` 的函数,它会将所有文件的列表(上面的 `files`)转换为列出每个类的文件的字典:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yyyivOX0sO19", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_class(fname):\n", " \"\"\" Retrieve the name of the class given a filename.\n", "\n", " Args:\n", " fname: Name of the file in the UCF101 dataset.\n", "\n", " Returns:\n", " Class that the file belongs to.\n", " \"\"\"\n", " return fname.split('_')[-3]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1qnH0xKzlyw_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def get_files_per_class(files):\n", " \"\"\" Retrieve the files that belong to each class.\n", "\n", " Args:\n", " files: List of files in the dataset.\n", "\n", " Returns:\n", " Dictionary of class names (key) and files (values). 
\n", " \"\"\"\n", " files_for_class = collections.defaultdict(list)\n", " for fname in files:\n", " class_name = get_class(fname)\n", " files_for_class[class_name].append(fname)\n", " return files_for_class" ] }, { "cell_type": "markdown", "metadata": { "id": "VxSt5YgSGrWn" }, "source": [ "获得每个类的文件列表后,您可以选择要使用多少个类,以及每个类需要多少视频,以创建数据集。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qPdURg74uUTk", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "NUM_CLASSES = 10\n", "FILES_PER_CLASS = 50" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GUs0xtXsr9i3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "files_for_class = get_files_per_class(files)\n", "classes = list(files_for_class.keys())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-YqFARvqwon9", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "print('Num classes:', len(classes))\n", "print('Num videos for class[0]:', len(files_for_class[classes[0]]))" ] }, { "cell_type": "markdown", "metadata": { "id": "yFAFqKqE92bQ" }, "source": [ "创建一个名为 `select_subset_of_classes` 的新函数,它会选择数据集中存在的类的子集并在每个类中选择特定数量的文件:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "O3jek4QimIj-", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def select_subset_of_classes(files_for_class, classes, files_per_class):\n", " \"\"\" Create a dictionary with the class name and a subset of the files in that class.\n", "\n", " Args:\n", " files_for_class: Dictionary of class names (key) and files (values).\n", " classes: List of classes.\n", " files_per_class: Number of files per class of interest.\n", "\n", " Returns:\n", " Dictionary with class as key and list of specified number of video files in that class.\n", " \"\"\"\n", " files_subset = dict()\n", "\n", " for class_name in classes:\n", " class_files = files_for_class[class_name]\n", " files_subset[class_name] = 
class_files[:files_per_class]\n", "\n", "  return files_subset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5cjcz6Gpcb-W", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "files_subset = select_subset_of_classes(files_for_class, classes[:NUM_CLASSES], FILES_PER_CLASS)\n", "list(files_subset.keys())" ] }, { "cell_type": "markdown", "metadata": { "id": "ALrlDS1lZx3E" }, "source": [ "Define helper functions that split the videos into training, validation, and test sets. The videos are downloaded from the URL of the ZIP file and placed into their respective subdirectories." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "AH9sWS_6nRz3", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def download_from_zip(zip_url, to_dir, file_names):\n", "  \"\"\" Download the contents of the zip file from the zip URL.\n", "\n", "    Args:\n", "      zip_url: A URL with a zip file containing data.\n", "      to_dir: A directory to download data to.\n", "      file_names: Names of files to download.\n", "  \"\"\"\n", "  with rz.RemoteZip(zip_url) as zip:\n", "    for fn in tqdm.tqdm(file_names):\n", "      class_name = get_class(fn)\n", "      zip.extract(fn, str(to_dir / class_name))\n", "      unzipped_file = to_dir / class_name / fn\n", "\n", "      fn = pathlib.Path(fn).parts[-1]\n", "      output_file = to_dir / class_name / fn\n", "      unzipped_file.rename(output_file)" ] }, { "cell_type": "markdown", "metadata": { "id": "pejRTChA6mrp" }, "source": [ "The following function returns the remaining data that has not yet been placed into a subset of data. It allows you to place that remaining data in the next specified subset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6ARYc-WLqqNF", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def split_class_lists(files_for_class, count):\n", "  \"\"\" Returns the list of files belonging to a subset of data as well as the remainder of\n", "    files that need to be downloaded.\n", "\n", "    Args:\n", "      files_for_class: Files belonging to a particular class of data.\n", "      count: Number of files to download.\n", "\n", "    Returns:\n", "      Files belonging to the subset of data and dictionary of the remainder of files that 
need to be downloaded.\n", "  \"\"\"\n", "  split_files = []\n", "  remainder = {}\n", "  for cls in files_for_class:\n", "    split_files.extend(files_for_class[cls][:count])\n", "    remainder[cls] = files_for_class[cls][count:]\n", "  return split_files, remainder" ] }, { "cell_type": "markdown", "metadata": { "id": "LlEQ_I0TLd1X" }, "source": [ "The following `download_ucf_101_subset` function allows you to download a subset of the UCF101 dataset and split it into training, validation, and test sets. You can specify the number of classes that you would like to use. The `splits` argument takes a dictionary in which each key is the name of a subset (for example, \"train\") and each value is the number of videos you would like to have per class." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IHH2Y1M06xoz", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def download_ucf_101_subset(zip_url, num_classes, splits, download_dir):\n", "  \"\"\" Download a subset of the UCF101 dataset and split it into various parts, such as\n", "    training, validation, and test.\n", "\n", "    Args:\n", "      zip_url: A URL with a ZIP file with the data.\n", "      num_classes: Number of labels.\n", "      splits: Dictionary specifying the training, validation, test, etc. 
(key) division of data\n", "              (value is number of files per split).\n", "      download_dir: Directory to download data to.\n", "\n", "    Return:\n", "      Mapping of the directories containing the subsections of data.\n", "  \"\"\"\n", "  files = list_files_from_zip_url(zip_url)\n", "  # Keep only the video files, as in the earlier cell. Filtering into a new list\n", "  # avoids removing items from `files` while iterating over it.\n", "  files = [f for f in files if f.endswith('.avi')]\n", "\n", "  files_for_class = get_files_per_class(files)\n", "\n", "  classes = list(files_for_class.keys())[:num_classes]\n", "\n", "  for cls in classes:\n", "    random.shuffle(files_for_class[cls])\n", "\n", "  # Only use the number of classes you want in the dictionary\n", "  files_for_class = {x: files_for_class[x] for x in classes}\n", "\n", "  dirs = {}\n", "  for split_name, split_count in splits.items():\n", "    print(split_name, \":\")\n", "    split_dir = download_dir / split_name\n", "    split_files, files_for_class = split_class_lists(files_for_class, split_count)\n", "    download_from_zip(zip_url, split_dir, split_files)\n", "    dirs[split_name] = split_dir\n", "\n", "  return dirs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NuD-xU8Q66Vm", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "download_dir = pathlib.Path('./UCF101_subset/')\n", "subset_paths = download_ucf_101_subset(URL,\n", "                                       num_classes = NUM_CLASSES,\n", "                                       splits = {\"train\": 30, \"val\": 10, \"test\": 10},\n", "                                       download_dir = download_dir)" ] }, { "cell_type": "markdown", "metadata": { "id": "MBMRm9Ub3Zrk" }, "source": [ "After downloading the data, you should now have a copy of a subset of the UCF101 dataset. Run the following code to print the total number of videos across all your subsets of data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zupvOLYP4D4q", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "video_count_train = len(list(download_dir.glob('train/*/*.avi')))\n", "video_count_val = len(list(download_dir.glob('val/*/*.avi')))\n", "video_count_test = 
len(list(download_dir.glob('test/*/*.avi')))\n", "video_total = video_count_train + video_count_val + video_count_test\n", "print(f\"Total videos: {video_total}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "JmJG1SlXiOX8" }, "source": [ "You can also preview the directory of data files now." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "y9be0WlDiNM0", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!find ./UCF101_subset" ] }, { "cell_type": "markdown", "metadata": { "id": "U4uslY4dScyu" }, "source": [ "## Create frames from each video file" ] }, { "cell_type": "markdown", "metadata": { "id": "D1vvyT0F7JAZ" }, "source": [ "The `frames_from_video_file` function splits a video into frames, reads a randomly chosen span of `n_frames` from the video file, and returns them as a NumPy `array`. To reduce memory and computation overhead, choose a small number of frames. In addition, pick the **same** number of frames from each video, which makes it easier to work with batches of data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vNBCiV3bMzpD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def format_frames(frame, output_size):\n", "  \"\"\"\n", "    Pad and resize an image from a video.\n", "\n", "    Args:\n", "      frame: Image that needs to be resized and padded. 
\n", " output_size: Pixel size of the output frame image.\n", "\n", " Return:\n", " Formatted frame with padding of specified output size.\n", " \"\"\"\n", " frame = tf.image.convert_image_dtype(frame, tf.float32)\n", " frame = tf.image.resize_with_pad(frame, *output_size)\n", " return frame" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9ujLDC9G7JyE", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def frames_from_video_file(video_path, n_frames, output_size = (224,224), frame_step = 15):\n", " \"\"\"\n", " Creates frames from each video file present for each category.\n", "\n", " Args:\n", " video_path: File path to the video.\n", " n_frames: Number of frames to be created per video file.\n", " output_size: Pixel size of the output frame image.\n", "\n", " Return:\n", " An NumPy array of frames in the shape of (n_frames, height, width, channels).\n", " \"\"\"\n", " # Read each video frame by frame\n", " result = []\n", " src = cv2.VideoCapture(str(video_path)) \n", "\n", " video_length = src.get(cv2.CAP_PROP_FRAME_COUNT)\n", "\n", " need_length = 1 + (n_frames - 1) * frame_step\n", "\n", " if need_length > video_length:\n", " start = 0\n", " else:\n", " max_start = video_length - need_length\n", " start = random.randint(0, max_start + 1)\n", "\n", " src.set(cv2.CAP_PROP_POS_FRAMES, start)\n", " # ret is a boolean indicating whether read was successful, frame is the image itself\n", " ret, frame = src.read()\n", " result.append(format_frames(frame, output_size))\n", "\n", " for _ in range(n_frames - 1):\n", " for _ in range(frame_step):\n", " ret, frame = src.read()\n", " if ret:\n", " frame = format_frames(frame, output_size)\n", " result.append(frame)\n", " else:\n", " result.append(np.zeros_like(result[0]))\n", " src.release()\n", " result = np.array(result)[..., [2, 1, 0]]\n", "\n", " return result" ] }, { "cell_type": "markdown", "metadata": { "id": "1ENtlwhxwyTe" }, "source": [ "## 呈现视频数据\n", "\n", 
"`frames_from_video_file` 函数会将一组帧作为 NumPy 数组返回。尝试在 Patrick Gillett 的 [Wikimedia](https://commons.wikimedia.org/wiki/Category:Videos_of_sports){:.external} 的新视频中使用此函数:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Z2hgSghlykzA", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "!curl -O https://upload.wikimedia.org/wikipedia/commons/8/86/End_of_a_jam.ogv" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xdHvHw3hym-U", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "video_path = \"End_of_a_jam.ogv\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u845YODXyqo5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "sample_video = frames_from_video_file(video_path, n_frames = 10)\n", "sample_video.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zFHGHiFgGjv2", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "def to_gif(images):\n", " converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)\n", " imageio.mimsave('./animation.gif', converted_images, fps=10)\n", " return embed.embed_file('./animation.gif')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7hiwUJenEN3p", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "to_gif(sample_video)" ] }, { "cell_type": "markdown", "metadata": { "id": "3dktTnDVG7xf" }, "source": [ "除了查看此视频外,您还可以显示 UCF-101 数据。为此,请运行以下代码:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MghJzJsWme0t", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# docs-infra: no-execute\n", "ucf_sample_video = frames_from_video_file(next(subset_paths['train'].glob('*/*.avi')), 50)\n", "to_gif(ucf_sample_video)" ] }, { "cell_type": "markdown", "metadata": { "id": "NlvuC5_E7XrF" }, "source": [ "接下来,定义 `FrameGenerator` 类以创建一个可迭代对象,该对象可以将数据输入 TensorFlow 数据流水线。生成器 (`__call__`) 函数产生由 
`frames_from_video_file` 生成的帧数组以及与帧集相关联的标签的独热编码向量。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MVmfLTlw7Ues", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "class FrameGenerator:\n", " def __init__(self, path, n_frames, training = False):\n", " \"\"\" Returns a set of frames with their associated label. \n", "\n", " Args:\n", " path: Video file paths.\n", " n_frames: Number of frames. \n", " training: Boolean to determine if training dataset is being created.\n", " \"\"\"\n", " self.path = path\n", " self.n_frames = n_frames\n", " self.training = training\n", " self.class_names = sorted(set(p.name for p in self.path.iterdir() if p.is_dir()))\n", " self.class_ids_for_name = dict((name, idx) for idx, name in enumerate(self.class_names))\n", "\n", " def get_files_and_class_names(self):\n", " video_paths = list(self.path.glob('*/*.avi'))\n", " classes = [p.parent.name for p in video_paths] \n", " return video_paths, classes\n", "\n", " def __call__(self):\n", " video_paths, classes = self.get_files_and_class_names()\n", "\n", " pairs = list(zip(video_paths, classes))\n", "\n", " if self.training:\n", " random.shuffle(pairs)\n", "\n", " for path, name in pairs:\n", " video_frames = frames_from_video_file(path, self.n_frames) \n", " label = self.class_ids_for_name[name] # Encode labels\n", " yield video_frames, label" ] }, { "cell_type": "markdown", "metadata": { "id": "xsvhPIkpzx-r" }, "source": [ "在将 `FrameGenerator` 对象封装为 TensorFlow Dataset 对象之前对其进行测试。此外,对于训练数据集,请确保启用训练模式,以便对数据进行重排。" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P5jwagZxzxOf", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "fg = FrameGenerator(subset_paths['train'], 10, training=True)\n", "\n", "frames, label = next(fg())\n", "\n", "print(f\"Shape: {frames.shape}\")\n", "print(f\"Label: {label}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "E7MRRFSks7l1" }, "source": [ "最后,创建一个 TensorFlow 
数据输入流水线。您从生成器对象创建的此流水线允许您将数据输入深度学习模型。在此视频流水线中,每个元素都是一组单独的帧及其关联标签。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "HM4NboJr7ck4", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create the training set\n", "output_signature = (tf.TensorSpec(shape = (None, None, None, 3), dtype = tf.float32),\n", " tf.TensorSpec(shape = (), dtype = tf.int16))\n", "train_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['train'], 10, training=True),\n", " output_signature = output_signature)" ] }, { "cell_type": "markdown", "metadata": { "id": "9oF_8m8IZvcY" }, "source": [ "检查标签是否重排。 " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3XYVmsgiZsJD", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "for frames, labels in train_ds.take(10):\n", " print(labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Pi8-WkOkEXw5", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Create the validation set\n", "val_ds = tf.data.Dataset.from_generator(FrameGenerator(subset_paths['val'], 10),\n", " output_signature = output_signature)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V6qXc-6i7eyK", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "# Print the shapes of the data\n", "train_frames, train_labels = next(iter(train_ds))\n", "print(f'Shape of training set of frames: {train_frames.shape}')\n", "print(f'Shape of training labels: {train_labels.shape}')\n", "\n", "val_frames, val_labels = next(iter(val_ds))\n", "print(f'Shape of validation set of frames: {val_frames.shape}')\n", "print(f'Shape of validation labels: {val_labels.shape}')" ] }, { "cell_type": "markdown", "metadata": { "id": "bIrFpUIxvTLe" }, "source": [ "## 配置数据集以提高性能\n", "\n", "使用缓冲预提取,以便从磁盘产生数据,而不会阻塞 I/O。下面是可以在加载数据时使用的两个重要函数:\n", "\n", "- 
`Dataset.cache`: keeps the images in memory after they are loaded off disk during the first epoch. This ensures that the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache.\n", "\n", "- `Dataset.prefetch`: overlaps data preprocessing and model execution while training. Refer to [Better performance with the `tf.data` API](https://tensorflow.google.cn/guide/data_performance) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QSxjFtxAvY3_", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "AUTOTUNE = tf.data.AUTOTUNE\n", "\n", "train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)\n", "val_ds = val_ds.cache().shuffle(1000).prefetch(buffer_size = AUTOTUNE)" ] }, { "cell_type": "markdown", "metadata": { "id": "VaY-hyr-Fbfr" }, "source": [ "To prepare the data to be fed into the model, use batching as shown below. Notice that when working with video data, such as AVI files, the data should be shaped as a five-dimensional object. These dimensions are as follows: `[batch_size, number_of_frames, height, width, channels]`. In comparison, an image would have four dimensions: `[batch_size, height, width, channels]`. The image below illustrates how the shape of video data is represented.\n", "\n", "![Video data shape](https://tensorflow.google.cn/images/tutorials/video/video_data_shape.png)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "pp2Qc6XSFmeB", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "train_ds = train_ds.batch(2)\n", "val_ds = val_ds.batch(2)\n", "\n", "train_frames, train_labels = next(iter(train_ds))\n", "print(f'Shape of training set of frames: {train_frames.shape}')\n", "print(f'Shape of training labels: {train_labels.shape}')\n", "\n", "val_frames, val_labels = next(iter(val_ds))\n", "print(f'Shape of validation set of frames: {val_frames.shape}')\n", "print(f'Shape of validation labels: {val_labels.shape}')" ] }, { "cell_type": "markdown", "metadata": { "id": "hqjXn1FgsMqZ" }, "source": [ "## Next steps\n", "\n", "Now that you have created a TensorFlow `Dataset` of video frames with their labels, you can use it with a deep learning model. The following classification model, which uses a pre-trained [EfficientNet](https://arxiv.org/abs/1905.11946), trains to high accuracy in a few minutes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qzqgPBUuForj", "vscode": { "languageId": "python" } }, "outputs": [], "source": [ "net = 
tf.keras.applications.EfficientNetB0(include_top = False)\n", "net.trainable = False\n", "\n", "model = tf.keras.Sequential([\n", "    tf.keras.layers.Rescaling(scale=255),\n", "    tf.keras.layers.TimeDistributed(net),\n", "    tf.keras.layers.Dense(10),\n", "    tf.keras.layers.GlobalAveragePooling3D()\n", "])\n", "\n", "model.compile(optimizer = 'adam',\n", "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True),\n", "              metrics=['accuracy'])\n", "\n", "model.fit(train_ds,\n", "          epochs = 10,\n", "          validation_data = val_ds,\n", "          callbacks = tf.keras.callbacks.EarlyStopping(patience = 2, monitor = 'val_loss'))" ] }, { "cell_type": "markdown", "metadata": { "id": "DdJm7ojgGxtT" }, "source": [ "To learn more about working with video data in TensorFlow, check out the following tutorials:\n", "\n", "- [Build a 3D CNN model for video classification](https://tensorflow.google.cn/tutorials/video/video_classification)\n", "- [MoViNet for streaming action recognition](https://tensorflow.google.cn/hub/tutorials/movinet)\n", "- [Transfer learning for video classification with MoViNet](https://tensorflow.google.cn/tutorials/video/transfer_learning_with_movinet)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "video.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }