##### Copyright 2023 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Import a JAX model using JAX2TF#

This notebook provides a complete, runnable example of creating a model with JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.

JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.

If you are a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.

There are many ways this can be useful, here are just a few:

  • Inference: taking a model written for JAX and deploying it either on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.

  • Fine-tuning: taking a model that was trained using JAX, you can bring its components to TF using JAX2TF, and continue training it in TensorFlow with your existing training data and setup.

  • Fusion: combining parts of models that were trained using JAX with those trained using TensorFlow, for maximum flexibility.

The key to enabling this kind of interoperation between JAX and TensorFlow is jax2tf.convert, which takes in model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.
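
To make that pattern concrete before building the full model, here is a minimal, self-contained sketch of the round trip. The function name (jax_linear), the parameters, and the save path are illustrative only and are not part of the model built later in this notebook.

import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

# Any JAX function of (parameters, data) -> outputs will do.
def jax_linear(params, x):
  return jnp.dot(x, params["w"]) + params["b"]

params = {"w": jnp.ones((4, 2)), "b": jnp.zeros((2,))}

# Create an equivalent TensorFlow function; `b` marks a symbolic (variable) batch dimension.
tf_linear = jax2tf.convert(jax_linear, polymorphic_shapes=["...", "(b, 4)"])

# Wrap the weights in `tf.Variable`, wrap the call in `tf.function`, then export a SavedModel.
m = tf.Module()
m.vars = tf.nest.map_structure(tf.Variable, params)
m.predict = tf.function(
    lambda x: tf_linear(m.vars, x),
    autograph=False,
    input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)])
tf.saved_model.save(m, "/tmp/minimal_jax2tf")
print(tf.saved_model.load("/tmp/minimal_jax2tf").predict(np.zeros((3, 4), np.float32)))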

Setup#

import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import flax
import optax
import os
from matplotlib import pyplot as plt
from jax.experimental import jax2tf
from threading import Lock # Only used in the visualization utility.
from functools import partial
# Needed for TensorFlow and JAX to coexist in GPU memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    print(e)
#@title Visualization utilities

plt.rcParams["figure.figsize"] = (20,8)

# The utility for displaying training and validation curves.
def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs, steps_per_epochs, ignore_first_n=10):

  ignore_first_n_epochs = int(ignore_first_n/steps_per_epochs)

  # The losses.
  ax = plt.subplot(121)
  if loss is not None:
    x = np.arange(len(loss)) / steps_per_epochs #* epochs
    ax.plot(x, loss)
  ax.plot(range(1, epochs+1), avg_loss, "-o", linewidth=3)
  ax.plot(range(1, epochs+1), eval_loss, "-o", linewidth=3)
  ax.set_title('Loss')
  ax.set_ylabel('loss')
  ax.set_xlabel('epoch')
  if loss is not None:
    ax.set_ylim(0, np.max(loss[ignore_first_n:]))
    ax.legend(['train', 'avg train', 'eval'])
  else:
    ymin = np.min(avg_loss[ignore_first_n_epochs:])
    ymax = np.max(avg_loss[ignore_first_n_epochs:])
    ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
    ax.legend(['avg train', 'eval'])

  # The accuracy.
  ax = plt.subplot(122)
  ax.set_title('Eval Accuracy')
  ax.set_ylabel('accuracy')
  ax.set_xlabel('epoch')
  ymin = np.min(eval_accuracy[ignore_first_n_epochs:])
  ymax = np.max(eval_accuracy[ignore_first_n_epochs:])
  ax.set_ylim(ymin-(ymax-ymin)/10, ymax+(ymax-ymin)/10)
  ax.plot(range(1, epochs+1), eval_accuracy, "-o", linewidth=3)

class Progress:
    """Text mode progress bar.
    Usage:
            p = Progress(30)
            p.step()
            p.step()
            p.step(reset=True) # to restart from 0%
    The progress bar displays a new header at each restart."""
    def __init__(self, maxi, size=100, msg=""):
        """
        :param maxi: the number of steps required to reach 100%
        :param size: the number of characters taken on the screen by the progress bar
        :param msg: the message displayed in the header of the progress bar
        """
        self.maxi = maxi
        self.p = self.__start_progress(maxi)()  # `()`: to get the iterator from the generator.
        self.header_printed = False
        self.msg = msg
        self.size = size
        self.lock = Lock()

    def step(self, reset=False):
        with self.lock:
            if reset:
                self.__init__(self.maxi, self.size, self.msg)
            if not self.header_printed:
                self.__print_header()
            next(self.p)

    def __print_header(self):
        print()
        format_string = "0%{: ^" + str(self.size - 6) + "}100%"
        print(format_string.format(self.msg))
        self.header_printed = True

    def __start_progress(self, maxi):
        def print_progress():
            # Bresenham's algorithm. Yields the number of dots printed.
            # This will always print 100 dots in max invocations.
            dx = maxi
            dy = self.size
            d = dy - dx
            for x in range(maxi):
                k = 0
                while d >= 0:
                    print('=', end="", flush=True)
                    k += 1
                    d -= dx
                d += dy
                yield k
            # Keep yielding the last result if there are too many steps.
            while True:
              yield k

        return print_progress

Download and prepare the MNIST dataset#

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                         tf.one_hot(y, depth=10)))

BATCH_SIZE = 256
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.cache()
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)

test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(lambda x,y: (tf.expand_dims(tf.cast(x, tf.float32)/255.0, axis=-1),
                                         tf.one_hot(y, depth=10)))
test_data = test_data.batch(10000)
test_data = test_data.cache()

(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch
(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000
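
As a quick, optional sanity check, the shapes of one training batch and the single test batch are:

# One training batch of 256 images, and the entire test set as a single batch of 10,000.
print(one_batch.shape, one_batch_labels.shape)        # (256, 28, 28, 1) (256, 10)
print(all_test_data.shape, all_test_labels.shape)     # (10000, 28, 28, 1) (10000, 10)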

Configure training#

This notebook will create and train a simple model for demonstration purposes.

# Training hyperparameters.
JAX_EPOCHS = 3
TF_EPOCHS = 7
STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6

# The learning rate schedule for JAX (with Optax).
jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

# The learning rate schedule for TensorFlow.
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)
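
Since training starts in JAX and resumes in TensorFlow, the two schedules above are configured identically. As an optional sanity check (a small sketch), they should produce the same learning rate at the same step count:

# The Optax schedule and the Keras schedule should agree at matching steps.
for step in (0, STEPS_PER_EPOCH, 5 * STEPS_PER_EPOCH):
  print("step", step, "JAX lr:", float(jlr_decay(step)), "TF lr:", float(tflr_decay(step)))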

Create the model using Flax#

class ConvModel(flax.linen.Module):

  @flax.linen.compact
  def __call__(self, x, train):
    x = flax.linen.Conv(features=12, kernel_size=(3,3), padding="SAME", use_bias=False)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = flax.linen.Dense(features=200, use_bias=True)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)
    x = flax.linen.relu(x)
    x = flax.linen.Dense(features=10)(x)
    #x = flax.linen.log_softmax(x)
    return x

  # JAX differentiation requires a function `f(params, other_state, data, labels)` -> `loss` (as a single number).
  # `jax.grad` will differentiate it against the first argument.
  # The user must split trainable and non-trainable variables into `params` and `other_state`.
  # Must pass a different RNG key each time for the dropout mask to be different.
  def loss(self, params, other_state, rng, data, labels, train):
    logits, batch_stats = self.apply({'params': params, **other_state},
                                     data,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': rng},
                                     train=train)
    # The loss averaged across the batch dimension.
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss, batch_stats

  def predict(self, state, data):
    logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)
    probabilities = flax.linen.log_softmax(logits)
    return probabilities

  def accuracy(self, state, data, labels):
    probabilities = self.predict(state, data)
    predictions = jnp.argmax(probabilities, axis=-1)
    dense_labels = jnp.argmax(labels, axis=-1)
    accuracy = jnp.equal(predictions, dense_labels).mean()
    return accuracy
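
As a quick, illustrative shape check, initializing the module on a dummy batch and running one forward pass should yield one logit per class:

# Initialize on a dummy MNIST-shaped batch, then run a single inference pass (train=False).
_check_model = ConvModel()
_dummy = jnp.zeros((1, 28, 28, 1))
_check_vars = _check_model.init({'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(0)},
                                _dummy, train=True)
print(_check_model.apply(_check_vars, _dummy, train=False).shape)  # Expected: (1, 10)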

Write the training step function#

# The training step.
@partial(jax.jit, static_argnums=[0]) # this forces jax.jit to recompile for every new model
def train_step(model, state, optimizer_state, rng, data, labels):

  other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables
  (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)

  updates, optimizer_state = optimizer.update(grads, optimizer_state)
  params = optax.apply_updates(params, updates)
  new_state = state.copy(add_or_replace={**batch_stats, 'params': params})

  rng, _ = jax.random.split(rng)

  return new_state, optimizer_state, rng, loss

Write the training loop#

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
  p = Progress(STEPS_PER_EPOCH)
  rng = jax.random.PRNGKey(0)
  for epoch in range(epochs):

    # This is where the learning rate schedule state is stored in the optimizer state.
    optimizer_step = optimizer_state[1].count

    # Run an epoch of training.
    for step, (data, labels) in enumerate(train_data):
      p.step(reset=(step==0))
      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
      losses.append(loss)
    avg_loss = np.mean(losses[-step:])
    avg_losses.append(avg_loss)

    # Run one epoch of evals (10,000 test images in a single batch).
    other_state, params = state.pop('params')
    # Gotcha: must discard modified batch_stats here
    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)
    eval_losses.append(eval_loss)
    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())
    eval_accuracies.append(eval_accuracy)

    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))

  return state, optimizer_state

Create the model and the optimizer (with Optax)#

# The model.
model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # Gotcha: it does not seem to be possible to pass just a callable as LR, must be an Optax Schedule
optimizer_state = optimizer.init(state['params'])

losses=[]
avg_losses=[]
eval_losses=[]
eval_accuracies=[]

Train the model#

new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Partially train the model#

You will continue training the model in TensorFlow shortly.

model = ConvModel()
state = model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # LR must be an Optax LR Schedule
optimizer_state = optimizer.init(state['params'])

losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []
state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Save just enough for inference#

If your goal is to deploy your JAX model (so you can run inference using model.predict()), simply exporting it to a SavedModel is enough. This section demonstrates how to do that.

# Test data with a different batch size to test polymorphic shapes.
x, y = next(iter(train_data.unbatch().batch(13)))

m = tf.Module()
# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function).
state_vars = tf.nest.map_structure(tf.Variable, state)
# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).
m.vars = tf.nest.flatten(state_vars)
# Convert the desired JAX function (`model.predict`).
predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])
# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).
@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
def predict(data):
    return predict_fn(state_vars, data)
m.predict = predict
tf.saved_model.save(m, "./")
# Test the converted function.
print("Converted function predictions:", np.argmax(m.predict(x).numpy(), axis=-1))
# Reload the model.
reloaded_model = tf.saved_model.load("./")
# Test the reloaded converted function (the result should be the same).
print("Reloaded  function predictions:", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

Save everything#

If your goal is a comprehensive export (useful if you plan to bring the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so that you can access the following methods:

  • model.predict

  • model.accuracy

  • model.loss (including the train=True/False bool, and the RNG used for dropout and BatchNorm state updates)

from collections import abc

def _fix_frozen(d):
  """Changes any mappings (e.g. frozendict) back to dict."""
  if isinstance(d, list):
    return [_fix_frozen(v) for v in d]
  elif isinstance(d, tuple):
    return tuple(_fix_frozen(v) for v in d)
  elif not isinstance(d, abc.Mapping):
    return d
  d = dict(d)
  for k, v in d.items():
    d[k] = _fix_frozen(v)
  return d
class TFModel(tf.Module):
  def __init__(self, state, model):
    super().__init__()

    # Special care needed for the train=True/False parameter in the loss
    @jax.jit
    def loss_with_train_bool(state, rng, data, labels, train):
      other_state, params = state.pop('params')
      loss, batch_stats = jax.lax.cond(train,
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),
                                       state, data, labels)
      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function
      new_rng, _ = jax.random.split(rng)
      return loss, batch_stats, new_rng

    self.state_vars = tf.nest.map_structure(tf.Variable, state)
    self.vars = tf.nest.flatten(self.state_vars)
    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

  # Must specify TensorSpec manually for variable batch size to work
  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
  def predict(self, data):
    # Make sure the TFModel.predict function implicitly uses self.state_vars and not the JAX state
    # directly; otherwise, all the model weights would be embedded in the TF graph as constants.
    return self.predict_fn(self.state_vars, data)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def train_loss(self, data, labels):
      loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
      # update batch norm stats
      flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
      flat_values = tf.nest.flatten(batch_stats['batch_stats'])
      for var, val in zip(flat_vars, flat_values):
        var.assign(val)
      # update RNG
      self.jax_rng.assign(new_rng)
      return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def eval_loss(self, data, labels):
      loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
      return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def accuracy(self, data, labels):
    return self.accuracy_fn(self.state_vars, data, labels)
# Instantiate the model.
tf_model = TFModel(state, model)

# Save the model.
tf.saved_model.save(tf_model, "./")

Reload the model#

reloaded_model = tf.saved_model.load("./")

# Test if it works and that the batch size is indeed variable.
x,y = next(iter(train_data.unbatch().batch(13)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))
x,y = next(iter(train_data.unbatch().batch(20)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

print(reloaded_model.accuracy(one_batch, one_batch_labels))
print(reloaded_model.accuracy(all_test_data, all_test_labels))
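
The exported loss functions are available as well; for instance, a single call to eval_loss on the test set (an illustrative check) returns a scalar loss:

print("Eval loss before fine-tuning in TF:", reloaded_model.eval_loss(all_test_data, all_test_labels).numpy())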

Continue training the converted JAX model in TensorFlow#

optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)

# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)

for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):

  # This is where the learning rate schedule state is stored in the optimizer state.
  optimizer_step = optimizer.iterations

  for step, (data, labels) in enumerate(train_data):
    p.step(reset=(step==0))
    with tf.GradientTape() as tape:
      #loss = reloaded_model.loss(data, labels, True)
      loss = reloaded_model.train_loss(data, labels)
      grads = tape.gradient(loss, reloaded_model.vars)
      optimizer.apply_gradients(zip(grads, reloaded_model.vars))
      losses.append(loss)
  avg_loss = np.mean(losses[-step:])
  avg_losses.append(avg_loss)

  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_losses.append(eval_loss)
  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_accuracies.append(eval_accuracy)

  print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer.iterations).numpy())
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)

# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.

Next steps#

You can learn more about JAX and Flax on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the JAX 101 tutorials and check out the Flax quickstart. To learn more about converting JAX models to TensorFlow format, check out the jax2tf utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit JAX on the Web with TensorFlow.js. If you'd like to prepare JAX models to run in TensorFlow Lite, visit the JAX Model Conversion For TFLite guide.