Skip to content

Commit 307110e

Browse files
update opset from 14 to 15, and optimize slice_update
1 parent a285aa7 commit 307110e

File tree

2 files changed

+35
-59
lines changed

2 files changed

+35
-59
lines changed

keras/src/backend/openvino/core.py

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import openvino as ov
7-
import openvino.runtime.opset14 as ov_opset
7+
import openvino.runtime.opset15 as ov_opset
88
from openvino import Model
99
from openvino import Tensor
1010
from openvino import compile_model
@@ -810,82 +810,58 @@ def prepare_slice_index(val):
810810

811811

812812
def slice_update(inputs, start_indices, updates):
813-
inputs = get_ov_output(inputs)
814813
if isinstance(start_indices, (list, np.ndarray)):
815814
start_indices = tuple(start_indices)
816815
assert isinstance(start_indices, tuple), (
817816
"`slice_update` is not supported by openvino backend"
818817
" for `start_indices` of type {}".format(type(start_indices))
819818
)
820-
processed_start_indices = []
821-
for idx in start_indices:
822-
val = get_ov_output(idx)
819+
820+
inputs = get_ov_output(inputs)
821+
updates = get_ov_output(updates)
822+
823+
assert len(start_indices) == len(updates.get_partial_shape()), (
824+
"Rank of updates must match length of start_indices"
825+
)
826+
827+
axes = []
828+
starts = []
829+
stops = []
830+
831+
def prepare_index(val):
832+
val = get_ov_output(val)
823833
val_type = val.get_element_type()
824834
if not val_type.is_integral():
825835
raise ValueError(
826-
"`slice` is not supported by OpenVINO backend "
827-
"for `start_indices` or `shape` with non-integer types"
836+
"`slice_update` is not supported by OpenVINO backend "
837+
"for `start_indices` with non-integer types"
828838
)
829839
if val_type != Type.i32:
830840
val = ov_opset.convert(val, Type.i32).output(0)
831841
if len(val.get_partial_shape()) == 0:
832842
val = ov_opset.unsqueeze(
833843
val, ov_opset.constant(0, Type.i32)
834844
).output(0)
835-
processed_start_indices.append(val)
836-
start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0)
837-
838-
rank = len(updates.shape)
839-
ranges = []
840-
for dim in updates.shape:
841-
r = ov_opset.range(
842-
ov_opset.constant(0, Type.i32),
843-
ov_opset.constant(dim, Type.i32),
844-
ov_opset.constant(1, Type.i32),
845-
output_type=Type.i32,
846-
)
847-
ranges.append(r)
848-
849-
broadcasted_ranges = []
850-
for i, r in enumerate(ranges):
851-
shape = [1] * rank
852-
shape[i] = updates.shape[i]
853-
r_reshaped = ov_opset.reshape(
854-
r, ov_opset.constant(shape, Type.i32), special_zero=False
855-
).output(0)
856-
target_shape = ov_opset.constant(list(updates.shape), Type.i32)
857-
r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0)
858-
broadcasted_ranges.append(r_broadcasted)
859-
860-
indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0)
861-
862-
num_updates = 1
863-
for dim in updates.shape:
864-
num_updates *= dim
865-
new_shape = ov_opset.constant([rank, num_updates], Type.i32)
866-
indices_reshaped = ov_opset.reshape(
867-
indices_stack, new_shape, special_zero=False
868-
).output(0)
869-
absolute_indices = ov_opset.transpose(
870-
indices_reshaped, ov_opset.constant([1, 0], Type.i32)
871-
).output(0)
845+
return val
872846

873-
start_indices_expanded = ov_opset.broadcast(
874-
start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32)
875-
).output(0)
876-
absolute_indices = ov_opset.add(
877-
absolute_indices, start_indices_expanded
878-
).output(0)
847+
for idx, dim in enumerate(updates.shape):
848+
axes.append(idx)
879849

880-
updates_tensor = get_ov_output(updates)
881-
updates_flat = ov_opset.reshape(
882-
updates_tensor,
883-
ov_opset.constant([num_updates], Type.i32),
884-
special_zero=False,
885-
).output(0)
886-
updated = ov_opset.scatter_nd_update(
887-
inputs, absolute_indices, updates_flat
850+
start_val = prepare_index(start_indices[idx])
851+
stop_val = prepare_index(start_indices[idx] + dim)
852+
853+
starts.append(start_val)
854+
stops.append(stop_val)
855+
856+
axes_tensor = ov_opset.constant(axes, dtype=Type.i32).output(0)
857+
starts_tensor = ov_opset.concat(starts, axis=0).output(0)
858+
stops_tensor = ov_opset.concat(stops, axis=0).output(0)
859+
steps_tensor = ov_opset.constant([1] * len(axes), dtype=Type.i32).output(0)
860+
861+
updated = ov_opset.slice_scatter(
862+
inputs, updates, starts_tensor, stops_tensor, steps_tensor, axes_tensor
888863
).output(0)
864+
889865
return OpenVINOKerasTensor(updated)
890866

891867

keras/src/backend/openvino/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
import openvino.runtime.opset14 as ov_opset
2+
import openvino.runtime.opset15 as ov_opset
33
from openvino import Type
44

55
from keras.src.backend import config

0 commit comments

Comments
 (0)