Larq Zoo Tutorial
This tutorial demonstrates how to load pretrained models from Larq Zoo. These models can be used for prediction, feature extraction, and fine-tuning.
pip install larq larq-zoo
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import larq_zoo as lqz
from urllib.request import urlopen
from PIL import Image
img_path = "https://raw.githubusercontent.com/larq/zoo/master/tests/fixtures/elephant.jpg"
with urlopen(img_path) as f:
    img = Image.open(f).resize((224, 224))
x = tf.keras.preprocessing.image.img_to_array(img)
x = lqz.preprocess_input(x)
x = np.expand_dims(x, axis=0)
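`model.predict` expects a leading batch axis, which is why the single image is wrapped with `np.expand_dims`. The NumPy sketch below uses zero-filled arrays as stand-ins for real preprocessed images. It shows the resulting shapes and how several same-sized images could be stacked into one batch instead.

```python
import numpy as np

# Stand-in for one preprocessed 224x224 RGB image: (height, width, channels).
single = np.zeros((224, 224, 3), dtype=np.float32)

# Add a leading batch axis: (224, 224, 3) -> (1, 224, 224, 3).
batch_of_one = np.expand_dims(single, axis=0)

# Several images of the same size can be stacked along a new batch axis.
images = [single, single + 1.0, single + 2.0]
batch_of_three = np.stack(images, axis=0)

print(batch_of_one.shape)    # (1, 224, 224, 3)
print(batch_of_three.shape)  # (3, 224, 224, 3)
```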
Classify ImageNet classes with QuickNet
We will first load the QuickNet architecture with pretrained weights and predict the image class.
model = lqz.sota.QuickNet(weights="imagenet")
preds = model.predict(x)
lqz.decode_predictions(preds, top=5)[0]
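`decode_predictions` maps the 1000 class scores to human-readable ImageNet labels. The core of such a top-k lookup can be sketched in NumPy as below. The `top_k` helper and the four-entry `class_names` list are hypothetical stand-ins, not part of the Larq Zoo API, and the real function's return format may differ.

```python
import numpy as np

def top_k(preds, class_names, k=5):
    """Hypothetical helper: return the k (name, score) pairs with the highest scores.

    preds: array of shape (batch, num_classes); only the first row is decoded.
    """
    scores = preds[0]
    # argsort sorts ascending, so reverse it and keep the first k indices.
    best = np.argsort(scores)[::-1][:k]
    return [(class_names[i], float(scores[i])) for i in best]

# Toy example with 4 made-up classes instead of the 1000 ImageNet classes.
class_names = ["cat", "dog", "elephant", "zebra"]
preds = np.array([[0.05, 0.10, 0.80, 0.05]])
result = top_k(preds, class_names, k=2)
print(result)
```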
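Extract features with QuickNet
Larq Zoo models can also serve as feature extractors whose outputs are fed into a second model. Passing include_top=False drops the final classification layers, so the model returns features instead of class probabilities.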
tf.keras.backend.clear_session()
model = lqz.sota.QuickNet(weights="imagenet", include_top=False)
features = model.predict(x)
print("Feature shape:", features.shape)
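Assuming the printed feature shape is four-dimensional and channels-last, i.e. `(batch, height, width, channels)`, one common way to turn it into a single fixed-length vector per image for a second model is global average pooling. This is what `tf.keras.layers.GlobalAveragePooling2D` computes for channels-last inputs. The sketch uses a random stand-in array; its 7x7x256 size is illustrative and is not QuickNet's actual output shape.

```python
import numpy as np

# Random stand-in for a feature map of shape (batch, height, width, channels).
# The sizes are illustrative only, not QuickNet's real output shape.
rng = np.random.default_rng(0)
features = rng.standard_normal((2, 7, 7, 256)).astype(np.float32)

# Global average pooling: average over the two spatial axes.
pooled = features.mean(axis=(1, 2))

print(pooled.shape)  # (2, 256)
```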
Extract features from an arbitrary intermediate layer
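Features can also be extracted from an arbitrary intermediate layer with just a few lines of code. Layer names such as add_7 are assigned automatically by Keras; model.summary() lists the names available in the loaded model.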
# `model` is the include_top=False QuickNet built in the previous step.
intermediate_layer = model.get_layer("add_7")
intermediate_model = tf.keras.models.Model(
    inputs=model.input, outputs=intermediate_layer.output
)
intermediate_features = intermediate_model.predict(x)
print("add_7 feature shape:", intermediate_features.shape)
Build QuickNet over a custom input Tensor
QuickNet can also be built on top of an existing input tensor, for example the output of a different Keras model or layer.
input_tensor = tf.keras.layers.Input(shape=(224, 224, 3))
model = lqz.sota.QuickNet(input_tensor=input_tensor, weights="imagenet")
Evaluate QuickNet with TensorFlow Datasets
TensorFlow Datasets can be used to re-run the evaluation on the entire ImageNet validation set.
Note that running this example requires manually downloading the full dataset and may take a very long time to complete.
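def preprocess(data):
    # Apply the standard Larq Zoo input preprocessing to the image.
    img = lqz.preprocess_input(data["image"])
    # One-hot encode the integer label over the 1000 ImageNet classes.
    label = tf.one_hot(data["label"], 1000)
    return img, label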
dataset = (
    tfds.load("imagenet2012:5.0.0", split=tfds.Split.VALIDATION)
    .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128)
    .prefetch(1)
)
model = lqz.sota.QuickNet()
model.compile(
    # compile() requires an optimizer, but evaluate() does not use it.
    optimizer="sgd",
    loss="categorical_crossentropy",
    metrics=["categorical_accuracy", "top_k_categorical_accuracy"],
)
model.evaluate(dataset)
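`model.evaluate` reports the metrics passed to `compile`. `categorical_accuracy` counts a prediction as correct when the highest-scoring class matches the label. `top_k_categorical_accuracy`, with the Keras default of `k=5`, counts it as correct when the label is among the five highest-scoring classes. The NumPy sketch below computes both on made-up scores, with one-hot labels built via `np.eye` in the same spirit as `tf.one_hot` in `preprocess`. The `top_k_accuracy` helper is hypothetical, and its handling of tied scores may differ from Keras.

```python
import numpy as np

def top_k_accuracy(y_true, y_pred, k):
    """Hypothetical helper: fraction of rows whose true class is among the k top scores.

    y_true: one-hot labels, shape (batch, num_classes).
    y_pred: class scores, shape (batch, num_classes).
    """
    true_class = np.argmax(y_true, axis=1)
    # Indices of the k largest scores per row (order within the k does not matter).
    top_k = np.argsort(y_pred, axis=1)[:, -k:]
    hits = np.any(top_k == true_class[:, None], axis=1)
    return float(np.mean(hits))

# Three made-up examples over 6 classes; true classes are 0, 3 and 5.
y_true = np.eye(6)[[0, 3, 5]]
y_pred = np.array([
    [0.50, 0.10, 0.10, 0.10, 0.10, 0.10],  # true class 0 ranked 1st
    [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],  # true class 3 ranked 4th
    [0.30, 0.25, 0.20, 0.15, 0.06, 0.04],  # true class 5 ranked 6th
])

top1 = top_k_accuracy(y_true, y_pred, k=1)
top5 = top_k_accuracy(y_true, y_pred, k=5)
print(top1, top5)
```

Here only the first example is a top-1 hit, while the second also counts as a top-5 hit, so top-5 accuracy is always at least as high as top-1 accuracy.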