AI Development/Tensorflow

[TensorFlow] 함수 내부에서 TensorFlow Graph 실행하기

꾸준희 2019. 8. 3. 14:28
728x90
반응형

 

 

 

텐서플로우 내부에서 코드를 예쁘게 구조화 하여 그래프(Graph)를 실행할 수 있다.

구현할 때, 몇가지 방법이 있다. 

 

 

1. Tensor 이름 전달하기 

with tf.Session(graph=graph) as sess:
	feed = {"Placeholder:0": 3}      
	print(sess.run("Add:0", feed_dict=feed))

위와 같이 Placeholder:0 이라는 텐서 자체를 feed 에 넘겨주는 것 대신, 

 

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()

with tf.Session(graph=graph) as sess:
     feed = {"a:0": 3}
     print(sess.run("b:0", feed_dict=feed))

위와 같이 with 문 내부에서 텐서를 정의하고 a 라는 Tensor 이름을 feed에 전달하는 것이다. 

마찬가지로, Add:0 텐서도 with 문 내부에서 정의하고, b 라는 이름을 전달하여 session을 run 시킬 수 있다.

 

 

 

 

2. build_graph() 함수를 통한 중요한 노드 반환

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     return g, a, b

graph, a, b = build_graph()

with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

1의 방법과는 달리, 그래프 g 뿐만 아니라 tf.placeholder 및 add 도 반환해주는 것이다. 

그러면 더 코드가 간결해 질 수 있고, 바로 노드를 사용할 수 있다. 

 

 

 

 

3. 컬렉션에 노드 추가하기 

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     for node in (a, b):
         g.add_to_collection("important_stuff", node)
     return g

graph = build_graph()

a, b = graph.get_collection("important_stuff")

with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

collection은 사용자가 식별할 수 있고, 쉽게 검색할 수 있는 Python 객체를 포함 시킬 수 있다.

이 객체들은 그래프 안에서 train_op 이나 하이퍼파라미터, learning rate 처럼 특별한 연산을 할 수 있다. 

사용자는 내보내고 싶은 컬렉션 목록을 지정할 수 있다. 

 

메타 그래프(MetaGraph) 파일을 그래프로 가져오려면 import_meta_graph() 함수를 사용하는데

처음부터 모델을 구축하지 않고, 가져와서 계속 훈련 시킬 수 있다. 

이 때 tf.add_to_collection('train_op', train_op) 함수를 사용하여 train op 을 추가하게 된다. 

 

또한, 실행 중인 모델을 메타 그래프로 내보낼 때 export_meta_graph() 함수를 사용하는데 

만약 collection_list 가 지정되지 않으면, 모든 컬렉션이 내보내진다. 

보통 collection_list 에는 input tensor 와 output tensor 가 포함된다. 

 

 

 

 

 

4. 텐서 이름 가져오기 

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()

a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]

with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

 

a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]

위와 같은 함수를 통하여 텐서를 가져올 수 있다. 

이 4번과 같은 방법은 누군가가 작성한 그래프로 실행할 때 유용하다.

 

 

 

참고자료 1

https://stackoverflow.com/questions/44418442/building-tensorflow-graphs-inside-of-functions

 

Building Tensorflow Graphs Inside of Functions

I'm learning Tensorflow and am trying to properly structure my code. I (more or less) know how to build graphs either bare or as class methods, but I'm trying to figure out how best to structure th...

stackoverflow.com

 

참고자료 2 

https://tensorflowkorea.gitbooks.io/tensorflow-kr/content/g3doc/how_tos/meta_graph/

 

메타 그래프 · 텐서플로우 문서 한글 번역본

CollectionDef map은 모델의 Variables, QueueRunners, etc와 같은 추가적인 요소를 더 자세히 설명합니다. Python 오브젝트를 MetaGraphDef로부터 직렬화하기 위해서, Python 클래스는 to_proto()와 from_proto()메소드를 실행하고, register_proto_function를 사용해서 시스템에 등록합니다. 예를 들어, def to_proto(self): """Converts a `Variabl

tensorflowkorea.gitbooks.io

 

728x90
반응형
1 2 3 4 5 6 7 8 9 ··· 15