Skip to content

Commit abf3a3a

Browse files
[OpenVINO Backend] support slice_update
1 parent 771b001 commit abf3a3a

File tree

1 file changed

+86
-2
lines changed

1 file changed

+86
-2
lines changed

keras/src/backend/openvino/core.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,93 @@ def slice(inputs, start_indices, shape):
657657

658658

659659
def slice_update(inputs, start_indices, updates):
660-
raise NotImplementedError(
661-
"`slice_update` is not supported with openvino backend"
660+
inputs = get_ov_output(inputs)
661+
if isinstance(start_indices, (list, np.ndarray)):
662+
if isinstance(start_indices, np.ndarray):
663+
start_indices = start_indices.tolist()
664+
start_indices = tuple(start_indices)
665+
if isinstance(updates, (list, np.ndarray)):
666+
if isinstance(updates, np.ndarray):
667+
updates = updates.tolist()
668+
updates = tuple(updates)
669+
assert isinstance(start_indices, tuple), (
670+
"`slice` is not supported by openvino backend"
671+
" for `start_indices` of type {}".format(type(start_indices))
662672
)
673+
assert isinstance(updates, tuple), (
674+
"`slice` is not supported by openvino backend"
675+
" for `updates` of type {}".format(type(updates))
676+
)
677+
processed_start_indices = []
678+
for idx in start_indices:
679+
val = get_ov_output(idx)
680+
val_type = val.get_element_type()
681+
if not val_type.is_integral():
682+
raise ValueError(
683+
"`slice` is not supported by OpenVINO backend "
684+
"for `start_indices` or `shape` with non-integer types"
685+
)
686+
if val_type != Type.i32:
687+
val = ov_opset.convert(val, Type.i32).output(0)
688+
if len(val.get_partial_shape()) == 0:
689+
val = ov_opset.unsqueeze(
690+
val, ov_opset.constant(0, Type.i32)
691+
).output(0)
692+
processed_start_indices.append(val)
693+
start_indices_tensor = ov_opset.concat(processed_start_indices, axis=0)
694+
695+
rank = len(updates.shape)
696+
ranges = []
697+
for dim in updates.shape:
698+
r = ov_opset.range(
699+
ov_opset.constant(0, Type.i32),
700+
ov_opset.constant(dim, Type.i32),
701+
ov_opset.constant(1, Type.i32),
702+
output_type=Type.i32,
703+
)
704+
ranges.append(r)
705+
706+
broadcasted_ranges = []
707+
for i, r in enumerate(ranges):
708+
shape = [1] * rank
709+
shape[i] = updates.shape[i]
710+
r_reshaped = ov_opset.reshape(
711+
r, ov_opset.constant(shape, Type.i32), special_zero=False
712+
).output(0)
713+
target_shape = ov_opset.constant(list(updates.shape), Type.i32)
714+
r_broadcasted = ov_opset.broadcast(r_reshaped, target_shape).output(0)
715+
broadcasted_ranges.append(r_broadcasted)
716+
717+
indices_stack = ov_opset.concat(broadcasted_ranges, axis=0).output(0)
718+
719+
num_updates = 1
720+
for dim in updates.shape:
721+
num_updates *= dim
722+
new_shape = ov_opset.constant([rank, num_updates], Type.i32)
723+
indices_reshaped = ov_opset.reshape(
724+
indices_stack, new_shape, special_zero=False
725+
).output(0)
726+
absolute_indices = ov_opset.transpose(
727+
indices_reshaped, ov_opset.constant([1, 0], Type.i32)
728+
).output(0)
729+
730+
start_indices_expanded = ov_opset.broadcast(
731+
start_indices_tensor, ov_opset.constant([num_updates, rank], Type.i32)
732+
).output(0)
733+
absolute_indices = ov_opset.add(
734+
absolute_indices, start_indices_expanded
735+
).output(0)
736+
737+
updates_tensor = get_ov_output(updates)
738+
updates_flat = ov_opset.reshape(
739+
updates_tensor,
740+
ov_opset.constant([num_updates], Type.i32),
741+
special_zero=False,
742+
).output(0)
743+
updated = ov_opset.scatter_nd_update(
744+
inputs, absolute_indices, updates_flat
745+
).output(0)
746+
return OpenVINOKerasTensor(updated)
663747

664748

665749
def while_loop(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy