casbin/rbac/
default_role_manager.rs

1use crate::{
2    error::RbacError,
3    rbac::{MatchingFn, RoleManager},
4    Result,
5};
6use petgraph::stable_graph::{NodeIndex, StableDiGraph};
7use std::collections::{hash_map::Entry, HashMap, HashSet};
8
9#[cfg(feature = "cached")]
10use crate::cache::{Cache, DefaultCache};
11
12#[cfg(feature = "cached")]
13use std::{
14    collections::hash_map::DefaultHasher,
15    hash::{Hash, Hasher},
16};
17
/// Domain name used whenever a caller passes `None` as the domain.
const DEFAULT_DOMAIN: &str = "DEFAULT";
19
20pub struct DefaultRoleManager {
21    all_domains: HashMap<String, StableDiGraph<String, EdgeVariant>>,
22    all_domains_indices: HashMap<String, HashMap<String, NodeIndex<u32>>>,
23    #[cfg(feature = "cached")]
24    cache: DefaultCache<u64, bool>,
25    max_hierarchy_level: usize,
26    role_matching_fn: Option<MatchingFn>,
27    domain_matching_fn: Option<MatchingFn>,
28}
29
/// Kind of edge stored in a domain's role graph.
#[derive(Clone, Debug)]
enum EdgeVariant {
    /// Explicit inheritance link created through `add_link`.
    Link,
    /// Edge produced by the role matching function (see `link_if_matches`).
    Match,
}
35
36impl DefaultRoleManager {
37    pub fn new(max_hierarchy_level: usize) -> Self {
38        DefaultRoleManager {
39            all_domains: HashMap::new(),
40            all_domains_indices: HashMap::new(),
41            max_hierarchy_level,
42            #[cfg(feature = "cached")]
43            cache: DefaultCache::new(50),
44            role_matching_fn: None,
45            domain_matching_fn: None,
46        }
47    }
48
49    fn get_or_create_role(
50        &mut self,
51        name: &str,
52        domain: Option<&str>,
53    ) -> NodeIndex<u32> {
54        let domain = domain.unwrap_or(DEFAULT_DOMAIN);
55
56        // detect whether this is a new domain creation
57        let is_new_domain = !self.all_domains.contains_key(domain);
58
59        let graph = self.all_domains.entry(domain.into()).or_default();
60
61        let role_entry = self
62            .all_domains_indices
63            .entry(domain.into())
64            .or_default()
65            .entry(name.into());
66
67        let vacant_entry = match role_entry {
68            Entry::Occupied(e) => return *e.get(),
69            Entry::Vacant(e) => e,
70        };
71
72        let new_role_id = graph.add_node(name.into());
73        vacant_entry.insert(new_role_id);
74
75        if let Some(role_matching_fn) = self.role_matching_fn {
76            let mut added = false;
77
78            let node_ids: Vec<_> =
79                graph.node_indices().filter(|&i| graph[i] != name).collect();
80
81            for existing_role_id in node_ids {
82                added |= link_if_matches(
83                    graph,
84                    role_matching_fn,
85                    new_role_id,
86                    existing_role_id,
87                );
88
89                added |= link_if_matches(
90                    graph,
91                    role_matching_fn,
92                    existing_role_id,
93                    new_role_id,
94                );
95            }
96
97            if added {
98                #[cfg(feature = "cached")]
99                self.cache.clear();
100            }
101        }
102
103        // If domain matching function exists and this was a new domain, copy
104        // role links from matching domains into the newly created domain so
105        // that BFS will see inherited links in this domain's graph.
106        if is_new_domain {
107            if let Some(domain_matching_fn) = self.domain_matching_fn {
108                let keys: Vec<String> =
109                    self.all_domains.keys().cloned().collect();
110                for d in keys {
111                    if d != domain && (domain_matching_fn)(domain, &d) {
112                        self.copy_from_domain(&d, domain);
113                    }
114                }
115            }
116        }
117
118        new_role_id
119    }
120
121    // propagate a Link addition (name1 -> name2) from `domain` into all
122    // affected/matching domains. This extracts the inline logic from
123    // `add_link` so the code is clearer and avoids nested borrows.
124    fn propagate_link_to_affected_domains(
125        &mut self,
126        name1: &str,
127        name2: &str,
128        domain: &str,
129    ) {
130        let name1_owned = name1.to_string();
131        let name2_owned = name2.to_string();
132        let affected = self.affected_domain_names(domain);
133        for d in affected {
134            // obtain mutable graph and index map for the affected domain
135            let g = self.all_domains.get_mut(&d).unwrap();
136            let idx_map =
137                self.all_domains_indices.entry(d.clone()).or_default();
138            let idx1 = Self::ensure_node_in_graph(g, idx_map, &name1_owned);
139            let idx2 = Self::ensure_node_in_graph(g, idx_map, &name2_owned);
140
141            // add Link edge if missing
142            let has_link = g
143                .edges_connecting(idx1, idx2)
144                .any(|e| matches!(*e.weight(), EdgeVariant::Link));
145            if !has_link {
146                g.add_edge(idx1, idx2, EdgeVariant::Link);
147            }
148        }
149
150        #[cfg(feature = "cached")]
151        self.cache.clear();
152    }
153
154    // ensure a node with `name` exists in graph `g` and in the provided
155    // `idx_map`. Returns the NodeIndex for the node.
156    fn ensure_node_in_graph(
157        g: &mut StableDiGraph<String, EdgeVariant>,
158        idx_map: &mut HashMap<String, NodeIndex<u32>>,
159        name: &str,
160    ) -> NodeIndex<u32> {
161        if let Some(idx) = idx_map.get(name) {
162            *idx
163        } else if let Some(idx) = g.node_indices().find(|&i| g[i] == name) {
164            idx_map.insert(name.to_string(), idx);
165            idx
166        } else {
167            let ni = g.add_node(name.to_string());
168            idx_map.insert(name.to_string(), ni);
169            ni
170        }
171    }
172
173    // return the list of affected domain names (immutable) to avoid nested
174    // mutable borrows when performing operations across domains
175    fn affected_domain_names(&self, domain: &str) -> Vec<String> {
176        let mut res = Vec::new();
177        if let Some(domain_matching_fn) = self.domain_matching_fn {
178            let keys: Vec<String> = self.all_domains.keys().cloned().collect();
179            for d in keys {
180                if d != domain && (domain_matching_fn)(&d, domain) {
181                    res.push(d);
182                }
183            }
184        }
185        res
186    }
187
188    // copy all role links and nodes from `src_domain` graph into `dst_domain` graph
189    fn copy_from_domain(&mut self, src_domain: &str, dst_domain: &str) {
190        if src_domain == dst_domain {
191            return;
192        }
193
194        // ensure both graphs exist
195        if !self.all_domains.contains_key(src_domain) {
196            return;
197        }
198
199        let src_graph = match self.all_domains.get(src_domain) {
200            Some(g) => g.clone(),
201            None => return,
202        };
203
204        // ensure dst indices map exists
205        let dst_indices = self
206            .all_domains_indices
207            .entry(dst_domain.into())
208            .or_default();
209
210        let dst_graph = self.all_domains.entry(dst_domain.into()).or_default();
211
212        // copy nodes: ensure names exist in dst and capture mapping
213        let mut id_map: HashMap<NodeIndex<u32>, NodeIndex<u32>> =
214            HashMap::new();
215        for src_idx in src_graph.node_indices() {
216            let name = &src_graph[src_idx];
217            let dst_idx = if let Some(idx) = dst_indices.get(name) {
218                *idx
219            } else {
220                let new_idx = dst_graph.add_node(name.clone());
221                dst_indices.insert(name.clone(), new_idx);
222                new_idx
223            };
224            id_map.insert(src_idx, dst_idx);
225        }
226
227        // copy edges: for each edge in src_graph, add equivalent edge in dst if missing
228        for edge_idx in src_graph.edge_indices() {
229            if let Some((src_s, src_t)) = src_graph.edge_endpoints(edge_idx) {
230                if let Some(weight) = src_graph.edge_weight(edge_idx) {
231                    let dst_s = id_map.get(&src_s).unwrap();
232                    let dst_t = id_map.get(&src_t).unwrap();
233
234                    let need_add = match dst_graph.find_edge(*dst_s, *dst_t) {
235                        Some(idx) => {
236                            // if existing edge is Match but source weight is Link, allow adding Link
237                            !matches!(dst_graph[idx], EdgeVariant::Match)
238                                || !matches!(weight, &EdgeVariant::Match)
239                        }
240                        None => true,
241                    };
242
243                    if need_add {
244                        dst_graph.add_edge(*dst_s, *dst_t, weight.clone());
245                    }
246                }
247            }
248        }
249
250        #[cfg(feature = "cached")]
251        self.cache.clear();
252    }
253
254    fn matched_domains(&self, domain: Option<&str>) -> Vec<String> {
255        let domain = domain.unwrap_or(DEFAULT_DOMAIN);
256        if let Some(domain_matching_fn) = self.domain_matching_fn {
257            self.all_domains
258                .keys()
259                .filter_map(|key| {
260                    if domain_matching_fn(domain, key) {
261                        Some(key.to_owned())
262                    } else {
263                        None
264                    }
265                })
266                .collect::<Vec<String>>()
267        } else {
268            self.all_domains
269                .get(domain)
270                .map_or(vec![], |_| vec![domain.to_owned()])
271        }
272    }
273
274    fn domain_has_role(&self, name: &str, domain: Option<&str>) -> bool {
275        let matched_domains = self.matched_domains(domain);
276
277        matched_domains.iter().any(|domain| {
278            // try to find direct match of role
279            if self.all_domains_indices[domain].contains_key(name) {
280                true
281            } else if let Some(role_matching_fn) = self.role_matching_fn {
282                // else if role_matching_fn is set, iterate all graph nodes and try to find matching role
283                let graph = &self.all_domains[domain];
284
285                graph
286                    .node_weights()
287                    .any(|role| role_matching_fn(name, role))
288            } else {
289                false
290            }
291        })
292    }
293}
294
295/// link node of `not_pattern_id` to `maybe_pattern_id` if
296/// `not_pattern` matches `maybe_pattern`'s pattern and
297/// there doesn't exist a match edge yet
298fn link_if_matches(
299    graph: &mut StableDiGraph<String, EdgeVariant>,
300    role_matching_fn: fn(&str, &str) -> bool,
301    not_pattern_id: NodeIndex<u32>,
302    maybe_pattern_id: NodeIndex<u32>,
303) -> bool {
304    let not_pattern = &graph[not_pattern_id];
305    let maybe_pattern = &graph[maybe_pattern_id];
306
307    if !role_matching_fn(maybe_pattern, not_pattern) {
308        return false;
309    }
310
311    let add_edge =
312        if let Some(idx) = graph.find_edge(not_pattern_id, maybe_pattern_id) {
313            !matches!(graph[idx], EdgeVariant::Match)
314        } else {
315            true
316        };
317
318    if add_edge {
319        graph.add_edge(not_pattern_id, maybe_pattern_id, EdgeVariant::Match);
320
321        true
322    } else {
323        false
324    }
325}
326
327impl RoleManager for DefaultRoleManager {
328    fn clear(&mut self) {
329        self.all_domains_indices.clear();
330        self.all_domains.clear();
331        #[cfg(feature = "cached")]
332        self.cache.clear();
333    }
334
335    fn add_link(&mut self, name1: &str, name2: &str, domain: Option<&str>) {
336        if name1 == name2 {
337            return;
338        }
339
340        let role1 = self.get_or_create_role(name1, domain);
341        let role2 = self.get_or_create_role(name2, domain);
342
343        let graph = self
344            .all_domains
345            .get_mut(domain.unwrap_or(DEFAULT_DOMAIN))
346            .unwrap();
347
348        let add_link = if let Some(edge) = graph.find_edge(role1, role2) {
349            !matches!(graph[edge], EdgeVariant::Link)
350        } else {
351            true
352        };
353
354        if add_link {
355            graph.add_edge(role1, role2, EdgeVariant::Link);
356
357            if let Some(domain_str) = domain {
358                self.propagate_link_to_affected_domains(
359                    name1, name2, domain_str,
360                );
361            }
362
363            #[cfg(feature = "cached")]
364            self.cache.clear();
365        }
366    }
367
368    fn matching_fn(
369        &mut self,
370        role_matching_fn: Option<MatchingFn>,
371        domain_matching_fn: Option<MatchingFn>,
372    ) {
373        self.domain_matching_fn = domain_matching_fn;
374        self.role_matching_fn = role_matching_fn;
375    }
376
377    fn delete_link(
378        &mut self,
379        name1: &str,
380        name2: &str,
381        domain: Option<&str>,
382    ) -> Result<()> {
383        if !self.domain_has_role(name1, domain)
384            || !self.domain_has_role(name2, domain)
385        {
386            return Err(
387                RbacError::NotFound(format!("{} OR {}", name1, name2)).into()
388            );
389        }
390
391        let role1 = self.get_or_create_role(name1, domain);
392        let role2 = self.get_or_create_role(name2, domain);
393
394        let graph = self
395            .all_domains
396            .get_mut(domain.unwrap_or(DEFAULT_DOMAIN))
397            .unwrap();
398
399        if let Some(edge_index) = graph.find_edge(role1, role2) {
400            graph.remove_edge(edge_index).unwrap();
401
402            #[cfg(feature = "cached")]
403            self.cache.clear();
404        }
405
406        Ok(())
407    }
408
409    fn has_link(&self, name1: &str, name2: &str, domain: Option<&str>) -> bool {
410        if name1 == name2 {
411            return true;
412        }
413
414        #[cfg(feature = "cached")]
415        let cache_key = {
416            let mut hasher = DefaultHasher::new();
417            name1.hash(&mut hasher);
418            name2.hash(&mut hasher);
419            domain.unwrap_or(DEFAULT_DOMAIN).hash(&mut hasher);
420            hasher.finish()
421        };
422
423        #[cfg(feature = "cached")]
424        if let Some(res) = self.cache.get(&cache_key) {
425            return res;
426        }
427
428        let matched_domains = self.matched_domains(domain);
429
430        let mut res = false;
431
432        for domain in matched_domains {
433            let graph = self.all_domains.get(&domain).unwrap();
434            let indices = self.all_domains_indices.get(&domain).unwrap();
435
436            let role1 = if let Some(role1) = indices.get(name1) {
437                Some(*role1)
438            } else {
439                graph.node_indices().find(|&i| {
440                    let role_name = &graph[i];
441
442                    role_name == name1
443                        || self
444                            .role_matching_fn
445                            .map(|f| f(name1, role_name))
446                            .unwrap_or_default()
447                })
448            };
449
450            let role1 = if let Some(role1) = role1 {
451                role1
452            } else {
453                continue;
454            };
455
456            let mut bfs = matching_bfs::Bfs::new(
457                graph,
458                role1,
459                self.max_hierarchy_level,
460                self.role_matching_fn.is_some(),
461            );
462
463            while let Some(node) = bfs.next(graph) {
464                let role_name = &graph[node];
465
466                if role_name == name2
467                    || self
468                        .role_matching_fn
469                        .map(|f| f(role_name, name2))
470                        .unwrap_or_default()
471                {
472                    res = true;
473                    break;
474                }
475            }
476        }
477
478        #[cfg(feature = "cached")]
479        self.cache.set(cache_key, res);
480
481        res
482    }
483
484    fn get_roles(&self, name: &str, domain: Option<&str>) -> Vec<String> {
485        let matched_domains = self.matched_domains(domain);
486
487        let res = matched_domains.into_iter().fold(
488            HashSet::new(),
489            |mut set, domain| {
490                let graph = &self.all_domains[&domain];
491
492                if let Some(role_node) = graph.node_indices().find(|&i| {
493                    graph[i] == name
494                        || self.role_matching_fn.unwrap_or(|_, _| false)(
495                            name, &graph[i],
496                        )
497                }) {
498                    let neighbors = matching_bfs::bfs_iterator(
499                        graph,
500                        role_node,
501                        self.role_matching_fn.is_some(),
502                    )
503                    .map(|i| graph[i].clone());
504
505                    set.extend(neighbors);
506                }
507
508                set
509            },
510        );
511        res.into_iter().collect()
512    }
513
514    fn get_users(&self, name: &str, domain: Option<&str>) -> Vec<String> {
515        let matched_domains = self.matched_domains(domain);
516
517        let res = matched_domains.into_iter().fold(
518            HashSet::new(),
519            |mut set, domain| {
520                let graph = &self.all_domains[&domain];
521
522                if let Some(role_node) = graph.node_indices().find(|&i| {
523                    graph[i] == name
524                        || self
525                            .role_matching_fn
526                            .map(|f| f(name, &graph[i]))
527                            .unwrap_or_default()
528                }) {
529                    let neighbors = graph
530                        .neighbors_directed(
531                            role_node,
532                            petgraph::Direction::Incoming,
533                        )
534                        .map(|i| graph[i].clone());
535
536                    set.extend(neighbors);
537                }
538
539                set
540            },
541        );
542
543        res.into_iter().collect()
544    }
545}
546
547mod matching_bfs {
548    use super::EdgeVariant;
549    use fixedbitset::FixedBitSet;
550    use petgraph::graph::NodeIndex;
551    use petgraph::stable_graph::StableDiGraph;
552    use petgraph::visit::{EdgeRef, VisitMap, Visitable};
553    use std::collections::VecDeque;
554
555    #[derive(Clone)]
556    pub(super) struct Bfs {
557        /// The queue of nodes to visit
558        pub queue: VecDeque<NodeIndex<u32>>,
559        /// The map of discovered nodes
560        pub discovered: FixedBitSet,
561        /// Maximum depth
562        pub max_depth: usize,
563        /// Consider `Match` edges
564        pub with_pattern_matching: bool,
565
566        /// Current depth
567        depth: usize,
568        /// Number of elements until next depth is reached
569        depth_elements_remaining: usize,
570    }
571
572    impl Bfs {
573        /// Create a new **Bfs**, using the graph's visitor map, and put **start**
574        /// in the stack of nodes to visit.
575        pub fn new(
576            graph: &StableDiGraph<String, EdgeVariant>,
577            start: NodeIndex<u32>,
578            max_depth: usize,
579            with_pattern_matching: bool,
580        ) -> Self {
581            let mut discovered = graph.visit_map();
582            discovered.visit(start);
583
584            let mut queue = VecDeque::new();
585            queue.push_front(start);
586
587            Bfs {
588                queue,
589                discovered,
590                max_depth,
591                with_pattern_matching,
592                depth: 0,
593                depth_elements_remaining: 1,
594            }
595        }
596
597        /// Return the next node in the bfs, or **None** if the traversal is done.
598        pub fn next(
599            &mut self,
600            graph: &StableDiGraph<String, EdgeVariant>,
601        ) -> Option<NodeIndex<u32>> {
602            if self.max_depth <= self.depth {
603                return None;
604            }
605
606            if let Some(node) = self.queue.pop_front() {
607                self.update_depth();
608
609                let mut counter = 0;
610                for succ in
611                    bfs_iterator(graph, node, self.with_pattern_matching)
612                {
613                    if self.discovered.visit(succ) {
614                        self.queue.push_back(succ);
615                        counter += 1;
616                    }
617                }
618
619                self.depth_elements_remaining += counter;
620
621                Some(node)
622            } else {
623                None
624            }
625        }
626
627        fn update_depth(&mut self) {
628            self.depth_elements_remaining -= 1;
629            if self.depth_elements_remaining == 0 {
630                self.depth += 1
631            }
632        }
633    }
634
635    pub(super) fn bfs_iterator(
636        graph: &StableDiGraph<String, EdgeVariant>,
637        node: NodeIndex<u32>,
638        with_matches: bool,
639    ) -> Box<dyn Iterator<Item = NodeIndex<u32>> + '_> {
640        // outgoing LINK edges of node
641        let outgoing_direct_edge = graph
642            .edges_directed(node, petgraph::Direction::Outgoing)
643            .filter_map(|edge| match *edge.weight() {
644                EdgeVariant::Link => Some(edge.target()),
645                EdgeVariant::Match => None,
646            });
647
648        if !with_matches {
649            return Box::new(outgoing_direct_edge);
650        }
651
652        // x := outgoing LINK edges of node
653        // outgoing_match_edge : outgoing MATCH edges of x FOR ALL x
654        let outgoing_match_edge = graph
655            .edges_directed(node, petgraph::Direction::Outgoing)
656            .filter(|edge| matches!(*edge.weight(), EdgeVariant::Link))
657            .flat_map(move |edge| {
658                graph
659                    .edges_directed(
660                        edge.target(),
661                        petgraph::Direction::Outgoing,
662                    )
663                    .filter_map(|edge| match *edge.weight() {
664                        EdgeVariant::Match => Some(edge.target()),
665                        EdgeVariant::Link => None,
666                    })
667            });
668
669        // x := incoming MATCH edges of node
670        // sibling_matched_by := outgoing LINK edges of x FOR ALL x
671        let sibling_matched_by = graph
672            .edges_directed(node, petgraph::Direction::Incoming)
673            .filter(|edge| matches!(*edge.weight(), EdgeVariant::Match))
674            .flat_map(move |edge| {
675                graph
676                    .edges_directed(
677                        edge.source(),
678                        petgraph::Direction::Outgoing,
679                    )
680                    .filter_map(|edge| match *edge.weight() {
681                        EdgeVariant::Link => Some(edge.target()),
682                        EdgeVariant::Match => None,
683                    })
684            });
685
686        Box::new(
687            outgoing_direct_edge
688                .chain(outgoing_match_edge)
689                .chain(sibling_matched_by),
690        )
691    }
692
693    #[cfg(test)]
694    mod test {
695        use super::*;
696        use petgraph::stable_graph::StableDiGraph;
697
698        #[test]
699        fn test_max_depth() {
700            let mut deps = StableDiGraph::<String, EdgeVariant>::new();
701            let pg = deps.add_node("petgraph".into());
702            let fb = deps.add_node("fixedbitset".into());
703            let qc = deps.add_node("quickcheck".into());
704            let rand = deps.add_node("rand".into());
705            let libc = deps.add_node("libc".into());
706
707            deps.extend_with_edges([
708                (pg, fb, EdgeVariant::Link),
709                (pg, qc, EdgeVariant::Link),
710                (qc, rand, EdgeVariant::Link),
711                (rand, libc, EdgeVariant::Link),
712            ]);
713
714            let mut bfs = Bfs::new(&deps, pg, 2, false);
715
716            let mut nodes = vec![];
717            while let Some(x) = bfs.next(&deps) {
718                nodes.push(x);
719            }
720
721            assert!(nodes.contains(&fb));
722            assert!(nodes.contains(&qc));
723            assert!(nodes.contains(&rand));
724            assert!(!nodes.contains(&libc));
725        }
726    }
727}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732
733    fn sort_unstable<T: Ord>(mut v: Vec<T>) -> Vec<T> {
734        v.sort_unstable();
735        v
736    }
737
738    #[test]
739    fn test_role() {
740        let mut rm = DefaultRoleManager::new(3);
741        rm.add_link("u1", "g1", None);
742        rm.add_link("u2", "g1", None);
743        rm.add_link("u3", "g2", None);
744        rm.add_link("u4", "g2", None);
745        rm.add_link("u4", "g3", None);
746        rm.add_link("g1", "g3", None);
747
748        assert_eq!(true, rm.has_link("u1", "g1", None));
749        assert_eq!(false, rm.has_link("u1", "g2", None));
750        assert_eq!(true, rm.has_link("u1", "g3", None));
751        assert_eq!(true, rm.has_link("u2", "g1", None));
752        assert_eq!(false, rm.has_link("u2", "g2", None));
753        assert_eq!(true, rm.has_link("u2", "g3", None));
754        assert_eq!(false, rm.has_link("u3", "g1", None));
755        assert_eq!(true, rm.has_link("u3", "g2", None));
756        assert_eq!(false, rm.has_link("u3", "g3", None));
757        assert_eq!(false, rm.has_link("u4", "g1", None));
758        assert_eq!(true, rm.has_link("u4", "g2", None));
759        assert_eq!(true, rm.has_link("u4", "g3", None));
760
761        // test get_roles
762        assert_eq!(vec!["g1"], rm.get_roles("u1", None));
763        assert_eq!(vec!["g1"], rm.get_roles("u2", None));
764        assert_eq!(vec!["g2"], rm.get_roles("u3", None));
765        assert_eq!(vec!["g2", "g3"], sort_unstable(rm.get_roles("u4", None)));
766        assert_eq!(vec!["g3"], rm.get_roles("g1", None));
767        assert_eq!(vec![String::new(); 0], rm.get_roles("g2", None));
768        assert_eq!(vec![String::new(); 0], rm.get_roles("g3", None));
769
770        // test delete_link
771        rm.delete_link("g1", "g3", None).unwrap();
772        rm.delete_link("u4", "g2", None).unwrap();
773        assert_eq!(true, rm.has_link("u1", "g1", None));
774        assert_eq!(false, rm.has_link("u1", "g2", None));
775        assert_eq!(false, rm.has_link("u1", "g3", None));
776        assert_eq!(true, rm.has_link("u2", "g1", None));
777        assert_eq!(false, rm.has_link("u2", "g2", None));
778        assert_eq!(false, rm.has_link("u2", "g3", None));
779        assert_eq!(false, rm.has_link("u3", "g1", None));
780        assert_eq!(true, rm.has_link("u3", "g2", None));
781        assert_eq!(false, rm.has_link("u3", "g3", None));
782        assert_eq!(false, rm.has_link("u4", "g1", None));
783        assert_eq!(false, rm.has_link("u4", "g2", None));
784        assert_eq!(true, rm.has_link("u4", "g3", None));
785        assert_eq!(vec!["g1"], rm.get_roles("u1", None));
786        assert_eq!(vec!["g1"], rm.get_roles("u2", None));
787        assert_eq!(vec!["g2"], rm.get_roles("u3", None));
788        assert_eq!(vec!["g3"], rm.get_roles("u4", None));
789        assert_eq!(vec![String::new(); 0], rm.get_roles("g1", None));
790        assert_eq!(vec![String::new(); 0], rm.get_roles("g2", None));
791        assert_eq!(vec![String::new(); 0], rm.get_roles("g3", None));
792    }
793
794    #[test]
795    fn test_clear() {
796        let mut rm = DefaultRoleManager::new(3);
797        rm.add_link("u1", "g1", None);
798        rm.add_link("u2", "g1", None);
799        rm.add_link("u3", "g2", None);
800        rm.add_link("u4", "g2", None);
801        rm.add_link("u4", "g3", None);
802        rm.add_link("g1", "g3", None);
803
804        rm.clear();
805        assert_eq!(false, rm.has_link("u1", "g1", None));
806        assert_eq!(false, rm.has_link("u1", "g2", None));
807        assert_eq!(false, rm.has_link("u1", "g3", None));
808        assert_eq!(false, rm.has_link("u2", "g1", None));
809        assert_eq!(false, rm.has_link("u2", "g2", None));
810        assert_eq!(false, rm.has_link("u2", "g3", None));
811        assert_eq!(false, rm.has_link("u3", "g1", None));
812        assert_eq!(false, rm.has_link("u3", "g2", None));
813        assert_eq!(false, rm.has_link("u3", "g3", None));
814        assert_eq!(false, rm.has_link("u4", "g1", None));
815        assert_eq!(false, rm.has_link("u4", "g2", None));
816        assert_eq!(false, rm.has_link("u4", "g3", None));
817    }
818
819    #[test]
820    fn test_domain_role() {
821        let mut rm = DefaultRoleManager::new(3);
822        rm.add_link("u1", "g1", Some("domain1"));
823        rm.add_link("u2", "g1", Some("domain1"));
824        rm.add_link("u3", "admin", Some("domain2"));
825        rm.add_link("u4", "admin", Some("domain2"));
826        rm.add_link("u4", "admin", Some("domain1"));
827        rm.add_link("g1", "admin", Some("domain1"));
828
829        assert_eq!(true, rm.has_link("u1", "g1", Some("domain1")));
830        assert_eq!(false, rm.has_link("u1", "g1", Some("domain2")));
831        assert_eq!(true, rm.has_link("u1", "admin", Some("domain1")));
832        assert_eq!(false, rm.has_link("u1", "admin", Some("domain2")));
833
834        assert_eq!(true, rm.has_link("u2", "g1", Some("domain1")));
835        assert_eq!(false, rm.has_link("u2", "g1", Some("domain2")));
836        assert_eq!(true, rm.has_link("u2", "admin", Some("domain1")));
837        assert_eq!(false, rm.has_link("u2", "admin", Some("domain2")));
838
839        assert_eq!(false, rm.has_link("u3", "g1", Some("domain1")));
840        assert_eq!(false, rm.has_link("u3", "g1", Some("domain2")));
841        assert_eq!(false, rm.has_link("u3", "admin", Some("domain1")));
842        assert_eq!(true, rm.has_link("u3", "admin", Some("domain2")));
843
844        assert_eq!(false, rm.has_link("u4", "g1", Some("domain1")));
845        assert_eq!(false, rm.has_link("u4", "g1", Some("domain2")));
846        assert_eq!(true, rm.has_link("u4", "admin", Some("domain1")));
847        assert_eq!(true, rm.has_link("u4", "admin", Some("domain2")));
848
849        rm.delete_link("g1", "admin", Some("domain1")).unwrap();
850
851        rm.delete_link("u4", "admin", Some("domain2")).unwrap();
852
853        assert_eq!(true, rm.has_link("u1", "g1", Some("domain1")));
854        assert_eq!(false, rm.has_link("u1", "g1", Some("domain2")));
855        assert_eq!(false, rm.has_link("u1", "admin", Some("domain1")));
856        assert_eq!(false, rm.has_link("u1", "admin", Some("domain2")));
857
858        assert_eq!(true, rm.has_link("u2", "g1", Some("domain1")));
859        assert_eq!(false, rm.has_link("u2", "g1", Some("domain2")));
860        assert_eq!(false, rm.has_link("u2", "admin", Some("domain1")));
861        assert_eq!(false, rm.has_link("u2", "admin", Some("domain2")));
862
863        assert_eq!(false, rm.has_link("u3", "g1", Some("domain1")));
864        assert_eq!(false, rm.has_link("u3", "g1", Some("domain2")));
865        assert_eq!(false, rm.has_link("u3", "admin", Some("domain1")));
866        assert_eq!(true, rm.has_link("u3", "admin", Some("domain2")));
867
868        assert_eq!(false, rm.has_link("u4", "g1", Some("domain1")));
869        assert_eq!(false, rm.has_link("u4", "g1", Some("domain2")));
870        assert_eq!(true, rm.has_link("u4", "admin", Some("domain1")));
871        assert_eq!(false, rm.has_link("u4", "admin", Some("domain2")));
872    }
873
874    #[test]
875    fn test_users() {
876        let mut rm = DefaultRoleManager::new(3);
877        rm.add_link("u1", "g1", Some("domain1"));
878        rm.add_link("u2", "g1", Some("domain1"));
879
880        rm.add_link("u3", "g2", Some("domain2"));
881        rm.add_link("u4", "g2", Some("domain2"));
882
883        rm.add_link("u5", "g3", None);
884
885        assert_eq!(
886            vec!["u1", "u2"],
887            sort_unstable(rm.get_users("g1", Some("domain1")))
888        );
889        assert_eq!(
890            vec!["u3", "u4"],
891            sort_unstable(rm.get_users("g2", Some("domain2")))
892        );
893        assert_eq!(vec!["u5"], rm.get_users("g3", None));
894    }
895
896    #[test]
897    fn test_pattern_domain() {
898        use crate::model::key_match;
899        let mut rm = DefaultRoleManager::new(3);
900        rm.matching_fn(None, Some(key_match));
901        rm.add_link("u1", "g1", Some("*"));
902
903        assert!(rm.domain_has_role("u1", Some("domain2")));
904    }
905
906    #[test]
907    fn test_basic_role_matching() {
908        use crate::model::key_match;
909        let mut rm = DefaultRoleManager::new(10);
910        rm.matching_fn(Some(key_match), None);
911        rm.add_link("bob", "book_group", None);
912        rm.add_link("*", "book_group", None);
913        rm.add_link("*", "pen_group", None);
914        rm.add_link("eve", "pen_group", None);
915
916        assert!(rm.has_link("alice", "book_group", None));
917        assert!(rm.has_link("eve", "book_group", None));
918        assert!(rm.has_link("bob", "book_group", None));
919
920        assert_eq!(
921            vec!["book_group", "pen_group"],
922            sort_unstable(rm.get_roles("alice", None))
923        );
924    }
925
926    #[test]
927    fn test_basic_role_matching2() {
928        use crate::model::key_match;
929        let mut rm = DefaultRoleManager::new(10);
930        rm.matching_fn(Some(key_match), None);
931        rm.add_link("alice", "book_group", None);
932        rm.add_link("alice", "*", None);
933        rm.add_link("bob", "pen_group", None);
934
935        assert!(rm.has_link("alice", "book_group", None));
936        assert!(rm.has_link("alice", "pen_group", None));
937        assert!(rm.has_link("bob", "pen_group", None));
938        assert!(!rm.has_link("bob", "book_group", None));
939
940        assert_eq!(
941            vec!["*", "alice", "bob", "book_group", "pen_group"],
942            sort_unstable(rm.get_roles("alice", None))
943        );
944
945        assert_eq!(vec!["alice"], sort_unstable(rm.get_users("*", None)));
946    }
947
948    #[test]
949    fn test_cross_domain_role_inheritance_complex() {
950        use crate::model::key_match;
951        let mut rm = DefaultRoleManager::new(10);
952        rm.matching_fn(None, Some(key_match));
953
954        rm.add_link("editor", "admin", Some("*"));
955        rm.add_link("viewer", "editor", Some("*"));
956
957        rm.add_link("alice", "editor", Some("domain1"));
958        rm.add_link("bob", "viewer", Some("domain2"));
959
960        assert!(rm.has_link("alice", "admin", Some("domain1")));
961        assert!(rm.has_link("bob", "editor", Some("domain2")));
962        assert!(rm.has_link("bob", "admin", Some("domain2")));
963
964        rm.add_link("charlie", "viewer", Some("domain3"));
965        assert!(rm.has_link("charlie", "editor", Some("domain3")));
966        assert!(rm.has_link("charlie", "admin", Some("domain3")));
967
968        rm.add_link("super_admin", "admin", Some("domain1"));
969        assert!(rm.has_link("super_admin", "admin", Some("domain1")));
970    }
971}