Flutter Image classification using TensorFlow in 4 steps



Google provides a few products within the TensorFlow family:

  • TensorFlow: the core open source library that is the foundation for developing and training machine learning models.
  • TensorFlow.js: similar to TensorFlow but focused purely on JavaScript.
  • TensorFlow Lite: as the name suggests, a lightweight version of TensorFlow for deploying models on mobile devices. Its scope is limited: it loads pre-trained models onto a mobile device and runs them there. You can use it for image classification, object detection, and question answering based on natural language models.
  • TensorFlow Extended (TFX): an extension of TensorFlow for large production environments.

In this article, I will use TensorFlow Lite to deploy a model into a Flutter application. At the time of writing, there isn’t an official TensorFlow library for Flutter, so we will use a third-party plugin called tflite.


Step 1: Get yourself a model

TensorFlow Lite itself cannot be used to train a model: you must convert a full TensorFlow model into the TensorFlow Lite format. Luckily, there are lots of pre-trained TensorFlow Lite models for image classification, object detection, smart reply, pose estimation and segmentation. Alternatively, you can find pre-trained models on TensorFlow Hub; make sure you select TFLite as the model format.

TensorFlow Hub gives you several options for getting a model: you can download it or import it directly into Android Studio.

If you choose to download a model, you will receive a file such as “some-image-classification-model.tflite”. Remember to unzip this file to extract the label file: you will need both files later on, saved as model.tflite and label.txt.

unzip some-image-classification-model.tflite
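If you would rather extract the labels in code (for example on a machine without the unzip tool), here is a minimal Python sketch using only the standard library. It assumes the model ships with TFLite metadata, which bundles the label file inside the .tflite in a zip-compatible form; a model without metadata will raise zipfile.BadZipFile. The helper name extract_labels and the rule of taking the first .txt member are my own assumptions, since the bundled file name differs between models.

```python
import zipfile
from pathlib import Path


def extract_labels(model_path, out_path="label.txt"):
    """Copy the label file bundled in a .tflite model to out_path and return its labels.

    Assumes the model has TFLite metadata with an associated .txt label file;
    a model without metadata is not a zip archive and raises zipfile.BadZipFile.
    """
    with zipfile.ZipFile(model_path) as archive:
        # Assumption: the first .txt member is the label file.
        txt_members = [name for name in archive.namelist() if name.endswith(".txt")]
        if not txt_members:
            raise ValueError(f"no .txt label file inside {model_path}")
        text = archive.read(txt_members[0]).decode("utf-8")
    Path(out_path).write_text(text, encoding="utf-8")
    # One label per line; line i is the label for class index i.
    return [line.strip() for line in text.splitlines()]
```

For example, extract_labels("some-image-classification-model.tflite") writes label.txt next to the script and returns the labels in class-index order.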

Step 2: Create a Flutter Project

*Prerequisite: IntelliJ or VS Code with a working Flutter build environment

Create a new Flutter project with Android and iOS enabled (the tflite plugin runs on mobile platforms), or use your existing Flutter project if you have one. In the project's root directory, create a folder called “assets” and save your “label.txt” and “model.tflite” files inside it.

Next, open your project's pubspec.yaml file and add the following dependencies and assets:

name: tensorflow
description: A new Flutter application.
version: 1.0.0+1
environment:
  sdk: ">=2.12.0 <3.0.0"

dependencies:
  flutter:
    sdk: flutter
  tflite: 1.1.2
  image_picker: 0.7.4

dev_dependencies:
  flutter_test:
    sdk: flutter

flutter:
  uses-material-design: true
  assets:
    - assets/model.tflite
    - assets/label.txt

Step 3: Coding time

  • Create a Flutter main app
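import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';

void main() => runApp(MaterialApp(
      home: ImageDetectApp(),
    ));

class ImageDetectApp extends StatefulWidget {
  @override
  _ImageDetectState createState() => _ImageDetectState();
}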
  • Create an _ImageDetectState class and initialize the Tflite library
class _ImageDetectState extends State<ImageDetectApp> {
  List? _listResult; // classification results returned by Tflite
  PickedFile? _imageFile; // image picked from the gallery
  bool _loading = false;

  @override
  void initState() {
    super.initState();
    _loading = true;
    _loadModel();
  }

  // Load the model and labels bundled as assets (see pubspec.yaml).
  Future<void> _loadModel() async {
    await Tflite.loadModel(
      model: "assets/model.tflite",
      labels: "assets/label.txt",
    );
    if (!mounted) return;
    setState(() {
      _loading = false;
    });
  }
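  • Inside this class, add a build method with a floating action button (or any other click event) that lets the user select an image
  @override
  Widget build(BuildContext context) {
    return Scaffold(
      // Minimal placeholder body: shows a status message or the raw results.
      body: Center(
        child: Text(_loading ? 'Loading...' : (_listResult?.toString() ?? 'Pick an image')),
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: _imageSelection,
        backgroundColor: Colors.blue,
        child: Icon(Icons.add_photo_alternate_outlined),
      ),
    );
  }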
  • Add the image selection function
  void _imageSelection() async {
    var imageFile = await ImagePicker().getImage(source: ImageSource.gallery);
    if (imageFile == null) return; // the user cancelled the picker
    setState(() {
      _loading = true;
      _imageFile = imageFile;
    });
    _imageClassification(imageFile);
  }
  • Add the image classification function
  Future<void> _imageClassification(PickedFile image) async {
    var output = await Tflite.runModelOnImage(
      path: image.path,
      numResults: 2, // return at most the 2 best labels
      threshold: 0.5, // drop labels with a confidence below 0.5
      imageMean: 127.5, // pixels are normalized as (value - imageMean) / imageStd,
      imageStd: 127.5, // which maps 0..255 to -1..1
    );
    setState(() {
      _loading = false;
      _listResult = output;
    });
  }
  • Last but not least, release the Tflite resources when the widget is disposed
  @override
  void dispose() {
    Tflite.close();
    super.dispose();
  }
}
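The numbers passed to runModelOnImage can look magical, so here is a rough Python sketch of the arithmetic they control. It is an illustration, not the tflite plugin's source: I am assuming each pixel channel is normalized as (value - imageMean) / imageStd, so 127.5/127.5 maps 0..255 onto -1..1, and that the per-label scores are sorted, filtered by threshold and cut to numResults. The helper names normalize and top_results are hypothetical; the index/label/confidence keys mirror the result maps described in the plugin's README.

```python
def normalize(pixels, image_mean=127.5, image_std=127.5):
    """Normalize raw 0-255 channel values as (value - mean) / std."""
    return [(p - image_mean) / image_std for p in pixels]


def top_results(scores, labels, num_results=2, threshold=0.5):
    """Return up to num_results labels scoring above threshold, best first."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    results = []
    for i in ranked:
        # Scores are sorted, so once one falls to the threshold the rest do too.
        if scores[i] <= threshold or len(results) == num_results:
            break
        results.append({"index": i, "label": labels[i], "confidence": scores[i]})
    return results


if __name__ == "__main__":
    print(normalize([0, 127.5, 255]))  # [-1.0, 0.0, 1.0]
    print(top_results([0.1, 0.7, 0.2], ["apple", "banana", "cherry"]))
    # [{'index': 1, 'label': 'banana', 'confidence': 0.7}]
```

Swapping in imageMean: 0 and imageStd: 255 would instead give a 0..1 range; which range is right depends on how your model was trained, so check the model's documentation on TensorFlow Hub.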

Run the project and voila!


Step 4 (Bonus): Train your own model

There are many ways to train your own model. In this example I will use Google Colab (https://colab.research.google.com/), but you can also run the same code from a local IDE with minor changes (the ! shell commands and the google.colab import are Colab-specific).

  • First, install the prerequisite package
!pip install -q tflite-model-maker

Paste the line above into a code cell and click Run. Then add another cell by clicking “+ Code”, paste in the code below and run it as well.

import os
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2')
from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader
from tflite_model_maker.image_classifier import ModelSpec
import matplotlib.pyplot as plt
  • Second, upload your data set

It is time to collect images! For accurate results, you need at least 100 images per category. Store each category in its own subfolder named after the label (for example food_images/pizza), and keep all the subfolders inside one parent folder; DataLoader.from_folder uses the subfolder names as labels.

Compress the food_images folder (or whichever folder you use), upload the zip file to Colab and, once the upload has finished, unzip it in a new code cell:

!unzip food_images.zip
  • Load the data and split it into training and testing data (in a new code cell)
data = DataLoader.from_folder('/content/food_images')
train_data, test_data = data.split(0.9)  # 90% training, 10% testing

# Optional: mount Google Drive, e.g. to save the exported files there
from google.colab import drive
drive.mount('/content/drive')
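  • Train a custom model on the training data and evaluate it on the test data
model = image_classifier.create(train_data)  # transfer learning on the training split
loss, accuracy = model.evaluate(test_data)   # loss and accuracy on the held-out 10%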
  • Export the TensorFlow Lite model and its labels
model.export(export_dir='.')  # writes model.tflite
model.export(export_dir='.', export_format=ExportFormat.LABEL)  # writes the label file
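Before zipping and uploading your own images, it can help to check that the folder has the layout DataLoader.from_folder expects (one subfolder per category, whose name becomes the label) and enough images per category. The sketch below does this with the standard library only. The set of image extensions, the 100-image minimum (taken from the advice above) and the function names are my assumptions, and split merely illustrates what data.split(0.9) does conceptually; it is not Model Maker's implementation.

```python
import random
from pathlib import Path

# Assumed set of image formats; adjust to whatever your photos use.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_label(root):
    """Map each subfolder name (used as the label) to its number of image files."""
    counts = {}
    for sub in sorted(Path(root).iterdir()):
        if sub.is_dir():
            counts[sub.name] = sum(
                1 for f in sub.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts


def too_small(counts, minimum=100):
    """Return the labels that have fewer than `minimum` images."""
    return [label for label, n in counts.items() if n < minimum]


def split(items, fraction=0.9, seed=0):
    """Shuffle a copy of items and cut it into (first `fraction`, remainder)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * fraction)
    return shuffled[:cut], shuffled[cut:]


if __name__ == "__main__":
    train, test = split(range(100))
    print(len(train), len(test))  # 90 10
```

For example, too_small(count_images_per_label("food_images")) lists the categories that still need more photos before you compress the folder.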

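Download the exported model and label files, copy them into your Flutter project's assets folder (saved as model.tflite and label.txt, the names used in pubspec.yaml), and voilà!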