12.9. Saving and Restoring Models
Linux 5.4.0-80-generic
Python 3.9.5 @ GCC 7.3.0
Latest build date 2021.07.31
tensorflow version: 2.5.0
import os

import numpy as np
import tensorflow as tf
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
class Loader:
    def __init__(self):
        # A toy binary-classification set: 100 samples, 2 informative features.
        self.train_data, self.train_label = make_classification(
            n_samples=100, n_features=2, n_classes=2,
            n_informative=2, n_redundant=0,
            n_repeated=0, n_clusters_per_class=1,
            flip_y=0, random_state=3)
        self.num_train_data = self.train_data.shape[0]

    def get_batch(self, batch_size):
        # Draw a random mini-batch (sampling with replacement).
        index = np.random.randint(0, self.num_train_data, batch_size)
        return self.train_data[index, :], self.train_label[index]  # (batch_size, 2), (batch_size,)
class Perceptron(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(1, activation=None, input_shape=(2,),
                                            use_bias=True, kernel_initializer="zeros",
                                            bias_initializer="zeros")

    def call(self, inputs):                            # (batch_size, 2)
        x = self.dense1(inputs)                        # (batch_size, 1)
        output = tf.keras.activations.hard_sigmoid(x)  # (batch_size, 1)
        return output
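A quick sanity check (a sketch added here, not part of the original walkthrough): with zero-initialized weights the linear output is 0 for every sample, and hard_sigmoid(0) = 0.5, so an untrained Perceptron should emit 0.5 everywhere.
# With zero weights and bias, dense1 outputs 0 and hard_sigmoid(0) = 0.5.
sanity_model = Perceptron()
dummy = np.zeros((4, 2), dtype=np.float32)
print(sanity_model(dummy).numpy())  # four rows of [0.5]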
Checkpoint
tf.train.Checkpoint tracks the variables of every trackable object passed to its constructor (here the model and the optimizer) and serializes them to disk.
Saving a Checkpoint
num_epochs = 50
batch_size = 16
learning_rate = 0.001

model = Perceptron()
data_loader = Loader()
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, decay=1e-6,
                                    momentum=0.9, nesterov=True)
# Track both the model and the optimizer, so the optimizer's slot
# variables (momentum) are checkpointed along with the weights.
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)

num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)  # (batch_size, 2), (batch_size,)
    with tf.GradientTape() as tape:
        y_pred = model(X)  # (batch_size, 1)
        # Reshape the labels to (batch_size, 1) so they line up with y_pred;
        # otherwise broadcasting silently computes the wrong loss.
        loss = tf.keras.losses.mean_squared_error(y_true=y.reshape(-1, 1), y_pred=y_pred)  # (batch_size,)
        loss = tf.reduce_mean(loss)  # scalar
        # print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))
    # print(model.get_weights())
    if batch_index % 100 == 0:  # save every 100 batches
        checkpoint.save("./KerasPerceptron/model.ckpt")
print(os.listdir("./KerasPerceptron"))
['checkpoint', 'model.ckpt-2.data-00000-of-00001',
'model.ckpt-1.index', 'model.ckpt-1.data-00000-of-00001',
'model.ckpt-3.index', 'model.ckpt-3.data-00000-of-00001',
'model.ckpt-2.index']
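If the number of saved files should stay bounded, tf.train.CheckpointManager can manage the save counter and prune old files; a minimal sketch (the max_to_keep value here is just an illustration):
# Keep at most three checkpoints; older ones are deleted automatically.
manager = tf.train.CheckpointManager(checkpoint, directory="./KerasPerceptron",
                                     checkpoint_name="model.ckpt", max_to_keep=3)
# In the training loop, call manager.save() instead of checkpoint.save(...).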
Restoring a Checkpoint
Checkpoint.restore
When restoring model weights, you need to specify the save counter, i.e. the number that checkpoint.save appended to the file prefix.
model_to_be_restored = Perceptron()  # the class defined above
checkpoint_2 = tf.train.Checkpoint(model=model_to_be_restored)
checkpoint_2.restore("./KerasPerceptron/model.ckpt-3")
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at
0x7f98447bd280>
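To verify that the weights actually came back, you can score the restored model; this short check (an addition using the accuracy_score import from the top of the section, with 0.5 as an assumed decision threshold) evaluates on the training data:
X_all, y_all = data_loader.train_data, data_loader.train_label
scores = model_to_be_restored(X_all).numpy().flatten()  # calling the model also triggers the deferred restore
y_hat = (scores > 0.5).astype(int)  # threshold the hard_sigmoid output at 0.5
print(accuracy_score(y_all, y_hat))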
You can also let tf.train.latest_checkpoint determine the latest save counter automatically. Note that it takes the checkpoint directory, not the file prefix:
checkpoint_2.restore(tf.train.latest_checkpoint("./KerasPerceptron"))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at
0x7f98447bdfd0>
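The CheckpointLoadStatus objects printed above can also verify the restore explicitly. Since checkpoint_2 tracks only the model, the optimizer variables stored in the file will go unmatched, which is expected here; a sketch of the relevant assertions:
status = checkpoint_2.restore(tf.train.latest_checkpoint("./KerasPerceptron"))
status.expect_partial()  # we deliberately did not track the optimizer, so a partial match is fine
# Once the model's variables exist (after it has been called at least once),
# assert that every tracked variable was filled from the checkpoint:
status.assert_existing_objects_matched()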