Skip to content

Commit 9fecc51

Browse files
authored
Fix cuDNN RNNs (keras-team#8244)
* Quick fix cudnn rnns * Tests passing.
1 parent 7144aeb commit 9fecc51

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

keras/layers/cudnn_recurrent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ def __init__(self,
3636
self.supports_masking = False
3737
self.input_spec = [InputSpec(ndim=3)]
3838
if hasattr(self.cell.state_size, '__len__'):
39-
self.state_spec = [InputSpec(shape=(None, dim))
40-
for dim in self.cell.state_size]
39+
state_size = self.cell.state_size
4140
else:
42-
self.state_spec = InputSpec(shape=(None, self.cell.state_size))
41+
state_size = [self.cell.state_size]
42+
self.state_spec = [InputSpec(shape=(None, dim))
43+
for dim in state_size]
4344
self.constants_spec = None
4445
self._states = None
4546
self._num_constants = None

keras/layers/recurrent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,9 +454,10 @@ def build(self, input_shape):
454454
# initial_state was passed in call, check compatibility
455455
if not [spec.shape[-1] for spec in self.state_spec] == state_size:
456456
raise ValueError(
457-
'an initial_state was passed that is not compatible with'
458-
' cell.state_size, state_spec: {}, cell.state_size:'
459-
' {}'.format(self.state_spec, self.cell.state_size))
457+
'An initial_state was passed that is not compatible with '
458+
'`cell.state_size`. Received `state_spec`={}; '
459+
'However `cell.state_size` is '
460+
'{}'.format(self.state_spec, self.cell.state_size))
460461
else:
461462
self.state_spec = [InputSpec(shape=(None, dim))
462463
for dim in state_size]

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy