Skip to content

Commit 6eaab4c

Browse files
authored
Add script to run benchmarks and script to download data (#39)
Example usage: ``` $ python scripts/bench.py -b graphblas -f pagerank -d amazon0302 ```
1 parent 83bfa40 commit 6eaab4c

File tree

5 files changed

+377
-250
lines changed

5 files changed

+377
-250
lines changed

environment.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ dependencies:
4646
- pre-commit
4747
# For testing
4848
- pytest-cov
49+
# For benchmarking
50+
- requests
4951
# For debugging
5052
- icecream
5153
- ipython
54+
# For type annotations
55+
- mypy

scripts/bench.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
#!/usr/bin/env python
2+
import argparse
3+
import json
4+
import os
5+
import statistics
6+
import sys
7+
import timeit
8+
9+
import download_data
10+
import graphblas as gb
11+
import networkx as nx
12+
import numpy as np
13+
import scipy.sparse
14+
15+
import graphblas_algorithms as ga
16+
import scipy_impl
17+
from graphblas_algorithms.interface import Dispatcher as ga_dispatcher
18+
19+
# Directory that contains this script.
thisdir = os.path.split(__file__)[0]
# Directories searched, in order, when resolving a data name to a file:
# the sibling "data" directory next to this script's folder, then the
# current working directory.
datapaths = [os.path.join(thisdir, os.pardir, "data"), os.curdir]
24+
25+
26+
def find_data(dataname):
    """Resolve ``dataname`` to a path relative to the current directory.

    Lookup order:

    1. ``dataname`` itself, if it is an existing path.
    2. For each directory in ``datapaths``: ``<dir>/<dataname>.mtx`` and then
       ``<dir>/<dataname>``.
    3. If ``dataname`` is a key of ``download_data.data_urls``, the first path
       returned by ``download_data.main([dataname])`` (presumably downloading
       the file if needed -- see ``download_data``).

    Raises
    ------
    FileNotFoundError
        If none of the above produces a file.
    """
    if os.path.exists(dataname):
        return os.path.relpath(dataname)
    for directory in datapaths:
        stem = os.path.join(directory, dataname)
        # Prefer the name with the Matrix Market extension, then the bare name.
        for candidate in (stem + ".mtx", stem):
            if os.path.exists(candidate):
                return os.path.relpath(candidate)
    if dataname not in download_data.data_urls:
        raise FileNotFoundError(f"Unable to find data file for {dataname}")
    return os.path.relpath(download_data.main([dataname])[0])
39+
40+
41+
def get_symmetry(file_or_mminfo):
    """Return the symmetry field of a Matrix Market file.

    Parameters
    ----------
    file_or_mminfo : tuple or file path / file-like
        Either a tuple as returned by ``scipy.io.mminfo`` or something
        ``scipy.io.mminfo`` can read.

    Returns
    -------
    The element at index 5 of the mminfo tuple (the symmetry entry, e.g.
    ``"symmetric"`` or ``"general"``).
    """
    if isinstance(file_or_mminfo, tuple):
        mminfo = file_or_mminfo
    else:
        # Import explicitly: the file only does `import scipy.sparse`, which
        # does not guarantee that the `scipy.io` submodule is loaded.
        from scipy.io import mminfo as read_mminfo

        mminfo = read_mminfo(file_or_mminfo)
    return mminfo[5]
47+
48+
49+
def readfile(filename, is_symmetric, backend):
    """Load a Matrix Market file as a graph object for ``backend``.

    Parameters
    ----------
    filename : str
        Path to the ``.mtx`` file.
    is_symmetric : bool
        If true, build an undirected graph; otherwise a directed one.
        Ignored by the ``"scipy"`` backend.
    backend : {"graphblas", "networkx", "scipy"}

    Returns
    -------
    ``ga.Graph`` / ``ga.DiGraph`` for ``"graphblas"``, ``nx.Graph`` /
    ``nx.DiGraph`` for ``"networkx"``, or a ``scipy.sparse.csr_array`` for
    ``"scipy"``.

    Raises
    ------
    ValueError
        If ``backend`` is not one of the supported names (raised after the
        file has been read).
    """
    # Name the matrix after the file's base name up to its first dot.
    # (Splitting the full path on "." and "/" returned the directory, or an
    # empty string for relative paths such as "../data/foo.mtx".)
    name = os.path.basename(filename).split(".", 1)[0]
    if backend == "graphblas":
        A = gb.io.mmread(filename, name=name)
        # NOTE(review): presumably forces pending work to finish so that
        # loading is not attributed to the benchmark -- confirm.
        A.wait()
        if is_symmetric:
            return ga.Graph(A)
        return ga.DiGraph(A)
    # Import explicitly: `import scipy.sparse` at file level does not
    # guarantee that the `scipy.io` submodule is loaded.
    from scipy.io import mmread

    a = mmread(filename)
    if backend == "networkx":
        create_using = nx.Graph if is_symmetric else nx.DiGraph
        return nx.from_scipy_sparse_array(a, create_using=create_using)
    if backend == "scipy":
        return scipy.sparse.csr_array(a)
    raise ValueError(
        f"Backend {backend!r} not understood; must be 'graphblas', 'networkx', or 'scipy'"
    )
66+
67+
68+
def best_units(num):
    """Return ``(scale, prefix)`` for displaying ``num`` with an SI prefix.

    For ``1e-15 <= num < 1e15`` the result satisfies
    ``1 <= num * scale < 1000``.  Values below ``1e-12`` always map to femto
    (``"f"``) and values of ``1e12`` or more always map to tera (``"T"``).
    """
    # (exclusive upper bound, scale factor, SI prefix), in ascending order.
    ladder = (
        (1e-12, 1e15, "f"),
        (1e-9, 1e12, "p"),
        (1e-6, 1e9, "n"),
        (1e-3, 1e6, "\N{MICRO SIGN}"),
        (1, 1e3, "m"),
        (1e3, 1.0, ""),
        (1e6, 1e-3, "k"),
        (1e9, 1e-6, "M"),
        (1e12, 1e-9, "G"),
    )
    for upper, scale, prefix in ladder:
        if num < upper:
            return scale, prefix
    return 1e-12, "T"
89+
90+
91+
def stime(time):
    """Format ``time`` (in seconds) as a short string with an SI-prefixed unit.

    The value is scaled with ``best_units`` and rendered with three
    significant digits, e.g. ``"12.3 ms"``.
    """
    factor, prefix = best_units(time)
    scaled = time * factor
    return f"{scaled:4.3g} {prefix}s"
94+
95+
96+
# Functions that aren't available in the main networkx namespace, mapped to
# their dotted attribute path relative to the `networkx` package.
functionpaths = dict(
    inter_community_edges="community.quality.inter_community_edges",
    intra_community_edges="community.quality.intra_community_edges",
    is_tournament="tournament.is_tournament",
    mutual_weight="structuralholes.mutual_weight",
    score_sequence="tournament.score_sequence",
    tournament_matrix="tournament.tournament_matrix",
)
# Functions whose benchmark statement is not the default "func(G)".
functioncall = dict(
    s_metric="func(G, normalized=False)",
)
# Functions for which `main` ignores nx.PowerIterationFailedConvergence while timing.
poweriteration = {
    "eigenvector_centrality",
    "katz_centrality",
    "pagerank",
}
# Functions that `main` always runs on a directed graph, even for symmetric data.
directed_only = {
    "in_degree_centrality",
    "is_tournament",
    "out_degree_centrality",
    "overall_reciprocity",
    "reciprocity",
    "score_sequence",
    "tournament_matrix",
}
# Functions that `main` refuses to run on non-symmetric data.
# Is square_clustering undirected only? graphblas-algorithms doesn't implement it for directed
undirected_only = {
    "generalized_degree",
    "k_truss",
    "square_clustering",
    "triangles",
}
120+
121+
122+
def getfunction(functionname, backend):
    """Look up the callable named ``functionname`` for ``backend``.

    ``"graphblas"`` resolves the name on ``ga_dispatcher`` and ``"scipy"`` on
    ``scipy_impl``.  Any other backend value resolves it on ``networkx``,
    following the dotted path in ``functionpaths`` when the function is not
    exposed in the top-level ``networkx`` namespace.

    Raises ``AttributeError`` if the attribute lookup fails.
    """
    if backend == "graphblas":
        owner, attr_path = ga_dispatcher, [functionname]
    elif backend == "scipy":
        owner, attr_path = scipy_impl, [functionname]
    elif functionname in functionpaths:
        owner, attr_path = nx, functionpaths[functionname].split(".")
    else:
        owner, attr_path = nx, [functionname]
    # Walk the attribute chain one component at a time.
    for attr in attr_path:
        owner = getattr(owner, attr)
    return owner
133+
134+
135+
def main(dataname, backend, functionname, time=3.0, n=None, extra=None, display=True):
    """Benchmark one function on one dataset with one backend.

    Parameters
    ----------
    dataname : str
        Path or data name, resolved with ``find_data``.
    backend : {"graphblas", "networkx", "scipy"}
    functionname : str
        Name of the function to benchmark, resolved with ``getfunction``.
    time : float, default 3.0
        Target minimum total time (seconds) used to choose the number of runs
        when ``n`` is not given.  ``0`` forces a single run; a negative value
        also results in a single run when ``n`` is not given.
    n : int, optional
        Number of runs.  If omitted, the smallest power of two such that
        ``n * first_run_time >= time`` is used.
    extra : str, optional
        Text appended as extra argument(s) inside the benchmarked call.
    display : bool, default True
        Print a human-readable report while running.

    Returns
    -------
    dict
        Keys ``backend``, ``function``, ``data`` (and ``extra`` if given),
        then ``n`` and ``first``; when more than one run was made also
        ``median``, ``mean``, ``stdev``, ``min`` and ``max`` (seconds).  If
        the first run fails and ``display`` is false, the dict instead
        carries an ``exception`` message.

    Raises
    ------
    ValueError
        If the data is not symmetric but the function is undirected-only.
    Exception
        Whatever the first timed run raises, when ``display`` is true.
    """
    filename = find_data(dataname)
    is_symmetric = get_symmetry(filename) == "symmetric"
    if not is_symmetric and functionname in undirected_only:
        # Should we automatically symmetrize?
        raise ValueError(
            f"Data {dataname!r} is not symmetric, but {functionname} only works on undirected"
        )
    if is_symmetric and functionname in directed_only:
        is_symmetric = False  # Make into directed graph
    G = readfile(filename, is_symmetric, backend)
    func = getfunction(functionname, backend)

    # Build the statement that timeit will execute.
    benchstring = functioncall.get(functionname, "func(G)")
    if extra is not None:
        # Drop the closing ")" and splice the extra argument text in.
        benchstring = f"{benchstring[:-1]}, {extra})"
    namespace = {"func": func, "G": G}
    if functionname in poweriteration:
        # Failing to converge still counts as a completed (timed) run.
        benchstring = f"try:\n {benchstring}\nexcept exc:\n pass"
        namespace["exc"] = nx.PowerIterationFailedConvergence
    if backend == "graphblas":
        # NOTE(review): presumably clears cached derived data so every run
        # starts cold -- confirm against graphblas_algorithms.
        benchstring = f"G._cache.clear()\n{benchstring}"
    timer = timeit.Timer(benchstring, globals=namespace)

    if display:
        line = f"Backend = {backend}, function = {functionname}, data = {dataname}"
        if extra is not None:
            line += f", extra = {extra}"
        print("=" * len(line))
        print(line)
        print("-" * len(line))
    info = {"backend": backend, "function": functionname, "data": dataname}
    if extra is not None:
        info["extra"] = extra

    try:
        first_time = timer.timeit(1)
    except Exception as exc:
        if display:
            print(f"EXCEPTION: {exc}")
            print("=" * len(line))
            raise
        # Machine-readable mode: report the failure instead of raising.
        info["exception"] = str(exc)
        return info

    if time == 0:
        n = 1
    elif n is None:
        ratio = time / first_time
        # A non-positive ratio (negative `time`, or underflow) has no usable
        # log2; fall back to a single run instead of failing on int(nan).
        n = 2 ** max(0, int(np.ceil(np.log2(ratio)))) if ratio > 0 else 1
    if display:
        print("Number of runs:", n)
        print("first: ", stime(first_time))
    info["n"] = n
    info["first"] = first_time

    if n > 1:
        # The first run counts as one of the n samples.
        results = timer.repeat(n - 1, 1)
        results.append(first_time)
        # Compute each statistic once; reuse it for display and for `info`.
        stats = {
            "median": statistics.median(results),
            "mean": statistics.mean(results),
            "stdev": statistics.stdev(results),
            "min": min(results),
            "max": max(results),
        }
        if display:
            print("median:", stime(stats["median"]))
            print("mean: ", stime(stats["mean"]))
            print("stdev: ", stime(stats["stdev"]))
            print("min: ", stime(stats["min"]))
            print("max: ", stime(stats["max"]))
        info.update(stats)
    if display:
        print("=" * len(line))
    return info
202+
203+
204+
if __name__ == "__main__":
    # Command-line entry point: parse the options, run one benchmark, and
    # print the result as JSON when requested (otherwise `main` prints a
    # human-readable report itself).
    known_data = ", ".join(sorted(download_data.data_urls))
    cli = argparse.ArgumentParser(
        description=f"Example usage: python {sys.argv[0]} -b graphblas -f pagerank -d amazon0302"
    )
    cli.add_argument(
        "-b",
        "--backend",
        default="graphblas",
        choices=["graphblas", "networkx", "scipy"],
    )
    cli.add_argument(
        "-t",
        "--time",
        default=3.0,
        type=float,
        help="Target minimum time to run benchmarks",
    )
    cli.add_argument(
        "-n",
        type=int,
        help="The number of times to run the benchmark (the default is to run according to time)",
    )
    cli.add_argument(
        "-d",
        "--data",
        required=True,
        help=(
            "The path to a mtx file or one of the following data names: "
            f"{{{known_data}}}; data will be downloaded if necessary"
        ),
    )
    cli.add_argument(
        "-j",
        "--json",
        action="store_true",
        help="Print results as json instead of human-readable text",
    )
    cli.add_argument("-f", "--func", required=True, help="Which function to benchmark")
    cli.add_argument("--extra", help="Extra string to add to the function call")
    options = cli.parse_args()

    report = main(
        options.data,
        options.backend,
        options.func,
        time=options.time,
        n=options.n,
        extra=options.extra,
        display=not options.json,
    )
    if options.json:
        print(json.dumps(report))

0 commit comments

Comments
 (0)