[TensorFlow] How to remove dropout from a graph

꾸준희 2019. 8. 3. 19:41

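Networks are often designed with dropout. However, when you try to run such a frozen graph on another platform, an error like the following reportedly occurs, because the target runtime has no kernel registered for the RandomUniform op that dropout uses to generate its random mask:
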
Invalid argument: No OpKernel was registered to support Op 'RandomUniform' with these attrs.  Registered devices: [CPU], Registered kernels:
  <no registered kernels>

     [[Node: dropout/random_uniform/RandomUniform = RandomUniform[T=DT_INT32, dtype=DT_FLOAT, seed=0, seed2=0](dropout/Shape)]]
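
The failing node in the log, dropout/random_uniform/RandomUniform, belongs to the dropout block. Before shipping a frozen graph, it can help to scan the node list for random ops and see which name scopes they come from. The sketch below is a hypothetical helper, not a TensorFlow API: it only reads `name` and `op`, which a NodeDef from `graph_def.node` also has, and the `Node` stand-in plus the set of op names it looks for are assumptions to adjust for your target runtime.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-in for a NodeDef: only .name and .op are used.
Node = namedtuple('Node', ['name', 'op'])

# Assumed set of random ops that a stripped-down runtime may lack kernels for.
RANDOM_OPS = {'RandomUniform', 'RandomUniformInt',
              'RandomStandardNormal', 'TruncatedNormal'}


def random_ops_by_scope(nodes, random_ops=RANDOM_OPS):
    """Group the names of nodes whose op is in random_ops by top-level scope."""
    scopes = defaultdict(list)
    for node in nodes:
        if node.op in random_ops:
            # 'dropout/random_uniform/RandomUniform' -> 'dropout'
            scope = node.name.split('/')[0]
            scopes[scope].append(node.name)
    return dict(scopes)


if __name__ == '__main__':
    nodes = [
        Node('input', 'Placeholder'),
        Node('Relu_2', 'Relu'),
        Node('dropout/Shape', 'Shape'),
        Node('dropout/random_uniform/RandomUniform', 'RandomUniform'),
        Node('MatMul_1', 'MatMul'),
    ]
    print(random_ops_by_scope(nodes))
    # {'dropout': ['dropout/random_uniform/RandomUniform']}
```

With a real graph you would pass `graph_def.node` instead of the toy list.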

In that case, take the generated .pb file and strip the dropout nodes out of it.
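
The script in this post does the surgery with hard-coded node indices (33, 44 and 1). As an alternative, here is a minimal sketch that removes every node under one name scope (assumed to be 'dropout/') plus the keep_prob placeholder, and reconnects anything that consumed the dropout output to the tensor that originally fed dropout. `strip_dropout` and the `Node` class are hypothetical: `Node` only mimics the `name`/`op`/`input` fields of a NodeDef, and the toy node names below are assumed to resemble a TF 1.x dropout sub-graph. The function edits inputs only through `del node.input[:]` and `extend`, which repeated protobuf fields also support, but it has only been exercised with the stand-in class.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical stand-in for a NodeDef: only the fields used below.
    name: str
    op: str
    input: List[str] = field(default_factory=list)


def _source_node(ref):
    # 'dropout/mul:0' -> 'dropout/mul'; '^dropout/mul' (control edge) -> 'dropout/mul'
    return ref.lstrip('^').split(':')[0]


def strip_dropout(nodes, dropout_input, scope='dropout/', keep_prob='keep_prob'):
    """Return the nodes outside `scope` (minus `keep_prob`), with edges rewired.

    Data inputs that pointed into the removed block are replaced by
    `dropout_input`, the tensor that originally fed dropout (e.g. 'Relu_2').
    Control inputs ('^name') into removed nodes are dropped.
    The kept nodes' input lists are modified in place.
    """
    removed = {n.name for n in nodes
               if n.name.startswith(scope) or n.name == keep_prob}
    kept = []
    for node in nodes:
        if node.name in removed:
            continue
        new_inputs = []
        for ref in node.input:
            if _source_node(ref) not in removed:
                new_inputs.append(ref)
            elif not ref.startswith('^'):
                new_inputs.append(dropout_input)
        del node.input[:]              # works for lists and repeated protobuf fields
        node.input.extend(new_inputs)
        kept.append(node)
    return kept


if __name__ == '__main__':
    graph = [
        Node('input', 'Placeholder'),
        Node('keep_prob', 'Placeholder'),
        Node('W', 'Const'),
        Node('Relu_2', 'Relu', ['input']),
        Node('dropout/Shape', 'Shape', ['Relu_2']),
        Node('dropout/random_uniform/RandomUniform', 'RandomUniform', ['dropout/Shape']),
        Node('dropout/add', 'Add', ['keep_prob', 'dropout/random_uniform/RandomUniform']),
        Node('dropout/Floor', 'Floor', ['dropout/add']),
        Node('dropout/div', 'RealDiv', ['Relu_2', 'keep_prob']),
        Node('dropout/mul', 'Mul', ['dropout/div', 'dropout/Floor']),
        Node('MatMul_1', 'MatMul', ['dropout/mul', 'W']),
    ]
    for n in strip_dropout(graph, dropout_input='Relu_2'):
        print(n.name, n.op, n.input)
```

On the toy graph this prints `input Placeholder []`, `W Const []`, `Relu_2 Relu ['input']` and `MatMul_1 MatMul ['Relu_2', 'W']`. A model with several dropout layers ('dropout/', 'dropout_1/', ...) would need one call per scope, each with its own `dropout_input`.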


How to remove dropout from a frozen model

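# Written against the TensorFlow 1.x API (tf.GraphDef, tf.Session, tf.gfile).
from __future__ import print_function
from tensorflow.core.framework import graph_pb2
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# MNIST test set, used to compare the original and the dropout-free graph.
mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)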

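def display_nodes(nodes):
    # Print each node's index, name and op type, followed by its inputs.
    # The indices printed here are the ones used in the graph surgery below.
    for i, node in enumerate(nodes):
        print('%d %s %s' % (i, node.name, node.op))
        for j, n in enumerate(node.input):
            print(u'└─── %d ─ %s' % (j, n))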
        
def accuracy(predictions, labels):
    return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])

def test_graph(graph_path, use_dropout):
    # Load a frozen graph and return its accuracy (%) on the first 256 MNIST test images.
    tf.reset_default_graph()
    graph_def = tf.GraphDef()

    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    _ = tf.import_graph_def(graph_def, name='')
    with tf.Session() as sess:
        # Tensor names ('input', 'keep_prob', 'final_result') are specific to this model.
        prediction_tensor = sess.graph.get_tensor_by_name('final_result:0')

        feed_dict = {'input:0': mnist.test.images[:256]}
        if use_dropout:
            # Disable dropout at inference time by keeping every unit.
            feed_dict['keep_prob:0'] = 1.0

        predictions = sess.run(prediction_tensor, feed_dict)
    return accuracy(predictions, mnist.test.labels[:256])


# Read the frozen graph and list its nodes.
# Open in binary mode ('rb'): a .pb file is a serialized protobuf, not text.
graph = tf.GraphDef()
with tf.gfile.Open('./frozen_model.pb', 'rb') as f:
    graph.ParseFromString(f.read())

display_nodes(graph.node)

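# Graph surgery. The indices below come from the display_nodes() output of
# this particular model; look them up again for any other graph.
# Node 44 is 'MatMul_1', which originally read the dropout output.
# Point its first input at 'Relu_2', the tensor that was fed into dropout.
graph.node[44].input[0] = 'Relu_2'
# Drop nodes 33-43 (the dropout sub-graph) by keeping [:33] and [44:].
nodes = graph.node[:33] + graph.node[44:]
# Node 1 is the 'keep_prob' placeholder, which nothing uses anymore.
del nodes[1]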

# Save the dropout-free graph (binary mode again: 'wb').
output_graph = graph_pb2.GraphDef()
output_graph.node.extend(nodes)
with tf.gfile.GFile('./frozen_model_without_dropout.pb', 'wb') as f:
    f.write(output_graph.SerializeToString())

# Compare the original graph (fed keep_prob = 1.0) with the dropout-free graph.
result_1 = test_graph('./frozen_model.pb', use_dropout=True)
result_2 = test_graph('./frozen_model_without_dropout.pb', use_dropout=False)

print('with dropout:    %f' % result_1)
print('without dropout: %f' % result_2)
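
Because the slice indices (33, 44 and 1) only fit this one model, an off-by-one slice can leave a node pointing at an input that no longer exists, or keep part of the dropout block. A quick sanity check before saving is to confirm that every input names a node still in the list and that no RandomUniform op survived. The sketch below is hypothetical: `check_graph` is not a TensorFlow API, and it reads only `name`, `op` and `input`, the NodeDef fields the script already uses; a namedtuple stands in for NodeDef here.

```python
from collections import namedtuple

# Hypothetical stand-in for a NodeDef (name, op type, list of input references).
Node = namedtuple('Node', ['name', 'op', 'input'])


def check_graph(nodes, banned_ops=('RandomUniform',)):
    """Return (dangling, banned) for a list of nodes.

    dangling: (node_name, input_ref) pairs whose input names no node in the list.
    banned:   names of nodes whose op type is in banned_ops.
    """
    names = {n.name for n in nodes}
    dangling = []
    for n in nodes:
        for ref in n.input:
            # Strip the control-edge marker '^' and the output index ':N'.
            source = ref.lstrip('^').split(':')[0]
            if source not in names:
                dangling.append((n.name, ref))
    banned = [n.name for n in nodes if n.op in banned_ops]
    return dangling, banned


if __name__ == '__main__':
    good = [Node('input', 'Placeholder', []),
            Node('Relu_2', 'Relu', ['input']),
            Node('MatMul_1', 'MatMul', ['Relu_2'])]
    bad = [Node('input', 'Placeholder', []),
           Node('MatMul_1', 'MatMul', ['dropout/mul'])]
    print(check_graph(good))   # ([], [])
    print(check_graph(bad))    # ([('MatMul_1', 'dropout/mul')], [])
```

In the script, calling `check_graph(nodes)` just before `output_graph.node.extend(nodes)` and expecting two empty lists would catch a wrong slice before the file is written.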

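Comparing the two, the dropout-free model reportedly reaches the same accuracy as the original model. That is expected: at inference time the original graph is fed keep_prob = 1.0, and with keep_prob = 1.0 dropout passes its input through unchanged, so removing it does not alter the predictions.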
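
To see why keep_prob = 1.0 makes dropout a no-op, here is a plain-Python paraphrase of the keep_prob-based formula that TF 1.x's `tf.nn.dropout` used at the time (a sketch, not TensorFlow's code): draw u from [0, 1), keep a unit when floor(keep_prob + u) is 1, and scale kept units by 1 / keep_prob. With keep_prob = 1.0, floor(1.0 + u) is always 1 and the scale is 1, so the output equals the input. The names `uniform01` and `tf1_style_dropout` are hypothetical.

```python
import math
import random


def uniform01():
    # u in [0, 1) on a 2**-23 grid, so 1.0 + u is computed exactly
    # and floor(1.0 + u) is always 1.
    return random.getrandbits(23) / (1 << 23)


def tf1_style_dropout(x, keep_prob, rng=uniform01):
    """Dropout on a list of floats, following the TF 1.x keep_prob formula."""
    out = []
    for v in x:
        mask = math.floor(keep_prob + rng())   # 1 with probability keep_prob, else 0
        out.append(v / keep_prob * mask)       # kept units are scaled by 1 / keep_prob
    return out


if __name__ == '__main__':
    x = [0.5, -1.25, 3.0]
    print(tf1_style_dropout(x, keep_prob=1.0) == x)   # True: dropout is the identity
    print(tf1_style_dropout(x, keep_prob=0.5))         # each entry is 0 or 2 * x[i]
```

During training (keep_prob < 1) the random draw is what requires the RandomUniform op; at inference with keep_prob = 1.0 the whole block only multiplies by one, which is why it can be cut out of the graph.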

References

Drop dropout from Tensorflow - Dato ML: https://dato.ml/drop-dropout-from-frozen-model/