
How to Use TensorFlow Mobile in Android Apps

With TensorFlow, one of the most popular machine learning frameworks available today, you can easily create and train deep models (for example, deep feed-forward neural networks) that can solve a variety of complex problems, such as image classification, object detection, and natural language comprehension. TensorFlow Mobile is a library designed to help you leverage those models in your mobile apps.

In this tutorial, I’ll show you how to use TensorFlow Mobile in Android Studio projects.

Prerequisites

To be able to follow this tutorial, you’ll need:

  • Android Studio 3.0 or higher
  • TensorFlow 1.5.0 or higher
  • an Android device running API level 21 or higher
  • and a basic understanding of the TensorFlow framework

1. Creating a Model

Before we start using TensorFlow Mobile, we’ll need a trained TensorFlow model. Let’s create one now.

Our model is going to be very basic. It will behave like an XOR gate, taking two inputs, both of which can be either zero or one, and producing one output, which will be zero if both the inputs are identical and one otherwise. Additionally, because it’s going to be a deep model, it will have two hidden layers, one with four neurons, and another with three neurons. You are free to change the number of hidden layers and the numbers of neurons they contain.

In order to keep this tutorial short, instead of using the low-level TensorFlow APIs directly, we'll be using TFLearn, a popular wrapper framework for TensorFlow offering more intuitive and concise APIs. If you don't have it already, install it inside your TensorFlow virtual environment by running pip install tflearn.

To start creating the model, create a Python script named create_model.py, preferably in an empty directory, and open it with your favorite text editor.

Inside the file, the first thing we need to do is import the TFLearn APIs.

Next, we must create the training data. For our simple model, there will be only four possible inputs and outputs, which will resemble the contents of the XOR gate’s truth table.
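Written out as plain Python lists, the truth table looks like the sketch below. The names X and Y are a common convention but otherwise arbitrary: each row of X is one input pair, and the matching row of Y is the expected output.

```python
# Training data for the XOR gate: one row per line of the truth table.
X = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
Y = [[0.0], [1.0], [1.0], [0.0]]

# Sanity check: every label is the XOR of its two inputs.
for (a, b), (y,) in zip(X, Y):
    assert y == float(int(a) ^ int(b))
```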

It is usually a good idea to use random values picked from a uniform distribution when assigning initial weights to the neurons in the hidden layers. To generate the values, use the uniform() function from TFLearn's initializations module.
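To show what a uniform initializer produces, here is a standard-library-only sketch. The helper name uniform_weights and the default range of -1 to 1 are assumptions made for illustration. TFLearn's own uniform() works with TensorFlow tensors rather than Python lists.

```python
import random

def uniform_weights(n_in, n_out, minval=-1.0, maxval=1.0, seed=None):
    """Return an n_in x n_out weight matrix (a list of rows) whose entries are
    drawn independently from the uniform distribution on [minval, maxval]."""
    rng = random.Random(seed)  # private generator, so a fixed seed is reproducible
    return [[rng.uniform(minval, maxval) for _ in range(n_out)]
            for _ in range(n_in)]

# Weights connecting the 2 inputs to a first hidden layer of 4 neurons.
w1 = uniform_weights(2, 4, seed=42)
```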

At this point, we can start creating the layers of our neural network. To create the input layer, we must use the input_data() method, which allows us to specify the number of inputs the network can accept. Once the input layer is ready, we can call the fully_connected() method multiple times to add more layers to the network.

Give the input and output layers meaningful names. This tutorial uses my_input and my_output. Doing so is important because we'll need those names when using the network from our Android app. Also note that the hidden and output layers use the sigmoid activation function. You are free to experiment with other activation functions, such as softmax, tanh, and relu.
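To make the architecture concrete, here is a framework-free sketch of the forward pass that the input_data() and fully_connected() calls describe: two inputs, hidden layers of four and three sigmoid neurons, and one sigmoid output. It is plain Python for illustration only. The helper names, the zero-initialized biases, and the -1 to 1 weight range are assumptions. TFLearn builds the equivalent TensorFlow graph instead.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dense(inputs, weights, biases):
    """Fully connected sigmoid layer; weights[i][j] links input i to neuron j."""
    return [sigmoid(sum(x * w[j] for x, w in zip(inputs, weights)) + biases[j])
            for j in range(len(biases))]

def init_layer(rng, n_in, n_out):
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_out)] for _ in range(n_in)]
    biases = [0.0] * n_out  # assumption: biases start at zero
    return weights, biases

rng = random.Random(0)
# Layer sizes from the text: 2 inputs -> 4 hidden -> 3 hidden -> 1 output.
layers = [init_layer(rng, 2, 4), init_layer(rng, 4, 3), init_layer(rng, 3, 1)]

def forward(x, layers):
    """Propagate one input pair through every layer and return the output list."""
    activation = x
    for weights, biases in layers:
        activation = dense(activation, weights, biases)
    return activation

print(forward([0.0, 1.0], layers))  # one untrained output between 0 and 1
```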

As the last layer of our network, we must create a regression layer using the regression() function, which expects a few hyper-parameters as its arguments, such as the network's learning rate and the optimizer and loss functions it should use. Here we'll use stochastic gradient descent, SGD for short, as the optimizer and mean square as the loss function.
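TFLearn handles the optimizer and loss for you, but the arithmetic is small enough to show by hand. The sketch below is plain Python, not TFLearn's implementation. It computes the mean-square loss and applies one gradient-descent update to a single sigmoid neuron. Strictly speaking, SGD updates on randomly sampled mini-batches. With only four training rows, one full-batch step like this illustrates the same idea. The OR-gate data, starting weights, and learning rate are hypothetical values chosen for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def mean_square(predictions, targets):
    """Average of the squared differences between predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

def neuron_outputs(w, b, xs):
    """Outputs of a single sigmoid neuron with weights w and bias b, one per input x."""
    return [sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) for x in xs]

def gradient_step(w, b, xs, ts, learning_rate):
    """One gradient-descent update of (w, b) under mean-square loss:
    every parameter moves by -learning_rate * dLoss/dparameter."""
    ys = neuron_outputs(w, b, xs)
    n = len(xs)
    # dLoss/dz per sample: (2 / n) * (y - t) * y * (1 - y), since sigmoid'(z) = y * (1 - y)
    deltas = [2.0 * (y - t) * y * (1.0 - y) / n for y, t in zip(ys, ts)]
    new_w = [wi - learning_rate * sum(d * x[i] for d, x in zip(deltas, xs))
             for i, wi in enumerate(w)]
    new_b = b - learning_rate * sum(deltas)
    return new_w, new_b

# Hypothetical toy data: an OR gate, which one neuron can learn (XOR needs hidden layers).
xs = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
ts = [0.0, 1.0, 1.0, 1.0]
w, b = [0.1, -0.2], 0.0
before = mean_square(neuron_outputs(w, b, xs), ts)
w, b = gradient_step(w, b, xs, ts, learning_rate=0.5)
after = mean_square(neuron_outputs(w, b, xs), ts)
print(f"loss before: {before:.4f}, after one step: {after:.4f}")
```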

Next, we must wrap the network in a deep neural network model by calling TFLearn's DNN() constructor. The resulting object provides the fit(), predict(), and save() methods we'll use from here on.

The model is now ready. All we need to do now is train it using the training data we created earlier. So call the fit() method of the model and, along with the training data, specify the number of training epochs to run. Because the training data is very small, our model will need thousands of epochs to attain reasonable accuracy.
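To show what an epoch loop does, here is a plain-Python sketch that trains the same 2-4-3-1 sigmoid network on the XOR data with mean-square loss. It is not what TFLearn does internally. To stay short and easy to verify, it estimates gradients with central finite differences instead of backpropagation. The learning rate, epsilon, seed, and flat parameter layout are assumptions. Each epoch is one full pass over the four training rows, followed by one update.

```python
import math
import random

SIZES = [2, 4, 3, 1]  # layer widths described in the text

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(params, x):
    """Run one input through the network; params is a flat list holding, per layer,
    n_in * n_out weights followed by n_out biases."""
    i, act = 0, x
    for n_in, n_out in zip(SIZES, SIZES[1:]):
        nxt = []
        for j in range(n_out):
            z = params[i + n_in * n_out + j]  # bias of neuron j
            z += sum(act[k] * params[i + k * n_out + j] for k in range(n_in))
            nxt.append(sigmoid(z))
        act = nxt
        i += n_in * n_out + n_out
    return act

def loss(params, X, Y):
    """Mean-square error over the whole training set."""
    return sum((forward(params, x)[0] - y[0]) ** 2 for x, y in zip(X, Y)) / len(X)

def train(X, Y, epochs, learning_rate=0.5, eps=1e-5, seed=0):
    rng = random.Random(seed)
    n_params = sum(a * b + b for a, b in zip(SIZES, SIZES[1:]))
    params = [rng.uniform(-1.0, 1.0) for _ in range(n_params)]
    history = []
    for _ in range(epochs):
        grad = []
        for p in range(n_params):
            # Central-difference estimate of dLoss/dparam.
            params[p] += eps
            up = loss(params, X, Y)
            params[p] -= 2 * eps
            down = loss(params, X, Y)
            params[p] += eps
            grad.append((up - down) / (2 * eps))
        params = [w - learning_rate * g for w, g in zip(params, grad)]
        history.append(loss(params, X, Y))  # loss after this epoch's update
    return params, history

X = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
Y = [[0.0], [1.0], [1.0], [0.0]]
params, history = train(X, Y, epochs=200)
print(history[0], history[-1])
```

Comparing history[0] with history[-1] shows the loss falling. With so little data per epoch, pushing it close to zero usually takes many more epochs than the 200 used here.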

Once the training is complete, we can call the predict() method of the model to check whether it generates the desired outputs. Because there are only four valid inputs, you can pass all of them to predict() and print the results.

If you run the Python script now, you should see output like the one in the following screenshot:

Predictions after training

Note that the outputs are never exactly 0 or 1. Instead, they are floating-point numbers that are either close to zero or close to one. Therefore, while using the outputs, you might want to use Python’s round() function.
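Here's a small helper along those lines. It assumes that each prediction is a one-element sequence, which is what you get from a network with a single output neuron. The function names are hypothetical.

```python
def to_binary(predictions):
    """Round each single-value prediction to the nearest integer, so an output
    near 0.97 becomes 1. (Python 3 rounds exact halves to even, which does not
    matter for outputs close to 0 or 1.)"""
    return [int(round(p[0])) for p in predictions]

def matches_xor(inputs, predictions):
    """True if every rounded prediction equals the XOR of its input pair."""
    return all(bit == (int(a) ^ int(b))
               for (a, b), bit in zip(inputs, to_binary(predictions)))
```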

Unless we explicitly save the model after training it, we will lose it as soon as the script ends. Fortunately, with TFLearn, a simple call to the save() method saves the model. However, to be able to use the saved model with TensorFlow Mobile, we must first remove all the training-related operations associated with it, which are stored in the tf.GraphKeys.TRAIN_OPS collection, and only then call save().

If you run the script again, you’ll see that it generates a checkpoint file, a metadata file, an index file, and a data file, all of which when used together can quickly recreate our trained model.

2. Freezing the Model

In addition to saving the model, we must freeze it before we can use it with TensorFlow Mobile. The process of freezing a model, as you might have guessed, involves converting all its variables into constants. Additionally, a frozen model must be a single binary file that conforms to the Google Protocol Buffers serialization format.
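In concept, freezing replaces each variable with a constant that holds the value saved in the checkpoint. The toy model below illustrates that with plain Python. The Node class, the node names, and the checkpoint values are hypothetical and do not correspond to TensorFlow's GraphDef types. In the real workflow, convert_variables_to_constants() does this job.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str
    op: str             # "Variable", "Const", or a computation such as "MatMul"
    value: object = None

def freeze(graph, checkpoint):
    """Return a copy of graph in which every Variable node has become a Const
    node holding the value saved for it in checkpoint (a name -> value dict)."""
    return [Node(n.name, "Const", checkpoint[n.name]) if n.op == "Variable" else n
            for n in graph]

# Hypothetical three-node graph and checkpoint contents.
graph = [Node("input", "Placeholder"),
         Node("weights", "Variable"),
         Node("matmul", "MatMul")]
frozen = freeze(graph, {"weights": [[0.5], [-0.5]]})
```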

Create a new Python script named freeze_model.py and open it using a text editor. We’ll be writing all the code to freeze our model inside this file.

Because TFLearn doesn't have any functions for freezing models, we'll have to use the TensorFlow APIs directly now. Import them by adding the line import tensorflow as tf to the top of the file.

Throughout the script, we’ll be using a single TensorFlow session. To create the session, use the constructor of the Session class.

Next, create a Saver object by calling the tf.train.import_meta_graph() function and passing it the name of the model's metadata file. In addition to returning a Saver object, import_meta_graph() also adds the model's graph definition to the session's graph.

Once the saver is created, we can initialize all the variables in the graph with their trained values by calling its restore() method. This method expects the session and the path of the model's latest checkpoint. The tf.train.latest_checkpoint() function returns that path when given the directory that contains the checkpoint file.

Now we can call the tf.graph_util.convert_variables_to_constants() function to create a frozen graph definition in which all the variables of the model are replaced with constants. As its inputs, the function expects the current session, the session's graph definition, and a list containing the names of the model's output nodes. For our model, that list contains only my_output/Sigmoid.

Calling the SerializeToString() method of the frozen graph definition gives us a binary protobuf representation of the model. Use Python's basic file I/O facilities to save it to a file. I suggest naming it frozen_model.pb.
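Below is a minimal sketch of the file-writing step. It assumes the argument is the bytes object that SerializeToString() returns, for example save_frozen_graph(frozen_graph_def.SerializeToString()), where frozen_graph_def is the result of convert_variables_to_constants(). The helper name is hypothetical. The important detail is to write in binary mode, because protobuf output is raw bytes, not text.

```python
from pathlib import Path

def save_frozen_graph(serialized: bytes, path="frozen_model.pb"):
    """Write serialized protobuf bytes to disk unchanged and return the file size."""
    if not isinstance(serialized, bytes):
        raise TypeError("expected the bytes returned by SerializeToString()")
    target = Path(path)
    target.write_bytes(serialized)  # binary write: no encoding or newline translation
    return target.stat().st_size
```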

You can run the script now to generate the frozen model.

We now have everything we need to start using TensorFlow Mobile.

3. Android Studio Project Setup

The TensorFlow Mobile library is available on JCenter as org.tensorflow:tensorflow-android, so we can add it directly as an implementation dependency in the app module's build.gradle file.

To add the frozen model to the project, place the frozen_model.pb file in the project’s assets folder.

4. Initializing the TensorFlow Interface

TensorFlow Mobile offers a simple interface we can use to interact with our frozen model. To create the interface, use the constructor of the TensorFlowInferenceInterface class, which expects an AssetManager instance and the filename of the frozen model.

It's a good idea to create the interface on a new thread rather than on the UI thread. Although this is not always necessary, it is recommended because it keeps the app's UI responsive.

To make sure that TensorFlow Mobile has read our model's file correctly, let's now print the names of all the operations in the model's graph. To get a reference to the graph, we can use the graph() method of the interface. To get all the operations, we can use the operations() method of the graph, which returns an iterator over them. Calling name() on each operation gives us the name to log.

If you run the app now, you should be able to see over a dozen operation names printed in Android Studio’s Logcat window. Among all those names, if there were no errors while freezing the model, you’ll be able to find the names of the input and output layers: my_input/X and my_output/Sigmoid.

Logcat window showing list of operations
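The check you do by eye in Logcat comes down to a set lookup. Below is a Python sketch that uses the two names from the text as defaults. The function name is hypothetical. On Android, you would build the list of names by iterating over graph().operations() and calling name() on each operation.

```python
REQUIRED_OPS = ("my_input/X", "my_output/Sigmoid")

def missing_ops(op_names, required=REQUIRED_OPS):
    """Return the required operation names that are absent from op_names,
    in the order listed in required; an empty list means all were found."""
    present = set(op_names)
    return [name for name in required if name not in present]
```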

5. Using the Model

To make predictions with the model, we must put data into its input layer and retrieve data from its output layer. To put data into the input layer, use the feed() method of the interface, which expects the name of the layer, an array containing the inputs, and the dimensions of the array. For example, to send the numbers 0 and 1 to the input layer, pass my_input/X as the name, a FloatArray holding 0 and 1, and the dimensions 1 and 2, meaning one sample with two inputs.
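Because feed() takes the values as one flat array and the shape separately, multi-dimensional input has to be flattened in row-major order first. The Python sketch below shows that bookkeeping with a hypothetical helper. On Android, the same values would go into a FloatArray, and the dimensions would be passed alongside it.

```python
def to_feed_args(batch):
    """Flatten a list of equal-length input rows into (flat_values, dims),
    the pair of arguments a flat-buffer API such as feed() expects."""
    if not batch or any(len(row) != len(batch[0]) for row in batch):
        raise ValueError("batch must be a non-empty list of equal-length rows")
    flat = [float(v) for row in batch for v in row]  # row-major order
    dims = (len(batch), len(batch[0]))               # (samples, inputs per sample)
    return flat, dims

values, dims = to_feed_args([[0, 1]])
# values == [0.0, 1.0], dims == (1, 2): one sample with two inputs
```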

After loading data into the input layer, we must run an inference operation using the run() method. It expects an array containing the names of the output layers we need, which here is just my_output/Sigmoid. Once the operation is complete, the output layer will contain the model's prediction. To copy the prediction into a Kotlin FloatArray, use the fetch() method, passing it the output layer's name and the array to fill.

How you use the prediction is of course up to you. For now, I suggest you simply print it.

You can run the app now to see that the model’s prediction is correct.

Logcat window displaying the prediction

Feel free to change the numbers you feed to the input layer to confirm that the model’s predictions are always correct.

Conclusion

You now know how to create a simple TensorFlow model and use it with TensorFlow Mobile in Android apps. You don’t always have to limit yourself to your own models, though. With the skills you learned today, you should have no problems using larger models, such as MobileNet and Inception, available in the TensorFlow model zoo. Note, however, that such models will lead to larger APKs, which may create issues for users with low-end devices.

To learn more about TensorFlow Mobile, do refer to the official documentation.
