Flatten A TensorFlow Tensor

Use the TensorFlow reshape operation to flatten a TensorFlow Tensor

Type: FREE   By: Sebastian Gutierrez, AIWorkbox.com Instructor   Duration: 3:17   Technologies: TensorFlow, Python





Transcript:

This video will show you how to use the TensorFlow reshape operation to flatten a TensorFlow tensor.


First, we import TensorFlow as tf.

import tensorflow as tf


Next, we print out what version of TensorFlow we are using.

print(tf.__version__)

We are using TensorFlow 1.10.0.


Let’s start with an initial TensorFlow constant tensor, created with tf.constant, that has shape 2x3x4 and holds the integer values 1 through 24, all with the data type int32.

tf_initial_tensor_constant = tf.constant(
    [
        [
            [ 1,  2,  3,  4],
            [ 5,  6,  7,  8],
            [ 9, 10, 11, 12],
        ],
        [
            [13, 14, 15, 16],
            [17, 18, 19, 20],
            [21, 22, 23, 24],
        ],
    ],
    dtype=tf.int32,
)

We assign this constant tensor to the Python variable tf_initial_tensor_constant.
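To make the 2x3x4 shape concrete, here is a small pure-Python sketch (no TensorFlow involved) that walks the same nested list and reports its shape the way tf.constant infers it: outermost list length first, innermost last. The helper name nested_shape is hypothetical, and it assumes the nesting is regular and non-empty (every sublist at a given depth has the same length).

```python
# The same nested values used for tf_initial_tensor_constant.
values = [
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
]


def nested_shape(data):
    """Hypothetical helper: return the shape of a regular, non-empty nested list.

    Assumes every sublist at the same depth has the same length, which is
    also what tf.constant needs to build a regular (non-ragged) tensor.
    """
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0]  # descend via the first element; regular nesting assumed
    return tuple(shape)


print(nested_shape(values))  # (2, 3, 4)
```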


Then let’s print this tf_initial_tensor_constant variable to see what we have.

print(tf_initial_tensor_constant)

We see that it’s a tensor with shape 2x3x4 and data type int32. Because we have not run it in a session yet, the values themselves are not printed.


Now that we have created our TensorFlow tensor, it’s time to run the computational graph.


So we launch the graph in a session.

sess = tf.Session()


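Then we initialize all the global variables in the graph. Our graph contains only a constant and no variables, so this step does nothing here, but it is a harmless habit to keep.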

sess.run(tf.global_variables_initializer())


Now that we have a TensorFlow session, let’s evaluate our tf_initial_tensor_constant Python variable with sess.run and print its value.

print(sess.run(tf_initial_tensor_constant))


There we go.

We see our tensor with the numbers 1 to 24 and its shape is still 2x3x4.


When we flatten this TensorFlow tensor, we want it to have only one dimension rather than the three dimensions it currently has, and we want that one dimension to have size 24, since 2 x 3 x 4 = 24.

So it will just be one flat tensor.
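The size of the flat dimension is simply the product of all the original dimensions. A quick check in plain Python, not TensorFlow (math.prod needs Python 3.8 or later):

```python
import math

# Shape of tf_initial_tensor_constant: 2 blocks x 3 rows x 4 columns.
initial_shape = (2, 3, 4)

# The flattened length must equal the total number of elements.
num_elements = math.prod(initial_shape)  # 2 * 3 * 4
print(num_elements)  # 24
```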


To flatten the tensor, we’re going to use the TensorFlow reshape operation.


So we call tf.reshape, passing in our tensor, currently held in tf_initial_tensor_constant, and then the new shape, which is a -1 inside a Python list.

flattened_tensor_example = tf.reshape(tf_initial_tensor_constant, [-1])

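The -1 tells TensorFlow to infer the size of that dimension from the total number of elements in the input. Because it is the only dimension in the new shape, it must hold all 24 elements, so whatever shape the input has, the result is a flat, one-dimensional tensor.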

This flattened tensor is going to be assigned to the Python variable flattened_tensor_example.
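The rule behind the -1 can be sketched in plain Python. This is an illustrative re-implementation under simplifying assumptions, not TensorFlow's own code, and the function name resolve_shape is hypothetical: at most one entry may be -1, and it is replaced by whatever size makes the total element count match. tf.reshape likewise rejects more than one -1 and sizes that do not divide evenly; this sketch does not try to handle zero-sized dimensions.

```python
import math


def resolve_shape(requested, num_elements):
    """Hypothetical helper: fill in a single -1 the way reshape infers it."""
    requested = list(requested)
    if requested.count(-1) > 1:
        raise ValueError("only one dimension may be -1")
    if -1 in requested:
        # Product of the dimensions we already know.
        known = math.prod(d for d in requested if d != -1)
        if known == 0 or num_elements % known != 0:
            raise ValueError("cannot infer the -1 dimension evenly")
        requested[requested.index(-1)] = num_elements // known
    if math.prod(requested) != num_elements:
        raise ValueError("shape does not match the number of elements")
    return tuple(requested)


print(resolve_shape([-1], 24))     # (24,)
print(resolve_shape([6, -1], 24))  # (6, 4)
```

The trailing comma in (24,) is Python's notation for a one-element tuple, and TensorFlow prints a one-dimensional shape the same way.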


Now that we have our flattened tensor, let’s use the tensor’s get_shape() method to see what shape our flattened tensor has.

print(flattened_tensor_example.get_shape())

We see that it’s (24,).

So it is a flat tensor with 24 elements in one dimension.

So we went from a 2x3x4 tensor to a flat tensor with 24 entries.


Let’s print the flattened_tensor_example Python variable to see what the tensor looks like.

print(flattened_tensor_example)

Notice that we’re not running it in a TensorFlow session, so this prints the tensor node itself rather than its values.

When we do that, we see that it’s a TensorFlow tensor node that has been given the name Reshape, with shape (24,) and data type int32.

So we went from an initial tensor with shape 2x3x4 to one with shape (24,), and the data type remained the same.


Lastly, let’s evaluate flattened_tensor_example with sess.run and print it to see what the values actually look like.

print(sess.run(flattened_tensor_example))

When we do that, we see that the result is indeed flat and contains all of our initial numbers, from 1 all the way to 24, in their original order.

So we see all of our values in a flattened tensor.
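The values come out as 1, 2, 3, ..., 24 because reshape keeps the elements in row-major order, meaning the last index varies fastest. A plain-Python sketch of that traversal, using a hypothetical helper named flatten_row_major (not a TensorFlow function), produces the same sequence from the original nested list:

```python
# The same nested values used for tf_initial_tensor_constant.
values = [
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
    [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],
]


def flatten_row_major(data):
    """Hypothetical helper: flatten a nested list, last index varying fastest."""
    if not isinstance(data, list):
        return [data]  # a scalar leaf becomes a one-item list
    flat = []
    for item in data:  # visit outer indices in order, recursing inward
        flat.extend(flatten_row_major(item))
    return flat


flat = flatten_row_major(values)
print(flat == list(range(1, 25)))  # True
```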


Perfect. We were able to use the TensorFlow reshape operation with a -1 to flatten a TensorFlow tensor.


