Saltar al contenido
Portada » Tensorflow en Java

Tensorflow en Java

Introducción

TensorFlow es una librería de software libre para la programación de flujo de datos (dataflow). Fue originalmente desarrollado por Google y está disponible para diversas plataformas. La librería puede trabajar en un único núcleo pero puede beneficiarse de múltiples cores CPU, GPU o TPU disponibles.

En este tutorial repasaremos las cuestiones básicas de TensorFlow y como usarlo en Java. Es importante considerar que el API Java de Tensorflow es experimental y no tiene porqué mantenerse estable en el tiempo.

Elementos fundamentales

La computación con TensorFlow se mueve alrededor de dos conceptos básicos: Grafo (Graph) y Sesión (Session). Exploraremos a continuación los prerrequisitos necesarios para abordar el resto del tutorial.

TensorFlow Graph

Las computaciones vienen representadas como grafos en TensorFlow. Un grafo es típicamente un diagrama acíclico de operaciones y datos, como por ejemplo:

El grafo anterior representa la siguiente ecuación: f(x, y) = z = a * x + b * y.

Un grafo computacional en TensorFlow consiste en dos elementos:

  • Tensor: Representa la unidad fundamental de datos en TensorFlow. Son los bordes de un grafo computacional, representando el flujo de datos a través del grafo. Un tensor puede tener una forma con cualquier número de dimensiones. El número de dimensiones en un tensor se denomina como rango. Un valor escalar es un tensor de rango cero, un vector es un tensor de rango uno, una matriz de rango dos y así sucesivamente.
  • Operación: Son los nodos de un grafo computacional. Se refiere a las diversas computaciones que pueden ocurrir sobre los tensores que alimentan la operación. El resultado suelen ser tensores.

TensorFlow Session

Un grafo en TensorFlow es un mero esquema de la computación sin valores concretos. El grafo debe ser ejecutado en una sesión (session) de TensorFlow para que sean evaluados los tensores del grafo.

A partir de aquí lo único que nos falta por hacer es ejecutarlo desde el API de Java.

Setup con Maven

Para poder empezar a usar el API de TensorFlow en nuestro proyecto Java podemos incluir las siguientes dependencias Maven:

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.12.0</version>
</dependency>

Crear el grafo

Ahora crearemos el grafo usando el API Java de TensorFlow. Usaremos el API para resolver la función definida por la siguiente ecuación: z = 3 * x + 2 * y.

El primer paso es instanciar la clase Graph:

Graph graph = new Graph()

Ahora es necesario definir las operaciones necesarias. Es importante señalar que las operaciones en TensorFlow consumen y producen de cero a N tensores. Cada nodo del grafo es una operación incluyendo constantes y placeholders.

La clase Graph tiene una función llamada opBuilder() para construir cualquier clase de operación en TensorFlow.

Definición de Constantes

Empezaremos definiendo las operaciones constantes de nuestro grafo. Una operación constante necesita de un tensor para su valor:

Operation a = graph.opBuilder("Const", "a")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .setAttr("value", Tensor.<Double>create(3.0, Double.class))
  .build();		
Operation b = graph.opBuilder("Const", "b")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .setAttr("value", Tensor.<Double>create(2.0, Double.class))
  .build();

Definición de Placeholders

Mientras que en el caso de las constantes necesitamos proveerlas de valores, los placeholders no necesitan valor en tiempo de definición. Los valores de los placeholders (como tensores) serán suministrados cuando el grafo se ejecute en una Session.

Operation x = graph.opBuilder("Placeholder", "x")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .build();			
Operation y = graph.opBuilder("Placeholder", "y")
  .setAttr("dtype", DataType.fromClass(Double.class))
  .build();

Definición de las funciones

Definición de las operaciones matemáticas de nuestra ecuación, serán operaciones en TensorFlow.

Operation ax = graph.opBuilder("Mul", "ax")
  .addInput(a.output(0))
  .addInput(x.output(0))
  .build();			
Operation by = graph.opBuilder("Mul", "by")
  .addInput(b.output(0))
  .addInput(y.output(0))
  .build();
Operation z = graph.opBuilder("Add", "z")
  .addInput(ax.output(0))
  .addInput(by.output(0))
  .build();

Las operaciones reciben tensores que son el resultados de operaciones anteriores. Una operación puede resultar en uno o más tensores.

Visualizando el grafo

TensorFlow provee una utilidad llamada TensorBoard para facilitar la tarea de visualización. El API de Java no tiene la capacidad de generar un fichero de evento para que sea consumido por TensorBoard, podemos hacerlo usando el API de Python:

writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()

Trabajando con Session

Una vez creado el grafo computacional con TensorFlow tenemos que ejecutarlo.

Primero podemos visualizar el estado del grafo mediante:

System.out.println(z.output(0));

Con el resultado:

<Add 'z:0' shape=<unknown> dtype=DOUBLE>

El grafo se ha definido pero no ha sido ejecutado por eso los tensores no tienen aun ningún valor.

Definimos a continuación la sesión:

Session sess = new Session(graph)

Ahora ya podemos ejecutar el grafo

Tensor<Double> tensor = sess.runner().fetch("z")
  .feed("x", Tensor.<Double>create(3.0, Double.class))
  .feed("y", Tensor.<Double>create(6.0, Double.class))
  .run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

La salida actual será 21.

Guardar y cargar modelos de TensorFlow

Salvar un modelo desde Python:

import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
  a = tf.constant(2, name='a')
  b = tf.constant(3, name='b')
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  z = tf.math.add(a*x, b*y, name='z')
  sess = tf.Session()
  sess.run(z, feed_dict = {x: 2, y: 3})
  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
  builder.save()

Cargar modelo en Java

SavedModelBundle model = SavedModelBundle.load("./model", "serve");	
Tensor<Integer> tensor = model.session().runner().fetch("z")
  .feed("x", Tensor.<Integer>create(3, Integer.class))
  .feed("y", Tensor.<Integer>create(3, Integer.class))
  .run().get(0).expect(Integer.class);	
System.out.println(tensor.intValue());


Más recursos Java gratuitos

AsuntoDescripción
Tutorial básico y sintaxisTutorial básico Java y sintaxis. Aprende los fundamentos del lenguaje.
Hilos (Threads)Aprende a manejar hilos y las cuestiones básicas de la concurrencia
Funciones LambdaAquí te enseñamos las nociones más importantes para arrancas con funciones lambda
PalíndromosPrograma de ejemplo para el uso de palíndromos en Java.
Máquina Virtual de JavaTe explicamos el funcionamiento de la máquina virtual de java (Java Virtual Machine – JVM)
JDK, JRE y JVMDiferencias entre el JDK, JRE y JVM.
Mejores libros Java en EspañolHazte con los mejores libros Java para aprender paso a paso y profundizar en las mejores prácticas
TensorFlowManejo del API de TensorFlow para la construcción de grafos de operaciones y su ejecución
Tutorial Log4jTutorial para el manejo de Log4j, herramienta ágil y flexible para la gestión de Logs en Java
Java SecurityEntiende y aplica las posibilidades que da Java para mantener la seguridad
Tutorial JConsoleAprende los conceptos básicos de monitorización de procesos Java con JConsole
JavaFXTutorial de JavaFX, librería gráfica moderna para construcción de GUIs en móvil, escritorio y web.
Estructuras de datos en JavaExplicación y ejemplos de las estructuras de datos más importantes: listas, pila, cola, arbol.
JavaapiConjunto de clases, interfaces, métodos y paquetes que forman parte de la plataforma Java estándar
Algoritmo HuffmanMétodo eficiente para codificar datos, asignando códigos más cortos a los caracteres más frecuentes