@@ -122,46 +122,45 @@ def test_cudnn_rnn_canonical_to_params_gru():
122
122
123
123
124
124
@keras_test
125
+ @pytest .mark .parametrize ('rnn_type' , ['lstm' , 'gru' ], ids = ['LSTM' , 'GRU' ])
125
126
@pytest .mark .skipif ((keras .backend .backend () != 'tensorflow' ),
126
127
reason = 'Requires TensorFlow backend' )
127
128
@pytest .mark .skipif (not keras .backend .tensorflow_backend ._get_available_gpus (),
128
129
reason = 'Requires GPU' )
129
- def test_cudnn_rnn_timing ():
130
+ def test_cudnn_rnn_timing (rnn_type ):
130
131
input_size = 1000
131
132
timesteps = 60
132
133
units = 256
133
134
num_samples = 10000
134
135
135
136
times = []
136
- for rnn_type in ['lstm' , 'gru' ]:
137
- for use_cudnn in [True , False ]:
138
- start_time = time .time ()
139
- inputs = keras .layers .Input (shape = (None , input_size ))
140
- if use_cudnn :
141
- if rnn_type == 'lstm' :
142
- layer = keras .layers .CuDNNLSTM (units )
143
- else :
144
- layer = keras .layers .CuDNNGRU (units )
137
+ for use_cudnn in [True , False ]:
138
+ start_time = time .time ()
139
+ inputs = keras .layers .Input (shape = (None , input_size ))
140
+ if use_cudnn :
141
+ if rnn_type == 'lstm' :
142
+ layer = keras .layers .CuDNNLSTM (units )
145
143
else :
146
- if rnn_type == 'lstm' :
147
- layer = keras .layers .LSTM (units )
148
- else :
149
- layer = keras .layers .GRU (units )
150
- outputs = layer (inputs )
144
+ layer = keras .layers .CuDNNGRU (units )
145
+ else :
146
+ if rnn_type == 'lstm' :
147
+ layer = keras .layers .LSTM (units )
148
+ else :
149
+ layer = keras .layers .GRU (units )
150
+ outputs = layer (inputs )
151
151
152
- model = keras .models .Model (inputs , outputs )
153
- model .compile ('sgd' , 'mse' )
152
+ model = keras .models .Model (inputs , outputs )
153
+ model .compile ('sgd' , 'mse' )
154
154
155
- x = np .random .random ((num_samples , timesteps , input_size ))
156
- y = np .random .random ((num_samples , units ))
157
- model .fit (x , y , epochs = 4 , batch_size = 32 )
155
+ x = np .random .random ((num_samples , timesteps , input_size ))
156
+ y = np .random .random ((num_samples , units ))
157
+ model .fit (x , y , epochs = 4 , batch_size = 32 )
158
158
159
- times .append (time .time () - start_time )
159
+ times .append (time .time () - start_time )
160
160
161
- speedup = times [1 ] / times [0 ]
162
- print (rnn_type , 'speedup' , speedup )
163
- assert speedup > 3
164
- keras .backend .clear_session ()
161
+ speedup = times [1 ] / times [0 ]
162
+ print (rnn_type , 'speedup' , speedup )
163
+ assert speedup > 3
165
164
166
165
167
166
@keras_test
0 commit comments