|
1 |
| -/** NOTE: This file is a current WIP */ |
| 1 | +// NOTE: This file is a current WIP! |
| 2 | +// |
| 3 | +// Run with: |
| 4 | +// ./gradlew run -Palgorithm=datastructures.segmenttree.RangeQueryPointUpdateSegmentTree |
| 5 | +// |
| 6 | +// Several thanks to cp-algorithms for their great article on segment trees: |
| 7 | +// https://cp-algorithms.com/data_structures/segment_tree.html |
| 8 | + |
2 | 9 | package com.williamfiset.algorithms.datastructures.segmenttree;
|
3 | 10 |
|
4 | 11 | public class RangeQueryPointUpdateSegmentTree {
|
5 | 12 |
|
6 |
| - // Tree values |
7 |
| - private int[] t; |
| 13 | + // Tree values. |
| 14 | + // TODO(william): make these members private |
| 15 | + Integer[] t; |
| 16 | + |
| 17 | + // The number of values in the original input values array. |
| 18 | + int n; |
8 | 19 |
|
9 |
| - private int n; |
| 20 | + // The size of the segment tree `t` |
| 21 | + // NOTE: the size is not necessarily = number of segments. |
| 22 | + int N; |
10 | 23 |
|
11 | 24 | public RangeQueryPointUpdateSegmentTree(int[] values) {
|
12 | 25 | if (values == null) {
|
13 | 26 | throw new NullPointerException("Segment tree values cannot be null.");
|
14 | 27 | }
|
15 | 28 | n = values.length;
|
16 |
| - t = new int[2 * n]; |
| 29 | + // TODO(william): Investigate to reduce this space. There are only 2n-1 segments, so we should |
| 30 | + // be able to reduce the space, but may need to reorganize the tree/queries. One idea is to use |
| 31 | + // the Eulerian tour structure of the tree to densely pack the segments. |
| 32 | + N = 4 * n; |
| 33 | + t = new Integer[N]; |
| 34 | + |
| 35 | + buildTree(0, 0, n - 1, values); |
17 | 36 |
|
18 |
| - // buildTree(0, 0, n-1); |
| 37 | + // System.out.println(java.util.Arrays.toString(values)); |
| 38 | + // System.out.println(java.util.Arrays.toString(t)); |
19 | 39 | }
|
20 | 40 |
|
21 |
| - // Builds tree bottom up starting with leaf nodes and combining |
22 |
| - // values on callback. |
23 |
| - // Range is inclusive: [l, r] |
24 |
| - private int buildTree(int i, int l, int r, int[] values) { |
25 |
| - if (l == r) { |
26 |
| - return 0; |
| 41 | + /** |
| 42 | + * Builds the segment tree starting with leaf nodes and combining values on callback. This |
| 43 | + * construction method takes O(n) time since there are only 2n - 1 segments in the segment tree. |
| 44 | + * |
| 45 | + * @param i the index of the segment in the segment tree |
| 46 | + * @param l the left index of the range on the values array |
| 47 | + * @param r the right index of the range on the values array |
| 48 | + * @param values the initial values array |
| 49 | + * <p>The range [l, r] over the values array is inclusive. |
| 50 | + */ |
| 51 | + private int buildTree(int i, int tl, int tr, int[] values) { |
| 52 | + if (tl == tr) { |
| 53 | + return t[i] = values[tl]; |
27 | 54 | }
|
28 |
| - int leftChild = (i * 2); |
29 |
| - int rightChild = (i * 2) + 1; |
30 |
| - int mid = (l + r) / 2; |
31 |
| - // TODO(will): herm... doesn't look quite righT? |
32 |
| - // t[i] = buildTree(leftChild, l, mid, values) + buildTree(rightChild, mid, r, values); |
| 55 | + int mid = (tl + tr) / 2; |
| 56 | + |
| 57 | + // TODO(william): fix segment index out of bounds issue |
| 58 | + // System.out.printf("Range [%d, %d] splits into: [%d, %d] and [%d, %d] | %d -> %d and %d\n", l, |
| 59 | + // r, l, mid, mid+1, r, i, tl, tr); |
| 60 | + int lSum = buildTree(2 * i + 1, tl, mid, values); |
| 61 | + int rSum = buildTree(2 * i + 2, mid + 1, tr, values); |
| 62 | + |
| 63 | + // TODO(william): Make generic to support min, max and other queries. One idea is to keep |
| 64 | + // segment multiple trees for each query type? |
| 65 | + return t[i] = lSum + rSum; |
| 66 | + } |
| 67 | + |
| 68 | + /** |
| 69 | + * Returns the sum of the range [l, r] in the original `values` array. |
| 70 | + * |
| 71 | + * @param l the left endpoint of the sum range query (inclusive) |
| 72 | + * @param r the right endpoint of the sum range query (inclusive) |
| 73 | + */ |
| 74 | + public long sumQuery(int l, int r) { |
| 75 | + return sumQuery(0, 0, n - 1, l, r); |
| 76 | + } |
33 | 77 |
|
34 |
| - return 0; |
| 78 | + /** |
| 79 | + * @param i the index of the current segment in the tree |
| 80 | + * @param tl the left endpoint that the of the current segment |
| 81 | + * @param tr the right endpoint that the of the current segment |
| 82 | + * @param l the target left endpoint for the range query |
| 83 | + * @param r the target right endpoint for the range query |
| 84 | + */ |
| 85 | + private long sumQuery(int i, int tl, int tr, int l, int r) { |
| 86 | + if (l > r) { |
| 87 | + return 0; |
| 88 | + } |
| 89 | + if (tl == l && tr == r) { |
| 90 | + return t[i]; |
| 91 | + } |
| 92 | + int tm = (tl + tr) / 2; |
| 93 | + // Instead of checking if [tl, tm] overlaps [l, r] and [tm+1, tr] overlaps |
| 94 | + // [l, r], simply recurse on both and return a sum of 0 if the interval is invalid. |
| 95 | + return sumQuery(2 * i + 1, tl, tm, l, Math.min(tm, r)) |
| 96 | + + sumQuery(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r); |
35 | 97 | }
|
36 | 98 |
|
37 |
| - public int query(int l, int r) { |
38 |
| - return 0; |
| 99 | + // Alternative implementation of summing that intelligently only digs into |
| 100 | + // the branches which overlap with the query [l, r] |
| 101 | + private long sumQuery2(int i, int tl, int tr, int l, int r) { |
| 102 | + if (tl == l && tr == r) { |
| 103 | + return t[i]; |
| 104 | + } |
| 105 | + int tm = (tl + tr) / 2; |
| 106 | + // Test how the current segment [tl, tr] overlaps with the query [l, r] |
| 107 | + boolean overlapsLeftSegment = (l <= tm); |
| 108 | + boolean overlapsRightSegment = (r > tm); |
| 109 | + if (overlapsLeftSegment && overlapsRightSegment) { |
| 110 | + return sumQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r)) |
| 111 | + + sumQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r); |
| 112 | + } else if (overlapsLeftSegment) { |
| 113 | + return sumQuery2(2 * i + 1, tl, tm, l, Math.min(tm, r)); |
| 114 | + } else { |
| 115 | + return sumQuery2(2 * i + 2, tm + 1, tr, Math.max(l, tm + 1), r); |
| 116 | + } |
39 | 117 | }
|
40 | 118 |
|
41 | 119 | public void set(int i, int value) {
|
42 | 120 | // update(i, 0, n-1, value);
|
43 | 121 | }
|
44 | 122 |
|
45 |
| - private void update(int at, int to, int l, int r, int value) { |
46 |
| - if (l == r) { |
| 123 | + public void update(int i, int newValue) { |
| 124 | + update(0, i, 0, n - 1, newValue); |
| 125 | + } |
| 126 | + |
| 127 | + /** |
| 128 | + * Update a point in the segment tree by doing a binary search, updating the leaf node and |
| 129 | + * re-computing all the segment values on the callback. |
| 130 | + * |
| 131 | + * @param at the index of the current segment in the tree |
| 132 | + * @param to the target position to update left endpoint for the range query |
| 133 | + * @param tl the left endpoint that the of the current segment |
| 134 | + * @param tr the right endpoint that the of the current segment |
| 135 | + * @param r the target right endpoint for the range query |
| 136 | + */ |
| 137 | + private void update(int at, int to, int tl, int tr, int newValue) { |
| 138 | + if (tl > tr) { |
| 139 | + return; |
| 140 | + } |
| 141 | + if (tl == tr) { // or `tl == to && tr == to` |
| 142 | + t[at] = newValue; |
47 | 143 | return;
|
| 144 | + } |
| 145 | + int tm = (tl + tr) / 2; |
| 146 | + // Dig into the left segment |
| 147 | + if (to <= tm) { |
| 148 | + update(2 * at + 1, to, tl, tm, newValue); |
| 149 | + // Dig into the right segment |
48 | 150 | } else {
|
49 |
| - int lv = t[at * 2]; |
50 |
| - int rv = t[at * 2 + 1]; |
51 |
| - int m = (l + r) >>> 1; |
52 |
| - if (l <= r) {} |
| 151 | + update(2 * at + 2, to, tm + 1, tr, newValue); |
53 | 152 | }
|
| 153 | + // Re-compute the segment value of the current segment on the callback |
| 154 | + t[at] = t[2 * at + 1] + t[2 * at + 2]; |
| 155 | + } |
| 156 | + |
| 157 | + public static void main(String[] args) { |
| 158 | + int[] values = new int[6]; |
| 159 | + java.util.Arrays.fill(values, 1); |
| 160 | + RangeQueryPointUpdateSegmentTree st = new RangeQueryPointUpdateSegmentTree(values); |
| 161 | + System.out.println(st.sumQuery(1, 4)); |
| 162 | + |
| 163 | + st.update(1, 2); |
| 164 | + System.out.println(st.sumQuery(1, 1)); |
| 165 | + System.out.println(st.sumQuery(0, 1)); |
| 166 | + System.out.println(st.sumQuery(0, 2)); |
| 167 | + |
| 168 | + // for (int i = 1; i < 500; i++) { |
| 169 | + // // System.out.println(); |
| 170 | + // int[] values = new int[i]; |
| 171 | + // java.util.Arrays.fill(values, 1); |
| 172 | + // RangeQueryPointUpdateSegmentTree st = new RangeQueryPointUpdateSegmentTree(values); |
| 173 | + // } |
| 174 | + // for (int i = 1; i < 20; i++) { |
| 175 | + // System.out.printf("%d -> %d\n", i, nextPowerOf2(i)); |
| 176 | + // } |
54 | 177 | }
|
55 | 178 | }
|
0 commit comments