// casbin/rbac/default_role_manager.rs

use crate::{
    error::RbacError,
    rbac::{MatchingFn, RoleManager},
    Result,
};
use petgraph::stable_graph::{NodeIndex, StableDiGraph};
use petgraph::visit::EdgeRef;
use std::collections::{hash_map::Entry, HashMap, HashSet};

#[cfg(feature = "cached")]
use crate::cache::{Cache, DefaultCache};

#[cfg(feature = "cached")]
use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
};

18const DEFAULT_DOMAIN: &str = "DEFAULT";
19
20pub struct DefaultRoleManager {
21    all_domains: HashMap<String, StableDiGraph<String, EdgeVariant>>,
22    all_domains_indices: HashMap<String, HashMap<String, NodeIndex<u32>>>,
23    #[cfg(feature = "cached")]
24    cache: DefaultCache<u64, bool>,
25    max_hierarchy_level: usize,
26    role_matching_fn: Option<MatchingFn>,
27    domain_matching_fn: Option<MatchingFn>,
28}
29
/// Kind of edge stored in a domain's role graph.
#[derive(Debug, Clone)]
enum EdgeVariant {
    /// Explicit inheritance edge, added by `add_link` or copied/propagated
    /// from another domain.
    Link,
    /// Edge added automatically when the role matching function relates
    /// two role names.
    Match,
}

36impl DefaultRoleManager {
37    pub fn new(max_hierarchy_level: usize) -> Self {
38        DefaultRoleManager {
39            all_domains: HashMap::new(),
40            all_domains_indices: HashMap::new(),
41            max_hierarchy_level,
42            #[cfg(feature = "cached")]
43            cache: DefaultCache::new(50),
44            role_matching_fn: None,
45            domain_matching_fn: None,
46        }
47    }
48
49    fn get_or_create_role(
50        &mut self,
51        name: &str,
52        domain: Option<&str>,
53    ) -> NodeIndex<u32> {
54        let domain = domain.unwrap_or(DEFAULT_DOMAIN);
55
56        // detect whether this is a new domain creation
57        let is_new_domain = !self.all_domains.contains_key(domain);
58
59        let graph = self.all_domains.entry(domain.into()).or_default();
60
61        let role_entry = self
62            .all_domains_indices
63            .entry(domain.into())
64            .or_default()
65            .entry(name.into());
66
67        let vacant_entry = match role_entry {
68            Entry::Occupied(e) => return *e.get(),
69            Entry::Vacant(e) => e,
70        };
71
72        let new_role_id = graph.add_node(name.into());
73        vacant_entry.insert(new_role_id);
74
75        if let Some(role_matching_fn) = self.role_matching_fn {
76            let mut added = false;
77
78            let node_ids: Vec<_> =
79                graph.node_indices().filter(|&i| graph[i] != name).collect();
80
81            for existing_role_id in node_ids {
82                added |= link_if_matches(
83                    graph,
84                    role_matching_fn,
85                    new_role_id,
86                    existing_role_id,
87                );
88
89                added |= link_if_matches(
90                    graph,
91                    role_matching_fn,
92                    existing_role_id,
93                    new_role_id,
94                );
95            }
96
97            if added {
98                #[cfg(feature = "cached")]
99                self.cache.clear();
100            }
101        }
102
103        // If domain matching function exists and this was a new domain, copy
104        // role links from matching domains into the newly created domain so
105        // that BFS will see inherited links in this domain's graph.
106        if is_new_domain {
107            if let Some(domain_matching_fn) = self.domain_matching_fn {
108                let keys: Vec<String> =
109                    self.all_domains.keys().cloned().collect();
110                for d in keys {
111                    if d != domain && (domain_matching_fn)(domain, &d) {
112                        self.copy_from_domain(&d, domain);
113                    }
114                }
115            }
116        }
117
118        new_role_id
119    }
120
121    // propagate a Link addition (name1 -> name2) from `domain` into all
122    // affected/matching domains. This extracts the inline logic from
123    // `add_link` so the code is clearer and avoids nested borrows.
124    fn propagate_link_to_affected_domains(
125        &mut self,
126        name1: &str,
127        name2: &str,
128        domain: &str,
129    ) {
130        let name1_owned = name1.to_string();
131        let name2_owned = name2.to_string();
132        let affected = self.affected_domain_names(domain);
133        for d in affected {
134            // obtain mutable graph and index map for the affected domain
135            let g = self.all_domains.get_mut(&d).unwrap();
136            let idx_map =
137                self.all_domains_indices.entry(d.clone()).or_default();
138            let idx1 = Self::ensure_node_in_graph(g, idx_map, &name1_owned);
139            let idx2 = Self::ensure_node_in_graph(g, idx_map, &name2_owned);
140
141            // add Link edge if missing
142            let has_link = g
143                .edges_connecting(idx1, idx2)
144                .any(|e| matches!(*e.weight(), EdgeVariant::Link));
145            if !has_link {
146                g.add_edge(idx1, idx2, EdgeVariant::Link);
147            }
148        }
149
150        #[cfg(feature = "cached")]
151        self.cache.clear();
152    }
153
154    // ensure a node with `name` exists in graph `g` and in the provided
155    // `idx_map`. Returns the NodeIndex for the node.
156    fn ensure_node_in_graph(
157        g: &mut StableDiGraph<String, EdgeVariant>,
158        idx_map: &mut HashMap<String, NodeIndex<u32>>,
159        name: &str,
160    ) -> NodeIndex<u32> {
161        if let Some(idx) = idx_map.get(name) {
162            *idx
163        } else if let Some(idx) = g.node_indices().find(|&i| g[i] == name) {
164            idx_map.insert(name.to_string(), idx);
165            idx
166        } else {
167            let ni = g.add_node(name.to_string());
168            idx_map.insert(name.to_string(), ni);
169            ni
170        }
171    }
172
173    // return the list of affected domain names (immutable) to avoid nested
174    // mutable borrows when performing operations across domains
175    fn affected_domain_names(&self, domain: &str) -> Vec<String> {
176        if let Some(matcher) = self.domain_matching_fn {
177            self.all_domains
178                .keys()
179                .filter(|d| *d != domain && matcher(d, domain))
180                .cloned()
181                .collect()
182        } else {
183            Vec::new()
184        }
185    }
186
187    // copy all role links and nodes from `src_domain` graph into `dst_domain` graph
188    fn copy_from_domain(&mut self, src_domain: &str, dst_domain: &str) {
189        if src_domain == dst_domain {
190            return;
191        }
192
193        // ensure both graphs exist
194        if !self.all_domains.contains_key(src_domain) {
195            return;
196        }
197
198        let src_graph = match self.all_domains.get(src_domain) {
199            Some(g) => g.clone(),
200            None => return,
201        };
202
203        // ensure dst indices map exists
204        let dst_indices = self
205            .all_domains_indices
206            .entry(dst_domain.into())
207            .or_default();
208
209        let dst_graph = self.all_domains.entry(dst_domain.into()).or_default();
210
211        // copy nodes: ensure names exist in dst and capture mapping
212        let mut id_map: HashMap<NodeIndex<u32>, NodeIndex<u32>> =
213            HashMap::new();
214        for src_idx in src_graph.node_indices() {
215            let name = &src_graph[src_idx];
216            let dst_idx = if let Some(idx) = dst_indices.get(name) {
217                *idx
218            } else {
219                let new_idx = dst_graph.add_node(name.clone());
220                dst_indices.insert(name.clone(), new_idx);
221                new_idx
222            };
223            id_map.insert(src_idx, dst_idx);
224        }
225
226        // copy edges: for each edge in src_graph, add equivalent edge in dst if missing
227        for edge_idx in src_graph.edge_indices() {
228            if let Some((src_s, src_t)) = src_graph.edge_endpoints(edge_idx) {
229                if let Some(weight) = src_graph.edge_weight(edge_idx) {
230                    let dst_s = id_map.get(&src_s).unwrap();
231                    let dst_t = id_map.get(&src_t).unwrap();
232
233                    let need_add = match dst_graph.find_edge(*dst_s, *dst_t) {
234                        Some(idx) => {
235                            // if existing edge is Match but source weight is Link, allow adding Link
236                            !matches!(dst_graph[idx], EdgeVariant::Match)
237                                || !matches!(weight, &EdgeVariant::Match)
238                        }
239                        None => true,
240                    };
241
242                    if need_add {
243                        dst_graph.add_edge(*dst_s, *dst_t, weight.clone());
244                    }
245                }
246            }
247        }
248
249        #[cfg(feature = "cached")]
250        self.cache.clear();
251    }
252
253    fn matched_domains(&self, domain: Option<&str>) -> Vec<String> {
254        let domain = domain.unwrap_or(DEFAULT_DOMAIN);
255        if let Some(domain_matching_fn) = self.domain_matching_fn {
256            self.all_domains
257                .keys()
258                .filter_map(|key| {
259                    if domain_matching_fn(domain, key) {
260                        Some(key.to_owned())
261                    } else {
262                        None
263                    }
264                })
265                .collect::<Vec<String>>()
266        } else {
267            self.all_domains
268                .get(domain)
269                .map_or(vec![], |_| vec![domain.to_owned()])
270        }
271    }
272
273    fn domain_has_role(&self, name: &str, domain: Option<&str>) -> bool {
274        let matched_domains = self.matched_domains(domain);
275
276        matched_domains.iter().any(|domain| {
277            // try to find direct match of role
278            if self.all_domains_indices[domain].contains_key(name) {
279                true
280            } else if let Some(role_matching_fn) = self.role_matching_fn {
281                // else if role_matching_fn is set, iterate all graph nodes and try to find matching role
282                let graph = &self.all_domains[domain];
283
284                graph
285                    .node_weights()
286                    .any(|role| role_matching_fn(name, role))
287            } else {
288                false
289            }
290        })
291    }
292}
293
294/// link node of `not_pattern_id` to `maybe_pattern_id` if
295/// `not_pattern` matches `maybe_pattern`'s pattern and
296/// there doesn't exist a match edge yet
297fn link_if_matches(
298    graph: &mut StableDiGraph<String, EdgeVariant>,
299    role_matching_fn: fn(&str, &str) -> bool,
300    not_pattern_id: NodeIndex<u32>,
301    maybe_pattern_id: NodeIndex<u32>,
302) -> bool {
303    let not_pattern = &graph[not_pattern_id];
304    let maybe_pattern = &graph[maybe_pattern_id];
305
306    if !role_matching_fn(maybe_pattern, not_pattern) {
307        return false;
308    }
309
310    let add_edge =
311        if let Some(idx) = graph.find_edge(not_pattern_id, maybe_pattern_id) {
312            !matches!(graph[idx], EdgeVariant::Match)
313        } else {
314            true
315        };
316
317    if add_edge {
318        graph.add_edge(not_pattern_id, maybe_pattern_id, EdgeVariant::Match);
319
320        true
321    } else {
322        false
323    }
324}
325
326impl RoleManager for DefaultRoleManager {
327    fn clear(&mut self) {
328        self.all_domains_indices.clear();
329        self.all_domains.clear();
330        #[cfg(feature = "cached")]
331        self.cache.clear();
332    }
333
334    fn add_link(&mut self, name1: &str, name2: &str, domain: Option<&str>) {
335        if name1 == name2 {
336            return;
337        }
338
339        let role1 = self.get_or_create_role(name1, domain);
340        let role2 = self.get_or_create_role(name2, domain);
341
342        let graph = self
343            .all_domains
344            .get_mut(domain.unwrap_or(DEFAULT_DOMAIN))
345            .unwrap();
346
347        let add_link = if let Some(edge) = graph.find_edge(role1, role2) {
348            !matches!(graph[edge], EdgeVariant::Link)
349        } else {
350            true
351        };
352
353        if add_link {
354            graph.add_edge(role1, role2, EdgeVariant::Link);
355
356            if let Some(domain_str) = domain {
357                self.propagate_link_to_affected_domains(
358                    name1, name2, domain_str,
359                );
360            }
361
362            #[cfg(feature = "cached")]
363            self.cache.clear();
364        }
365    }
366
367    fn matching_fn(
368        &mut self,
369        role_matching_fn: Option<MatchingFn>,
370        domain_matching_fn: Option<MatchingFn>,
371    ) {
372        self.domain_matching_fn = domain_matching_fn;
373        self.role_matching_fn = role_matching_fn;
374    }
375
376    fn delete_link(
377        &mut self,
378        name1: &str,
379        name2: &str,
380        domain: Option<&str>,
381    ) -> Result<()> {
382        if !self.domain_has_role(name1, domain)
383            || !self.domain_has_role(name2, domain)
384        {
385            return Err(
386                RbacError::NotFound(format!("{} OR {}", name1, name2)).into()
387            );
388        }
389
390        let role1 = self.get_or_create_role(name1, domain);
391        let role2 = self.get_or_create_role(name2, domain);
392
393        let graph = self
394            .all_domains
395            .get_mut(domain.unwrap_or(DEFAULT_DOMAIN))
396            .unwrap();
397
398        if let Some(edge_index) = graph.find_edge(role1, role2) {
399            graph.remove_edge(edge_index).unwrap();
400
401            #[cfg(feature = "cached")]
402            self.cache.clear();
403        }
404
405        Ok(())
406    }
407
408    fn has_link(&self, name1: &str, name2: &str, domain: Option<&str>) -> bool {
409        if name1 == name2 {
410            return true;
411        }
412
413        #[cfg(feature = "cached")]
414        let cache_key = {
415            let mut hasher = DefaultHasher::new();
416            name1.hash(&mut hasher);
417            name2.hash(&mut hasher);
418            domain.unwrap_or(DEFAULT_DOMAIN).hash(&mut hasher);
419            hasher.finish()
420        };
421
422        #[cfg(feature = "cached")]
423        if let Some(res) = self.cache.get(&cache_key) {
424            return res;
425        }
426
427        let matched_domains = self.matched_domains(domain);
428
429        let mut res = false;
430
431        for domain in matched_domains {
432            let graph = self.all_domains.get(&domain).unwrap();
433            let indices = self.all_domains_indices.get(&domain).unwrap();
434
435            let role1 = if let Some(role1) = indices.get(name1) {
436                Some(*role1)
437            } else {
438                graph.node_indices().find(|&i| {
439                    let role_name = &graph[i];
440
441                    role_name == name1
442                        || self
443                            .role_matching_fn
444                            .map(|f| f(name1, role_name))
445                            .unwrap_or_default()
446                })
447            };
448
449            let role1 = if let Some(role1) = role1 {
450                role1
451            } else {
452                continue;
453            };
454
455            let mut bfs = matching_bfs::Bfs::new(
456                graph,
457                role1,
458                self.max_hierarchy_level,
459                self.role_matching_fn.is_some(),
460            );
461
462            while let Some(node) = bfs.next(graph) {
463                let role_name = &graph[node];
464
465                if role_name == name2
466                    || self
467                        .role_matching_fn
468                        .map(|f| f(role_name, name2))
469                        .unwrap_or_default()
470                {
471                    res = true;
472                    break;
473                }
474            }
475        }
476
477        #[cfg(feature = "cached")]
478        self.cache.set(cache_key, res);
479
480        res
481    }
482
483    fn get_roles(&self, name: &str, domain: Option<&str>) -> Vec<String> {
484        let matched_domains = self.matched_domains(domain);
485
486        let res = matched_domains.into_iter().fold(
487            HashSet::new(),
488            |mut set, domain| {
489                let graph = &self.all_domains[&domain];
490
491                if let Some(role_node) = graph.node_indices().find(|&i| {
492                    graph[i] == name
493                        || self.role_matching_fn.unwrap_or(|_, _| false)(
494                            name, &graph[i],
495                        )
496                }) {
497                    let neighbors = matching_bfs::bfs_iterator(
498                        graph,
499                        role_node,
500                        self.role_matching_fn.is_some(),
501                    )
502                    .map(|i| graph[i].clone());
503
504                    set.extend(neighbors);
505                }
506
507                set
508            },
509        );
510        res.into_iter().collect()
511    }
512
513    fn get_users(&self, name: &str, domain: Option<&str>) -> Vec<String> {
514        let matched_domains = self.matched_domains(domain);
515
516        let res = matched_domains.into_iter().fold(
517            HashSet::new(),
518            |mut set, domain| {
519                let graph = &self.all_domains[&domain];
520
521                if let Some(role_node) = graph.node_indices().find(|&i| {
522                    graph[i] == name
523                        || self
524                            .role_matching_fn
525                            .map(|f| f(name, &graph[i]))
526                            .unwrap_or_default()
527                }) {
528                    let neighbors = graph
529                        .neighbors_directed(
530                            role_node,
531                            petgraph::Direction::Incoming,
532                        )
533                        .map(|i| graph[i].clone());
534
535                    set.extend(neighbors);
536                }
537
538                set
539            },
540        );
541
542        res.into_iter().collect()
543    }
544}
545
546mod matching_bfs {
547    use super::EdgeVariant;
548    use fixedbitset::FixedBitSet;
549    use petgraph::graph::NodeIndex;
550    use petgraph::stable_graph::StableDiGraph;
551    use petgraph::visit::{EdgeRef, VisitMap, Visitable};
552    use std::collections::VecDeque;
553
554    #[derive(Clone)]
555    pub(super) struct Bfs {
556        /// The queue of nodes to visit
557        pub queue: VecDeque<NodeIndex<u32>>,
558        /// The map of discovered nodes
559        pub discovered: FixedBitSet,
560        /// Maximum depth
561        pub max_depth: usize,
562        /// Consider `Match` edges
563        pub with_pattern_matching: bool,
564
565        /// Current depth
566        depth: usize,
567        /// Number of elements until next depth is reached
568        depth_elements_remaining: usize,
569    }
570
571    impl Bfs {
572        /// Create a new **Bfs**, using the graph's visitor map, and put **start**
573        /// in the stack of nodes to visit.
574        pub fn new(
575            graph: &StableDiGraph<String, EdgeVariant>,
576            start: NodeIndex<u32>,
577            max_depth: usize,
578            with_pattern_matching: bool,
579        ) -> Self {
580            let mut discovered = graph.visit_map();
581            discovered.visit(start);
582
583            let mut queue = VecDeque::new();
584            queue.push_front(start);
585
586            Bfs {
587                queue,
588                discovered,
589                max_depth,
590                with_pattern_matching,
591                depth: 0,
592                depth_elements_remaining: 1,
593            }
594        }
595
596        /// Return the next node in the bfs, or **None** if the traversal is done.
597        pub fn next(
598            &mut self,
599            graph: &StableDiGraph<String, EdgeVariant>,
600        ) -> Option<NodeIndex<u32>> {
601            if self.max_depth <= self.depth {
602                return None;
603            }
604
605            if let Some(node) = self.queue.pop_front() {
606                self.update_depth();
607
608                let mut counter = 0;
609                for succ in
610                    bfs_iterator(graph, node, self.with_pattern_matching)
611                {
612                    if self.discovered.visit(succ) {
613                        self.queue.push_back(succ);
614                        counter += 1;
615                    }
616                }
617
618                self.depth_elements_remaining += counter;
619
620                Some(node)
621            } else {
622                None
623            }
624        }
625
626        fn update_depth(&mut self) {
627            self.depth_elements_remaining -= 1;
628            if self.depth_elements_remaining == 0 {
629                self.depth += 1
630            }
631        }
632    }
633
634    pub(super) fn bfs_iterator(
635        graph: &StableDiGraph<String, EdgeVariant>,
636        node: NodeIndex<u32>,
637        with_matches: bool,
638    ) -> Box<dyn Iterator<Item = NodeIndex<u32>> + '_> {
639        // outgoing LINK edges of node
640        let outgoing_direct_edge = graph
641            .edges_directed(node, petgraph::Direction::Outgoing)
642            .filter_map(|edge| match *edge.weight() {
643                EdgeVariant::Link => Some(edge.target()),
644                EdgeVariant::Match => None,
645            });
646
647        if !with_matches {
648            return Box::new(outgoing_direct_edge);
649        }
650
651        // x := outgoing LINK edges of node
652        // outgoing_match_edge : outgoing MATCH edges of x FOR ALL x
653        let outgoing_match_edge = graph
654            .edges_directed(node, petgraph::Direction::Outgoing)
655            .filter(|edge| matches!(*edge.weight(), EdgeVariant::Link))
656            .flat_map(move |edge| {
657                graph
658                    .edges_directed(
659                        edge.target(),
660                        petgraph::Direction::Outgoing,
661                    )
662                    .filter_map(|edge| match *edge.weight() {
663                        EdgeVariant::Match => Some(edge.target()),
664                        EdgeVariant::Link => None,
665                    })
666            });
667
668        // x := incoming MATCH edges of node
669        // sibling_matched_by := outgoing LINK edges of x FOR ALL x
670        let sibling_matched_by = graph
671            .edges_directed(node, petgraph::Direction::Incoming)
672            .filter(|edge| matches!(*edge.weight(), EdgeVariant::Match))
673            .flat_map(move |edge| {
674                graph
675                    .edges_directed(
676                        edge.source(),
677                        petgraph::Direction::Outgoing,
678                    )
679                    .filter_map(|edge| match *edge.weight() {
680                        EdgeVariant::Link => Some(edge.target()),
681                        EdgeVariant::Match => None,
682                    })
683            });
684
685        Box::new(
686            outgoing_direct_edge
687                .chain(outgoing_match_edge)
688                .chain(sibling_matched_by),
689        )
690    }
691
692    #[cfg(test)]
693    mod test {
694        use super::*;
695        use petgraph::stable_graph::StableDiGraph;
696
697        #[test]
698        fn test_max_depth() {
699            let mut deps = StableDiGraph::<String, EdgeVariant>::new();
700            let pg = deps.add_node("petgraph".into());
701            let fb = deps.add_node("fixedbitset".into());
702            let qc = deps.add_node("quickcheck".into());
703            let rand = deps.add_node("rand".into());
704            let libc = deps.add_node("libc".into());
705
706            deps.extend_with_edges([
707                (pg, fb, EdgeVariant::Link),
708                (pg, qc, EdgeVariant::Link),
709                (qc, rand, EdgeVariant::Link),
710                (rand, libc, EdgeVariant::Link),
711            ]);
712
713            let mut bfs = Bfs::new(&deps, pg, 2, false);
714
715            let mut nodes = vec![];
716            while let Some(x) = bfs.next(&deps) {
717                nodes.push(x);
718            }
719
720            assert!(nodes.contains(&fb));
721            assert!(nodes.contains(&qc));
722            assert!(nodes.contains(&rand));
723            assert!(!nodes.contains(&libc));
724        }
725    }
726}
727
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `v` sorted so results with unspecified order can be compared.
    fn sort_unstable<T: Ord>(v: Vec<T>) -> Vec<T> {
        let mut sorted = v;
        sorted.sort_unstable();
        sorted
    }

    /// Adds each `(name1, name2)` pair as a link in `domain`, in order.
    fn add_links(
        rm: &mut DefaultRoleManager,
        domain: Option<&str>,
        links: &[(&str, &str)],
    ) {
        for &(name1, name2) in links {
            rm.add_link(name1, name2, domain);
        }
    }

    /// Checks `has_link` for every `(name1, name2, domain, expected)` row,
    /// in order; the whole row is compared so a failure names itself.
    fn assert_links(
        rm: &DefaultRoleManager,
        cases: &[(&str, &str, Option<&str>, bool)],
    ) {
        for &(name1, name2, domain, expected) in cases {
            let actual = rm.has_link(name1, name2, domain);
            assert_eq!(
                (name1, name2, domain, expected),
                (name1, name2, domain, actual)
            );
        }
    }

    /// Like `assert_links`, for rows that all use the default domain.
    fn assert_default_links(
        rm: &DefaultRoleManager,
        cases: &[(&str, &str, bool)],
    ) {
        for &(name1, name2, expected) in cases {
            assert_links(rm, &[(name1, name2, None, expected)]);
        }
    }

    const DEFAULT_LINKS: &[(&str, &str)] = &[
        ("u1", "g1"),
        ("u2", "g1"),
        ("u3", "g2"),
        ("u4", "g2"),
        ("u4", "g3"),
        ("g1", "g3"),
    ];

    #[test]
    fn test_role() {
        let mut rm = DefaultRoleManager::new(3);
        add_links(&mut rm, None, DEFAULT_LINKS);

        assert_default_links(
            &rm,
            &[
                ("u1", "g1", true),
                ("u1", "g2", false),
                ("u1", "g3", true),
                ("u2", "g1", true),
                ("u2", "g2", false),
                ("u2", "g3", true),
                ("u3", "g1", false),
                ("u3", "g2", true),
                ("u3", "g3", false),
                ("u4", "g1", false),
                ("u4", "g2", true),
                ("u4", "g3", true),
            ],
        );

        let no_roles: Vec<String> = Vec::new();

        // get_roles
        assert_eq!(vec!["g1"], rm.get_roles("u1", None));
        assert_eq!(vec!["g1"], rm.get_roles("u2", None));
        assert_eq!(vec!["g2"], rm.get_roles("u3", None));
        assert_eq!(vec!["g2", "g3"], sort_unstable(rm.get_roles("u4", None)));
        assert_eq!(vec!["g3"], rm.get_roles("g1", None));
        assert_eq!(no_roles, rm.get_roles("g2", None));
        assert_eq!(no_roles, rm.get_roles("g3", None));

        // delete_link
        rm.delete_link("g1", "g3", None).unwrap();
        rm.delete_link("u4", "g2", None).unwrap();

        assert_default_links(
            &rm,
            &[
                ("u1", "g1", true),
                ("u1", "g2", false),
                ("u1", "g3", false),
                ("u2", "g1", true),
                ("u2", "g2", false),
                ("u2", "g3", false),
                ("u3", "g1", false),
                ("u3", "g2", true),
                ("u3", "g3", false),
                ("u4", "g1", false),
                ("u4", "g2", false),
                ("u4", "g3", true),
            ],
        );

        assert_eq!(vec!["g1"], rm.get_roles("u1", None));
        assert_eq!(vec!["g1"], rm.get_roles("u2", None));
        assert_eq!(vec!["g2"], rm.get_roles("u3", None));
        assert_eq!(vec!["g3"], rm.get_roles("u4", None));
        assert_eq!(no_roles, rm.get_roles("g1", None));
        assert_eq!(no_roles, rm.get_roles("g2", None));
        assert_eq!(no_roles, rm.get_roles("g3", None));
    }

    #[test]
    fn test_clear() {
        let mut rm = DefaultRoleManager::new(3);
        add_links(&mut rm, None, DEFAULT_LINKS);

        rm.clear();

        for &user in &["u1", "u2", "u3", "u4"] {
            for &group in &["g1", "g2", "g3"] {
                assert_default_links(&rm, &[(user, group, false)]);
            }
        }
    }

    #[test]
    fn test_domain_role() {
        let d1 = Some("domain1");
        let d2 = Some("domain2");

        let mut rm = DefaultRoleManager::new(3);
        add_links(&mut rm, d1, &[("u1", "g1"), ("u2", "g1")]);
        add_links(&mut rm, d2, &[("u3", "admin"), ("u4", "admin")]);
        add_links(&mut rm, d1, &[("u4", "admin"), ("g1", "admin")]);

        assert_links(
            &rm,
            &[
                ("u1", "g1", d1, true),
                ("u1", "g1", d2, false),
                ("u1", "admin", d1, true),
                ("u1", "admin", d2, false),
                ("u2", "g1", d1, true),
                ("u2", "g1", d2, false),
                ("u2", "admin", d1, true),
                ("u2", "admin", d2, false),
                ("u3", "g1", d1, false),
                ("u3", "g1", d2, false),
                ("u3", "admin", d1, false),
                ("u3", "admin", d2, true),
                ("u4", "g1", d1, false),
                ("u4", "g1", d2, false),
                ("u4", "admin", d1, true),
                ("u4", "admin", d2, true),
            ],
        );

        rm.delete_link("g1", "admin", d1).unwrap();
        rm.delete_link("u4", "admin", d2).unwrap();

        assert_links(
            &rm,
            &[
                ("u1", "g1", d1, true),
                ("u1", "g1", d2, false),
                ("u1", "admin", d1, false),
                ("u1", "admin", d2, false),
                ("u2", "g1", d1, true),
                ("u2", "g1", d2, false),
                ("u2", "admin", d1, false),
                ("u2", "admin", d2, false),
                ("u3", "g1", d1, false),
                ("u3", "g1", d2, false),
                ("u3", "admin", d1, false),
                ("u3", "admin", d2, true),
                ("u4", "g1", d1, false),
                ("u4", "g1", d2, false),
                ("u4", "admin", d1, true),
                ("u4", "admin", d2, false),
            ],
        );
    }

    #[test]
    fn test_users() {
        let mut rm = DefaultRoleManager::new(3);
        add_links(&mut rm, Some("domain1"), &[("u1", "g1"), ("u2", "g1")]);
        add_links(&mut rm, Some("domain2"), &[("u3", "g2"), ("u4", "g2")]);
        add_links(&mut rm, None, &[("u5", "g3")]);

        let g1_users = sort_unstable(rm.get_users("g1", Some("domain1")));
        assert_eq!(vec!["u1", "u2"], g1_users);

        let g2_users = sort_unstable(rm.get_users("g2", Some("domain2")));
        assert_eq!(vec!["u3", "u4"], g2_users);

        assert_eq!(vec!["u5"], rm.get_users("g3", None));
    }

    #[test]
    fn test_pattern_domain() {
        use crate::model::key_match;

        let mut rm = DefaultRoleManager::new(3);
        rm.matching_fn(None, Some(key_match));
        add_links(&mut rm, Some("*"), &[("u1", "g1")]);

        assert!(rm.domain_has_role("u1", Some("domain2")));
    }

    #[test]
    fn test_basic_role_matching() {
        use crate::model::key_match;

        let mut rm = DefaultRoleManager::new(10);
        rm.matching_fn(Some(key_match), None);
        add_links(
            &mut rm,
            None,
            &[
                ("bob", "book_group"),
                ("*", "book_group"),
                ("*", "pen_group"),
                ("eve", "pen_group"),
            ],
        );

        for &user in &["alice", "eve", "bob"] {
            assert!(rm.has_link(user, "book_group", None));
        }

        let alice_roles = sort_unstable(rm.get_roles("alice", None));
        assert_eq!(vec!["book_group", "pen_group"], alice_roles);
    }

    #[test]
    fn test_basic_role_matching2() {
        use crate::model::key_match;

        let mut rm = DefaultRoleManager::new(10);
        rm.matching_fn(Some(key_match), None);
        add_links(
            &mut rm,
            None,
            &[
                ("alice", "book_group"),
                ("alice", "*"),
                ("bob", "pen_group"),
            ],
        );

        assert_default_links(
            &rm,
            &[
                ("alice", "book_group", true),
                ("alice", "pen_group", true),
                ("bob", "pen_group", true),
                ("bob", "book_group", false),
            ],
        );

        let alice_roles = sort_unstable(rm.get_roles("alice", None));
        assert_eq!(
            vec!["*", "alice", "bob", "book_group", "pen_group"],
            alice_roles
        );

        assert_eq!(vec!["alice"], sort_unstable(rm.get_users("*", None)));
    }

    #[test]
    fn test_cross_domain_role_inheritance_complex() {
        use crate::model::key_match;

        let mut rm = DefaultRoleManager::new(10);
        rm.matching_fn(None, Some(key_match));

        add_links(
            &mut rm,
            Some("*"),
            &[("editor", "admin"), ("viewer", "editor")],
        );
        add_links(&mut rm, Some("domain1"), &[("alice", "editor")]);
        add_links(&mut rm, Some("domain2"), &[("bob", "viewer")]);

        assert_links(
            &rm,
            &[
                ("alice", "admin", Some("domain1"), true),
                ("bob", "editor", Some("domain2"), true),
                ("bob", "admin", Some("domain2"), true),
            ],
        );

        add_links(&mut rm, Some("domain3"), &[("charlie", "viewer")]);
        assert_links(
            &rm,
            &[
                ("charlie", "editor", Some("domain3"), true),
                ("charlie", "admin", Some("domain3"), true),
            ],
        );

        add_links(&mut rm, Some("domain1"), &[("super_admin", "admin")]);
        assert_links(&rm, &[("super_admin", "admin", Some("domain1"), true)]);
    }
}
970}