##### Copyright 2021 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
将 tf.feature_column
迁移到 Keras 预处理层#
在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
训练模型通常会伴随一些特征预处理,尤其是在处理结构化数据时。在 TensorFlow 1 中训练 tf.estimator.Estimator
时,通常使用 tf.feature_column
API 执行特征预处理。在 TensorFlow 2 中,您可以直接使用 Keras 预处理层执行此操作。
本迁移指南演示了使用特征列和预处理层的常见特征转换,然后使用这两种 API 训练一个完整的模型。
首先,从几个必要的导入开始:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math
接下来,添加一个用于调用特征列的效用函数进行演示:
def call_feature_columns(feature_columns, inputs):
# This is a convenient way to call a `feature_column` outside of an estimator
# to display its output.
feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
return feature_layer(inputs)
输入处理#
要将特征列与 Estimator 一起使用,模型输入始终应为张量的字典:
input_dict = {
'foo': tf.constant([1]),
'bar': tf.constant([0]),
'baz': tf.constant([-1])
}
每个特征列都需要有一个键来索引到源数据。所有特征列的输出串联并由 Estimator 模型使用。
columns = [
tf1.feature_column.numeric_column('foo'),
tf1.feature_column.numeric_column('bar'),
tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)
在 Keras 中,模型输入更加灵活。tf.keras.Model
可以处理单个张量输入、张量特征列表或张量特征字典。可以通过在模型创建时传递 tf.keras.Input
的字典来处理字典输入。输入不会自动串联,这样它们便能以更灵活的方式使用。它们可以与 tf.keras.layers.Concatenate
串联。
inputs = {
'foo': tf.keras.Input(shape=()),
'bar': tf.keras.Input(shape=()),
'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)
独热编码整数 ID#
常见的特征转换是对已知范围内的整数输入进行独热编码。下面是一个使用特征列的示例:
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})
利用 Keras 预处理层,这些列可被替换为单个 tf.keras.layers.CategoryEncoding
层,其中 output_mode
设置为 'one_hot'
:
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
注:对于大型独热编码,使用输出的稀疏表示会更高效。如果将 sparse=True
传递给 CategoryEncoding
层,则该层的输出将是 tf.sparse.SparseTensor
,它可以作为 tf.keras.layers.Dense
层的输入高效地处理。
归一化数字特征#
在处理具有特征列的连续浮点特征时,需要使用 tf.feature_column.numeric_column
。在输入已经归一化的情况下,将其转换为 Keras 的操作十分简单。可以直接在模型中使用 tf.keras.Input
,如上面所示。
numeric_column
也可用于归一化输入:
def normalize(x):
mean, variance = (2.0, 1.0)
return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})
相比之下,使用 Keras,这种归一化可以使用 tf.keras.layers.Normalization
完成。
normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
对数字特征进行分桶和独热编码#
连续浮点输入的另一种常见转换是分桶为固定范围的整数。
在特征列中,可以使用 tf.feature_column.bucketized_column
实现:
numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})
在 Keras 中,可以使用 tf.keras.layers.Discretization
代替:
discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
使用词汇表对字符串数据进行独热编码#
处理字符串特征通常需要词汇查找来将字符串转换为索引。下面是一个使用特征列查找字符串,然后对索引进行独热编码的示例:
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'sizes',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})
利用 Keras 预处理层,可以使用 tf.keras.layers.StringLookup
层,并将 output_mode
设置为 'one_hot'
:
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'],
num_oov_indices=0,
output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])
注:对于大型独热编码,使用输出的稀疏表示会更高效。如果将 sparse=True
传递给 StringLookup
层,则该层的输出将是 tf.sparse.SparseTensor
,它可以作为 tf.keras.layers.Dense
层的输入高效地处理。
使用词汇表嵌入字符串数据#
对于较大的词汇表,通常需要嵌入向量才能获得良好的性能。下面是一个使用特征列嵌入字符串特征的示例:
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'col',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})
利用 Keras 预处理层,可以通过组合 tf.keras.layers.StringLookup
层和 tf.keras.layers.Embedding
层来实现。StringLookup
的默认输出将是可直接馈送到嵌入向量中的整数索引。
注:Embedding
层包含可训练参数。虽然 StringLookup
层可以应用于模型内部或外部的数据,但 Embedding
必须始终是可训练 Keras 模型的一部分才能正确训练。
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
对加权分类数据求和#
在某些情况下,您需要处理分类数据,其中类别的每次出现都附带关联的权重。在特征列中,这由 tf.feature_column.weighted_categorical_column
处理。与 indicator_column
配对时,效果是对每个类别的权重求和。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})
在 Keras 中,这可以通过将 count_weights
输入传递给 tf.keras.layers.CategoryEncoding
来完成,其中 output_mode='count'
。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
嵌入加权分类数据#
您可能还想嵌入加权分类输入。在特征列中, embedding_column
包含 combiner
参数。如果任何样本包含一个类别的多个条目,则它们将根据参数设置进行组合(默认为 'mean'
)。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})
在 Keras 中,tf.keras.layers.Embedding
没有 combiner
选项,但可以使用 tf.keras.layers.Dense
实现相同的效果。上面的 embedding_column
只是根据类别权重线性组合嵌入向量。虽然一开始并不明显,但它完全等效于将您的分类输入表示为大小为 (num_tokens)
的稀疏权重向量,随后将它们乘以形状为 (embedding_size, num_tokens)
的 Dense
内核 。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be equivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
完整的训练示例#
为了展示完整的训练工作流,首先准备一些具有三种不同类型特征的数据:
features = {
'type': [0, 1, 1],
'size': ['small', 'small', 'medium'],
'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}
为 TensorFlow 1 和 TensorFlow 2 工作流定义一些通用常量:
vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0
使用特征列#
特征列在创建时必须作为列表传递给 Estimator,并在训练期间隐式调用。
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g. [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)
# Convert strings to indices; e.g. ['small'] -> [1].
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)
normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
'weight', normalizer_fn=normalizer_fn)
estimator = tf1.estimator.DNNClassifier(
feature_columns=[indicator_col, embedding_col, numeric_col],
hidden_units=[1])
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
estimator.train(_input_fn)
在模型上运行推断时,特征列也将用于转换输入数据。
def _predict_fn():
return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)
next(estimator.predict(_predict_fn))
使用 Keras 预处理层#
Keras 预处理层在调用它们的位置上更加灵活。层可以直接应用于张量,在 tf.data
输入流水线内使用,或者直接构建到可训练的 Keras 模型中。
在此示例中,您将在 tf.data
输入流水线中应用预处理层。为此,可以定义一个单独的 tf.keras.Model
来预处理您的输入特征。此模型不可训练,但可以方便地对预处理层进行分组。
inputs = {
'type': tf.keras.Input(shape=(), dtype='int64'),
'size': tf.keras.Input(shape=(), dtype='string'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g. [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g. ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
'type': type_output,
'size': size_output,
'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)
注:作为在层创建时提供词汇表和归一化统计信息的替代方式,许多预处理层提供了一个 adapt()
方法,可用于直接从输入数据学习层状态。请参阅预处理指南 ,了解更多详细信息。
您现在可以在对 tf.data.Dataset.map
的调用中应用此模型。请注意,传递给 map
的函数将自动转换为 tf.function
,并且用于编写 tf.function
代码的通常注意事项适用(无副作用)。
# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())
接下来,可以定义一个包含可训练层的单独 Model
。请注意此模型的输入现在如何反映预处理的特征类型和形状。
inputs = {
'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
'size': tf.keras.Input(shape=(), dtype='int64'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Since the embedding is trainable, it needs to be part of the training model.
embedding = tf.keras.layers.Embedding(len(vocab), embedding_dims)
outputs = tf.keras.layers.Concatenate()([
inputs['type'],
embedding(inputs['size']),
tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)
您现在可以使用 tf.keras.Model.fit
训练 training_model
。
# Train on the preprocessed data.
training_model.compile(
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
最后,在推断时,不妨将这些单独的阶段组合成处理原始特征输入的单一模型。
inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)
predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)
可以将此组合模型保存为 .keras 文件供以后使用。
inference_model.save('model.keras')
restored_model = tf.keras.models.load_model('model.keras')
restored_model.predict(predict_dataset)
注:预处理层不可训练,这允许您使用 tf.data
异步应用它们。这样做可以获得性能优势,因为您既可以预提取预处理的批次,又可以释放任何加速器以专注于模型的可微分部分(请在使用 tf.data
API 提升性能指南的预提取部分了解更多信息)。如本指南中所示,在训练期间分离预处理并在推断期间组合预处理是一种利用这些性能提升的灵活方式。但是,如果您的模型很小或预处理时间可以忽略不计,那么从一开始就将预处理构建为一个完整的模型可能会更简单。为此,您可以从 tf.keras.Input
开始构建单一模型,接着是预处理层,最后是可训练层。
特征列对应关系表#
作为参考,下面是特征列和 Keras 预处理层之间的大致对应关系:
特征列 | Keras 层 |
---|---|
`tf.feature_column.bucketized_column` | `tf.keras.layers.Discretization` |
`tf.feature_column.categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
`tf.feature_column.categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
`tf.feature_column.categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup` 或 `tf.keras.layers.IntegerLookup` |
`tf.feature_column.categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup` 或 `tf.keras.layers.IntegerLookup` |
`tf.feature_column.crossed_column` | `tf.keras.layers.experimental.preprocessing.HashedCrossing` |
`tf.feature_column.embedding_column` | `tf.keras.layers.Embedding` |
`tf.feature_column.indicator_column` | `output_mode='one_hot'` 或 `output_mode='multi_hot'`* |
`tf.feature_column.numeric_column` | `tf.keras.layers.Normalization` |
`tf.feature_column.sequence_categorical_column_with_hash_bucket` | `tf.keras.layers.Hashing` |
`tf.feature_column.sequence_categorical_column_with_identity` | `tf.keras.layers.CategoryEncoding` |
`tf.feature_column.sequence_categorical_column_with_vocabulary_file` | `tf.keras.layers.StringLookup`、`tf.keras.layers.IntegerLookup` 或 `tf.keras.layer.TextVectorization`† |
`tf.feature_column.sequence_categorical_column_with_vocabulary_list` | `tf.keras.layers.StringLookup`、`tf.keras.layers.IntegerLookup` 或 `tf.keras.layer.TextVectorization`† |
`tf.feature_column.sequence_numeric_column` | `tf.keras.layers.Normalization` |
`tf.feature_column.weighted_categorical_column` | `tf.keras.layers.CategoryEncoding` |
output_mode
可以传递给 tf.keras.layers.CategoryEncoding
、tf.keras.layers.StringLookup
、tf.keras.layers.IntegerLookup
和 tf.keras.layers.TextVectorization
。
† tf.keras.layers.TextVectorization
可以直接处理自由格式的文本输入(例如,整个句子或段落)。这不是 TensorFlow 1 中分类序列处理的一对一替代,但可以为临时文本预处理提供方便的替代。
注:线性 Estimator(例如 tf.estimator.LinearClassifier
)可以在没有 embedding_column
或 indicator_column
的情况下处理直接分类输入(整数索引)。但是,整数索引不能直接传递给 tf.keras.layers.Dense
或 tf.keras.experimental.LinearModel
。在调用 Dense
或 LinearModel
之前,应当首先使用 tf.layers.CategoryEncoding
对这些输入进行编码,其中 output_mode='count'
(如果类别大小很大,则为 sparse=True
)。
后续步骤#
有关 Keras 预处理层的更多信息,请转到使用预处理层指南。
有关将预处理层应用于结构化数据的更深入示例,请参阅使用 Keras 预处理层对结构化数据进行分类教程。