@@ -657,9 +657,93 @@ def slice(inputs, start_indices, shape):
657
657
658
658
659
659
def slice_update (inputs , start_indices , updates ):
660
- raise NotImplementedError (
661
- "`slice_update` is not supported with openvino backend"
660
+ inputs = get_ov_output (inputs )
661
+ if isinstance (start_indices , (list , np .ndarray )):
662
+ if isinstance (start_indices , np .ndarray ):
663
+ start_indices = start_indices .tolist ()
664
+ start_indices = tuple (start_indices )
665
+ if isinstance (updates , (list , np .ndarray )):
666
+ if isinstance (updates , np .ndarray ):
667
+ updates = updates .tolist ()
668
+ updates = tuple (updates )
669
+ assert isinstance (start_indices , tuple ), (
670
+ "`slice` is not supported by openvino backend"
671
+ " for `start_indices` of type {}" .format (type (start_indices ))
662
672
)
673
+ assert isinstance (updates , tuple ), (
674
+ "`slice` is not supported by openvino backend"
675
+ " for `updates` of type {}" .format (type (updates ))
676
+ )
677
+ processed_start_indices = []
678
+ for idx in start_indices :
679
+ val = get_ov_output (idx )
680
+ val_type = val .get_element_type ()
681
+ if not val_type .is_integral ():
682
+ raise ValueError (
683
+ "`slice` is not supported by OpenVINO backend "
684
+ "for `start_indices` or `shape` with non-integer types"
685
+ )
686
+ if val_type != Type .i32 :
687
+ val = ov_opset .convert (val , Type .i32 ).output (0 )
688
+ if len (val .get_partial_shape ()) == 0 :
689
+ val = ov_opset .unsqueeze (
690
+ val , ov_opset .constant (0 , Type .i32 )
691
+ ).output (0 )
692
+ processed_start_indices .append (val )
693
+ start_indices_tensor = ov_opset .concat (processed_start_indices , axis = 0 )
694
+
695
+ rank = len (updates .shape )
696
+ ranges = []
697
+ for dim in updates .shape :
698
+ r = ov_opset .range (
699
+ ov_opset .constant (0 , Type .i32 ),
700
+ ov_opset .constant (dim , Type .i32 ),
701
+ ov_opset .constant (1 , Type .i32 ),
702
+ output_type = Type .i32 ,
703
+ )
704
+ ranges .append (r )
705
+
706
+ broadcasted_ranges = []
707
+ for i , r in enumerate (ranges ):
708
+ shape = [1 ] * rank
709
+ shape [i ] = updates .shape [i ]
710
+ r_reshaped = ov_opset .reshape (
711
+ r , ov_opset .constant (shape , Type .i32 ), special_zero = False
712
+ ).output (0 )
713
+ target_shape = ov_opset .constant (list (updates .shape ), Type .i32 )
714
+ r_broadcasted = ov_opset .broadcast (r_reshaped , target_shape ).output (0 )
715
+ broadcasted_ranges .append (r_broadcasted )
716
+
717
+ indices_stack = ov_opset .concat (broadcasted_ranges , axis = 0 ).output (0 )
718
+
719
+ num_updates = 1
720
+ for dim in updates .shape :
721
+ num_updates *= dim
722
+ new_shape = ov_opset .constant ([rank , num_updates ], Type .i32 )
723
+ indices_reshaped = ov_opset .reshape (
724
+ indices_stack , new_shape , special_zero = False
725
+ ).output (0 )
726
+ absolute_indices = ov_opset .transpose (
727
+ indices_reshaped , ov_opset .constant ([1 , 0 ], Type .i32 )
728
+ ).output (0 )
729
+
730
+ start_indices_expanded = ov_opset .broadcast (
731
+ start_indices_tensor , ov_opset .constant ([num_updates , rank ], Type .i32 )
732
+ ).output (0 )
733
+ absolute_indices = ov_opset .add (
734
+ absolute_indices , start_indices_expanded
735
+ ).output (0 )
736
+
737
+ updates_tensor = get_ov_output (updates )
738
+ updates_flat = ov_opset .reshape (
739
+ updates_tensor ,
740
+ ov_opset .constant ([num_updates ], Type .i32 ),
741
+ special_zero = False ,
742
+ ).output (0 )
743
+ updated = ov_opset .scatter_nd_update (
744
+ inputs , absolute_indices , updates_flat
745
+ ).output (0 )
746
+ return OpenVINOKerasTensor (updated )
663
747
664
748
665
749
def while_loop (
0 commit comments