|
| 1 | +#!/usr/bin/env python |
| 2 | +import argparse |
| 3 | +import json |
| 4 | +import os |
| 5 | +import statistics |
| 6 | +import sys |
| 7 | +import timeit |
| 8 | + |
| 9 | +import download_data |
| 10 | +import graphblas as gb |
| 11 | +import networkx as nx |
| 12 | +import numpy as np |
| 13 | +import scipy.sparse |
| 14 | + |
| 15 | +import graphblas_algorithms as ga |
| 16 | +import scipy_impl |
| 17 | +from graphblas_algorithms.interface import Dispatcher as ga_dispatcher |
| 18 | + |
| 19 | +thisdir = os.path.dirname(__file__) |
| 20 | +datapaths = [ |
| 21 | + os.path.join(thisdir, "..", "data"), |
| 22 | + os.path.curdir, |
| 23 | +] |
| 24 | + |
| 25 | + |
| 26 | +def find_data(dataname): |
| 27 | + if os.path.exists(dataname): |
| 28 | + return os.path.relpath(dataname) |
| 29 | + for path in datapaths: |
| 30 | + path = os.path.join(path, dataname) + ".mtx" |
| 31 | + if os.path.exists(path): |
| 32 | + return os.path.relpath(path) |
| 33 | + path = path.removesuffix(".mtx") |
| 34 | + if os.path.exists(path): |
| 35 | + return os.path.relpath(path) |
| 36 | + if dataname in download_data.data_urls: |
| 37 | + return os.path.relpath(download_data.main([dataname])[0]) |
| 38 | + raise FileNotFoundError(f"Unable to find data file for {dataname}") |
| 39 | + |
| 40 | + |
| 41 | +def get_symmetry(file_or_mminfo): |
| 42 | + if not isinstance(file_or_mminfo, tuple): |
| 43 | + mminfo = scipy.io.mminfo(file_or_mminfo) |
| 44 | + else: |
| 45 | + mminfo = file_or_mminfo |
| 46 | + return mminfo[5] |
| 47 | + |
| 48 | + |
| 49 | +def readfile(filename, is_symmetric, backend): |
| 50 | + name = filename.split(".", 1)[0].rsplit("/", 1)[0] |
| 51 | + if backend == "graphblas": |
| 52 | + A = gb.io.mmread(filename, name=name) |
| 53 | + A.wait() |
| 54 | + if is_symmetric: |
| 55 | + return ga.Graph(A) |
| 56 | + return ga.DiGraph(A) |
| 57 | + a = scipy.io.mmread(filename) |
| 58 | + if backend == "networkx": |
| 59 | + create_using = nx.Graph if is_symmetric else nx.DiGraph |
| 60 | + return nx.from_scipy_sparse_array(a, create_using=create_using) |
| 61 | + if backend == "scipy": |
| 62 | + return scipy.sparse.csr_array(a) |
| 63 | + raise ValueError( |
| 64 | + f"Backend {backend!r} not understood; must be 'graphblas', 'networkx', or 'scipy'" |
| 65 | + ) |
| 66 | + |
| 67 | + |
| 68 | +def best_units(num): |
| 69 | + """Returns scale factor and prefix such that 1 <= num*scale < 1000""" |
| 70 | + if num < 1e-12: |
| 71 | + return 1e15, "f" |
| 72 | + if num < 1e-9: |
| 73 | + return 1e12, "p" |
| 74 | + if num < 1e-6: |
| 75 | + return 1e9, "n" |
| 76 | + if num < 1e-3: |
| 77 | + return 1e6, "\N{MICRO SIGN}" |
| 78 | + if num < 1: |
| 79 | + return 1e3, "m" |
| 80 | + if num < 1e3: |
| 81 | + return 1.0, "" |
| 82 | + if num < 1e6: |
| 83 | + return 1e-3, "k" |
| 84 | + if num < 1e9: |
| 85 | + return 1e-6, "M" |
| 86 | + if num < 1e12: |
| 87 | + return 1e-9, "G" |
| 88 | + return 1e-12, "T" |
| 89 | + |
| 90 | + |
| 91 | +def stime(time): |
| 92 | + scale, units = best_units(time) |
| 93 | + return f"{time * scale:4.3g} {units}s" |
| 94 | + |
| 95 | + |
| 96 | +# Functions that aren't available in the main networkx namespace |
| 97 | +functionpaths = { |
| 98 | + "inter_community_edges": "community.quality.inter_community_edges", |
| 99 | + "intra_community_edges": "community.quality.intra_community_edges", |
| 100 | + "is_tournament": "tournament.is_tournament", |
| 101 | + "mutual_weight": "structuralholes.mutual_weight", |
| 102 | + "score_sequence": "tournament.score_sequence", |
| 103 | + "tournament_matrix": "tournament.tournament_matrix", |
| 104 | +} |
| 105 | +functioncall = { |
| 106 | + "s_metric": "func(G, normalized=False)", |
| 107 | +} |
| 108 | +poweriteration = {"eigenvector_centrality", "katz_centrality", "pagerank"} |
| 109 | +directed_only = { |
| 110 | + "in_degree_centrality", |
| 111 | + "is_tournament", |
| 112 | + "out_degree_centrality", |
| 113 | + "score_sequence", |
| 114 | + "tournament_matrix", |
| 115 | + "reciprocity", |
| 116 | + "overall_reciprocity", |
| 117 | +} |
| 118 | +# Is square_clustering undirected only? graphblas-algorthms doesn't implement it for directed |
| 119 | +undirected_only = {"generalized_degree", "k_truss", "triangles", "square_clustering"} |
| 120 | + |
| 121 | + |
| 122 | +def getfunction(functionname, backend): |
| 123 | + if backend == "graphblas": |
| 124 | + return getattr(ga_dispatcher, functionname) |
| 125 | + if backend == "scipy": |
| 126 | + return getattr(scipy_impl, functionname) |
| 127 | + if functionname in functionpaths: |
| 128 | + func = nx |
| 129 | + for attr in functionpaths[functionname].split("."): |
| 130 | + func = getattr(func, attr) |
| 131 | + return func |
| 132 | + return getattr(nx, functionname) |
| 133 | + |
| 134 | + |
| 135 | +def main(dataname, backend, functionname, time=3.0, n=None, extra=None, display=True): |
| 136 | + filename = find_data(dataname) |
| 137 | + is_symmetric = get_symmetry(filename) == "symmetric" |
| 138 | + if not is_symmetric and functionname in undirected_only: |
| 139 | + # Should we automatically symmetrize? |
| 140 | + raise ValueError( |
| 141 | + f"Data {dataname!r} is not symmetric, but {functionname} only works on undirected" |
| 142 | + ) |
| 143 | + if is_symmetric and functionname in directed_only: |
| 144 | + is_symmetric = False # Make into directed graph |
| 145 | + G = readfile(filename, is_symmetric, backend) |
| 146 | + func = getfunction(functionname, backend) |
| 147 | + benchstring = functioncall.get(functionname, "func(G)") |
| 148 | + if extra is not None: |
| 149 | + benchstring = f"{benchstring[:-1]}, {extra})" |
| 150 | + globals = {"func": func, "G": G} |
| 151 | + if functionname in poweriteration: |
| 152 | + benchstring = f"try:\n {benchstring}\nexcept exc:\n pass" |
| 153 | + globals["exc"] = nx.PowerIterationFailedConvergence |
| 154 | + if backend == "graphblas": |
| 155 | + benchstring = f"G._cache.clear()\n{benchstring}" |
| 156 | + timer = timeit.Timer(benchstring, globals=globals) |
| 157 | + if display: |
| 158 | + line = f"Backend = {backend}, function = {functionname}, data = {dataname}" |
| 159 | + if extra is not None: |
| 160 | + line += f", extra = {extra}" |
| 161 | + print("=" * len(line)) |
| 162 | + print(line) |
| 163 | + print("-" * len(line)) |
| 164 | + info = {"backend": backend, "function": functionname, "data": dataname} |
| 165 | + if extra is not None: |
| 166 | + info["extra"] = extra |
| 167 | + try: |
| 168 | + first_time = timer.timeit(1) |
| 169 | + except Exception as exc: |
| 170 | + if display: |
| 171 | + print(f"EXCEPTION: {exc}") |
| 172 | + print("=" * len(line)) |
| 173 | + raise |
| 174 | + info["exception"] = str(exc) |
| 175 | + return info |
| 176 | + if time == 0: |
| 177 | + n = 1 |
| 178 | + elif n is None: |
| 179 | + n = 2 ** max(0, int(np.ceil(np.log2(time / first_time)))) |
| 180 | + if display: |
| 181 | + print("Number of runs:", n) |
| 182 | + print("first: ", stime(first_time)) |
| 183 | + info["n"] = n |
| 184 | + info["first"] = first_time |
| 185 | + if n > 1: |
| 186 | + results = timer.repeat(n - 1, 1) |
| 187 | + results.append(first_time) |
| 188 | + if display: |
| 189 | + print("median:", stime(statistics.median(results))) |
| 190 | + print("mean: ", stime(statistics.mean(results))) |
| 191 | + print("stdev: ", stime(statistics.stdev(results))) |
| 192 | + print("min: ", stime(min(results))) |
| 193 | + print("max: ", stime(max(results))) |
| 194 | + info["median"] = statistics.median(results) |
| 195 | + info["mean"] = statistics.mean(results) |
| 196 | + info["stdev"] = statistics.stdev(results) |
| 197 | + info["min"] = min(results) |
| 198 | + info["max"] = max(results) |
| 199 | + if display: |
| 200 | + print("=" * len(line)) |
| 201 | + return info |
| 202 | + |
| 203 | + |
| 204 | +if __name__ == "__main__": |
| 205 | + parser = argparse.ArgumentParser( |
| 206 | + description=f"Example usage: python {sys.argv[0]} -b graphblas -f pagerank -d amazon0302" |
| 207 | + ) |
| 208 | + parser.add_argument( |
| 209 | + "-b", "--backend", choices=["graphblas", "networkx", "scipy"], default="graphblas" |
| 210 | + ) |
| 211 | + parser.add_argument( |
| 212 | + "-t", "--time", type=float, default=3.0, help="Target minimum time to run benchmarks" |
| 213 | + ) |
| 214 | + parser.add_argument( |
| 215 | + "-n", |
| 216 | + type=int, |
| 217 | + help="The number of times to run the benchmark (the default is to run according to time)", |
| 218 | + ) |
| 219 | + parser.add_argument( |
| 220 | + "-d", |
| 221 | + "--data", |
| 222 | + required=True, |
| 223 | + help="The path to a mtx file or one of the following data names: {" |
| 224 | + + ", ".join(sorted(download_data.data_urls)) |
| 225 | + + "}; data will be downloaded if necessary", |
| 226 | + ) |
| 227 | + parser.add_argument( |
| 228 | + "-j", |
| 229 | + "--json", |
| 230 | + action="store_true", |
| 231 | + help="Print results as json instead of human-readable text", |
| 232 | + ) |
| 233 | + parser.add_argument("-f", "--func", required=True, help="Which function to benchmark") |
| 234 | + parser.add_argument("--extra", help="Extra string to add to the function call") |
| 235 | + args = parser.parse_args() |
| 236 | + info = main( |
| 237 | + args.data, |
| 238 | + args.backend, |
| 239 | + args.func, |
| 240 | + time=args.time, |
| 241 | + n=args.n, |
| 242 | + extra=args.extra, |
| 243 | + display=not args.json, |
| 244 | + ) |
| 245 | + if args.json: |
| 246 | + print(json.dumps(info)) |
0 commit comments