-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtensormon
executable file
·51 lines (43 loc) · 1.29 KB
/
tensormon
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/python3
import argparse
import time
import tensorcom
# Command-line interface. The first positional argument of ArgumentParser is
# `prog` (the program name shown in usage), so the summary text must be passed
# as `description=` instead.
parser = argparse.ArgumentParser(description="show tensor inputs")
parser.add_argument(
    "input",
    nargs="*",
    help="addresses to connect to (default: zsub://127.0.0.1:7880)",
)
parser.add_argument(
    "-R", "--raw", action="store_true", help="forward raw=True to tensorcom.Connection"
)
parser.add_argument(
    "-r",
    "--report",
    type=int,
    default=10,
    help="print a throughput line every N batches (default: %(default)s)",
)
parser.add_argument(
    "-c",
    "--count",
    type=int,
    default=999999999,
    help="batch limit after which monitoring stops (default: %(default)s)",
)
args = parser.parse_args()
# The report interval is used as a modulus, so it must be strictly positive.
if args.report <= 0:
    parser.error("--report must be a positive integer")
# nargs="*" yields an empty list when no address is given; fall back to the
# local default endpoint.
if not args.input:
    args.input = ["zsub://127.0.0.1:7880"]
def make_source():
    """Build a tensorcom connection and connect it to every configured input.

    Reads the module-level ``args`` (``args.input`` for the address list and
    ``args.raw`` for the ``raw=`` flag). Echoes the full address list, then
    each address as it is connected.

    Returns:
        The ``tensorcom.Connection`` after ``connect()`` has been called for
        every address.
    """
    addresses = args.input
    print("input:", addresses)
    # device=None is passed through unchanged; its meaning is defined by
    # tensorcom (NOTE(review): presumably "no device placement" — confirm).
    connection = tensorcom.Connection(device=None, raw=args.raw)
    for address in addresses:
        print(address)
        connection.connect(address)
    return connection
def monitor(connect, report=10, count=999999999):
    """Consume batches from connected sources and print throughput.

    ``connect()`` is called to obtain a source whose ``items()`` yields
    batches. When that iterator is exhausted, ``connect()`` is called again,
    so monitoring continues across reconnects until ``count`` batches have
    been seen.

    Args:
        connect: zero-argument callable returning an object with ``items()``.
        report: print one throughput line every ``report`` batches; must be > 0.
        count: stop after exactly this many batches (nothing is read if
            ``count <= 0``).

    Returns:
        The number of batches processed.

    Raises:
        ValueError: if ``report`` is not positive (it is used as a modulus).
    """
    if report <= 0:
        raise ValueError("report must be a positive integer")
    index = 0  # batches seen since start-up (never reset)
    total = 0  # batches seen in the current reporting window
    last = None  # perf_counter() timestamp at the start of the current window
    while index < count:
        source = connect()
        for batch in source.items():
            if last is None:
                # First batch overall: start the clock. This happens once;
                # it is not repeated after a reconnect.
                # NOTE: this first batch is counted in `total`, so the first
                # window's rate is slightly overstated.
                print("connected")
                last = time.perf_counter()
            index += 1
            total += 1
            # Batch size is taken as len(batch[0]); presumably batch[0] is the
            # per-sample input tensor — TODO confirm.
            bs = len(batch[0])
            if index % report == 0:
                # perf_counter is monotonic and high-resolution, unlike the
                # wall clock, so intervals cannot go negative on clock changes.
                now = time.perf_counter()
                delta = now - last
                # Guard against a zero-length window on a coarse clock.
                rate = total / delta if delta > 0 else float("inf")
                print(
                    "{:20d} {:8.3f} batches/s {:8.3f} samples/s (batchsize: {:d})".format(
                        index, rate, rate * bs, bs
                    )
                )
                total = 0
                last = now
            # Stop as soon as the requested number of batches has been read
            # (the previous `>` check processed one batch too many).
            if index >= count:
                break
        # NOTE(review): if items() ends without yielding anything, this
        # reconnects in a tight loop; consider adding a backoff delay.
    return index


if __name__ == "__main__":
    monitor(make_source, args.report, args.count)