
[TensorFlow] How to remove dropout from a graph

꾸준희 2019. 8. 3. 19:41




Networks are commonly designed with dropout. However, when you try to use the resulting frozen graph on another platform, the following error reportedly occurs:


Invalid argument: No OpKernel was registered to support Op 'RandomUniform' with these attrs.  Registered devices: [CPU], Registered kernels:
  <no registered kernels>

     [[Node: dropout/random_uniform/RandomUniform = RandomUniform[T=DT_INT32, dtype=DT_FLOAT, seed=0, seed2=0](dropout/Shape)]]



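In that case, take the generated .pb file and remove the dropout from it by editing the GraphDef directly. The node that consumes the dropout output is rewired to read the layer before dropout, and the dropout nodes and the keep_prob placeholder are deleted.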
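To see which nodes trigger the error, it helps to list every node whose op type the target platform lacks (here RandomUniform) and note its name scope. The sketch below is minimal and uses plain Python dicts as a hypothetical stand-in for the NodeDef entries of graph.node, with the same name, op and input fields that display_nodes prints. find_ops and top_scopes are illustrative helpers, not TensorFlow API.

```python
# Hypothetical stand-in for GraphDef.node: each dict carries the same
# three fields the post's display_nodes() prints (name, op, input).
nodes = [
    {"name": "input", "op": "Placeholder", "input": []},
    {"name": "keep_prob", "op": "Placeholder", "input": []},
    {"name": "Relu_2", "op": "Relu", "input": ["input"]},
    {"name": "dropout/Shape", "op": "Shape", "input": ["Relu_2"]},
    {"name": "dropout/random_uniform/RandomUniform", "op": "RandomUniform",
     "input": ["dropout/Shape"]},
]


def find_ops(nodes, op_types):
    """Return (index, name, op) for every node whose op type is in op_types."""
    return [(i, n["name"], n["op"]) for i, n in enumerate(nodes) if n["op"] in op_types]


def top_scopes(names):
    """Top-level name scopes of the given node names, e.g. 'dropout'."""
    return sorted({name.split("/")[0] for name in names})


hits = find_ops(nodes, {"RandomUniform"})
print(hits)  # [(4, 'dropout/random_uniform/RandomUniform', 'RandomUniform')]
print(top_scopes(name for _, name, _ in hits))  # ['dropout']
```

The scope name tells you which sub-graph has to go. The indices tell you where it sits in graph.node.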


How to remove dropout from a frozen model

from __future__ import print_function
from tensorflow.core.framework import graph_pb2
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/tmp/data/', one_hot=True)

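def display_nodes(nodes):
    # print index, name and op type of every node, then each of its inputs
    for i, node in enumerate(nodes):
        print('%d %s %s' % (i, node.name, node.op))
        for j, n in enumerate(node.input):
            print(u'└─── %d ─ %s' % (j, n))

def accuracy(predictions, labels):
    # percentage of rows whose predicted class matches the one-hot label
    return (100.0 * np.sum(np.argmax(predictions, 1) == np.argmax(labels, 1)) / predictions.shape[0])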

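def test_graph(graph_path, use_dropout):
    # load the frozen GraphDef from disk
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    # import into a fresh graph so the two test runs do not share node names
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
    prediction_tensor = graph.get_tensor_by_name('final_result:0')
    feed_dict = {'input:0': mnist.test.images[:256]}
    if use_dropout:
        # the original graph still has the keep_prob placeholder; 1.0 disables dropout
        feed_dict['keep_prob:0'] = 1.0
    with tf.Session(graph=graph) as sess:
        predictions = sess.run(prediction_tensor, feed_dict)
    result = accuracy(predictions, mnist.test.labels[:256])
    return result
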
# read frozen graph and display nodes
graph = tf.GraphDef()
with tf.gfile.Open('./frozen_model.pb', 'rb') as f:
    data = f.read()
    graph.ParseFromString(data)

display_nodes(graph.node)

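# Connect 'MatMul_1' with 'Relu_2', bypassing the dropout output
# (indices come from the display_nodes output for this particular graph)
graph.node[44].input[0] = 'Relu_2' # 44 -> MatMul_1
# Remove dropout nodes: nodes 33-43 are the dropout sub-graph,
# so MatMul_1 moves to index 33 of the new list
nodes = graph.node[:33] + graph.node[44:]
del nodes[1] # 1 -> keep_prob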

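# Save graph: copy the kept nodes into a new GraphDef and serialize it
output_graph = graph_pb2.GraphDef()
output_graph.node.extend(nodes)
with tf.gfile.GFile('./frozen_model_without_dropout.pb', 'wb') as f:
    f.write(output_graph.SerializeToString())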

# test graph via simple test
result_1 = test_graph('./frozen_model.pb', use_dropout=True)
result_2 = test_graph('./frozen_model_without_dropout.pb', use_dropout=False)

print('with dropout:    %f' % result_1)
print('without dropout: %f' % result_2)
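The indices 44, 33 and 1 above hold only for this particular graph. Below is a hedged sketch of the same rewiring done by name instead. It drops everything under a name scope, drops extra nodes such as keep_prob, and points any outside consumer of the scope at the pre-dropout tensor. It again uses hypothetical plain-dict stand-ins for NodeDef, and strip_scope is an illustrative helper, not part of TensorFlow. Porting it to a real GraphDef would mean applying the same logic to node.name and node.input and copying the kept nodes into a new GraphDef, as the code above does.

```python
# Hypothetical stand-in for GraphDef.node (name, op, input), shaped like
# the Relu_2 -> dropout -> MatMul_1 section of the frozen model.
nodes = [
    {"name": "input", "op": "Placeholder", "input": []},
    {"name": "keep_prob", "op": "Placeholder", "input": []},
    {"name": "Relu_2", "op": "Relu", "input": ["input"]},
    {"name": "dropout/Shape", "op": "Shape", "input": ["Relu_2"]},
    {"name": "dropout/random_uniform/RandomUniform", "op": "RandomUniform",
     "input": ["dropout/Shape"]},
    {"name": "dropout/add", "op": "Add",
     "input": ["keep_prob", "dropout/random_uniform/RandomUniform"]},
    {"name": "dropout/Floor", "op": "Floor", "input": ["dropout/add"]},
    {"name": "dropout/div", "op": "RealDiv", "input": ["Relu_2", "keep_prob"]},
    {"name": "dropout/mul", "op": "Mul", "input": ["dropout/div", "dropout/Floor"]},
    {"name": "w_2", "op": "Const", "input": []},  # hypothetical weight node
    {"name": "MatMul_1", "op": "MatMul", "input": ["dropout/mul", "w_2"]},
]


def _source(inp):
    """Strip a control-dependency '^' prefix and an ':N' output suffix."""
    return inp.lstrip("^").split(":")[0]


def strip_scope(nodes, scope, bypass_input, extra_drops=()):
    """Drop every node under `scope` plus `extra_drops`, and point any
    outside consumer of the scope at `bypass_input` instead."""
    inside = {n["name"] for n in nodes if n["name"].startswith(scope)}
    removed = inside | set(extra_drops)
    kept = []
    for node in nodes:
        if node["name"] in removed:
            continue
        new_inputs = []
        for inp in node["input"]:
            if _source(inp) not in inside:
                new_inputs.append(inp)
            elif not inp.startswith("^"):
                new_inputs.append(bypass_input)  # data edge: rewire
            # control edges into the removed scope are simply dropped
        kept.append({**node, "input": new_inputs})
    # sanity check: no kept node may still read from a removed node
    for node in kept:
        for inp in node["input"]:
            if _source(inp) in removed:
                raise ValueError("%s still reads removed node %s" % (node["name"], inp))
    return kept


stripped = strip_scope(nodes, "dropout/", "Relu_2", extra_drops=["keep_prob"])
print([n["name"] for n in stripped])  # ['input', 'Relu_2', 'w_2', 'MatMul_1']
print(stripped[-1]["input"])          # ['Relu_2', 'w_2']
```

The final check fails loudly if a dropped node such as keep_prob is still referenced. With the hard-coded index approach, that situation would only surface when the graph is loaded.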




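Comparing the accuracy of the model with dropout removed against the original model (fed keep_prob = 1.0), the two reportedly come out the same.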
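Here is why the accuracies match. In TF1, tf.nn.dropout builds a mask as floor(keep_prob + U[0, 1)) and scales the kept values by 1/keep_prob. With keep_prob = 1.0, every mask entry is 1 and the layer is an identity. The sketch below reproduces that arithmetic in plain Python to check the claim. tf1_style_dropout is a hypothetical name, and the code models the op's math rather than calling TensorFlow.

```python
import math
import random


def tf1_style_dropout(xs, keep_prob, rng):
    """Inverted dropout with the same arithmetic as TF1's tf.nn.dropout:
    mask = floor(keep_prob + U[0, 1)), output = x / keep_prob * mask."""
    out = []
    for x in xs:
        mask = math.floor(keep_prob + rng.random())  # 1 with probability keep_prob, else 0
        out.append(x / keep_prob * mask)
    return out


rng = random.Random(0)
activations = [0.5, -1.25, 3.0, 0.0]

# keep_prob = 1.0: keep_prob + U lies in [1, 2), so every mask is 1 and the
# layer returns its input unchanged.
print(tf1_style_dropout(activations, 1.0, rng) == activations)  # True

# keep_prob = 0.5: each value is either zeroed or doubled.
halved = tf1_style_dropout(activations, 0.5, rng)
print(all(y == 0 or y == 2 * x for x, y in zip(activations, halved)))  # True
```

This identity is only exact at keep_prob = 1.0. That is why the original graph must be evaluated with keep_prob:0 fed as 1.0 for the comparison to be meaningful.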







Source: Drop dropout from Tensorflow - Dato ML ("How to remove dropout from frozen tensorflow model.")

