Skip to content

Commit f11b422

Browse files
ozabludafchollet
authored andcommitted
Simplify and fix count_params() (keras-team#8206)
* Simplify and fix count_params() * Replace non-ascii ' with an ascii one
1 parent fd3ac2a commit f11b422

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

keras/backend/cntk_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,13 +493,13 @@ def ones_like(x, name=None):
493493

494494

495495
def count_params(x):
496-
for _ in x.shape:
496+
for _ in get_variable_shape(x):
497497
if _ == C.InferredDimension or _ == C.FreeDimension:
498498
raise ValueError('CNTK backend: `count_params` with dynamic '
499499
'shape is not supported. Please provide '
500500
'fixed dimension instead of `None`.')
501501

502-
return np.prod([x.shape[i] for i in range(len(x.shape))])
502+
return np.prod(get_variable_shape(x))
503503

504504

505505
def cast(x, dtype):

keras/backend/tensorflow_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ def shape(x):
541541

542542

543543
def int_shape(x):
544-
"""Returns the shape tensor or variable as a tuple of int or None entries.
544+
"""Returns the shape of tensor or variable as a tuple of int or None entries.
545545
546546
# Arguments
547547
x: Tensor or variable.
@@ -865,13 +865,14 @@ def random_normal_variable(shape, mean, scale, dtype=None,
865865

866866

867867
def count_params(x):
868-
"""Returns the number of scalars in a Keras variable.
868+
"""Returns the static number of elements in a Keras variable or tensor.
869869
870870
# Arguments
871-
x: Keras variable.
871+
x: Keras variable or tensor.
872872
873873
# Returns
874-
Integer, the number of scalars in `x`.
874+
Integer, the number of elements in `x`, i.e., the product of the
875+
array's static dimensions.
875876
876877
# Example
877878
```python
@@ -883,8 +884,7 @@ def count_params(x):
883884
[ 0., 0., 0.]], dtype=float32)
884885
```
885886
"""
886-
shape = x.get_shape()
887-
return np.prod([shape[i]._value for i in range(len(shape))])
887+
return np.prod(get_variable_shape(x))
888888

889889

890890
def cast(x, dtype):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy