Skip to content

Commit 2dfad7e

Browse files
Add matrix multiplication with double[][] and unit tests (TheAlgorithms#6417)
* MatrixMultiplication.java created and updated. * Add necessary comment to MatrixMultiplication.java * Create MatrixMultiplicationTest.java * method for 2 by 2 matrix multiplication is created * Use assertMatrixEquals(), otherwise there can be error due to floating point arithmetic errors * assertMatrixEquals method created and updated * method created for 3by2 matrix multiply with 2by1 matrix * method created for null matrix multiplication * method for test matrix dimension error * method for test empty matrix input * testMultiply3by2and2by1 test case updated * Check for empty matrices part updated * Updated Unit test coverage * files updated * clean the code * clean the code * Updated files with google-java-format * Updated files * Updated files * Updated files * Updated files * Add reference links and complexities * Add test cases for 1by1 matrix and non-rectangular matrix * Add reference links and complexities --------- Co-authored-by: Deniz Altunkapan <93663085+DenizAltunkapan@users.noreply.github.com>
1 parent 2722b0e commit 2dfad7e

File tree

2 files changed

+170
-0
lines changed

2 files changed

+170
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package com.thealgorithms.matrix;
2+
3+
/**
4+
* This class provides a method to perform matrix multiplication.
5+
*
6+
* <p>Matrix multiplication takes two 2D arrays (matrices) as input and
7+
* produces their product, following the mathematical definition of
8+
* matrix multiplication.
9+
*
10+
* <p>For more details:
11+
* https://www.geeksforgeeks.org/java/java-program-to-multiply-two-matrices-of-any-size/
12+
* https://en.wikipedia.org/wiki/Matrix_multiplication
13+
*
14+
* <p>Time Complexity: O(n^3) – where n is the dimension of the matrices
15+
* (assuming square matrices for simplicity).
16+
*
17+
* <p>Space Complexity: O(n^2) – for storing the result matrix.
18+
*
19+
*
20+
* @author Nishitha Wihala Pitigala
21+
*
22+
*/
23+
24+
public final class MatrixMultiplication {
25+
private MatrixMultiplication() {
26+
}
27+
28+
/**
29+
* Multiplies two matrices.
30+
*
31+
* @param matrixA the first matrix rowsA x colsA
32+
* @param matrixB the second matrix rowsB x colsB
33+
* @return the product of the two matrices rowsA x colsB
34+
* @throws IllegalArgumentException if the matrices cannot be multiplied
35+
*/
36+
public static double[][] multiply(double[][] matrixA, double[][] matrixB) {
37+
// Check the input matrices are not null
38+
if (matrixA == null || matrixB == null) {
39+
throw new IllegalArgumentException("Input matrices cannot be null");
40+
}
41+
42+
// Check for empty matrices
43+
if (matrixA.length == 0 || matrixB.length == 0 || matrixA[0].length == 0 || matrixB[0].length == 0) {
44+
throw new IllegalArgumentException("Input matrices must not be empty");
45+
}
46+
47+
// Validate the matrix dimensions
48+
if (matrixA[0].length != matrixB.length) {
49+
throw new IllegalArgumentException("Matrices cannot be multiplied: incompatible dimensions.");
50+
}
51+
52+
int rowsA = matrixA.length;
53+
int colsA = matrixA[0].length;
54+
int colsB = matrixB[0].length;
55+
56+
// Initialize the result matrix with zeros
57+
double[][] result = new double[rowsA][colsB];
58+
59+
// Perform matrix multiplication
60+
for (int i = 0; i < rowsA; i++) {
61+
for (int j = 0; j < colsB; j++) {
62+
for (int k = 0; k < colsA; k++) {
63+
result[i][j] += matrixA[i][k] * matrixB[k][j];
64+
}
65+
}
66+
}
67+
return result;
68+
}
69+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package com.thealgorithms.matrix;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
import static org.junit.jupiter.api.Assertions.assertTrue;
6+
7+
import org.junit.jupiter.api.Test;
8+
9+
public class MatrixMultiplicationTest {
10+
11+
private static final double EPSILON = 1e-9; // for floating point comparison
12+
13+
@Test
14+
void testMultiply1by1() {
15+
double[][] matrixA = {{1.0}};
16+
double[][] matrixB = {{2.0}};
17+
double[][] expected = {{2.0}};
18+
19+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
20+
assertMatrixEquals(expected, result);
21+
}
22+
23+
@Test
24+
void testMultiply2by2() {
25+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
26+
double[][] matrixB = {{5.0, 6.0}, {7.0, 8.0}};
27+
double[][] expected = {{19.0, 22.0}, {43.0, 50.0}};
28+
29+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
30+
assertMatrixEquals(expected, result); // Use custom method due to floating point issues
31+
}
32+
33+
@Test
34+
void testMultiply3by2and2by1() {
35+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
36+
double[][] matrixB = {{7.0}, {8.0}};
37+
double[][] expected = {{23.0}, {53.0}, {83.0}};
38+
39+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
40+
assertMatrixEquals(expected, result);
41+
}
42+
43+
@Test
44+
void testMultiplyNonRectangularMatrices() {
45+
double[][] matrixA = {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}};
46+
double[][] matrixB = {{7.0, 8.0}, {9.0, 10.0}, {11.0, 12.0}};
47+
double[][] expected = {{58.0, 64.0}, {139.0, 154.0}};
48+
49+
double[][] result = MatrixMultiplication.multiply(matrixA, matrixB);
50+
assertMatrixEquals(expected, result);
51+
}
52+
53+
@Test
54+
void testNullMatrixA() {
55+
double[][] b = {{1, 2}, {3, 4}};
56+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(null, b));
57+
}
58+
59+
@Test
60+
void testNullMatrixB() {
61+
double[][] a = {{1, 2}, {3, 4}};
62+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, null));
63+
}
64+
65+
@Test
66+
void testMultiplyNull() {
67+
double[][] matrixA = {{1.0, 2.0}, {3.0, 4.0}};
68+
double[][] matrixB = null;
69+
70+
Exception exception = assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(matrixA, matrixB));
71+
72+
String expectedMessage = "Input matrices cannot be null";
73+
String actualMessage = exception.getMessage();
74+
75+
assertTrue(actualMessage.contains(expectedMessage));
76+
}
77+
78+
@Test
79+
void testIncompatibleDimensions() {
80+
double[][] a = {{1.0, 2.0}};
81+
double[][] b = {{1.0, 2.0}};
82+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
83+
}
84+
85+
@Test
86+
void testEmptyMatrices() {
87+
double[][] a = new double[0][0];
88+
double[][] b = new double[0][0];
89+
assertThrows(IllegalArgumentException.class, () -> MatrixMultiplication.multiply(a, b));
90+
}
91+
92+
private void assertMatrixEquals(double[][] expected, double[][] actual) {
93+
assertEquals(expected.length, actual.length, "Row count mismatch");
94+
for (int i = 0; i < expected.length; i++) {
95+
assertEquals(expected[i].length, actual[i].length, "Column count mismatch at row " + i);
96+
for (int j = 0; j < expected[i].length; j++) {
97+
assertEquals(expected[i][j], actual[i][j], EPSILON, "Mismatch at (" + i + "," + j + ")");
98+
}
99+
}
100+
}
101+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy