
12.9. Saving and Loading Models

Linux 5.4.0-80-generic
Python 3.9.5 @ GCC 7.3.0
Latest build date 2021.07.31
tensorflow version:  2.5.0
import keras
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import os
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_classification


class Loader:
    def __init__(self):
        self.train_data, self.train_label = make_classification(n_samples=100, n_features=2, n_classes=2,
                                                                n_informative=2, n_redundant=0,
                                                                n_repeated=0, n_clusters_per_class=1,
                                                                flip_y=0, random_state=3)
        self.num_train_data = self.train_data.shape[0]

    def get_batch(self, batch_size):
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]  # (batch_size, 2), (batch_size,)


class Perceptron(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(1, activation=None, input_shape=(2,),
                                            use_bias=True, kernel_initializer="zeros",
                                            bias_initializer="zeros")

    def call(self, inputs):                            # (batch_size, 2)
        x = self.dense1(inputs)                        # (batch_size, 1)
        output = tf.keras.activations.hard_sigmoid(x)  # (batch_size, 1)
        return output

Checkpoint

Saving a Checkpoint

tf.train.Checkpoint tracks the variables of the objects passed to it (here the model and the optimizer); each call to checkpoint.save writes their current values to disk and increments the save counter in the filename.

num_epochs = 50
batch_size = 16
learning_rate = 0.001

model = Perceptron()
data_loader = Loader()
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)  # (batch_size, 2), (batch_size,)
    with tf.GradientTape() as tape:
        y_pred = model(X)  # (batch_size, 1)
        y_pred = tf.squeeze(y_pred, axis=1)  # (batch_size,), match the shape of y so MSE does not broadcast
        loss = tf.keras.losses.mean_squared_error(y_true=y, y_pred=y_pred)  # (batch_size,)
        loss = tf.reduce_mean(loss)  # scalar
        # print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
    # print(model.get_weights())
    if batch_index % 100 == 0:
        checkpoint.save("./KerasPerceptron/model.ckpt")

print(os.listdir("./KerasPerceptron"))
['checkpoint', 'model.ckpt-2.data-00000-of-00001',
'model.ckpt-1.index', 'model.ckpt-1.data-00000-of-00001',
'model.ckpt-3.index', 'model.ckpt-3.data-00000-of-00001',
'model.ckpt-2.index']
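
If you would rather not manage the save counter yourself, tf.train.CheckpointManager can assign it and prune old files. A minimal sketch, assuming the checkpoint object created above; the directory, checkpoint_name and max_to_keep values are illustrative choices, not part of the original text:

manager = tf.train.CheckpointManager(checkpoint, directory="./KerasPerceptron",
                                     checkpoint_name="model.ckpt", max_to_keep=3)
path = manager.save()  # writes model.ckpt-N and updates the 'checkpoint' bookkeeping file
print(path, manager.latest_checkpoint)

Calling manager.save() inside the training loop instead of checkpoint.save() keeps only the most recent max_to_keep checkpoints on disk.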

Restoring a Checkpoint

When restoring model weights with Checkpoint.restore, you need to specify the save counter, i.e. the numeric suffix that checkpoint.save appended to the filename prefix (model.ckpt-3 below).

model_to_be_restored = Perceptron()
checkpoint_2 = tf.train.Checkpoint(model=model_to_be_restored)
checkpoint_2.restore(r"./KerasPerceptron/model.ckpt-3")
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at
0x7f98447bd280>

Alternatively, the tf.train.latest_checkpoint function can determine the most recent save counter automatically. Note that it takes the checkpoint directory, not the file prefix.

checkpoint_2.restore(tf.train.latest_checkpoint("./KerasPerceptron"))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at
0x7f98447bdfd0>
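
As a quick sanity check that the restored weights are usable, the model can be run on the toy dataset and compared against the true labels with the accuracy_score imported above. A minimal sketch; thresholding the hard_sigmoid output at 0.5 is an illustrative choice, not part of the original text:

# Hypothetical check: predict on the full toy dataset with the restored model.
X, y = data_loader.train_data, data_loader.train_label     # (100, 2), (100,)
y_prob = model_to_be_restored(X)                            # (100, 1), values in [0, 1]
y_pred = tf.cast(tf.squeeze(y_prob, axis=1) >= 0.5, tf.int32).numpy()
print(accuracy_score(y, y_pred))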

SavedModel