مشاهده در TensorFlow.org | در Google Colab اجرا شود | مشاهده منبع در GitHub | دانلود دفترچه یادداشت |
عبارت "Saving a TensorFlow model" معمولاً به معنای یکی از دو چیز است:
- ایست های بازرسی، OR
- SavedModel.
چک پوینتها مقدار دقیق تمام پارامترها (اشیاء tf.Variable
) را که توسط یک مدل استفاده میشود را میگیرند. نقاط بازرسی حاوی هیچ توضیحی از محاسبات تعریف شده توسط مدل نیستند و بنابراین معمولاً فقط زمانی مفید هستند که کد منبعی که از مقادیر پارامتر ذخیره شده استفاده می کند در دسترس باشد.
فرمت SavedModel از سوی دیگر شامل یک توصیف سریالی از محاسبات تعریف شده توسط مدل علاوه بر مقادیر پارامتر (نقطه بازرسی) است. مدلهای این قالب مستقل از کد منبعی هستند که مدل را ایجاد کرده است. بنابراین، آنها برای استقرار از طریق TensorFlow Serving، TensorFlow Lite، TensorFlow.js، یا برنامه هایی در سایر زبان های برنامه نویسی (C، C++، Java، Go، Rust، C# و غیره) مناسب هستند.
این راهنما APIهایی را برای نوشتن و خواندن نقاط بازرسی پوشش می دهد.
برپایی
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
صرفه جویی از API های آموزشی tf.keras
راهنمای tf.keras
در مورد ذخیره و بازیابی را ببینید.
tf.keras.Model.save_weights
یک ایست بازرسی TensorFlow را ذخیره می کند.
net.save_weights('easy_checkpoint')
نوشتن پست های بازرسی
حالت پایدار یک مدل TensorFlow در اشیاء tf.Variable
ذخیره می شود. اینها را میتوان مستقیماً ساخت، اما اغلب از طریق APIهای سطح بالا مانند tf.keras.layers
یا tf.keras.Model
.
ساده ترین راه برای مدیریت متغیرها این است که آنها را به اشیاء پایتون متصل کنید و سپس به آن اشیا ارجاع دهید.
زیر tf.train.Checkpoint
، tf.keras.layers.Layer
، و tf.keras.Model
به طور خودکار متغیرهای اختصاص داده شده به ویژگیهای آنها را ردیابی میکنند. مثال زیر یک مدل خطی ساده می سازد، سپس نقاط بازرسی را می نویسد که حاوی مقادیری برای همه متغیرهای مدل هستند.
با Model.save_weights
میتوانید به راحتی یک مدل-بررسی را ذخیره کنید.
ایست بازرسی دستی
برپایی
برای کمک به نشان دادن همه ویژگیهای tf.train.Checkpoint
، یک مجموعه داده اسباببازی و مرحله بهینهسازی را تعریف کنید:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
اشیاء ایست بازرسی را ایجاد کنید
از یک شی tf.train.Checkpoint
استفاده کنید تا به صورت دستی یک نقطه بازرسی ایجاد کنید، جایی که اشیایی که می خواهید به چکپوینت بپردازید به عنوان ویژگی روی شی تنظیم می شوند.
یک tf.train.CheckpointManager
همچنین می تواند برای مدیریت چندین ایست بازرسی مفید باشد.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
آموزش و بازرسی مدل
حلقه آموزشی زیر نمونه ای از مدل و یک بهینه ساز را ایجاد می کند، سپس آنها را در یک شی tf.train.Checkpoint
جمع می کند. مرحله آموزش را در یک حلقه روی هر دسته از داده ها فراخوانی می کند و به طور دوره ای نقاط بازرسی را روی دیسک می نویسد.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
بازیابی و ادامه آموزش
پس از اولین چرخه آموزشی، می توانید یک مدل و مدیر جدید را پاس کنید، اما آموزش را دقیقاً از جایی که متوقف کرده اید ادامه دهید:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
شی tf.train.CheckpointManager
نقاط بازرسی قدیمی را حذف می کند. در بالا پیکربندی شده است تا فقط سه نقطه بازرسی اخیر را حفظ کند.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
این مسیرها، به عنوان مثال './tf_ckpts/ckpt-10'
، فایل های روی دیسک نیستند. در عوض آنها پیشوندهای یک فایل index
و یک یا چند فایل داده ای هستند که حاوی مقادیر متغیر هستند. این پیشوندها در یک فایل checkpoint
واحد ( './tf_ckpts/checkpoint'
) گروه بندی می شوند که CheckpointManager
وضعیت خود را ذخیره می کند.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
مکانیک بارگذاری
TensorFlow با عبور از یک نمودار جهتدار با یالهای نامگذاریشده، با شروع از شی در حال بارگذاری، متغیرها را با مقادیر نقطهبازرسی مطابقت میدهد. نامهای لبه معمولاً از نامهای ویژگی در اشیاء میآیند، برای مثال "l1"
در self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
از نام آرگومان های کلیدواژه خود استفاده می کند، مانند "step"
در tf.train.Checkpoint(step=...)
.
نمودار وابستگی از مثال بالا به شکل زیر است:
بهینه ساز قرمز، متغیرهای معمولی به رنگ آبی، و متغیرهای اسلات بهینه ساز به رنگ نارنجی هستند. گره های دیگر - برای مثال، نشان دهنده tf.train.Checkpoint
- به رنگ سیاه هستند.
متغیرهای اسلات بخشی از وضعیت بهینه ساز هستند، اما برای یک متغیر خاص ایجاد می شوند. به عنوان مثال، لبه های 'm'
بالا با تکانه مطابقت دارد که بهینه ساز Adam برای هر متغیر آن را دنبال می کند. متغیرهای اسلات فقط در صورتی ذخیره می شوند که متغیر و بهینه ساز هر دو ذخیره شوند، بنابراین لبه های چین خورده.
فراخوانی restore
در یک شی tf.train.Checkpoint
، بازیابی های درخواستی را در صف قرار می دهد، به محض اینکه یک مسیر منطبق از شی Checkpoint
وجود دارد، مقادیر متغیر را بازیابی می کند. به عنوان مثال، شما می توانید فقط بایاس را از مدلی که در بالا تعریف کردید با بازسازی یک مسیر به آن از طریق شبکه و لایه بارگذاری کنید.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
نمودار وابستگی برای این اشیاء جدید یک زیرگراف بسیار کوچکتر از نقطه بازرسی بزرگتری است که در بالا نوشتید. این فقط شامل سوگیری و یک شمارنده ذخیره است که tf.train.Checkpoint
از آن برای شماره گذاری نقاط بازرسی استفاده می کند.
restore
یک شی وضعیت را برمیگرداند که دارای ادعاهای اختیاری است. تمام اشیاء ایجاد شده در Checkpoint
جدید بازیابی شده اند، بنابراین status.assert_existing_objects_matched
می گذرد.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
اشیاء زیادی در چک پوینت وجود دارد که مطابقت ندارند، از جمله هسته لایه و متغیرهای بهینه ساز. status.assert_consumed
فقط در صورتی میگذرد که نقطه بازرسی و برنامه دقیقاً مطابقت داشته باشند، و یک استثنا در اینجا ایجاد میکند.
ترمیم های معوق
اشیاء Layer
در TensorFlow ممکن است ایجاد متغیرها را به اولین فراخوانی خود موکول کنند، زمانی که اشکال ورودی در دسترس هستند. به عنوان مثال، شکل هسته یک لایه Dense
به هر دو شکل ورودی و خروجی لایه بستگی دارد، و بنابراین شکل خروجی مورد نیاز به عنوان آرگومان سازنده، اطلاعات کافی برای ایجاد متغیر به تنهایی نیست. از آنجایی که فراخوانی یک Layer
مقدار متغیر را نیز می خواند، بازیابی باید بین ایجاد متغیر و اولین استفاده از آن اتفاق بیفتد.
برای پشتیبانی از این اصطلاح، tf.train.Checkpoint
بازیابی هایی را که هنوز متغیر منطبقی ندارند به تعویق می اندازد.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
بازرسی دستی ایست های بازرسی
tf.train.load_checkpoint
یک CheckpointReader
را برمی گرداند که سطح پایین تری را به محتویات ایست بازرسی می دهد. این شامل نگاشتهایی از کلید هر متغیر، به شکل و نوع d برای هر متغیر در نقطه بازرسی است. کلید یک متغیر مسیر شی آن است، مانند نمودارهای نمایش داده شده در بالا.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
بنابراین اگر به مقدار net.l1.kernel
دارید، می توانید مقدار را با کد زیر دریافت کنید:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
همچنین یک متد get_tensor
ارائه می دهد که به شما امکان می دهد مقدار یک متغیر را بررسی کنید:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
ردیابی اشیا
چک پوینت ها مقادیر اشیاء tf.Variable
را با "ردیابی" هر متغیر یا شی قابل پیگیری در یکی از ویژگی های آن ذخیره و بازیابی می کنند. هنگام اجرای یک ذخیره، متغیرها به صورت بازگشتی از همه اشیاء قابل دسترسی ردیابی شده جمع آوری می شوند.
همانند تخصیصهای مستقیم ویژگیها مانند self.l1 = tf.keras.layers.Dense(5)
، اختصاص فهرستها و فرهنگهای لغت به ویژگیها، محتوای آنها را دنبال میکند.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
ممکن است متوجه اشیاء بسته بندی لیست ها و لغت نامه ها شوید. این بستهبندیها نسخههای قابل بررسی ساختارهای دادهای هستند. درست مانند بارگذاری مبتنی بر ویژگی، این wrapper ها مقدار متغیر را به محض اضافه شدن به ظرف بازیابی می کنند.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
اشیاء قابل ردیابی عبارتند از tf.train.Checkpoint
، tf.Module
و زیر کلاسهای آن (به عنوان مثال keras.layers.Layer
و keras.Model
)، و کانتینرهای شناسایی پایتون:
-
dict
(وcollections.OrderedDict
. OrderedDict) -
list
-
tuple
(وcollections.namedtuple
، تایپ کردن.typing.NamedTuple
)
سایر انواع کانتینر پشتیبانی نمی شوند ، از جمله:
-
collections.defaultdict
-
set
سایر اشیاء پایتون نادیده گرفته می شوند، از جمله:
-
int
-
string
-
float
خلاصه
اشیاء TensorFlow یک مکانیسم خودکار آسان برای ذخیره و بازیابی مقادیر متغیرهایی که استفاده می کنند ارائه می کنند.