텐서플로우 내부에서 코드를 예쁘게 구조화 하여 그래프(Graph)를 실행할 수 있다.
구현할 때, 몇가지 방법이 있다.
1. Tensor 이름 전달하기
with tf.Session(graph=graph) as sess:
feed = {"Placeholder:0": 3}
print(sess.run("Add:0", feed_dict=feed))
위와 같이 Placeholder:0 이라는 텐서 자체를 feed 에 넘겨주는 것 대신,
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {"a:0": 3}
print(sess.run("b:0", feed_dict=feed))
위와 같이 with 문 내부에서 텐서를 정의하고 a 라는 Tensor 이름을 feed에 전달하는 것이다.
마찬가지로, Add:0 텐서도 with 문 내부에서 정의하고, b 라는 이름을 전달하여 session을 run 시킬 수 있다.
2. build_graph() 함수를 통한 중요한 노드 반환
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g, a, b
graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
1의 방법과는 달리, 그래프 g 뿐만 아니라 tf.placeholder 및 add 도 반환해주는 것이다.
그러면 더 코드가 간결해 질 수 있고, 바로 노드를 사용할 수 있다.
3. 컬렉션에 노드 추가하기
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
for node in (a, b):
g.add_to_collection("important_stuff", node)
return g
graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
collection은 사용자가 식별할 수 있고, 쉽게 검색할 수 있는 Python 객체를 포함 시킬 수 있다.
이 객체들은 그래프 안에서 train_op 이나 하이퍼파라미터, learning rate 처럼 특별한 연산을 할 수 있다.
사용자는 내보내고 싶은 컬렉션 목록을 지정할 수 있다.
메타 그래프(MetaGraph) 파일을 그래프로 가져오려면 import_meta_graph() 함수를 사용하는데
처음부터 모델을 구축하지 않고, 가져와서 계속 훈련 시킬 수 있다.
이 때 tf.add_to_collection('train_op', train_op) 함수를 사용하여 train op 을 추가하게 된다.
또한, 실행 중인 모델을 메타 그래프로 내보낼 때 export_meta_graph() 함수를 사용하는데
만약 collection_list 가 지정되지 않으면, 모든 컬렉션이 내보내진다.
보통 collection_list 에는 input tensor 와 output tensor 가 포함된다.
4. 텐서 이름 가져오기
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
위와 같은 함수를 통하여 텐서를 가져올 수 있다.
이 4번과 같은 방법은 누군가가 작성한 그래프로 실행할 때 유용하다.
참고자료 1
https://stackoverflow.com/questions/44418442/building-tensorflow-graphs-inside-of-functions
참고자료 2
https://tensorflowkorea.gitbooks.io/tensorflow-kr/content/g3doc/how_tos/meta_graph/
'AI Development > TensorFlow | TFLite' 카테고리의 다른 글
[TensorFlow] Graph 에서 dropout 을 제거하는 방법 (0) | 2019.08.03 |
---|---|
[TensorFlow] .ckpt vs .pb vs .pbtxt 차이점 (15) | 2019.08.03 |
[TensorFlow] pb 파일 TensorBoard에 띄우기 (TF 1.x 버전용) (0) | 2019.05.31 |
[TensorFlow] Anaconda 가상환경 이용하여 TensorFlow GPU 설치 (2) | 2019.03.05 |
[Tensorflow] Tensorflow GPU 버전 설치하기 (2) | 2018.01.11 |