We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b771b97 commit bbe9f3dCopy full SHA for bbe9f3d
utils.py
@@ -228,10 +228,16 @@ def num_or_str(x):
228
return str(x).strip()
229
230
231
-def normalize(numbers):
+def normalize(dist):
232
"""Multiply each number by a constant such that the sum is 1.0"""
233
- total = float(sum(numbers))
234
- return [(n / total) for n in numbers]
+ if isinstance(dist, dict):
+ total = sum(dist.values())
235
+ for key in dist:
236
+ dist[key] = dist[key] / total
237
+ assert 0 <= dist[key] <= 1, "Probabilities must be between 0 and 1."
238
+ return dist
239
+ total = sum(dist)
240
+ return [(n / total) for n in dist]
241
242
243
def clip(x, lowest, highest):
0 commit comments