بازرسی های آموزشی

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

عبارت "Saving a TensorFlow model" معمولاً به معنای یکی از دو چیز است:

  1. ایست های بازرسی، OR
  2. SavedModel.

چک پوینت‌ها مقدار دقیق تمام پارامترها (اشیاء tf.Variable ) را که توسط یک مدل استفاده می‌شود را می‌گیرند. نقاط بازرسی حاوی هیچ توضیحی از محاسبات تعریف شده توسط مدل نیستند و بنابراین معمولاً فقط زمانی مفید هستند که کد منبعی که از مقادیر پارامتر ذخیره شده استفاده می کند در دسترس باشد.

فرمت SavedModel از سوی دیگر شامل یک توصیف سریالی از محاسبات تعریف شده توسط مدل علاوه بر مقادیر پارامتر (نقطه بازرسی) است. مدل‌های این قالب مستقل از کد منبعی هستند که مدل را ایجاد کرده است. بنابراین، آنها برای استقرار از طریق TensorFlow Serving، TensorFlow Lite، TensorFlow.js، یا برنامه هایی در سایر زبان های برنامه نویسی (C، C++، Java، Go، Rust، C# و غیره) مناسب هستند.

این راهنما APIهایی را برای نوشتن و خواندن نقاط بازرسی پوشش می دهد.

برپایی

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

صرفه جویی از API های آموزشی tf.keras

راهنمای tf.keras در مورد ذخیره و بازیابی را ببینید.

tf.keras.Model.save_weights یک ایست بازرسی TensorFlow را ذخیره می کند.

net.save_weights('easy_checkpoint')

نوشتن پست های بازرسی

حالت پایدار یک مدل TensorFlow در اشیاء tf.Variable ذخیره می شود. اینها را می‌توان مستقیماً ساخت، اما اغلب از طریق APIهای سطح بالا مانند tf.keras.layers یا tf.keras.Model .

ساده ترین راه برای مدیریت متغیرها این است که آنها را به اشیاء پایتون متصل کنید و سپس به آن اشیا ارجاع دهید.

زیر tf.train.Checkpoint ، tf.keras.layers.Layer ، و tf.keras.Model به طور خودکار متغیرهای اختصاص داده شده به ویژگی‌های آنها را ردیابی می‌کنند. مثال زیر یک مدل خطی ساده می سازد، سپس نقاط بازرسی را می نویسد که حاوی مقادیری برای همه متغیرهای مدل هستند.

با Model.save_weights می‌توانید به راحتی یک مدل-بررسی را ذخیره کنید.

ایست بازرسی دستی

برپایی

برای کمک به نشان دادن همه ویژگی‌های tf.train.Checkpoint ، یک مجموعه داده اسباب‌بازی و مرحله بهینه‌سازی را تعریف کنید:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

اشیاء ایست بازرسی را ایجاد کنید

از یک شی tf.train.Checkpoint استفاده کنید تا به صورت دستی یک نقطه بازرسی ایجاد کنید، جایی که اشیایی که می خواهید به چکپوینت بپردازید به عنوان ویژگی روی شی تنظیم می شوند.

یک tf.train.CheckpointManager همچنین می تواند برای مدیریت چندین ایست بازرسی مفید باشد.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

آموزش و بازرسی مدل

حلقه آموزشی زیر نمونه ای از مدل و یک بهینه ساز را ایجاد می کند، سپس آنها را در یک شی tf.train.Checkpoint جمع می کند. مرحله آموزش را در یک حلقه روی هر دسته از داده ها فراخوانی می کند و به طور دوره ای نقاط بازرسی را روی دیسک می نویسد.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

بازیابی و ادامه آموزش

پس از اولین چرخه آموزشی، می توانید یک مدل و مدیر جدید را پاس کنید، اما آموزش را دقیقاً از جایی که متوقف کرده اید ادامه دهید:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

شی tf.train.CheckpointManager نقاط بازرسی قدیمی را حذف می کند. در بالا پیکربندی شده است تا فقط سه نقطه بازرسی اخیر را حفظ کند.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

این مسیرها، به عنوان مثال './tf_ckpts/ckpt-10' ، فایل های روی دیسک نیستند. در عوض آنها پیشوندهای یک فایل index و یک یا چند فایل داده ای هستند که حاوی مقادیر متغیر هستند. این پیشوندها در یک فایل checkpoint واحد ( './tf_ckpts/checkpoint' ) گروه بندی می شوند که CheckpointManager وضعیت خود را ذخیره می کند.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

مکانیک بارگذاری

TensorFlow با عبور از یک نمودار جهت‌دار با یال‌های نام‌گذاری‌شده، با شروع از شی در حال بارگذاری، متغیرها را با مقادیر نقطه‌بازرسی مطابقت می‌دهد. نام‌های لبه معمولاً از نام‌های ویژگی در اشیاء می‌آیند، برای مثال "l1" در self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint از نام آرگومان های کلیدواژه خود استفاده می کند، مانند "step" در tf.train.Checkpoint(step=...) .

نمودار وابستگی از مثال بالا به شکل زیر است:

تجسم نمودار وابستگی برای حلقه آموزش مثال

بهینه ساز قرمز، متغیرهای معمولی به رنگ آبی، و متغیرهای اسلات بهینه ساز به رنگ نارنجی هستند. گره های دیگر - برای مثال، نشان دهنده tf.train.Checkpoint - به رنگ سیاه هستند.

متغیرهای اسلات بخشی از وضعیت بهینه ساز هستند، اما برای یک متغیر خاص ایجاد می شوند. به عنوان مثال، لبه های 'm' بالا با تکانه مطابقت دارد که بهینه ساز Adam برای هر متغیر آن را دنبال می کند. متغیرهای اسلات فقط در صورتی ذخیره می شوند که متغیر و بهینه ساز هر دو ذخیره شوند، بنابراین لبه های چین خورده.

فراخوانی restore در یک شی tf.train.Checkpoint ، بازیابی های درخواستی را در صف قرار می دهد، به محض اینکه یک مسیر منطبق از شی Checkpoint وجود دارد، مقادیر متغیر را بازیابی می کند. به عنوان مثال، شما می توانید فقط بایاس را از مدلی که در بالا تعریف کردید با بازسازی یک مسیر به آن از طریق شبکه و لایه بارگذاری کنید.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

نمودار وابستگی برای این اشیاء جدید یک زیرگراف بسیار کوچکتر از نقطه بازرسی بزرگتری است که در بالا نوشتید. این فقط شامل سوگیری و یک شمارنده ذخیره است که tf.train.Checkpoint از آن برای شماره گذاری نقاط بازرسی استفاده می کند.

تجسم یک زیرگراف برای متغیر بایاس

restore یک شی وضعیت را برمی‌گرداند که دارای ادعاهای اختیاری است. تمام اشیاء ایجاد شده در Checkpoint جدید بازیابی شده اند، بنابراین status.assert_existing_objects_matched می گذرد.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

اشیاء زیادی در چک پوینت وجود دارد که مطابقت ندارند، از جمله هسته لایه و متغیرهای بهینه ساز. status.assert_consumed فقط در صورتی می‌گذرد که نقطه بازرسی و برنامه دقیقاً مطابقت داشته باشند، و یک استثنا در اینجا ایجاد می‌کند.

ترمیم های معوق

اشیاء Layer در TensorFlow ممکن است ایجاد متغیرها را به اولین فراخوانی خود موکول کنند، زمانی که اشکال ورودی در دسترس هستند. به عنوان مثال، شکل هسته یک لایه Dense به هر دو شکل ورودی و خروجی لایه بستگی دارد، و بنابراین شکل خروجی مورد نیاز به عنوان آرگومان سازنده، اطلاعات کافی برای ایجاد متغیر به تنهایی نیست. از آنجایی که فراخوانی یک Layer مقدار متغیر را نیز می خواند، بازیابی باید بین ایجاد متغیر و اولین استفاده از آن اتفاق بیفتد.

برای پشتیبانی از این اصطلاح، tf.train.Checkpoint بازیابی هایی را که هنوز متغیر منطبقی ندارند به تعویق می اندازد.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

بازرسی دستی ایست های بازرسی

tf.train.load_checkpoint یک CheckpointReader را برمی گرداند که سطح پایین تری را به محتویات ایست بازرسی می دهد. این شامل نگاشتهایی از کلید هر متغیر، به شکل و نوع d برای هر متغیر در نقطه بازرسی است. کلید یک متغیر مسیر شی آن است، مانند نمودارهای نمایش داده شده در بالا.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

بنابراین اگر به مقدار net.l1.kernel دارید، می توانید مقدار را با کد زیر دریافت کنید:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

همچنین یک متد get_tensor ارائه می دهد که به شما امکان می دهد مقدار یک متغیر را بررسی کنید:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

ردیابی اشیا

چک پوینت ها مقادیر اشیاء tf.Variable را با "ردیابی" هر متغیر یا شی قابل پیگیری در یکی از ویژگی های آن ذخیره و بازیابی می کنند. هنگام اجرای یک ذخیره، متغیرها به صورت بازگشتی از همه اشیاء قابل دسترسی ردیابی شده جمع آوری می شوند.

همانند تخصیص‌های مستقیم ویژگی‌ها مانند self.l1 = tf.keras.layers.Dense(5) ، اختصاص فهرست‌ها و فرهنگ‌های لغت به ویژگی‌ها، محتوای آنها را دنبال می‌کند.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

ممکن است متوجه اشیاء بسته بندی لیست ها و لغت نامه ها شوید. این بسته‌بندی‌ها نسخه‌های قابل بررسی ساختارهای داده‌ای هستند. درست مانند بارگذاری مبتنی بر ویژگی، این wrapper ها مقدار متغیر را به محض اضافه شدن به ظرف بازیابی می کنند.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

اشیاء قابل ردیابی عبارتند از tf.train.Checkpoint ، tf.Module و زیر کلاس‌های آن (به عنوان مثال keras.layers.Layer و keras.Model )، و کانتینرهای شناسایی پایتون:

  • dictcollections.OrderedDict . OrderedDict)
  • list
  • tuplecollections.namedtuple ، تایپ کردن. typing.NamedTuple )

سایر انواع کانتینر پشتیبانی نمی شوند ، از جمله:

  • collections.defaultdict
  • set

سایر اشیاء پایتون نادیده گرفته می شوند، از جمله:

  • int
  • string
  • float

خلاصه

اشیاء TensorFlow یک مکانیسم خودکار آسان برای ذخیره و بازیابی مقادیر متغیرهایی که استفاده می کنند ارائه می کنند.