Skip to content

Commit 7b35041

Browse files
taehoonleefchollet
authored andcommitted
Fix cuDNN tests (keras-team#8189)
Move `times = []` from the outer to the inner loop Delete `clear_session()` which is redundant with `keras_test` Reduce the example size under the 3x speed-up is satisfied
1 parent ab30f73 commit 7b35041

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

tests/keras/layers/cudnn_recurrent_test.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,46 +122,45 @@ def test_cudnn_rnn_canonical_to_params_gru():
122122

123123

124124
@keras_test
125+
@pytest.mark.parametrize('rnn_type', ['lstm', 'gru'], ids=['LSTM', 'GRU'])
125126
@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'),
126127
reason='Requires TensorFlow backend')
127128
@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(),
128129
reason='Requires GPU')
129-
def test_cudnn_rnn_timing():
130+
def test_cudnn_rnn_timing(rnn_type):
130131
input_size = 1000
131132
timesteps = 60
132133
units = 256
133134
num_samples = 10000
134135

135136
times = []
136-
for rnn_type in ['lstm', 'gru']:
137-
for use_cudnn in [True, False]:
138-
start_time = time.time()
139-
inputs = keras.layers.Input(shape=(None, input_size))
140-
if use_cudnn:
141-
if rnn_type == 'lstm':
142-
layer = keras.layers.CuDNNLSTM(units)
143-
else:
144-
layer = keras.layers.CuDNNGRU(units)
137+
for use_cudnn in [True, False]:
138+
start_time = time.time()
139+
inputs = keras.layers.Input(shape=(None, input_size))
140+
if use_cudnn:
141+
if rnn_type == 'lstm':
142+
layer = keras.layers.CuDNNLSTM(units)
145143
else:
146-
if rnn_type == 'lstm':
147-
layer = keras.layers.LSTM(units)
148-
else:
149-
layer = keras.layers.GRU(units)
150-
outputs = layer(inputs)
144+
layer = keras.layers.CuDNNGRU(units)
145+
else:
146+
if rnn_type == 'lstm':
147+
layer = keras.layers.LSTM(units)
148+
else:
149+
layer = keras.layers.GRU(units)
150+
outputs = layer(inputs)
151151

152-
model = keras.models.Model(inputs, outputs)
153-
model.compile('sgd', 'mse')
152+
model = keras.models.Model(inputs, outputs)
153+
model.compile('sgd', 'mse')
154154

155-
x = np.random.random((num_samples, timesteps, input_size))
156-
y = np.random.random((num_samples, units))
157-
model.fit(x, y, epochs=4, batch_size=32)
155+
x = np.random.random((num_samples, timesteps, input_size))
156+
y = np.random.random((num_samples, units))
157+
model.fit(x, y, epochs=4, batch_size=32)
158158

159-
times.append(time.time() - start_time)
159+
times.append(time.time() - start_time)
160160

161-
speedup = times[1] / times[0]
162-
print(rnn_type, 'speedup', speedup)
163-
assert speedup > 3
164-
keras.backend.clear_session()
161+
speedup = times[1] / times[0]
162+
print(rnn_type, 'speedup', speedup)
163+
assert speedup > 3
165164

166165

167166
@keras_test

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy