Friday, November 27, 2015

Building TensorFlow for Jetson TK1

Google recently released TensorFlow, an open source software library for numerical computation using data flow graphs. 

TensorFlow has a GPU backend built on CUDA, so I wanted to install it on a Jetson TK1. Even though the system does not meet the requirements (CUDA 7.0 is not available for it and the GPU has compute capability 3.2), I decided to give it a try anyway. This post documents all the steps required to build TensorFlow from source: it is quite challenging, but it can be done. Including all the prerequisites, the whole build will take several hours (if you just want to try TensorFlow, you can download the wheel file I generated and do a pip install. The file is at https://drive.google.com/file/d/0B1uGKNpQ7xNqZ2pvSmc3SlZJS2c/view?usp=sharing ).

TensorFlow is under active development and the code uses a lot of advanced C++ features that really push the compiler. These instructions worked with the version available on 11/26, but newer versions may need additional fixes.

The first challenge is to build Bazel, another piece of software developed at Google, which is used as the build system for TensorFlow. Bazel requires a protobuf version newer than the one present in the Ubuntu 14.04 repos, so the first step will be to install protobuf 3 from source, since there are no prebuilt binaries for ARM32.

Java 8:
The first step is to install Java 8, which is quite simple since Oracle provides a package:

$ sudo add-apt-repository ppa:webupd8team/java
$ sudo apt-get update
$ sudo apt-get install oracle-java8-installer
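
You can verify the installation with the command below; it should report a 1.8.0 build:

$ java -version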

Protobuf:
In order to build protobuf and bazel, we will need several other packages. The exact list will depend on the state of your Jetson, but you will need at least these:

$ sudo apt-get install git zip unzip autoconf automake libtool curl zlib1g-dev  

After downloading the latest source from GitHub:

$ git clone https://github.com/google/protobuf.git

you need to first generate the configuration file and then run make:

$ cd protobuf
$ ./autogen.sh 
$ ./configure --prefix=/usr
$ make -j 4
$ sudo make install

Protoc will be installed in /usr/bin and its libraries in /usr/lib. This will be important when we run bazel: it uses a sandbox and would not find the libraries in /usr/local/lib.

If you have followed all the steps, you should see this output:

ubuntu@tegra-ubuntu:~/protobuf$ protoc --version
libprotoc 3.0.0

We also need to build the Java interface for protobuf, which requires Maven.
Luckily, Maven is available from the repos, so we can just issue:

$ sudo apt-get install maven

Go to the subdirectory java inside protobuf and type:
$ mvn package

Once the build completes, there will be a protobuf-java-3.0.0-beta-1.jar inside the target subdirectory.
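
You can confirm that the jar was generated with:

$ ls target/protobuf-java-*.jar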

Bazel:

We are now ready to tackle Bazel.
The first step is to download the source code for Bazel (we use the 0.1.0 release, which is known to work with TensorFlow).

$ git clone https://github.com/bazelbuild/bazel.git
$ cd bazel
$ git checkout tags/0.1.0
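
You can confirm the checkout (it should print 0.1.0) with:

$ git describe --tags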

Before compiling, we need to copy the protoc binary we just built to third_party/protobuf/protoc-linux-arm32.exe.
We also need to copy the jar file from protobuf into the same directory. Bazel expects an alpha-3 version, but we have built a beta-1.
There is probably a better way of doing this, but just copying the file and renaming it did the trick for me.


$ cp /usr/bin/protoc   third_party/protobuf/protoc-linux-arm32.exe
$ cp ~/protobuf/java/target/protobuf-java-3.0.0-beta-1.jar  third_party/protobuf/protobuf-java-3.0.0-alpha-3.jar

We are now ready to compile bazel. 

$ ./compile.sh

At the end of the compilation, the bazel binary will be in the output directory. You can add this directory to your PATH or copy the binary to /usr/local/bin.
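
For example, to install it system-wide and check that the binary runs:

$ sudo cp output/bazel /usr/local/bin/
$ bazel version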

TensorFlow

We are now ready to tackle the TensorFlow build for GPU. Just be sure to have CUDA 6.5 and cuDNN 6.5 installed on your Jetson TK1.
You will also need some files from the CUDA 7.0 package ( cuda-repo-l4t-r23.1-7-0-local_7.0-71_armhf.deb ) that you can download from
the NVIDIA web site (it is the one for Jetson TX1).
While the Jetson TK1 cannot run the 7.0 runtime, since the driver shipped with the system does not support it, it is still possible to run the CUDA 7.0 compiler. We need the 7.0 compiler because some of the TensorFlow source files trigger an internal compiler error with the 6.5 nvcc.
All the libraries and the runtime will be the standard 6.5 ones.
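
You can check which compiler is currently installed (assuming nvcc is on your PATH); it should report release 6.5:

$ nvcc --version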


On my system I have also enabled some swap space. You can plug in a USB memory stick, turn it into swap space, and enable it with:
$ sudo mkswap /dev/sda
$ sudo swapon /dev/sda 
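
You can check that the swap space is active with:

$ swapon -s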

The first step to build TensorFlow is to clone the github repository:
$ git clone --recurse-submodules https://github.com/tensorflow/tensorflow

and install other dependencies:
$ sudo apt-get install python-numpy swig python-dev

TensorFlow expects a 64-bit system and has a bunch of library paths and library names hard-coded in its files.
Before starting the build, we will need to modify several files: change all the references from lib64 to lib, and point the 7.0 library names to 6.5. We can find all the files containing these strings and apply the changes with these commands:

$ cd tensorflow
$ grep -Rl "lib64"| xargs sed -i 's/lib64/lib/g'
$ grep -Rl "so.7.0"| xargs sed -i 's/so\.7\.0/so\.6\.5/g'


TensorFlow officially supports CUDA devices with 3.5 and 5.2 compute capability; we want to target a GPU with compute capability 3.2.
This can be done through TensorFlow's unofficial settings, by running configure with the TF_UNOFFICIAL_SETTING variable set.
When prompted, specify that you only want a 3.2 compute capability device.

$ TF_UNOFFICIAL_SETTING=1 ./configure

# Same as the official settings above

WARNING: You are configuring unofficial settings in TensorFlow. Because some
external libraries are not backward compatible, these settings are largely
untested and unsupported.

Please specify a list of comma-separated Cuda compute capabilities you want to
build with. You can find the compute capability of your device at:
https://developer.nvidia.com/cuda-gpus.
Please note that each additional compute capability significantly increases
your build time and binary size. [Default is: "3.5,5.2"]: 3.2

Setting up Cuda include
Setting up Cuda lib
Setting up Cuda bin
Setting up Cuda nvvm
Configuration finished


After the configure step, bazel has copied or symlinked all the binaries and libraries needed for the build into the third_party/gpus/cuda subdirectory.
It is now time to replace the CUDA compiler with the one from the 7.0 toolchain.

We want to extract (not install) the files from the  cuda-repo-l4t-r23.1-7-0-local_7.0-71_armhf.deb package with the following commands:

$ dpkg -x cuda-repo-l4t-r23.1-7-0-local_7.0-71_armhf.deb /tmp/cuda_repo
$ cd /tmp/cuda_repo/var/cuda-repo-7-0-local
$ dpkg -x cuda-core-7-0_7.0-71_armhf.deb /tmp/cuda7.0
$ rm -fr /tmp/cuda_repo

$ cd ~/tensorflow/third_party/gpus/cuda
$ rm -fr bin nvvm
$ cp -R  /tmp/cuda7.0/usr/local/cuda-7.0/bin bin
$ cp -R /tmp/cuda7.0/usr/local/cuda-7.0/nvvm nvvm
$ rm -fr /tmp/cuda7.0

At this point, bazel is ready to use the 7.0 toolchain to compile TensorFlow.
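
You can verify the swap by asking the copied compiler for its version; it should now report release 7.0:

$ ~/tensorflow/third_party/gpus/cuda/bin/nvcc --version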

We still need to add the ARM target to the build.
This can be done by adding the following lines to the third_party/gpus/crosstool/CROSSTOOL file:

default_toolchain {
  cpu: "arm"
  toolchain_identifier: "local_linux"
}

Before starting the build, we need to edit a few files to avoid compiler crashes and double instantiations
(on ARMv7, Eigen::DenseIndex is typedefed to int):

third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h
tensorflow/core/kernels/conv_ops_gpu_2.cu.cc
tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
tensorflow/core/kernels/adjust_contrast_op.h


third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h:
The compiler crashes when evaluating the code inside the ifdef at line 312, so we just take the alternative path.
Change line 312 to something like:
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES_TK1

tensorflow/core/kernels/conv_ops_gpu_2.cu.cc:
To avoid a double instantiation, guard the second functor for InflatePadAndShuffle with:
/* On ARMv7 Eigen::DenseIndex is typedefed to int */
#ifndef __arm__
template struct functor::InflatePadAndShuffle<GPUDevice, float, 4,
                                              Eigen::DenseIndex>;
#endif
We also need to comment out the tensor.h include (it crashes the compiler):
//#include "tensorflow/core/public/tensor.h"

tensorflow/core/kernels/conv_ops_gpu_3.cu.cc:
To avoid a double instantiation, guard the second functor for ShuffleAndReverse with:
/* On ARMv7 Eigen::DenseIndex is typedefed to int */
#ifndef __arm__
template struct functor::ShuffleAndReverse<GPUDevice, float, 4,
                                           Eigen::DenseIndex>;
#endif

tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:
ARMv7 has no numa_node file, so the lookup should return 0 instead of -1 (kUnknownNumaNode); otherwise TensorFlow will crash at runtime:
  FILE *file = fopen(filename.c_str(), "r");
  if (file == nullptr) {
    LOG(ERROR) << "could not open file to read NUMA node: " << filename;
#ifdef __arm__
    // There is no numa_node on Jetson TK1
    return 0;
#else
    return kUnknownNumaNode;
#endif
  }


tensorflow/core/kernels/adjust_contrast_op.h:
The compiler crashes on some of the array initializations; we need to rewrite them in a simpler way:

//MF Eigen::array<int, 4> scalar_broadcast{{batch, height, width, channels}};
    Eigen::array<int, 4> scalar_broadcast;
    scalar_broadcast[0] = batch;
    scalar_broadcast[1] = height;
    scalar_broadcast[2] = width;
    scalar_broadcast[3] = channels;
#if !defined(EIGEN_HAS_INDEX_LIST)
//MF Eigen::array<int, 2> reduction_axis{{1, 2}};
//MF Eigen::array<int, 4> scalar{{1, 1, 1, 1}};
//MF Eigen::array<int, 4> broadcast_dims{{1, height, width, 1}};
//MF Eigen::Tensor<float, 4>::Dimensions reshape_dims{{batch, 1, 1, channels}};
    Eigen::array<int, 2> reduction_axis;
    reduction_axis[0] = 1;
    reduction_axis[1] = 2;
    Eigen::array<int, 4> scalar;
    scalar[0] = 1;
    scalar[1] = 1;
    scalar[2] = 1;
    scalar[3] = 1;
    Eigen::array<int, 4> broadcast_dims;
    broadcast_dims[0] = 1;
    broadcast_dims[1] = height;
    broadcast_dims[2] = width;
    broadcast_dims[3] = 1;
    Eigen::Tensor<float, 4>::Dimensions reshape_dims;
    reshape_dims[0] = batch;
    reshape_dims[1] = 1;
    reshape_dims[2] = 1;
    reshape_dims[3] = channels;
#else

The source code is now ready. The Jetson TK1 has only 2GB of memory, and bazel will try to compile several files at the same time.
We want to avoid this, so we pass a --local_resources flag that limits bazel to 2GB and half a core (don't ask: if you specify one core, it will still try
to compile two files at the same time). This build will take a long time:

$ bazel build -c opt --local_resources 2048,0.5,1.0 --verbose_failures --config=cuda //tensorflow/cc:tutorials_example_trainer

If you get failures during the build, keep trying: bazel scheduling seems to be non-deterministic, and the TensorFlow code really stresses the
compiler.

Once the build is completed, we can test the code:

$ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu

You should see output similar to this:

# Lots of output. This tutorial iteratively calculates the major eigenvalue of
# a 2x2 matrix, on GPU. The last few lines look like this.
000009/000005 lambda = 2.000000 x = [0.894427 -0.447214] y = [1.788854 -0.894427]
000006/000001 lambda = 2.000000 x = [0.894427 -0.447214] y = [1.788854 -0.894427]
000009/000009 lambda = 2.000000 x = [0.894427 -0.447214] y = [1.788854 -0.894427]


We are now ready to create the pip package and install it:
# To build with GPU support:
$ bazel build -c opt --local_resources 2048,0.5,1.0 --verbose_failures --config=cuda //tensorflow/tools/pip_package:build_pip_package

$ bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg
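
You can list the generated package to see the exact file name on your system:

$ ls /tmp/tensorflow_pkg/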

# The name of the .whl file will depend on your platform.
$ sudo pip install /tmp/tensorflow_pkg/tensorflow-0.5.0-cp27-none-linux_armv7l.whl

Congratulations, TensorFlow is now installed on your system.
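
As a quick sanity check, you can run a one-line session (from outside the tensorflow source directory, so Python picks up the installed package):

$ python -c "import tensorflow as tf; print(tf.Session().run(tf.constant('Hello from Jetson TK1')))"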

We can also try a more interesting example of image classification:
$ bazel build -c opt --local_resources 2048,0.5,1.0 --verbose_failures --config=cuda //tensorflow/examples/label_image/...

$ wget https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -O tensorflow/examples/label_image/data/inception5h.zip
$ unzip tensorflow/examples/label_image/data/inception5h.zip -d tensorflow/examples/label_image/data/
$ mv tensorflow/examples/label_image/data/tensorflow_inception_graph.pb tensorflow/examples/label_image/data/googlenet_graph.pb
$ mv tensorflow/examples/label_image/data/imagenet_comp_graph_label_strings.txt tensorflow/examples/label_image/data/googlenet_labels.txt 

And run it with:
$ bazel-bin/tensorflow/examples/label_image/label_image
I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 1
E tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:890] could not open file to read NUMA node: /sys/bus/pci/devices/0000:00:00.0/numa_node
I tensorflow/core/common_runtime/gpu/gpu_init.cc:103] Found device 0 with properties: 
name: GK20A
major: 3 minor: 2 memoryClockRate (GHz) 0.852
pciBusID 0000:00:00.0
Total memory: 1.85GiB
Free memory: 218.46MiB
I tensorflow/core/common_runtime/gpu/gpu_init.cc:127] DMA: 0 
I tensorflow/core/common_runtime/gpu/gpu_init.cc:137] 0:   Y 
I tensorflow/core/common_runtime/gpu/gpu_device.cc:702] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GK20A, pci bus id: 0000:00:00.0)
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Allocating 18.46MiB bytes.
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:52] GPU 0 memory begins at 0xa45ea000 extends to 0xa585f000
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 1.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 2.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 4.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 8.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 16.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 32.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 64.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 128.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 256.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 512.0KiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 1.00MiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 2.00MiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 4.00MiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 8.00MiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 16.00MiB
I tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:66] Creating bin of max chunk size 32.00MiB
I tensorflow/core/common_runtime/direct_session.cc:60] Direct session inter op parallelism threads: 1
I tensorflow/core/common_runtime/gpu/gpu_device.cc:702] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GK20A, pci bus id: 0000:00:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:702] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GK20A, pci bus id: 0000:00:00.0)
I tensorflow/examples/label_image/main.cc:221] military uniform (866): 0.902268
I tensorflow/examples/label_image/main.cc:221] bow tie (817): 0.05407
I tensorflow/examples/label_image/main.cc:221] suit (794): 0.0113196
I tensorflow/examples/label_image/main.cc:221] bulletproof vest (833): 0.0100269
I tensorflow/examples/label_image/main.cc:221] bearskin (849): 0.00649747