]> Pileus Git - ~andy/linux/blob - drivers/infiniband/hw/usnic/usnic_uiom_interval_tree.c
Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/jikos/hid
[~andy/linux] / drivers / infiniband / hw / usnic / usnic_uiom_interval_tree.c
1 #include <linux/init.h>
2 #include <linux/list.h>
3 #include <linux/slab.h>
4 #include <linux/list_sort.h>
5
6 #include <linux/interval_tree_generic.h>
7 #include "usnic_uiom_interval_tree.h"
8
/* Accessors INTERVAL_TREE_DEFINE() uses to read a node's inclusive endpoints. */
#define START(node) ((node)->start)
#define LAST(node) ((node)->last)

/*
 * MAKE_NODE - allocate an interval node into @node.
 *
 * On allocation failure sets @err to -ENOMEM and jumps to the label @err_out
 * in the calling function. Every argument except the label (which cannot be
 * parenthesised) is wrapped in parentheses so callers may pass expressions.
 */
#define MAKE_NODE(node, start, end, ref_cnt, flags, err, err_out)	\
	do {								\
		(node) = usnic_uiom_interval_node_alloc((start), (end),	\
							(ref_cnt), (flags)); \
		if (!(node)) {						\
			(err) = -ENOMEM;				\
			goto err_out;					\
		}							\
	} while (0)

/* Append @node to the tail of @list through its ->link member. */
#define MARK_FOR_ADD(node, list) (list_add_tail(&(node)->link, (list)))

/*
 * MAKE_NODE_AND_APPEND - MAKE_NODE() followed by MARK_FOR_ADD().
 * Same error behaviour as MAKE_NODE(): -ENOMEM into @err, then goto @err_out.
 */
#define MAKE_NODE_AND_APPEND(node, start, end, ref_cnt, flags, err,	\
				err_out, list)				\
	do {								\
		MAKE_NODE(node, start, end, ref_cnt, flags, err,	\
			  err_out);					\
		MARK_FOR_ADD(node, list);				\
	} while (0)

/* True when @flags1 and @flags2 agree on every bit that is set in @mask. */
#define FLAGS_EQUAL(flags1, flags2, mask)				\
	(((flags1) & (mask)) == ((flags2) & (mask)))
35
36 static struct usnic_uiom_interval_node*
37 usnic_uiom_interval_node_alloc(long int start, long int last, int ref_cnt,
38                                 int flags)
39 {
40         struct usnic_uiom_interval_node *interval = kzalloc(sizeof(*interval),
41                                                                 GFP_ATOMIC);
42         if (!interval)
43                 return NULL;
44
45         interval->start = start;
46         interval->last = last;
47         interval->flags = flags;
48         interval->ref_cnt = ref_cnt;
49
50         return interval;
51 }
52
53 static int interval_cmp(void *priv, struct list_head *a, struct list_head *b)
54 {
55         struct usnic_uiom_interval_node *node_a, *node_b;
56
57         node_a = list_entry(a, struct usnic_uiom_interval_node, link);
58         node_b = list_entry(b, struct usnic_uiom_interval_node, link);
59
60         /* long to int */
61         if (node_a->start < node_b->start)
62                 return -1;
63         else if (node_a->start > node_b->start)
64                 return 1;
65
66         return 0;
67 }
68
69 static void
70 find_intervals_intersection_sorted(struct rb_root *root, unsigned long start,
71                                         unsigned long last,
72                                         struct list_head *list)
73 {
74         struct usnic_uiom_interval_node *node;
75
76         INIT_LIST_HEAD(list);
77
78         for (node = usnic_uiom_interval_tree_iter_first(root, start, last);
79                 node;
80                 node = usnic_uiom_interval_tree_iter_next(node, start, last))
81                 list_add_tail(&node->link, list);
82
83         list_sort(NULL, list, interval_cmp);
84 }
85
86 int usnic_uiom_get_intervals_diff(unsigned long start, unsigned long last,
87                                         int flags, int flag_mask,
88                                         struct rb_root *root,
89                                         struct list_head *diff_set)
90 {
91         struct usnic_uiom_interval_node *interval, *tmp;
92         int err = 0;
93         long int pivot = start;
94         LIST_HEAD(intersection_set);
95
96         INIT_LIST_HEAD(diff_set);
97
98         find_intervals_intersection_sorted(root, start, last,
99                                                 &intersection_set);
100
101         list_for_each_entry(interval, &intersection_set, link) {
102                 if (pivot < interval->start) {
103                         MAKE_NODE_AND_APPEND(tmp, pivot, interval->start - 1,
104                                                 1, flags, err, err_out,
105                                                 diff_set);
106                         pivot = interval->start;
107                 }
108
109                 /*
110                  * Invariant: Set [start, pivot] is either in diff_set or root,
111                  * but not in both.
112                  */
113
114                 if (pivot > interval->last) {
115                         continue;
116                 } else if (pivot <= interval->last &&
117                                 FLAGS_EQUAL(interval->flags, flags,
118                                 flag_mask)) {
119                         pivot = interval->last + 1;
120                 }
121         }
122
123         if (pivot <= last)
124                 MAKE_NODE_AND_APPEND(tmp, pivot, last, 1, flags, err, err_out,
125                                         diff_set);
126
127         return 0;
128
129 err_out:
130         list_for_each_entry_safe(interval, tmp, diff_set, link) {
131                 list_del(&interval->link);
132                 kfree(interval);
133         }
134
135         return err;
136 }
137
138 void usnic_uiom_put_interval_set(struct list_head *intervals)
139 {
140         struct usnic_uiom_interval_node *interval, *tmp;
141         list_for_each_entry_safe(interval, tmp, intervals, link)
142                 kfree(interval);
143 }
144
145 int usnic_uiom_insert_interval(struct rb_root *root, unsigned long start,
146                                 unsigned long last, int flags)
147 {
148         struct usnic_uiom_interval_node *interval, *tmp;
149         unsigned long istart, ilast;
150         int iref_cnt, iflags;
151         unsigned long lpivot = start;
152         int err = 0;
153         LIST_HEAD(to_add);
154         LIST_HEAD(intersection_set);
155
156         find_intervals_intersection_sorted(root, start, last,
157                                                 &intersection_set);
158
159         list_for_each_entry(interval, &intersection_set, link) {
160                 /*
161                  * Invariant - lpivot is the left edge of next interval to be
162                  * inserted
163                  */
164                 istart = interval->start;
165                 ilast = interval->last;
166                 iref_cnt = interval->ref_cnt;
167                 iflags = interval->flags;
168
169                 if (istart < lpivot) {
170                         MAKE_NODE_AND_APPEND(tmp, istart, lpivot - 1, iref_cnt,
171                                                 iflags, err, err_out, &to_add);
172                 } else if (istart > lpivot) {
173                         MAKE_NODE_AND_APPEND(tmp, lpivot, istart - 1, 1, flags,
174                                                 err, err_out, &to_add);
175                         lpivot = istart;
176                 } else {
177                         lpivot = istart;
178                 }
179
180                 if (ilast > last) {
181                         MAKE_NODE_AND_APPEND(tmp, lpivot, last, iref_cnt + 1,
182                                                 iflags | flags, err, err_out,
183                                                 &to_add);
184                         MAKE_NODE_AND_APPEND(tmp, last + 1, ilast, iref_cnt,
185                                                 iflags, err, err_out, &to_add);
186                 } else {
187                         MAKE_NODE_AND_APPEND(tmp, lpivot, ilast, iref_cnt + 1,
188                                                 iflags | flags, err, err_out,
189                                                 &to_add);
190                 }
191
192                 lpivot = ilast + 1;
193         }
194
195         if (lpivot <= last)
196                 MAKE_NODE_AND_APPEND(tmp, lpivot, last, 1, flags, err, err_out,
197                                         &to_add);
198
199         list_for_each_entry_safe(interval, tmp, &intersection_set, link) {
200                 usnic_uiom_interval_tree_remove(interval, root);
201                 kfree(interval);
202         }
203
204         list_for_each_entry(interval, &to_add, link)
205                 usnic_uiom_interval_tree_insert(interval, root);
206
207         return 0;
208
209 err_out:
210         list_for_each_entry_safe(interval, tmp, &to_add, link)
211                 kfree(interval);
212
213         return err;
214 }
215
216 void usnic_uiom_remove_interval(struct rb_root *root, unsigned long start,
217                                 unsigned long last, struct list_head *removed)
218 {
219         struct usnic_uiom_interval_node *interval;
220
221         for (interval = usnic_uiom_interval_tree_iter_first(root, start, last);
222                         interval;
223                         interval = usnic_uiom_interval_tree_iter_next(interval,
224                                                                         start,
225                                                                         last)) {
226                 if (--interval->ref_cnt == 0)
227                         list_add_tail(&interval->link, removed);
228         }
229
230         list_for_each_entry(interval, removed, link)
231                 usnic_uiom_interval_tree_remove(interval, root);
232 }
233
234 INTERVAL_TREE_DEFINE(struct usnic_uiom_interval_node, rb,
235                         unsigned long, __subtree_last,
236                         START, LAST, , usnic_uiom_interval_tree)