{ "cells": [ { "cell_type": "markdown", "metadata": { "cellView": "form", "id": "BZSlp3DAjdYf" }, "source": [ "````{admonition} Copyright 2019 The TensorFlow Authors.\n", "```\n", "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License.\n", "```\n", "````" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "# Quiet TensorFlow's C++ log output to reduce warning noise.\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n", "# Silence start-up log warnings from the libraries underneath Gemini (gRPC and Abseil).\n", "os.environ[\"GRPC_VERBOSITY\"] = \"ERROR\"\n", "os.environ[\"GLOG_minloglevel\"] = \"3\"  # 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL\n", "# A shell `!export` only changes a throw-away subshell, never this kernel; set the\n", "# variable in-process, before TensorFlow is imported, so TensorFlow can see it.\n", "os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n", "import logging\n", "import tensorflow as tf\n", "tf.get_logger().setLevel(logging.ERROR)\n", "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)" ] }, { "cell_type": "markdown", "metadata": { "id": "3wF5wszaj97Y" }, "source": [ "# 专家的 TensorFlow 2 快速入门" ] }, { "cell_type": "markdown", "metadata": { "id": "DUNzJc4jTj6G" }, "source": [ "
![]() | \n",
" ![]() | \n",
" ![]() | \n",
" ![]() | \n",
"
tf.keras
模型:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "h3IKyzTCDNGo"
},
"outputs": [],
"source": [
"class MyModel(Model):\n",
"  \"\"\"Small image classifier: Conv2D -> Flatten -> Dense(128) -> Dense(10).\n",
"\n",
"  The final Dense layer has no activation, so `call` returns raw logits.\n",
"  \"\"\"\n",
"\n",
"  def __init__(self):\n",
"    super().__init__()\n",
"    # Layers are created once here and reused on every forward pass.\n",
"    self.conv1 = Conv2D(32, 3, activation='relu')\n",
"    self.flatten = Flatten()\n",
"    self.d1 = Dense(128, activation='relu')\n",
"    self.d2 = Dense(10)\n",
"\n",
"  def call(self, x):\n",
"    # Push the input through the hidden stack in order, then project to 10 outputs.\n",
"    for hidden_layer in (self.conv1, self.flatten, self.d1):\n",
"      x = hidden_layer(x)\n",
"    return self.d2(x)\n",
"\n",
"\n",
"# The single model instance that the rest of the notebook trains and evaluates.\n",
"model = MyModel()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uGih-c2LgbJu"
},
"source": [
"选择用于训练的优化器和损失函数："
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "u48C9WQ774n4"
},
"outputs": [],
"source": [
"# Adam, with its default settings, applies the weight updates.\n",
"optimizer = tf.keras.optimizers.Adam()\n",
"\n",
"# The model's last layer has no softmax, so the loss must be told it receives logits.\n",
"loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JB6A1vcigsIe"
},
"source": [
"选择指标来衡量模型的损失和准确率。这些指标会在每个周期（epoch）内累积值，然后打印总体结果。"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "N0MqHFb4F_qn"
},
"outputs": [],
"source": [
"# Running means of the per-batch loss, one per data split.\n",
"train_loss, test_loss = (\n",
"    tf.keras.metrics.Mean(name=metric_name)\n",
"    for metric_name in ('train_loss', 'test_loss'))\n",
"\n",
"# Running accuracy, one per data split; the Sparse* variant takes class-index labels.\n",
"train_accuracy, test_accuracy = (\n",
"    tf.keras.metrics.SparseCategoricalAccuracy(name=metric_name)\n",
"    for metric_name in ('train_accuracy', 'test_accuracy'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ix4mEL65on-w"
},
"source": [
"使用 `tf.GradientTape` 训练模型："
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "OZACiVqA8KQV"
},
"outputs": [],
"source": [
"@tf.function\n",
"def train_step(images, labels):\n",
"  \"\"\"Run one optimization step on a batch and update the training metrics.\"\"\"\n",
"  with tf.GradientTape() as tape:\n",
"    # training=True only matters for layers that behave differently while\n",
"    # training than at inference time (e.g. Dropout).\n",
"    logits = model(images, training=True)\n",
"    batch_loss = loss_object(labels, logits)\n",
"\n",
"  weights = model.trainable_variables\n",
"  grads = tape.gradient(batch_loss, weights)\n",
"  optimizer.apply_gradients(zip(grads, weights))\n",
"\n",
"  train_loss(batch_loss)\n",
"  train_accuracy(labels, logits)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z8YT7UmFgpjV"
},
"source": [
"测试模型："
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "xIKdEzHAJGt7"
},
"outputs": [],
"source": [
"@tf.function\n",
"def test_step(images, labels):\n",
"  \"\"\"Score one batch without touching the weights and accumulate the test metrics.\"\"\"\n",
"  # training=False only matters for layers that behave differently at\n",
"  # inference time than during training (e.g. Dropout).\n",
"  logits = model(images, training=False)\n",
"  test_loss(loss_object(labels, logits))\n",
"  test_accuracy(labels, logits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "i-2pkctU_Ci7"
},
"outputs": [],
"source": [
"EPOCHS = 5\n",
"\n",
"for epoch in range(EPOCHS):\n",
"  # Clear the accumulated values so every printed line reflects one epoch only.\n",
"  # Keras metrics expose `reset_state()`; `reset_states()` is not available on\n",
"  # Keras 3 metrics and raises AttributeError.\n",
"  for metric in (train_loss, train_accuracy, test_loss, test_accuracy):\n",
"    metric.reset_state()\n",
"\n",
"  for images, labels in train_ds:\n",
"    train_step(images, labels)\n",
"\n",
"  for test_images, test_labels in test_ds:\n",
"    test_step(test_images, test_labels)\n",
"\n",
"  print(\n",
"    f'Epoch {epoch + 1}, '\n",
"    f'Loss: {train_loss.result()}, '\n",
"    f'Accuracy: {train_accuracy.result() * 100}, '\n",
"    f'Test Loss: {test_loss.result()}, '\n",
"    f'Test Accuracy: {test_accuracy.result() * 100}'\n",
"  )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4JfEh7kvx6m"
},
"source": [
"现在，经过训练，图像分类器在此数据集上的准确率约为 98%。要了解详情，请阅读 [TensorFlow 教程](https://tensorflow.google.cn/tutorials)。"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "advanced.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "xxx",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}