{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "tDnwEv8FtJm7" }, "outputs": [], "source": [ "##### Copyright 2018 The TensorFlow Authors." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "JlknJBWQtKkI" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import logging\n", "import os\n", "from pathlib import Path\n", "\n", "# Environment variables below must be set before TensorFlow is imported for them to take effect.\n", "# TF_CPP_MIN_LOG_LEVEL: 0 = all, 1 = hide INFO, 2 = hide WARNING, 3 = hide ERROR too (only FATAL remains).\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "# Silence start-up log noise from the native gRPC and Abseil/glog libraries (e.g. those used by Gemini).\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\"  # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "# Let TensorFlow allocate GPU memory on demand. Set it via os.environ: `!export` only affects a throwaway subshell.\n", "os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n", "\n", "import tensorflow as tf\n", "\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n", "\n", "# Local scratch directory (created if missing).\n", "temp_dir = Path(\".temp\")\n", "temp_dir.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "60RdWsg1tETW" }, "source": [ "# 自定义层" ] }, { "cell_type": "markdown", "metadata": { "id": "BcJg7Enms86w" }, "source": [ "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
Model: \"\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ conv2d (Conv2D) │ (1, 2, 3, 1) │ 4 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization │ (1, 2, 3, 1) │ 4 │\n", "│ (BatchNormalization) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_1 (Conv2D) │ (1, 2, 3, 2) │ 4 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_1 │ (1, 2, 3, 2) │ 8 │\n", "│ (BatchNormalization) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_2 (Conv2D) │ (1, 2, 3, 3) │ 9 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_2 │ (1, 2, 3, 3) │ 12 │\n", "│ (BatchNormalization) │ │ │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ conv2d (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", 
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_1 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_1 │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_2 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m9\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_2 │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m12\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 41 (164.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m41\u001b[0m (164.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 29 (116.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m29\u001b[0m (116.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 12 (48.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m12\u001b[0m (48.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "block.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "wYfucVw65PMj" }, "source": [ "但是，在很多时候，由多个层组合而成的模型只需要逐一地调用各层。为此，使用 `tf.keras.Sequential` 只需少量代码即可完成：" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
Model: \"sequential\"\n",
"
\n"
],
"text/plain": [
"\u001b[1mModel: \"sequential\"\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ conv2d_3 (Conv2D) │ (1, 2, 3, 1) │ 4 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_3 │ (1, 2, 3, 1) │ 4 │\n", "│ (BatchNormalization) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_4 (Conv2D) │ (1, 2, 3, 2) │ 4 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_4 │ (1, 2, 3, 2) │ 8 │\n", "│ (BatchNormalization) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_5 (Conv2D) │ (1, 2, 3, 3) │ 9 │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_5 │ (1, 2, 3, 3) │ 12 │\n", "│ (BatchNormalization) │ │ │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", "│ conv2d_3 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_3 │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", 
"├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_4 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m4\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_4 │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m2\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ conv2d_5 (\u001b[38;5;33mConv2D\u001b[0m) │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m9\u001b[0m │\n", "├─────────────────────────────────┼────────────────────────┼───────────────┤\n", "│ batch_normalization_5 │ (\u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m2\u001b[0m, \u001b[38;5;34m3\u001b[0m, \u001b[38;5;34m3\u001b[0m) │ \u001b[38;5;34m12\u001b[0m │\n", "│ (\u001b[38;5;33mBatchNormalization\u001b[0m) │ │ │\n", "└─────────────────────────────────┴────────────────────────┴───────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Total params: 41 (164.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m41\u001b[0m (164.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 29 (116.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m29\u001b[0m (116.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Non-trainable params: 12 (48.00 B)\n", "\n" ], "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m12\u001b[0m (48.00 B)\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "my_seq.summary()" ] }, { "cell_type": "markdown", "metadata": { "id": "c5YwYcnuK-wc" }, "source": [ "# 后续步骤\n", "\n", "现在，您可以回到上一个笔记本，调整线性回归示例，以使用结构更好的层和模型。" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "custom_layers.ipynb", "toc_visible": true }, "kernelspec": { "display_name": "xxx", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 0 }