|
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | import openvino as ov
|
7 |
| -import openvino.runtime.opset14 as ov_opset |
| 7 | +import openvino.runtime.opset15 as ov_opset |
8 | 8 | from openvino import Model
|
9 | 9 | from openvino import Tensor
|
10 | 10 | from openvino import compile_model
|
@@ -810,82 +810,58 @@ def prepare_slice_index(val):
|
810 | 810 |
|
811 | 811 |
|
812 | 812 | def slice_update(inputs, start_indices, updates):
|
813 |
| - inputs = get_ov_output(inputs) |
814 | 813 | if isinstance(start_indices, (list, np.ndarray)):
|
815 | 814 | start_indices = tuple(start_indices)
|
816 | 815 | assert isinstance(start_indices, tuple), (
|
817 | 816 | "`slice_update` is not supported by openvino backend"
|
818 | 817 | " for `start_indices` of type {}".format(type(start_indices))
|
819 | 818 | )
|
820 |
| - processed_start_indices = [] |
821 |
| - for idx in start_indices: |
822 |
| - val = get_ov_output(idx) |
| 819 | + |
| 820 | + inputs = get_ov_output(inputs) |
| 821 | + updates = get_ov_output(updates) |
| 822 | + |
| 823 | + assert len(start_indices) == len(updates.get_partial_shape()), ( |
| 824 | + "Rank of updates must match length of start_indices" |
| 825 | + ) |
| 826 | + |
| 827 | + axes = [] |
| 828 | + starts = [] |
| 829 | + stops = [] |
| 830 | + |
| 831 | + def prepare_index(val): |
| 832 | + val = get_ov_output(val) |
823 | 833 | val_type = val.get_element_type()
|
824 | 834 | if not val_type.is_integral():
|
825 | 835 | raise ValueError(
|
826 |
| - "`slice` is not supported by OpenVINO backend " |
827 |
| - "for `start_indices` or `shape` with non-integer types" |
| 836 | + "`slice_update` is not supported by OpenVINO backend " |
| 837 | + "for `start_indices` with non-integer types" |
828 | 838 | )
|
829 | 839 | if val_type != Type.i32:
|
830 | 840 | val = ov_opset.convert(val, Type.i32).output(0)
|
831 | 841 | if len(val.get_partial_shape()) == 0:
|
832 | 842 | val = ov_opset.unsqueeze(
|
833 | 843 | val, ov_opset.constant(0, Type.i32)
|
834 | 844 | ).output(0)
|
835 |
| - processed_start_indices.append(val) |
836 |
| - start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0) |
837 |
| - |
838 |
| - rank = len(updates.shape) |
839 |
| - ranges = [] |
840 |
| - for dim in updates.shape: |
841 |
| - r = ov_opset.range( |
842 |
| - ov_opset.constant(0, Type.i32), |
843 |
| - ov_opset.constant(dim, Type.i32), |
844 |
| - ov_opset.constant(1, Type.i32), |
845 |
| - output_type=Type.i32, |
846 |
| - ) |
847 |
| - ranges.append(r) |
848 |
| - |
849 |
| - broadcasted_ranges = [] |
850 |
| - for i, r in enumerate(ranges): |
851 |
| - shape = [1] * rank |
852 |
| - shape[i] = updates.shape[i] |
853 |
| - r_reshaped = ov_opset.reshape( |
854 |
| - r, ov_opset.constant(shape, Type.i32), special_zero=False |
855 |
| - ).output(0) |
856 |
| - target_shape = ov_opset.constant(list(updates.shape), Type.i32) |
857 |
| - r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0) |
858 |
| - broadcasted_ranges.append(r_broadcasted) |
859 |
| - |
860 |
| - indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0) |
861 |
| - |
862 |
| - num_updates = 1 |
863 |
| - for dim in updates.shape: |
864 |
| - num_updates *= dim |
865 |
| - new_shape = ov_opset.constant([rank, num_updates], Type.i32) |
866 |
| - indices_reshaped = ov_opset.reshape( |
867 |
| - indices_stack, new_shape, special_zero=False |
868 |
| - ).output(0) |
869 |
| - absolute_indices = ov_opset.transpose( |
870 |
| - indices_reshaped, ov_opset.constant([1, 0], Type.i32) |
871 |
| - ).output(0) |
| 845 | + return val |
872 | 846 |
|
873 |
| - start_indices_expanded = ov_opset.broadcast( |
874 |
| - start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32) |
875 |
| - ).output(0) |
876 |
| - absolute_indices = ov_opset.add( |
877 |
| - absolute_indices, start_indices_expanded |
878 |
| - ).output(0) |
| 847 | + for idx, dim in enumerate(updates.shape): |
| 848 | + axes.append(idx) |
879 | 849 |
|
880 |
| - updates_tensor = get_ov_output(updates) |
881 |
| - updates_flat = ov_opset.reshape( |
882 |
| - updates_tensor, |
883 |
| - ov_opset.constant([num_updates], Type.i32), |
884 |
| - special_zero=False, |
885 |
| - ).output(0) |
886 |
| - updated = ov_opset.scatter_nd_update( |
887 |
| - inputs, absolute_indices, updates_flat |
| 850 | + start_val = prepare_index(start_indices[idx]) |
| 851 | + stop_val = prepare_index(start_indices[idx] + dim) |
| 852 | + |
| 853 | + starts.append(start_val) |
| 854 | + stops.append(stop_val) |
| 855 | + |
| 856 | + axes_tensor = ov_opset.constant(axes, dtype=Type.i32).output(0) |
| 857 | + starts_tensor = ov_opset.concat(starts, axis=0).output(0) |
| 858 | + stops_tensor = ov_opset.concat(stops, axis=0).output(0) |
| 859 | + steps_tensor = ov_opset.constant([1] * len(axes), dtype=Type.i32).output(0) |
| 860 | + |
| 861 | + updated = ov_opset.slice_scatter( |
| 862 | + inputs, updates, starts_tensor, stops_tensor, steps_tensor, axes_tensor |
888 | 863 | ).output(0)
|
| 864 | + |
889 | 865 | return OpenVINOKerasTensor(updated)
|
890 | 866 |
|
891 | 867 |
|
|
0 commit comments