@@ -100,40 +100,43 @@ def find_solve_clients(var, assume_a):
100
100
elif isinstance (cl .op , DimShuffle ) and cl .op .is_left_expand_dims :
101
101
# If it's a left expand_dims, recurse on the output
102
102
clients .extend (find_solve_clients (cl .outputs [0 ], assume_a ))
103
+
103
104
return clients
104
105
105
106
assume_a = node .op .core_op .assume_a
106
107
107
108
if assume_a not in allowed_assume_a :
108
109
return None
109
110
110
- A , _ = get_root_A (node .inputs [0 ])
111
+ root_A , root_A_transposed = get_root_A (node .inputs [0 ])
111
112
112
113
# Find Solve using A (or left expand_dims of A)
113
114
# TODO: We could handle arbitrary shuffle of the batch dimensions, just need to propagate
114
115
# that to the A_decomp outputs
115
- A_solve_clients_and_transpose = [
116
- (client , False ) for client in find_solve_clients (A , assume_a )
116
+ root_A_solve_clients_and_transpose = [
117
+ (client , False ) for client in find_solve_clients (root_A , assume_a )
117
118
]
118
119
119
120
# Find Solves using A.T
120
- for cl , _ in fgraph .clients [A ]:
121
+ for cl , _ in fgraph .clients [root_A ]:
121
122
if isinstance (cl .op , DimShuffle ) and is_matrix_transpose (cl .out ):
122
123
A_T = cl .out
123
- A_solve_clients_and_transpose .extend (
124
+ root_A_solve_clients_and_transpose .extend (
124
125
(client , True ) for client in find_solve_clients (A_T , assume_a )
125
126
)
126
127
127
- if not eager and len (A_solve_clients_and_transpose ) == 1 :
128
+ if not eager and len (root_A_solve_clients_and_transpose ) == 1 :
128
129
# If there's a single use, don't do it... unless it's being broadcast in a Blockwise (or we're eager)
129
130
# That's a "reuse" inside the inner vectorized loop
130
131
batch_ndim = node .op .batch_ndim (node )
131
- (client , _ ) = A_solve_clients_and_transpose [0 ]
132
- original_A , b = client .inputs
132
+ (client , _ ) = root_A_solve_clients_and_transpose [0 ]
133
+
134
+ A , b = client .inputs
135
+
133
136
if not any (
134
137
a_bcast and not b_bcast
135
138
for a_bcast , b_bcast in zip (
136
- original_A .type .broadcastable [:batch_ndim ],
139
+ A .type .broadcastable [:batch_ndim ],
137
140
b .type .broadcastable [:batch_ndim ],
138
141
strict = True ,
139
142
)
@@ -142,19 +145,27 @@ def find_solve_clients(var, assume_a):
142
145
143
146
# If any Op had check_finite=True, we also do it for the LU decomposition
144
147
check_finite_decomp = False
145
- for client , _ in A_solve_clients_and_transpose :
148
+ for client , _ in root_A_solve_clients_and_transpose :
146
149
if client .op .core_op .check_finite :
147
150
check_finite_decomp = True
148
151
break
149
152
150
- lower = node .op .core_op .lower
153
+ (first_solve , transposed ) = root_A_solve_clients_and_transpose [0 ]
154
+ lower = first_solve .op .core_op .lower
155
+ if transposed :
156
+ lower = not lower
157
+
151
158
A_decomp = decompose_A (
152
- A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
159
+ root_A , assume_a = assume_a , check_finite = check_finite_decomp , lower = lower
153
160
)
154
161
155
162
replacements = {}
156
- for client , transposed in A_solve_clients_and_transpose :
163
+ for client , transposed in root_A_solve_clients_and_transpose :
157
164
_ , b = client .inputs
165
+ lower = client .op .core_op .lower
166
+ if transposed :
167
+ lower = not lower
168
+
158
169
new_x = solve_decomposed_system (
159
170
A_decomp ,
160
171
b ,
0 commit comments