-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDotProduct.java
72 lines (64 loc) · 2.66 KB
/
DotProduct.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Objects;
import java.util.Scanner;
import java.util.concurrent.ThreadLocalRandom;
/**
 * Computes the dot product of two vectors of doubles, either sequentially
 * ({@link #basicCalculate}) or split across several threads ({@link #advancedCalculate}).
 *
 * <p>Each element-wise product is computed in {@code double} and then accumulated exactly in a
 * {@link BigDecimal} (no {@code MathContext}, so every addition is exact). The final sum
 * therefore does not depend on how the work is partitioned between threads.
 */
public class DotProduct {

    /**
     * Reads a vector length from standard input and generates two random vectors of that length.
     * It prints their dot product computed in parallel, then prints the elapsed wall-clock time
     * (generation plus calculation) in milliseconds.
     */
    public static void main(String[] args) {
        // Deliberately not closed: closing the Scanner would also close System.in.
        Scanner in = new Scanner(System.in);
        System.out.print("生成的长度:");
        int n = in.nextInt();
        System.out.println();
        // nanoTime is the monotonic clock intended for measuring elapsed time.
        long startNanos = System.nanoTime();
        double[] vector1 = generateVector(n);
        double[] vector2 = generateVector(n);
        advancedCalculate(vector1, vector2, Runtime.getRuntime().availableProcessors());
        long elapsedMillis = (System.nanoTime() - startNanos) / 1_000_000;
        System.out.println("花费时间:" + elapsedMillis + "ms");
    }

    /**
     * Returns a new array of {@code n} doubles drawn uniformly from [-10000, 10000).
     *
     * @param n number of elements; must be non-negative
     * @return the generated vector
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public static double[] generateVector(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("vector length must be non-negative: " + n);
        }
        double[] nums = new double[n];
        // parallelSetAll may fill the array from several threads, hence ThreadLocalRandom.
        Arrays.parallelSetAll(nums, i -> ThreadLocalRandom.current().nextDouble(-10000, 10000));
        return nums;
    }

    /**
     * Computes the dot product on the calling thread and prints it.
     *
     * @param vector1 first vector
     * @param vector2 second vector; must have the same length as {@code vector1}
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static void basicCalculate(double[] vector1, double[] vector2) {
        System.out.println("结果是:" + sequentialDotProduct(vector1, vector2));
    }

    /**
     * Computes the dot product using {@code numThreads} worker threads and prints it.
     *
     * @param vector1 first vector
     * @param vector2 second vector; must have the same length as {@code vector1}
     * @param numThreads requested number of worker threads; must be at least 1
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in length or {@code numThreads < 1}
     * @throws IllegalStateException if a worker thread fails
     */
    public static void advancedCalculate(double[] vector1, double[] vector2, int numThreads) {
        System.out.println("结果是:" + parallelDotProduct(vector1, vector2, numThreads));
    }

    /**
     * Returns the exact sum of {@code vector1[i] * vector2[i]}, computed on the calling thread.
     *
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in length
     * @throws NumberFormatException if any product is NaN or infinite (rejected by
     *     {@link BigDecimal#valueOf(double)})
     */
    static BigDecimal sequentialDotProduct(double[] vector1, double[] vector2) {
        requireSameLength(vector1, vector2);
        return sumOfProducts(vector1, vector2, 0, vector1.length);
    }

    /**
     * Returns the exact sum of {@code vector1[i] * vector2[i]}, splitting the index range into
     * contiguous slices that are summed by separate threads.
     *
     * <p>At most {@code min(numThreads, length)} threads are started (at least one), so no thread
     * is created just to sit idle. This method always waits for every worker. If the calling
     * thread is interrupted while waiting, its interrupt status is restored before returning.
     *
     * @throws NullPointerException if either vector is {@code null}
     * @throws IllegalArgumentException if the vectors differ in length or {@code numThreads < 1}
     * @throws IllegalStateException if a worker thread fails; the worker's exception is the cause
     */
    static BigDecimal parallelDotProduct(double[] vector1, double[] vector2, int numThreads) {
        requireSameLength(vector1, vector2);
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be at least 1: " + numThreads);
        }
        int n = vector1.length;
        int workers = Math.max(1, Math.min(numThreads, n));
        BigDecimal[] partialSums = new BigDecimal[workers];
        Throwable[] failures = new Throwable[workers];
        Thread[] threads = new Thread[workers];
        for (int i = 0; i < workers; i++) {
            final int tid = i;
            // Balanced split computed in long arithmetic: slice sizes differ by at most one,
            // and the bounds cannot overflow even when n is close to Integer.MAX_VALUE.
            final int from = (int) ((long) n * tid / workers);
            final int to = (int) ((long) n * (tid + 1) / workers);
            threads[i] = new Thread(() -> {
                try {
                    partialSums[tid] = sumOfProducts(vector1, vector2, from, to);
                } catch (Throwable e) {
                    // Thread boundary: hand the failure back to the caller instead of losing it.
                    failures[tid] = e;
                }
            });
            threads[i].start();
        }
        joinAll(threads);
        // Safe to read the per-thread results: join() establishes happens-before with each worker.
        BigDecimal sum = BigDecimal.ZERO;
        for (int i = 0; i < workers; i++) {
            if (failures[i] != null) {
                throw new IllegalStateException("worker thread " + i + " failed", failures[i]);
            }
            sum = sum.add(partialSums[i]);
        }
        return sum;
    }

    /** Sums {@code vector1[j] * vector2[j]} exactly for {@code from <= j < to}. */
    private static BigDecimal sumOfProducts(double[] vector1, double[] vector2, int from, int to) {
        BigDecimal sum = BigDecimal.ZERO;
        for (int j = from; j < to; j++) {
            sum = sum.add(BigDecimal.valueOf(vector1[j] * vector2[j]));
        }
        return sum;
    }

    /** Rejects null vectors and vectors of different lengths (the dot product is undefined). */
    private static void requireSameLength(double[] vector1, double[] vector2) {
        Objects.requireNonNull(vector1, "vector1");
        Objects.requireNonNull(vector2, "vector2");
        if (vector1.length != vector2.length) {
            throw new IllegalArgumentException(
                    "vectors must have the same length: " + vector1.length + " != " + vector2.length);
        }
    }

    /**
     * Waits for every thread to terminate, even if the calling thread is interrupted meanwhile.
     *
     * <p>Partial results may only be read once all workers have finished. An interrupt is
     * therefore remembered and re-asserted on return instead of abandoning the wait.
     */
    private static void joinAll(Thread[] threads) {
        boolean interrupted = false;
        for (Thread thread : threads) {
            boolean finished = false;
            while (!finished) {
                try {
                    thread.join();
                    finished = true;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }
}