tf.cast()

tf.cast(x, dtype, name=None)

The operation casts x (in case of Tensor) or x.values (in case of SparseTensor or IndexedSlices) to dtype.

해석하자면 x값을 새로운 형태의 dtype으로 캐스팅한다는 의미입니다.

  • 부동 소수점형에서 정수형으로 바꾼 경우 소수점을 버린다.

  • Boolean으로 참조한 경우 True이면 1, False이면 0을 출력한다.

  • 예시를 보면 이해가 될겁니다.

1
2
3
4
5
6
7
8
9
10
11
12
x = tf.constant([1.8, 2.2, 3.3], dtype=tf.float32)
print(x)
# tf.Tensor([1.8 2.2 3.3], shape=(3,), dtype=float32)


tf.cast(x, tf.int32)
# 출력 결과를 보시면 반올림, 내림이 아닌 소수점을 버립니다.
# <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>


tf.cast(x>2, tf.float32)
# <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 1.], dtype=float32)>
Author

InhwanCho

Posted on

2023-01-08

Updated on

2023-01-08

Licensed under

Comments