TensorFlow Lite Model Maker: Build an Image Classifier for Android

Building machine learning models for edge devices just got a whole lot easier

TensorFlow only recently concluded its yearly Dev Summit via livestream (due to the COVID-19 global pandemic), and there were a lot of exciting announcements, most focused on propelling machine learning to even greater heights.

The announcements ranged from a robust new release of the core TensorFlow platform (TF 2.2) to the new Google Cloud AI Platform Pipelines for making the use of TensorFlow in production even easier, and beyond.

But that isn’t the focus of this piece. Instead, we’ll dig into one of the breakthrough announcements of the year: the TensorFlow Lite Model Maker.

Using the TF Lite Model Maker, which is baked into the TF Lite support library, building models ready for mobile and edge devices is super easy. Moreover, Android Studio 4.1 (currently a Canary release), with its new code generation capability for TF Lite models, automatically generates the wrapper Java classes for them, thereby simplifying model development and deployment for mobile machine learning developers.

TensorFlow Lite is a lightweight, cross-platform solution for deploying ML models on mobile and embedded devices. If you’d like to get up to speed with all the news and announcements regarding TensorFlow Lite from this year’s TF Dev Summit, I’d highly recommend that you check out this resource.

TensorFlow Lite Model Maker

The TF Lite Model Maker is a Python API that makes building machine learning models from scratch a no-brainer. All it needs is five lines of code (excluding imports), as shown below:
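Here’s a minimal sketch of that script using the tflite_model_maker package (module paths and the export signature may differ slightly across versions, and the dataset folder name is a placeholder):

from tflite_model_maker import image_classifier
from tflite_model_maker.image_classifier import DataLoader

data = DataLoader.from_folder('dataset/')      # one subfolder per label
train_data, test_data = data.split(0.9)        # 90/10 train/test split
model = image_classifier.create(train_data)    # transfer learning on the base model
loss, accuracy = model.evaluate(test_data)     # evaluate on the held-out split
model.export(export_dir='export/')             # writes the TF Lite model (and labels)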

In the snippet above, we’re loading a dataset and splitting it into training and testing sets. Subsequently, we train, evaluate, and export the TF Lite model along with the labels (retrieved from the subfolder names).

Under the hood, the Model Maker API uses transfer learning to re-train a model with a different dataset and categories. By default, the Model Maker API uses EfficientNet-Lite0 as the base model.

EfficientNet-Lite was only recently released and belongs to the family of image classification models capable of achieving state-of-the-art accuracy on edge devices. The accuracy vs. size comparison published with its release stacks the EfficientNet-Lite models up against MobileNet and ResNet.

The Model Maker API also lets us switch the underlying model. For example:
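For instance, we can pass one of the specs bundled with the library to the create function (mobilenet_v2 here is one such predefined spec):

from tflite_model_maker import model_spec

# Re-train on MobileNetV2 instead of the default EfficientNet-Lite0.
model = image_classifier.create(train_data, model_spec=model_spec.get('mobilenet_v2'))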

Alternatively, we can also pass hosted models from TensorFlow Hub, along with customized input shapes, as shown below:
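Something along these lines, mirroring the example in the Model Maker documentation (the TF Hub URL points to an Inception V3 feature-vector model):

# Wrap a TF Hub feature-vector model in a ModelSpec and override the input shape.
inception_v3_spec = image_classifier.ModelSpec(
    uri='https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1')
inception_v3_spec.input_image_shape = [299, 299]

model = image_classifier.create(train_data, model_spec=inception_v3_spec)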

We can also fine-tune training hyperparameters like epochs, dropout_rate, and batch_size in the create function of the Model Maker API.
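For example, with illustrative values:

# epochs, batch_size, and dropout_rate are keyword arguments of create.
model = image_classifier.create(train_data, epochs=10, batch_size=32, dropout_rate=0.2)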

Now that we’ve had a good look at the core functionality of the Model Maker API, let’s set up the dependencies required to run the above Python script.

Upgrading TensorFlow

Make sure you’re running Python 3.6 or above and have the latest pip version installed on your Mac (TensorFlow 2 packages require pip version 19.0 or higher). Subsequently, pip install the following to update TensorFlow:
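# Upgrade pip first, then TensorFlow itself.
pip install --upgrade pip
pip install --upgrade tensorflow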

From our terminal, let’s quickly test that we have the latest TensorFlow version installed, using the following command:
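# Prints the installed TensorFlow version.
python3 -c "import tensorflow as tf; print(tf.__version__)"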

Installing the Model Maker library

Run the following command on your terminal to install the Model Maker library:
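# The PyPI package name is an assumption based on the released package;
# earlier previews were installed directly from the tensorflow/examples repo instead.
pip install tflite-model-maker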

Things are all set up now, which means it’s time to train our model. Just run the Python script from your macOS terminal. For this demo, we’ve used an appropriate NSFW dataset from Kaggle. Once our model is ready, it’s time to import it into our new Android Studio project.

Setting Up Our Gradle Dependencies for the TensorFlow Lite Model

Android Studio 4.1 has a few new enhancements for TensorFlow Lite models:

  • First, it lets us import a tflite model directly from the import menu and places it in an ml folder. Just go to File > New > Other > TensorFlow Lite Model.
  • Secondly, Android Studio now has a model viewer that shows the metadata summary: the input and output tensors, descriptions of each, and sample code for invoking the model.

By default, the Model Maker API generates only the bare minimum metadata, comprising the input and output shapes. To add more context, such as the author, version, license, and input and output descriptions, we can leverage the new extended metadata feature (currently in an experimental stage).

Enable ML Model Binding

Even after you place the tflite model in the ml directory, model binding isn’t enabled automatically. You’ll need to add the buildFeatures and aaptOptions elements in your app’s build.gradle script to enable it:
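// A sketch of the app-level build.gradle additions; merge these into your existing android block.
android {
    buildFeatures {
        mlModelBinding true   // generates the wrapper class for models placed in ml/
    }
    aaptOptions {
        noCompress "tflite"   // keep the model uncompressed so it can be memory-mapped
    }
}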

With binding enabled, the classifier class generated from our model is available for us to run inference. It’s time to add the tensorflow-lite dependencies in the build.gradle file:
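// The artifact versions below are assumptions based on the releases current at the
// time of writing; substitute the latest stable (or nightly) versions.
dependencies {
    implementation 'org.tensorflow:tensorflow-lite:2.1.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.0.0-nightly'
    implementation 'org.tensorflow:tensorflow-lite-metadata:0.0.0-nightly'
}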

Setting Up Our Activity Layout

Now it’s time to lay down the UI elements in our Activity. To keep it simple, our activity_main.xml file consists of a RecyclerView and a Button:

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <androidx.recyclerview.widget.RecyclerView
        android:id="@+id/recyclerView"
        android:layout_width="match_parent"
        android:layout_above="@+id/btnClassifier"
        android:layout_height="match_parent" />


    <Button
        android:id="@+id/btnClassifier"
        android:text="Run Classifier"
        android:layout_alignParentBottom="true"
        android:layout_centerHorizontal="true"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"/>


</RelativeLayout>

To populate our RecyclerView’s adapter, we need a model. The following Kotlin data class holds an image, predicted text, and a boolean flag to indicate whether our input image is NSFW.
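Here’s a sketch of that class, reconstructed from how it’s used in the adapter and activity below:

data class DataModel(
    val drawableID: Int,    // drawable resource ID of the image to classify
    var isNSFW: Boolean,    // whether the image should be blacked out
    var prediction: String  // prediction text shown alongside the image
)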

The following XML code represents the layout for each of the RecyclerView’s rows:
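<?xml version="1.0" encoding="utf-8"?>
<!-- item_row.xml: a minimal sketch, assuming only the two views the adapter references. -->
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
    android:layout_width="match_parent"
    android:layout_height="wrap_content"
    android:orientation="vertical"
    android:padding="8dp">

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="match_parent"
        android:layout_height="200dp"
        android:scaleType="centerCrop" />

    <TextView
        android:id="@+id/tvPrediction"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:textSize="16sp" />

</LinearLayout>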

Now that we’ve created the data model and the view, it’s time to feed them to the RecyclerView’s Adapter.

Setting Up Our RecyclerView’s Adapter

The following code creates the Adapter class of the RecyclerView:

import android.content.Context
import android.graphics.Color
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import androidx.core.content.ContextCompat
import androidx.recyclerview.widget.RecyclerView
import kotlinx.android.synthetic.main.item_row.view.*


class RecyclerViewAdapter(val items: ArrayList<DataModel>, val context: Context) :
    RecyclerView.Adapter<ViewHolder>() {

    override fun getItemCount(): Int {
        return items.size
    }

    override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
        return ViewHolder(
            LayoutInflater.from(context).inflate(
                R.layout.item_row,
                parent,
                false
            )
        )
    }

    override fun onBindViewHolder(holder: ViewHolder, position: Int) {

        val data = items[position]

        holder.tvPrediction.text = data.prediction

        val image = ContextCompat.getDrawable(context, data.drawableID)
        holder.imageView.setImageDrawable(image)

        // Black out images the classifier flagged as NSFW.
        if (data.isNSFW) {
            holder.imageView.setColorFilter(Color.BLACK)
        } else {
            holder.imageView.setColorFilter(Color.TRANSPARENT)
        }
    }
}

class ViewHolder(view: View) : RecyclerView.ViewHolder(view) {
    val tvPrediction = view.tvPrediction
    val imageView = view.imageView
}

We’re setting a color filter on the ImageView based on the NSFW output (NSFW images are hidden in black for obvious reasons).

Finally, it’s time to dive into our MainActivity.kt, where we initialize the above adapter and, more importantly, run inference on a list of images.

Running the TF Lite Image Classifier

To run the model, we need to pre-process the input to satisfy the model’s constraints. TensorFlow Lite has a bunch of image pre-processing methods built-in. To use them, we first need to initialize an ImageProcessor and subsequently add the required operators:
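// The same processor is used in the full activity listed below.
val imageProcessor = ImageProcessor.Builder()
    .add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
    .build()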

Pre-processing the Input Image

In the following code, we’re resizing the input image to 224 by 224, the dimensions of the model’s input shape:
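// bitmap is decoded from the drawable resource (see the full activity below).
var tImage = TensorImage(DataType.FLOAT32)
tImage.load(bitmap)
tImage = imageProcessor.process(tImage)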

TensorImage is the input that’s fed to our TensorFlow Lite model. But before we run inference, let’s create a post-processor that’ll normalize the output probabilities.

Setting Up Our Post-processor

A post-processor is basically a container of tensor operations. Here, a single NormalizeOp(0f, 255f) scales values from the [0, 255] range down to [0, 1]:
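val probabilityProcessor = TensorProcessor.Builder()
    .add(NormalizeOp(0f, 255f))   // (value - 0) / 255
    .build()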

Running Inference

The following few lines of code instantiate the classifier that was auto-generated from the model, pass the input tensor image, and get the results in the outputBuffer:
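// NsfwClassifier is the wrapper class Android Studio generated from our imported
// .tflite file; the class name follows the model's file name.
val model = NsfwClassifier.newInstance(this@MainActivity)
val outputs = model.process(probabilityProcessor.process(tImage.tensorBuffer))
val outputBuffer = outputs.outputFeature0AsTensorBuffer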

TensorLabel is used to map the output probabilities to their associated labels. In our model, there are just a couple of labels: “NSFW” and “SFW”. We’ve set them in the labelsList ArrayList. In a different scenario, you can parse the labels.txt file to get ahold of all the categories, like what’s done here.

Finally, using the mapWithFloatValue accessor, we can retrieve the probabilities of both the NSFW and SFW categories:
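val tensorLabel = TensorLabel(labelsList, outputBuffer)
val nsfwProbability = tensorLabel.mapWithFloatValue["NSFW"]
val sfwProbability = tensorLabel.mapWithFloatValue["SFW"]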

The complete code of the MainActivity.kt is given below. It runs the above image classifier on every image and updates the RecyclerView adapter with the data changes accordingly:

import android.graphics.BitmapFactory
import android.os.Bundle
import android.util.Log
import androidx.appcompat.app.AppCompatActivity
import androidx.recyclerview.widget.LinearLayoutManager
import kotlinx.android.synthetic.main.activity_main.*
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.common.TensorProcessor
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import org.tensorflow.lite.support.label.TensorLabel
// NsfwClassifier is generated under <your app package>.ml once the model is imported.

class MainActivity : AppCompatActivity() {

    val dataArray: ArrayList<DataModel> = ArrayList()
    val labelsList = arrayListOf("NSFW", "SFW")

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        btnClassifier.setOnClickListener {
            // Classify every image in the list, then refresh the UI.
            dataArray.forEach { runImageClassifier(it) }
            recyclerView.adapter?.notifyDataSetChanged()
        }

        populateData()
        recyclerView.layoutManager = LinearLayoutManager(this)
        recyclerView.adapter = RecyclerViewAdapter(dataArray, this)

    }

    fun populateData() {
        // All items start flagged as NSFW; the classifier clears the flag for safe images.
        dataArray.add(DataModel(R.drawable.sfw_1, true, ""))
        dataArray.add(DataModel(R.drawable.nsfw, true, ""))
        dataArray.add(DataModel(R.drawable.nsfw2, true, ""))
        dataArray.add(DataModel(R.drawable.sfw, true, ""))
    }

    fun runImageClassifier(data: DataModel) {

        val bitmap =
            BitmapFactory.decodeResource(applicationContext.resources, data.drawableID)

        try {
            // Normalizes values from [0, 255] to [0, 1] before inference.
            val probabilityProcessor =
                TensorProcessor.Builder().add(NormalizeOp(0f, 255f)).build()

            // Resizes the input to the 224x224 shape the model expects.
            val imageProcessor = ImageProcessor.Builder()
                .add(ResizeOp(224, 224, ResizeOp.ResizeMethod.BILINEAR))
                .build()

            var tImage = TensorImage(DataType.FLOAT32)
            tImage.load(bitmap)
            tImage = imageProcessor.process(tImage)

            // NsfwClassifier is the wrapper class generated by ML Model Binding.
            val model = NsfwClassifier.newInstance(this@MainActivity)
            val outputs =
                model.process(probabilityProcessor.process(tImage.tensorBuffer))
            val outputBuffer = outputs.outputFeature0AsTensorBuffer

            // Map the output probabilities to their labels.
            val tensorLabel = TensorLabel(labelsList, outputBuffer)

            val nsfwProbability = tensorLabel.mapWithFloatValue["NSFW"]
            if (nsfwProbability != null && nsfwProbability < 0.5f) {
                data.isNSFW = false
            }
            data.prediction = "NSFW : $nsfwProbability"

            model.close()
        } catch (e: Exception) {
            Log.d("TAG", "Exception is " + e.localizedMessage)
        }
    }
}

When run, the app displays the NSFW probability beneath each image and blacks out the ones classified as NSFW.

Closing Thoughts

Suffice it to say that the Model Maker Python library is here to stay and will be widely used by mobile developers looking to quickly train and deploy ML models on-device.

For those of you who are familiar with Apple’s machine learning technologies, the TF Lite Model Maker is similar to Create ML, at least in principle. Currently, the Model Maker API only supports image and text classification use cases, with object detection and QR code reading expected to be out soon.

Android Studio’s support for ML model binding and automatic code generation removes the need to interact with ByteBuffer as we did in a previous TensorFlow Lite Android tutorial.

Extended metadata (in an experimental stage at the time of writing) also allows us to generate custom, platform-specific wrapper code, thereby further reducing the amount of boilerplate code we need to write. We’ll look into custom code generation and much more in a future tutorial.

The full source code of the above tutorial is available in this GitHub Repository.

That’s a wrap for this one. Thanks for reading.

