-
Notifications
You must be signed in to change notification settings - Fork 2
/
区间树
173 lines (150 loc) · 4.95 KB
/
区间树
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include <iostream>
#include <vector>
#include <math.h>
using namespace std;
int getMin(vector<int>& hi, int i, int j);
// Builds the segment tree for hi[l..r] inside the flat array st.
//
// The tree uses heap layout: the node stored at st[idx] has its children at
// st[2 * idx + 1] and st[2 * idx + 2]. Every node stores the INDEX into hi of
// the smallest element of the segment it covers (not the value itself).
//
//   hi   source values
//   st   output storage; must already be large enough for every node reached
//   idx  slot in st of the node being built
//   l,r  inclusive bounds of the segment of hi covered by this node
//
// Returns the index stored at st[idx], or -1 when the segment is empty
// (l > r), in which case st is left untouched.
int constructST(vector<int>& hi, vector<int>& st, int idx, int l, int r) {
    if (r < l) {
        return -1;
    }
    if (l != r) {
        // Internal node: build both halves, then keep the better of the two.
        const int split = l + (r - l) / 2;
        const int leftBest = constructST(hi, st, 2 * idx + 1, l, split);
        const int rightBest = constructST(hi, st, 2 * idx + 2, split + 1, r);
        st[idx] = getMin(hi, leftBest, rightBest);
    } else {
        // Leaf: a one-element segment is its own minimum.
        st[idx] = l;
    }
    return st[idx];
}
// Picks, out of two indices into hi, the one whose element is smaller and
// returns that index (not the value).
//
// An index of -1 means "no candidate": the other index is returned as is, so
// the result is -1 only when both are -1. When the two elements are equal,
// j is returned.
int getMin(vector<int>& hi, int i, int j) {
    const int kNone = -1;
    if (i == kNone) {
        return j;
    }
    if (j == kNone) {
        return i;
    }
    if (hi[i] < hi[j]) {
        return i;
    }
    return j;
}
// Range query: returns the index into hi of the smallest element within the
// query range [fs, fe], using the tree previously built into st.
//
//   hi     source values
//   st     segment tree in heap layout (node pos -> children 2*pos+1, 2*pos+2)
//   s, e   inclusive bounds of the segment covered by node pos
//   fs,fe  inclusive bounds of the query range
//   pos    slot in st of the node being examined
//
// Returns -1 when either range is empty or the two ranges do not intersect.
int find(vector<int>& hi, vector<int>& st, int s, int e, int fs, int fe, int pos) {
    const bool emptyRange = (s > e) || (fs > fe);
    const bool disjoint = (s > fe) || (e < fs);
    if (disjoint || emptyRange) {
        return -1;
    }

    // The node's segment lies completely inside the query: its stored answer
    // can be used directly.
    const bool nodeInsideQuery = (s >= fs) && (e <= fe);
    if (nodeInsideQuery) {
        return st[pos];
    }

    // Partial overlap: ask both children and combine their answers.
    const int split = s + (e - s) / 2;
    const int fromLeft = find(hi, st, s, split, fs, fe, 2 * pos + 1);
    const int fromRight = find(hi, st, split + 1, e, fs, fe, 2 * pos + 2);
    return getMin(hi, fromLeft, fromRight);
}
// Returns the number of slots needed to store, in heap layout, a segment tree
// over n elements: 2^level - 1, where level = ceil(log2(n)) + 1 is the number
// of tree levels.
//
// The level is computed with integer doubling rather than ceil(log(n)/log(2)):
// the floating-point form can round to the wrong level, and for n <= 0 it
// produces -inf/NaN, whose conversion to int is undefined behaviour.
//
// Returns 0 for n <= 0 (an empty tree needs no storage) and prints nothing in
// that case. For n > 0 the same two diagnostic lines as before are written to
// stdout.
//
// Precondition for a meaningful result: n <= 2^30, otherwise 2^level - 1 does
// not fit in an int.
int getStSize(int n) {
    if (n <= 0) {
        return 0;
    }

    // leaves is the smallest power of two >= n; level counts the tree levels
    // needed to reach that many leaves (a single leaf is one level).
    int level = 1;
    long long leaves = 1;
    while (leaves < n) {
        leaves *= 2;
        ++level;
    }

    cout<<"log(n)="<<log(n)<<endl;
    cout <<level<<" "<< pow(2, level) - 1<<endl;

    // A full binary tree with `leaves` leaves has 2 * leaves - 1 nodes.
    return static_cast<int>(2 * leaves - 1);
}
// Demo driver: builds the segment tree for a fixed five-element array and
// prints the contents of the tree storage (indices into the array, -1 for
// unused slots) on one line.
int main() {
    const int samples[] = {1, 4, 2, 6, 9};
    const int sampleCount = sizeof(samples) / sizeof(samples[0]);
    vector<int> hi(samples, samples + sampleCount);

    int n = hi.size();
    cout << "hi size=" << n << endl;

    // Every slot starts as -1 so that nodes never written stay recognisable.
    vector<int> st(getStSize(n), -1);
    constructST(hi, st, 0, 0, n - 1);

    for (vector<int>::const_iterator it = st.begin(); it != st.end(); ++it) {
        cout << *it << " ";
    }
    cout << endl;
    return 0;
}
package cn.webank.ctrl.learn.lecture.lecture.list;
import java.util.ArrayList;
import java.util.List;
/**
 * Segment tree for range-minimum queries: finds the position of the smallest
 * element within an index range of a list.
 *
 * <p>The tree is kept in a flat list in heap layout (the children of node
 * {@code i} are {@code 2 * i + 1} and {@code 2 * i + 2}). Every node stores the
 * index, in the original data, of the smallest element of the segment it
 * covers; {@code -1} stands for "no element".
 *
 * @author ctrlzhang on 2018/5/15.
 */
public class SegmentTree {
    /** The data this instance was created for; the methods below take the data explicitly. */
    private final List<Integer> data;

    /**
     * Creates a helper for the given data.
     *
     * @param data the values the tree will be built over
     */
    public SegmentTree(List<Integer> data) {
        this.data = data;
    }

    /**
     * Builds the tree for {@code data[l..r]} into {@code tree}.
     *
     * <p>The tree stores indices into the original data, and the value
     * returned is such an index as well (not a position inside {@code tree}).
     * Each leaf that is written is also traced to standard output.
     *
     * @param data the source values
     * @param tree pre-filled storage, large enough for every node that is reached
     * @param l    inclusive lower bound of the segment covered by this node
     * @param r    inclusive upper bound of the segment covered by this node
     * @param idx  position in {@code tree} of the node being built
     * @return index into {@code data} of the smallest element in {@code [l, r]},
     *         or {@code -1} if the segment is empty ({@code l > r})
     */
    public int constructTree(List<Integer> data, List<Integer> tree, int l, int r, int idx) {
        if (l > r) {
            return -1;
        }
        if (l == r) {
            System.out.println("l == r " + l + " " + idx);
            tree.set(idx, l);
            return l;
        }
        int mid = l + (r - l) / 2;
        int leftMin = constructTree(data, tree, l, mid, 2 * idx + 1);
        int rightMin = constructTree(data, tree, mid + 1, r, 2 * idx + 2);
        int minPos = getMin(data, leftMin, rightMin);
        tree.set(idx, minPos);
        return minPos;
    }

    /**
     * Returns whichever of the two indices refers to the smaller element.
     *
     * <p>An index of {@code -1} means "no candidate" and yields the other
     * index. When both elements are equal, {@code r} is returned.
     *
     * @param data the source values
     * @param l    first candidate index, or {@code -1}
     * @param r    second candidate index, or {@code -1}
     * @return the index of the smaller element, or {@code -1} if both are {@code -1}
     */
    int getMin(List<Integer> data, int l, int r) {
        if (l == -1) {
            return r;
        }
        if (r == -1) {
            return l;
        }
        // Must be '<': with '>' the tree would track the maximum instead.
        return data.get(l) < data.get(r) ? l : r;
    }

    /**
     * Recursively looks up the smallest element within the query range.
     *
     * <p>The segment {@code [ss, se]} and the node position {@code pos} always
     * belong together: {@code pos} is the node that covers exactly that
     * segment. Every visited node is traced to standard output.
     *
     * @param data the source values
     * @param tree the tree produced by {@link #constructTree}
     * @param ss   inclusive lower bound of the segment covered by node {@code pos}
     * @param se   inclusive upper bound of the segment covered by node {@code pos}
     * @param fs   inclusive lower bound of the query range
     * @param fe   inclusive upper bound of the query range
     * @param pos  position in {@code tree} of the node being examined
     * @return index into {@code data} of the smallest element in the part of
     *         {@code [fs, fe]} covered by this node, or {@code -1} if either
     *         range is empty or they do not intersect
     */
    int search(List<Integer> data, List<Integer> tree, int ss, int se, int fs, int fe, int pos) {
        System.out.print(ss + " ");
        System.out.print(se + " ");
        System.out.print(fs + " ");
        System.out.print(fe + " ");
        System.out.println(pos);

        // Empty or non-intersecting ranges.
        if (ss > se || fs > fe || ss > fe || se < fs) {
            return -1;
        }
        // The query covers this node's whole segment: use the stored answer.
        if (fs <= ss && fe >= se) {
            return tree.get(pos);
        }
        // Partial overlap: combine the answers of both children (both are indices).
        int mid = ss + (se - ss) / 2;
        int leftAns = search(data, tree, ss, mid, fs, fe, 2 * pos + 1);
        int rightAns = search(data, tree, mid + 1, se, fs, fe, 2 * pos + 2);
        int findAns = getMin(data, leftAns, rightAns);
        if (-1 != findAns) {
            System.out.println("findAns=" + findAns + " " + data.get(findAns));
        }
        return findAns;
    }

    /**
     * Returns the number of slots a heap-layout tree over {@code leafCount}
     * elements needs: {@code 2 * p - 1}, where {@code p} is the smallest power
     * of two that is at least {@code leafCount}. Integer arithmetic is used so
     * the result cannot be affected by floating-point rounding.
     */
    private static int treeCapacity(int leafCount) {
        if (leafCount <= 0) {
            return 0;
        }
        if (leafCount > (1 << 30)) {
            throw new IllegalArgumentException("too many elements: " + leafCount);
        }
        int leaves = 1;
        while (leaves < leafCount) {
            leaves <<= 1;
        }
        return 2 * leaves - 1;
    }

    /**
     * Demo: builds the tree for a fixed list and prints the index and value of
     * the smallest element in positions 0 to 4.
     */
    public static void main(String[] args) {
        List<Integer> input = new ArrayList<>();
        input.add(1);
        input.add(2);
        input.add(3);
        input.add(9);
        input.add(11);
        input.add(13);

        SegmentTree tree = new SegmentTree(input);

        int initSize = treeCapacity(input.size());
        List<Integer> segTree = new ArrayList<Integer>(initSize);
        for (int i = 0; i < initSize; i++) {
            segTree.add(-1);
        }
        tree.constructTree(input, segTree, 0, input.size() - 1, 0);

        int ans = tree.search(input, segTree, 0, input.size() - 1, 0, 4, 0);
        if (-1 == ans) {
            System.out.println("not found");
            return;
        }
        System.out.println(ans);
        System.out.println(input.get(ans));
    }
}