Image Classification on Android with TensorFlow Lite and CameraX

Leverage the GPU delegate for machine learning on the edge

TensorFlow Lite is the lightweight successor to TensorFlow Mobile. It brings machine learning to your smartphone while keeping model binaries small and inference latency low. It also supports hardware acceleration through the Android Neural Networks API and is expected to run up to 4X faster with GPU support.

CameraX is the latest camera API, released as part of the Jetpack support libraries. It makes developing with the camera much easier, and with Google’s automated lab testing, it strives to behave consistently across the many Android devices in use. CameraX is a huge improvement over the Camera2 API in terms of ease of use and simplicity.

The goal of this article is to merge the camera and ML worlds by processing CameraX frames for image classification using a TensorFlow Lite model. We’ll build an Android application in Kotlin that leverages the power of your smartphone’s GPU.

CameraX: A Brief Overview

CameraX is lifecycle-aware, so it removes the need to handle camera state in the onResume and onPause methods.

The API is use case-based. The three main use cases that are currently supported are:

  • Preview: Displays the camera feed.
  • Analyze: Processes images for computer vision or other machine learning-related tasks.
  • Capture: Saves high-quality images.

Additionally, CameraX provides Extensions to easily access features such as HDR, Portrait, and Night Mode on supported devices.

TensorFlow Lite Converter

The TensorFlow Lite converter takes a TensorFlow model and generates a TensorFlow Lite FlatBuffer file. The .tflite model can then be deployed on mobile or embedded devices and run locally using the TensorFlow Lite interpreter.

The following code snippet depicts one way of converting a Keras model to a mobile-compatible .tflite file:

In the following sections, we’ll be demonstrating a hands-on implementation of CameraX with a MobileNet TensorFlow Lite model using Kotlin. You can create your own custom trained models or choose among the hosted, pre-trained ones.

Implementation

Under the Hood

The flow is really simple. We pass bitmap images from the Analyze use case in CameraX to the TensorFlow Lite interpreter, which runs inference on each image using the MobileNet model and the label classes. Here’s an illustration of how CameraX and TensorFlow Lite interact with one another.

Setup

Launch a new Android Studio Kotlin project and add the following dependencies in your app’s build.gradle file.

The nightly TensorFlow Lite build provides experimental support for GPUs. The Google Play Services Task API is used for handling asynchronous method calls.

Next, add the MVP files, the labels, and the .tflite model file under your assets directory. Also, you need to ensure that the model isn’t compressed by setting the following aaptOptions in the build.gradle file:

Add the necessary permissions for the camera in your AndroidManifest.xml file:

Now that the setup is complete, it’s time to establish the layout!

Layout

The layout is defined inside the activity_main.xml file, and it consists of a TextureView for displaying the Camera Preview and a TextView that shows the predicted output from your image classification model.

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:background="@color/colorPrimary"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <TextureView
        android:id="@+id/textureView"
        android:layout_width="0dp"
        android:layout_height="0dp"
        app:layout_constraintBottom_toTopOf="@+id/rlBottom"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />


    <RelativeLayout
        android:id="@+id/rlBottom"
        android:layout_width="wrap_content"
        android:layout_height="200dp"
        app:layout_constraintBottom_toBottomOf="parent"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/textureView">

        <TextView
            android:id="@+id/predictedTextView"
            android:layout_width="wrap_content"
            android:layout_height="wrap_content"
            android:layout_centerInParent="true"
            android:text=""
            android:textColor="@android:color/white"
            android:textSize="26sp" />

    </RelativeLayout>


</androidx.constraintlayout.widget.ConstraintLayout>

Request Camera Permissions

You’ll need to request runtime permissions before accessing the camera. The following code from the MainActivity.kt class shows how that’s done.

class MainActivity : AppCompatActivity() {

    private val REQUEST_CODE_PERMISSIONS = 101
    private val REQUIRED_PERMISSIONS = arrayOf("android.permission.CAMERA")

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        if (allPermissionsGranted()) {
            textureView.post { startCamera() }
            textureView.addOnLayoutChangeListener { _, _, _, _, _, _, _, _, _ ->
                updateTransform()
            }
        } else {
            ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSIONS)
        }
    }
    
    override fun onRequestPermissionsResult(
        requestCode: Int,
        permissions: Array<String>,
        grantResults: IntArray
    ) {

        if (requestCode == REQUEST_CODE_PERMISSIONS) {
            if (allPermissionsGranted()) {
                startCamera()
            } else {
                Toast.makeText(this, "Permissions not granted by the user.", Toast.LENGTH_SHORT)
                    .show()
                finish()
            }
        }
    }

    // True only if every permission in REQUIRED_PERMISSIONS has been granted
    private fun allPermissionsGranted(): Boolean =
        REQUIRED_PERMISSIONS.all { permission ->
            ContextCompat.checkSelfPermission(this, permission) ==
                PackageManager.PERMISSION_GRANTED
        }
}

Once permission is granted, we’ll start our camera!

Setting Up Camera Use Cases

As seen in the previous section’s code, startCamera is called from the post method on the TextureView. This ensures that the camera starts only once the TextureView has been laid out on the screen. In the updateTransform method, we correct the orientation of the preview relative to the device’s orientation.

private var lensFacing = CameraX.LensFacing.BACK

private fun startCamera() {
        val metrics = DisplayMetrics().also { textureView.display.getRealMetrics(it) }
        val screenSize = Size(metrics.widthPixels, metrics.heightPixels)
        val screenAspectRatio = Rational(1, 1)

        val previewConfig = PreviewConfig.Builder().apply {
            setLensFacing(lensFacing)
            setTargetResolution(screenSize)
            setTargetAspectRatio(screenAspectRatio)
            // Set the rotation once; a second call would simply override the first
            setTargetRotation(textureView.display.rotation)
        }.build()

        val preview = Preview(previewConfig)
        preview.setOnPreviewOutputUpdateListener {
            textureView.surfaceTexture = it.surfaceTexture
            updateTransform()
        }


        val analyzerConfig = ImageAnalysisConfig.Builder().apply {
            // Use a worker thread for image analysis to prevent glitches
            val analyzerThread = HandlerThread("AnalysisThread").apply {
                start()
            }
            setCallbackHandler(Handler(analyzerThread.looper))
            setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
        }.build()


        val analyzerUseCase = ImageAnalysis(analyzerConfig)
        analyzerUseCase.analyzer =
            ImageAnalysis.Analyzer { image: ImageProxy, rotationDegrees: Int ->

                val bitmap = image.toBitmap()

                tfLiteClassifier
                    .classifyAsync(bitmap)
                    .addOnSuccessListener { resultText -> predictedTextView?.text = resultText }
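                    // Log failures instead of silently discarding them
                    .addOnFailureListener { error -> Log.e("MainActivity", "Classification failed", error) }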

            }
        CameraX.bindToLifecycle(this, preview, analyzerUseCase)
    }

    private fun updateTransform() {
        val matrix = Matrix()
        val centerX = textureView.width / 2f
        val centerY = textureView.height / 2f

        val rotationDegrees = when (textureView.display.rotation) {
            Surface.ROTATION_0 -> 0
            Surface.ROTATION_90 -> 90
            Surface.ROTATION_180 -> 180
            Surface.ROTATION_270 -> 270
            else -> return
        }
        matrix.postRotate(-rotationDegrees.toFloat(), centerX, centerY)
        textureView.setTransform(matrix)
    }

In the above code, we’re doing quite a few things. Let’s go through each of them:

  • Setting up our Preview use case using the PreviewConfig.Builder.
  • setOnPreviewOutputUpdateListener is where we add the surface texture of the camera preview to the TextureView.
  • Inside the Analyzer use case, we convert the ImageProxy to a Bitmap and pass it to the TFLiteClassifier’s classifyAsync method. If this looks out of place, skip it for now, as we’ll be discussing the TFLiteClassifier class at length in the next section.

The following code snippet is used for converting the ImageProxy to a Bitmap:
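A minimal sketch of the core of such a conversion, assuming the frame arrives in YUV_420_888 format with three planes. The helper name packNv21 is hypothetical, and the sketch relies on the common shortcut of copying the V plane before the U plane, which only yields valid NV21 data when the chroma planes are interleaved (pixel stride of 2). That isn’t guaranteed on every device, so treat it as an illustration rather than a robust implementation. It works on plain ByteBuffers so that it can run off-device:

```kotlin
import java.nio.ByteBuffer

// Hypothetical helper: packs Y, U, and V plane buffers into one NV21-style array.
// Layout: all Y bytes first, then the V plane bytes, then the U plane bytes.
fun packNv21(y: ByteBuffer, u: ByteBuffer, v: ByteBuffer): ByteArray {
    val ySize = y.remaining()
    val uSize = u.remaining()
    val vSize = v.remaining()

    val nv21 = ByteArray(ySize + uSize + vSize)
    y.get(nv21, 0, ySize)              // luminance
    v.get(nv21, ySize, vSize)          // V comes first in NV21
    u.get(nv21, ySize + vSize, uSize)  // then U
    return nv21
}
```

On Android, an ImageProxy extension could pass planes[0].buffer, planes[1].buffer, and planes[2].buffer to this helper, wrap the result in a YuvImage with ImageFormat.NV21, compress it with compressToJpeg, and decode the bytes with BitmapFactory.decodeByteArray to obtain the Bitmap.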

It’s now time to run image classification! Let’s jump to the next section.

TensorFlow Lite Interpreter

The TensorFlow Lite interpreter goes through the following steps to return predictions for a given input.

1. Converting the model into a ByteBuffer

We must memory map the model from the Assets folder to get a ByteBuffer, which is ultimately loaded into the interpreter:
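A minimal sketch of the memory-mapping step, assuming the model bytes live in a region of a regular file. The function name mapRegion is hypothetical; it uses only java.nio so it can be tried off-device, and the comment shows how the same call would look with an Android asset file descriptor:

```kotlin
import java.io.File
import java.io.FileInputStream
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

// Hypothetical helper: maps `length` bytes of `file`, starting at `startOffset`,
// into memory as a read-only buffer. The mapping stays valid after the stream closes.
fun mapRegion(file: File, startOffset: Long, length: Long): MappedByteBuffer =
    FileInputStream(file).use { stream ->
        stream.channel.map(FileChannel.MapMode.READ_ONLY, startOffset, length)
    }

// On Android (an assumption about what loadModelFile does), the equivalent would be:
//   val fd = assetManager.openFd(filename)
//   FileInputStream(fd.fileDescriptor).channel
//       .map(FileChannel.MapMode.READ_ONLY, fd.startOffset, fd.declaredLength)
```

Mapping the file avoids copying the whole model onto the heap, which is also why the asset has to be stored uncompressed.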

2. Loading the label classes into a Data Structure

The labels file lists roughly a thousand ImageNet classes, one per line. We’ll load those labels into an ArrayList. In the end, the interpreter’s output is mapped back to these label strings.
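A minimal sketch of label loading, assuming one label per line. The function name readLabels is hypothetical; it takes a plain InputStream so that it runs anywhere, and every line is kept so that list positions stay aligned with the model’s output indices:

```kotlin
import java.io.InputStream

// Hypothetical helper: reads one label per line, preserving order.
// No lines are filtered out, so index i always refers to line i of the file.
fun readLabels(input: InputStream): ArrayList<String> =
    input.bufferedReader().useLines { lines ->
        lines.toCollection(ArrayList())
    }
```

On Android, the stream could come from context.assets.open("labels.txt").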

3. Initializing Our Interpreter

Now that we’ve got our ByteBuffer and label list, it’s time to initialize our interpreter. In the following code, we add a GpuDelegate to the Interpreter.Options object before creating the interpreter:

class TFLiteClassifier(private val context: Context) {

    private var interpreter: Interpreter? = null
    var isInitialized = false
        private set

    private var gpuDelegate: GpuDelegate? = null

    var labels = ArrayList<String>()

    private val executorService: ExecutorService = Executors.newCachedThreadPool()

    private var inputImageWidth: Int = 0
    private var inputImageHeight: Int = 0
    private var modelInputSize: Int = 0

    fun initialize(): Task<Void> {
        return call(
            executorService,
            Callable<Void> {
                initializeInterpreter()
                null
            }
        )
    }

    @Throws(IOException::class)
    private fun initializeInterpreter() {

        val assetManager = context.assets
        val model = loadModelFile(assetManager, "mobilenet_v1_1.0_224.tflite")

        labels = loadLines(context, "labels.txt")
        val options = Interpreter.Options()
        gpuDelegate = GpuDelegate()
        options.addDelegate(gpuDelegate)
        val interpreter = Interpreter(model, options)

        // Image models take NHWC input: [batch, height, width, channels]
        val inputShape = interpreter.getInputTensor(0).shape()
        inputImageHeight = inputShape[1]
        inputImageWidth = inputShape[2]
        modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth * inputImageHeight * CHANNEL_SIZE

        this.interpreter = interpreter

        isInitialized = true
    }
    
    companion object {
        private const val TAG = "TfliteClassifier"
        private const val FLOAT_TYPE_SIZE = 4
        private const val CHANNEL_SIZE = 3
        private const val IMAGE_MEAN = 127.5f
        private const val IMAGE_STD = 127.5f
    }
}

In the above code, once the model’s setup is done in the interpreter, we retrieve the input tensor shape of the model. This is done in order to preprocess the Bitmap into the same shape that the model accepts.
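To make the buffer-size arithmetic concrete, here is a small sketch. It assumes a float32 model with an input shape of [1, 224, 224, 3], which is what the model file name suggests but isn’t confirmed here; the function name inputBufferSize is hypothetical:

```kotlin
// Hypothetical helper: bytes needed for one input image, given an NHWC
// shape [batch, height, width, channels]. The batch dimension is ignored,
// which assumes a batch size of 1.
fun inputBufferSize(shape: IntArray, bytesPerValue: Int = 4): Int {
    require(shape.size == 4) { "Expected an NHWC shape with 4 dimensions" }
    val (_, height, width, channels) = shape
    return bytesPerValue * height * width * channels
}

// 4 bytes * 224 * 224 * 3 channels = 602,112 bytes
val mobileNetInputBytes = inputBufferSize(intArrayOf(1, 224, 224, 3))
```

This is the same product that modelInputSize holds in the classifier.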

The Callable interface is similar to Runnable but allows us to return a result. The ExecutorService is used for managing multiple threads from a ThreadPool.
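As a rough illustration of that difference, the sketch below submits a Callable to an ExecutorService and reads the result back. It uses a plain java.util.concurrent Future as a stand-in for the Play Services Task used in this article, and the helper name runAsync is hypothetical:

```kotlin
import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.Future

// Hypothetical helper: runs `work` on the executor and exposes its return value.
// A Runnable couldn't do this, because its run() method returns nothing.
fun <T> runAsync(executor: ExecutorService, work: () -> T): Future<T> =
    executor.submit(Callable<T> { work() })

fun main() {
    val pool = Executors.newCachedThreadPool()
    val future = runAsync(pool) { "Prediction is tabby" }
    println(future.get()) // blocks until the background work finishes
    pool.shutdown()
}
```

The Task API serves the same purpose but delivers the result through listeners rather than a blocking get call.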

The initialize method is called in the onCreate method of our MainActivity, as shown below:

4. Preprocessing the Input and Running Inference

We can now resize our Bitmap to fit the model input shape. Then, we’ll convert the new Bitmap into a ByteBuffer for model execution:

private fun classify(bitmap: Bitmap): String {

        check(isInitialized) { "TF Lite Interpreter is not initialized yet." }
        val resizedImage =
            Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true)

        val byteBuffer = convertBitmapToByteBuffer(resizedImage)

        val output = Array(1) { FloatArray(labels.size) }
        val startTime = SystemClock.uptimeMillis()
        interpreter?.run(byteBuffer, output)
        val endTime = SystemClock.uptimeMillis()

        val inferenceTime = endTime - startTime
        val index = getMaxResult(output[0])
        val result = "Prediction is ${labels[index]}\nInference Time $inferenceTime ms"

        return result
}

private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val byteBuffer = ByteBuffer.allocateDirect(modelInputSize)
        byteBuffer.order(ByteOrder.nativeOrder())

        val pixels = IntArray(inputImageWidth * inputImageHeight)
        bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        var pixel = 0
        for (i in 0 until inputImageWidth) {
            for (j in 0 until inputImageHeight) {
                val pixelVal = pixels[pixel++]

                byteBuffer.putFloat(((pixelVal shr 16 and 0xFF) - IMAGE_MEAN) / IMAGE_STD)
                byteBuffer.putFloat(((pixelVal shr 8 and 0xFF) - IMAGE_MEAN) / IMAGE_STD)
                byteBuffer.putFloat(((pixelVal and 0xFF) - IMAGE_MEAN) / IMAGE_STD)

            }
        }
        bitmap.recycle()

        return byteBuffer
}


fun classifyAsync(bitmap: Bitmap): Task<String> {
        return call(executorService, Callable<String> { classify(bitmap) })
}

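In the above code, convertBitmapToByteBuffer shifts each packed ARGB pixel and masks it with 0xFF to extract the red, green, and blue channels as 8-bit values. The alpha channel sits in the top 8 bits and is never read, so it’s ignored. Each channel is then normalized with IMAGE_MEAN and IMAGE_STD into the range [-1, 1].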
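The per-pixel arithmetic can be checked in isolation. The sketch below applies the same shifts, masks, and normalization to a single packed ARGB value; the function name normalizeArgb is hypothetical, and the defaults mirror the IMAGE_MEAN and IMAGE_STD constants from the classifier:

```kotlin
// Hypothetical helper: unpacks one ARGB pixel into normalized [R, G, B] floats.
// Alpha occupies bits 24..31 and is simply never read.
fun normalizeArgb(pixel: Int, mean: Float = 127.5f, std: Float = 127.5f): FloatArray {
    val r = (pixel shr 16) and 0xFF
    val g = (pixel shr 8) and 0xFF
    val b = pixel and 0xFF
    return floatArrayOf((r - mean) / std, (g - mean) / std, (b - mean) / std)
}
```

With these defaults, a channel value of 0 maps to -1.0 and a value of 255 maps to 1.0.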

Along with the ByteBuffer, we pass a float array with one entry per image class, which the interpreter fills with the predicted scores.

5. Computing Arg Max

Finally, the getMaxResult function returns the index of the label with the highest confidence, as shown in the code snippet below:
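A minimal sketch of such an arg max, written to match how classify uses the result as an index into labels. The body is an assumption about the implementation rather than a copy of it:

```kotlin
// Returns the index of the largest score. Ties resolve to the earliest index.
fun getMaxResult(result: FloatArray): Int {
    require(result.isNotEmpty()) { "Score array must not be empty" }
    var bestIndex = 0
    for (i in result.indices) {
        if (result[i] > result[bestIndex]) bestIndex = i
    }
    return bestIndex
}
```

For example, scores of 0.1, 0.7, and 0.2 give index 1, so the second label would be reported.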

The classifyAsync method that runs in the Analyzer use case returns a string containing the prediction and the inference time, delivered through addOnSuccessListener thanks to the Callable interface.

In return, we display the predicted label and the inference time on the screen, as shown below:

Conclusion

So that sums up this article. We used TensorFlow Lite and CameraX to build an image classification Android application using MobileNet while leveraging the GPU delegate—and we got a pretty accurate result pretty quickly. Moving on from here, you can try building your own custom TFLite Models and see how they fare with CameraX. CameraX is still in alpha stages, but there’s already a lot you can do with it.

The full source code of this guide is available here.

That’s it for this one. I hope you enjoyed reading.

