Copyright 2024 The AI Edge Authors.
Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The LiteRT Authoring API provides a way to keep your tf.function models compatible with LiteRT.
Setup
import tensorflow as tf
TensorFlow to LiteRT compatibility issue
If you want to use your TensorFlow model on devices, you need to convert it to a TFLite model and run it with the TFLite interpreter. During conversion, you might hit a compatibility error because some TensorFlow ops are not supported by the TFLite builtin op set.
This is an annoying problem to discover so late. How can you detect it earlier, at model authoring time?
Note that the following code fails at the converter.convert() call.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
    return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print(f"result = {result}")
# Convert the tf.function
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()], f)
try:
    fb_model = converter.convert()
except Exception as e:
    print(f"Got an exception: {e}")
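You can list the op types in the traced graph of the concrete function to see which TensorFlow ops the converter has to map. The minimal sketch below uses the standard `get_concrete_function()` and `Graph.get_operations()` APIs. It only lists the ops and does not check them against the TFLite op set. The helper name `list_op_types` is hypothetical.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def f(x):
    return tf.cosh(x)

def list_op_types(tf_fn):
    """Hypothetical helper: return the op types in the traced graph of tf_fn."""
    graph = tf_fn.get_concrete_function().graph
    return [op.type for op in graph.get_operations()]

op_types = list_op_types(f)
# The list should include 'Cosh', the op that fails with the builtin op set.
print(op_types)
```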
Simple Target Aware Authoring usage
We introduced the Authoring API to detect LiteRT compatibility issues at model authoring time.
You just need to add the @tf.lite.experimental.authoring.compatible decorator to your tf.function model to check its TFLite compatibility.
After this, compatibility is checked automatically when you evaluate your model.
@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
    return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print(f"result = {result}")
If any LiteRT compatibility issue is found, the API shows a COMPATIBILITY WARNING or COMPATIBILITY ERROR message with the exact location of the problematic op. In this example, it shows the location of the tf.Cosh op in your tf.function model.
You can also check the compatibility log with the <function_name>.get_compatibility_log() method.
compatibility_log = '\n'.join(f.get_compatibility_log())
print(f"compatibility_log = {compatibility_log}")
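get_compatibility_log() returns a list of strings (the code above joins them with newlines), so you can also inspect it programmatically, for example to fail a CI job. The sketch below assumes each reported issue appears in an entry containing the COMPATIBILITY WARNING or COMPATIBILITY ERROR marker described above. The exact log format is not specified here, so treat the substring matching as an assumption.

```python
import tensorflow as tf

@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def f(x):
    return tf.cosh(x)

# The check runs when the decorated function is evaluated, so call it first.
f(tf.constant([0.0]))

log = f.get_compatibility_log()
# Assumption: each issue's entry contains one of these marker strings.
warnings = [entry for entry in log if "COMPATIBILITY WARNING" in entry]
errors = [entry for entry in log if "COMPATIBILITY ERROR" in entry]
print(f"{len(warnings)} warning(s), {len(errors)} error(s)")
```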
Raise an exception for an incompatibility
You can pass options to the @tf.lite.experimental.authoring.compatible decorator. The raise_exception option raises an exception when you try to evaluate the decorated model.
@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
    return tf.cosh(x)
# Evaluate the tf.function
try:
    result = f(tf.constant([0.0]))
    print(f"result = {result}")
except Exception as e:
    print(f"Got an exception: {e}")
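One way to pass the strict check without Select TF ops is to express cosh with ops that TFLite supports natively. The sketch below is an illustrative workaround, not part of the Authoring API. It relies on the identity cosh(x) = (e^x + e^-x) / 2 and assumes that Exp, Neg, Add and Mul all map to TFLite builtins, so evaluation should not raise.

```python
import numpy as np
import tensorflow as tf

# cosh(x) = (exp(x) + exp(-x)) / 2, written with builtin-friendly ops only.
@tf.lite.experimental.authoring.compatible(raise_exception=True)
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def cosh_builtin(x):
    return 0.5 * (tf.exp(x) + tf.exp(-x))

x = tf.constant([0.0, 1.0, -2.0])
result = cosh_builtin(x)  # the strict check runs here
print(f"result = {result}")
# Compare against the reference tf.cosh implementation.
print(np.allclose(result.numpy(), tf.cosh(x).numpy()))
```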
Specifying "Select TF ops" usage
If you already use Select TF ops, you can tell the Authoring API by setting converter_target_spec. It takes the same tf.lite.TargetSpec object you would use with the tf.lite.TFLiteConverter API.
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec, raise_exception=True)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
    return tf.cosh(x)
# Evaluate the tf.function
result = f(tf.constant([0.0]))
print(f"result = {result}")
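The Authoring API takes the same tf.lite.TargetSpec type as the converter, so you can reuse the op sets when you do the real conversion. This is a minimal sketch that converts from a concrete function, as in the first example. Running the resulting model needs an interpreter with Select TF ops support.

```python
import tensorflow as tf

# Same op sets as the authoring-time check: builtins plus Select TF ops.
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def f(x):
    return tf.cosh(x)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [f.get_concrete_function()], f)
# Copy the op sets that were checked at authoring time onto the converter.
converter.target_spec.supported_ops = target_spec.supported_ops
tflite_model = converter.convert()  # serialized flatbuffer as bytes
print(f"model size = {len(tflite_model)} bytes")
```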
Checking GPU compatibility
If you want to ensure your model is compatible with the LiteRT GPU delegate, set the experimental_supported_backends field of tf.lite.TargetSpec.
The following example shows how to check your model's GPU delegate compatibility. This model has compatibility issues because it uses the tf.slice operator on a 2D tensor and the unsupported tf.cosh operator, so you'll see two COMPATIBILITY WARNING messages with location information.
target_spec = tf.lite.TargetSpec()
target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
target_spec.experimental_supported_backends = ["GPU"]
@tf.lite.experimental.authoring.compatible(converter_target_spec=target_spec)
@tf.function(input_signature=[
    tf.TensorSpec(shape=[4, 4], dtype=tf.float32)
])
def func(x):
    y = tf.cosh(x)
    return y + tf.slice(x, [1, 1], [1, 1])
result = func(tf.ones(shape=(4, 4), dtype=tf.float32))
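You can wrap the raise_exception=True pattern in a small helper, for example to assert compatibility in unit tests for several candidate implementations under different TargetSpec settings. `is_tflite_compatible` and `cosh_fn` are hypothetical names. The helper treats any exception from the first evaluation as an incompatibility, so it also catches unrelated runtime errors. Whether GPU-backend warnings alone trigger the exception is not verified here, so treat this as a sketch.

```python
import tensorflow as tf

def is_tflite_compatible(py_fn, input_signature, sample_args, target_spec=None):
    """Hypothetical helper: True if the Authoring API raises nothing for py_fn.

    target_spec=None falls back to the decorator's default target spec.
    """
    tf_fn = tf.function(py_fn, input_signature=input_signature)
    checked = tf.lite.experimental.authoring.compatible(
        tf_fn, converter_target_spec=target_spec, raise_exception=True)
    try:
        checked(*sample_args)  # the compatibility check runs on evaluation
    except Exception:  # broad catch, as in the tutorial's own examples
        return False
    return True

def cosh_fn(x):
    return tf.cosh(x)

signature = [tf.TensorSpec(shape=[None], dtype=tf.float32)]
args = (tf.constant([0.0]),)

select_spec = tf.lite.TargetSpec()
select_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

print(is_tflite_compatible(cosh_fn, signature, args))               # builtins only
print(is_tflite_compatible(cosh_fn, signature, args, select_spec))  # + Select TF ops
```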