Given two sparse matrices A and B, return the result of AB.

You may assume that A’s column number is equal to B’s row number.

Example:

```
A = [
  [ 1, 0, 0],
  [-1, 0, 3]
]

B = [
  [ 7, 0, 0 ],
  [ 0, 0, 0 ],
  [ 0, 0, 1 ]
]


     |  1 0 0 |   | 7 0 0 |   |  7 0 0 |
AB = | -1 0 3 | x | 0 0 0 | = | -7 0 3 |
                  | 0 0 1 |
```
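For reference, here is a minimal sketch of the plain dense triple loop that the sparse versions below try to beat. It computes every C[i][j] as a full dot product, zeros included, so it costs O(m · n · l) for an m × n matrix A and an n × l matrix B. The class name `DenseMultiply` is hypothetical and used only for this illustration.

```java
import java.util.Arrays;

final class DenseMultiply {
    // Textbook multiplication: C[i][j] is the dot product of row i of A
    // and column j of B. Every term is computed, even when it is zero.
    static int[][] multiply(int[][] A, int[][] B) {
        int m = A.length, n = A[0].length, l = B[0].length;
        int[][] C = new int[m][l];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < l; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += A[i][k] * B[k][j];
                }
                C[i][j] = sum;
            }
        }
        return C;
    }

    public static void main(String[] args) {
        int[][] A = {{1, 0, 0}, {-1, 0, 3}};
        int[][] B = {{7, 0, 0}, {0, 0, 0}, {0, 0, 1}};
        System.out.println(Arrays.deepToString(multiply(A, B)));
        // prints [[7, 0, 0], [-7, 0, 3]]
    }
}
```

Most of the multiplications in this loop involve a zero, which is exactly what the sparse approaches avoid.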

Optimized brute force

```java
public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        int row = A.length, column = B[0].length, colA = A[0].length;
        int[][] res = new int[row][column];

        for (int i = 0; i < row; i++) {
            for (int j = 0; j < colA; j++) {
                // A zero in A contributes nothing to row i of the result.
                if (A[i][j] != 0) {
                    // A[i][j] is multiplied into every non-zero entry of row j of B.
                    for (int k = 0; k < column; k++) {
                        if (B[j][k] != 0) {
                            res[i][k] += A[i][j] * B[j][k];
                        }
                    }
                }
            }
        }
        return res;
    }
}
```

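However, this solution still scans matrix B repeatedly: for every non-zero A[i][j] it walks the whole of row j of B, zeros included, so a row of B is re-read once for each non-zero in the corresponding column of A.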
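To make the repeated work concrete, the hypothetical helper `countBReads` below (not part of the original solution) counts how many entries of B the loop above inspects. Each non-zero A[i][j] triggers a full scan of row j of B, so on the example the first row of B is read twice, once for A[0][0] and once for A[1][0].

```java
final class ScanCounter {
    // Counts the entries of B read by the "optimized brute force" loop.
    // Row j of B is scanned in full (zeros included) once per non-zero A[i][j].
    static int countBReads(int[][] A, int[][] B) {
        int reads = 0;
        int colB = B[0].length;
        for (int[] rowA : A) {
            for (int j = 0; j < rowA.length; j++) {
                if (rowA[j] != 0) {
                    reads += colB; // the inner k-loop visits every column of row j
                }
            }
        }
        return reads;
    }

    public static void main(String[] args) {
        int[][] A = {{1, 0, 0}, {-1, 0, 3}};
        int[][] B = {{7, 0, 0}, {0, 0, 0}, {0, 0, 1}};
        // Non-zeros of A: (0,0), (1,0), (1,2), i.e. 3 row scans of 3 entries each.
        System.out.println(countBReads(A, B)); // prints 9
    }
}
```

Only 3 of those 9 reads hit a non-zero of B; indexing the non-zeros of each row up front removes the wasted reads.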

One hash table that indexes the non-zero values in each row of matrix B

```java
import java.util.HashMap;
import java.util.Map;

public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        if (A == null || B == null || A.length == 0 || B.length == 0 || A[0] == null || B[0] == null) return null;
        int m = A.length, n = A[0].length, l = B[0].length;
        int[][] C = new int[m][l];
        // tableB.get(k) maps column j -> B[k][j] for the non-zero entries of row k of B.
        Map<Integer, Map<Integer, Integer>> tableB = new HashMap<>();

        for (int k = 0; k < n; k++) {
            tableB.put(k, new HashMap<>());
            for (int j = 0; j < l; j++) {
                if (B[k][j] != 0) {
                    tableB.get(k).put(j, B[k][j]);
                }
            }
        }

        for (int i = 0; i < m; i++) {
            for (int k = 0; k < n; k++) {
                if (A[i][k] != 0) {
                    // Visit only the non-zero entries of row k of B.
                    for (Map.Entry<Integer, Integer> e : tableB.get(k).entrySet()) {
                        C[i][e.getKey()] += A[i][k] * e.getValue();
                    }
                }
            }
        }
        return C;
    }
}
```
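The per-row HashMap can also be replaced by flat arrays. The sketch below swaps in a compressed sparse row (CSR) layout, a standard sparse-matrix format that is not the approach used above: the non-zeros of B are packed row by row into `cols` and `vals`, and `rowStart[k]` marks where row k begins. It still visits only the non-zeros of each row of B, but without boxing or hashing. The class and array names are illustrative assumptions.

```java
import java.util.Arrays;

final class CsrMultiply {
    // Multiplies A (m x n) by B (n x l) after packing B's non-zeros into CSR arrays.
    static int[][] multiply(int[][] A, int[][] B) {
        int m = A.length, n = B.length, l = B[0].length;

        // Row k's non-zeros occupy positions rowStart[k] .. rowStart[k + 1] - 1.
        int[] rowStart = new int[n + 1];
        for (int k = 0; k < n; k++) {
            int count = 0;
            for (int v : B[k]) {
                if (v != 0) count++;
            }
            rowStart[k + 1] = rowStart[k] + count;
        }
        int[] cols = new int[rowStart[n]];
        int[] vals = new int[rowStart[n]];
        for (int k = 0, p = 0; k < n; k++) {
            for (int j = 0; j < l; j++) {
                if (B[k][j] != 0) {
                    cols[p] = j;
                    vals[p] = B[k][j];
                    p++;
                }
            }
        }

        int[][] C = new int[m][l];
        for (int i = 0; i < m; i++) {
            for (int k = 0; k < n; k++) {
                int a = A[i][k];
                if (a == 0) continue; // zero entries of A contribute nothing
                for (int p = rowStart[k]; p < rowStart[k + 1]; p++) {
                    C[i][cols[p]] += a * vals[p];
                }
            }
        }
        return C;
    }

    public static void main(String[] args) {
        int[][] A = {{1, 0, 0}, {-1, 0, 3}};
        int[][] B = {{7, 0, 0}, {0, 0, 0}, {0, 0, 1}};
        System.out.println(Arrays.deepToString(multiply(A, B)));
        // prints [[7, 0, 0], [-7, 0, 3]]
    }
}
```

CSR is cheap to scan but awkward to update once built, whereas the HashMap version can absorb new non-zeros directly.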

Two hash tables, one for each of matrices A and B

```java
import java.util.HashMap;
import java.util.Map;

public class Solution {
    public int[][] multiply(int[][] A, int[][] B) {
        if (A == null || B == null || A.length == 0 || B.length == 0 || A[0].length == 0 || B[0].length == 0)
            return null;
        int m = A.length, n = A[0].length, l = B[0].length;
        // The inner dimensions must agree: A's column count equals B's row count.
        if (n != B.length)
            throw new IllegalArgumentException("A's column number must be equal to B's row number.");
        int[][] res = new int[m][l];

        Map<Integer, Map<Integer, Integer>> matrixA = convertToMatrix(A);
        Map<Integer, Map<Integer, Integer>> matrixB = convertToMatrix(B);

        // Combine each non-zero A[i][k] with each non-zero B[k][j].
        for (Map.Entry<Integer, Map<Integer, Integer>> rowA : matrixA.entrySet()) {
            int i = rowA.getKey();
            for (Map.Entry<Integer, Integer> a : rowA.getValue().entrySet()) {
                Map<Integer, Integer> rowB = matrixB.get(a.getKey());
                if (rowB != null) { // row k of B is absent when it is all zeros
                    for (Map.Entry<Integer, Integer> b : rowB.entrySet()) {
                        res[i][b.getKey()] += a.getValue() * b.getValue();
                    }
                }
            }
        }
        return res;
    }

    // Maps row index -> (column index -> value), keeping only non-zero entries
    // and omitting rows that are entirely zero.
    public Map<Integer, Map<Integer, Integer>> convertToMatrix(int[][] A) {
        Map<Integer, Map<Integer, Integer>> map = new HashMap<>();
        for (int i = 0; i < A.length; i++) {
            Map<Integer, Integer> row = new HashMap<>();
            for (int j = 0; j < A[i].length; j++) {
                if (A[i][j] != 0) {
                    row.put(j, A[i][j]);
                }
            }
            if (!row.isEmpty()) {
                map.put(i, row);
            }
        }
        return map;
    }
}
```