Skip to content

Commit 8a26b28

Browse files
krishnaw14norvig
authored andcommitted
Added function and test cases for cross-entropy loss (aimacode#853)
* Correction in the formula for mean square error * Added cross-entropy loss * Test case for cross-entropy loss * Decimal point mistake * Added spaces around = and ==
1 parent d1ea3fe commit 8a26b28

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

learning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
def euclidean_distance(X, Y):
2222
return math.sqrt(sum((x - y)**2 for x, y in zip(X, Y)))
2323

24+
def cross_entropy_loss(X,Y):
25+
n=len(X)
26+
return (-1.0/n)*sum(x*math.log(y)+(1-x)*math.log(1-y) for x,y in zip(X,Y) )
27+
2428

2529
def rms_error(X, Y):
2630
return math.sqrt(ms_error(X, Y))

tests/test_learning.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ def test_euclidean():
1818
distance = euclidean_distance([0, 0, 0], [0, 0, 0])
1919
assert distance == 0
2020

21+
def test_cross_entropy():
22+
loss = cross_entropy_loss([1,0], [0.9, 0.3])
23+
assert round(loss,2) == 0.23
24+
25+
loss = cross_entropy_loss([1,0,0,1], [0.9,0.3,0.5,0.75])
26+
assert round(loss,2) == 0.36
27+
28+
loss = cross_entropy_loss([1,0,0,1,1,0,1,1], [0.9,0.3,0.5,0.75,0.85,0.14,0.93,0.79])
29+
assert round(loss,2) == 0.26
30+
2131

2232
def test_rms_error():
2333
assert rms_error([2, 2], [2, 2]) == 0

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy