casbin/
enforcer.rs

use crate::{
    adapter::{Adapter, Filter},
    convert::{EnforceArgs, TryIntoAdapter, TryIntoModel},
    core_api::CoreApi,
    effector::{DefaultEffector, EffectKind, Effector},
    emitter::{Event, EventData, EventEmitter},
    error::{ModelError, PolicyError, RequestError},
    get_or_err, get_or_err_with_context,
    management_api::MgmtApi,
    model::{Assertion, FunctionMap, Model, OperatorFunction},
    rbac::{DefaultRoleManager, RoleManager},
    register_g_function,
    util::{escape_assertion, escape_eval},
    Result,
};
16
17use crate::model::DefaultModel;
18
19#[cfg(any(feature = "logging", feature = "watcher"))]
20use crate::emitter::notify_logger_and_watcher;
21
22#[cfg(feature = "watcher")]
23use crate::watcher::Watcher;
24
25#[cfg(feature = "logging")]
26use crate::{DefaultLogger, Logger};
27
28use async_trait::async_trait;
29use once_cell::sync::Lazy;
30use parking_lot::RwLock;
31use rhai::{
32    def_package,
33    packages::{
34        ArithmeticPackage, BasicArrayPackage, BasicMapPackage, LogicPackage,
35        Package,
36    },
37    Dynamic, Engine, EvalAltResult, ImmutableString, Scope,
38};
39
// Rhai package installed on every enforcer's engine (see `new_raw`, which
// registers `CASBIN_PACKAGE` as a global module). It bundles the standard
// Rhai arithmetic, logic, basic array and basic map packages together with
// Casbin's native `escape_assertion` function.
def_package! {
    pub CasbinPackage(lib) {
        ArithmeticPackage::init(lib);
        LogicPackage::init(lib);
        BasicArrayPackage::init(lib);
        BasicMapPackage::init(lib);

        // Make `escape_assertion` callable from matcher expressions.
        lib.set_native_fn("escape_assertion", |s: ImmutableString| {
            Ok(escape_assertion(&s))
        });
    }
}
52
// Process-wide `CasbinPackage`, built lazily on first use; `new_raw` shares it
// with each new engine through `as_shared_module`.
static CASBIN_PACKAGE: Lazy<CasbinPackage> = Lazy::new(CasbinPackage::new);
54
55use std::{cmp::max, collections::HashMap, sync::Arc};
56
/// Callback signature stored per event in `Enforcer::events` and invoked by
/// `emit` with mutable access to the enforcer.
type EventCallback = fn(&mut Enforcer, EventData);
58
/// Enforcer is the main interface for authorization enforcement and policy management.
pub struct Enforcer {
    // Model holding the definition sections (`r`, `p`, `e`, `m`, `g`, ...)
    // and the loaded policy rules.
    model: Box<dyn Model>,
    // Backend that policies are loaded from, saved to and cleared in.
    adapter: Box<dyn Adapter>,
    // Operator functions; each one is also registered on `engine`.
    fm: FunctionMap,
    // Produces the effect stream that folds per-rule effects into a decision.
    eft: Box<dyn Effector>,
    // Shared role manager; cleared and rebuilt by `build_role_links`.
    rm: Arc<RwLock<dyn RoleManager>>,
    // When false, enforcement is bypassed and every request is allowed.
    enabled: bool,
    // When true, `clear_policy` also clears the adapter's storage
    // (NOTE(review): management APIs elsewhere presumably honour it too).
    auto_save: bool,
    // When true, role links are rebuilt after (re)loading policy and after
    // replacing the role manager.
    auto_build_role_links: bool,
    // Whether the `PolicyChange` notifier is meant to be registered.
    #[cfg(feature = "watcher")]
    auto_notify_watcher: bool,
    // Optional watcher installed via `set_watcher`.
    #[cfg(feature = "watcher")]
    watcher: Option<Box<dyn Watcher>>,
    // Callbacks per event, run in registration order by `emit`.
    events: HashMap<Event, Vec<EventCallback>>,
    // Rhai engine used to compile and evaluate matcher expressions.
    engine: Engine,
    // Sink for enforce, status and explain log messages.
    #[cfg(feature = "logging")]
    logger: Box<dyn Logger>,
}
78
/// Names of the request (`r`), policy (`p`), effect (`e`) and matcher (`m`)
/// sections that `enforce_with_context` evaluates against.
///
/// The context is consumed by `enforce_with_context`, so it derives `Clone`
/// to allow reuse across calls and `Debug` so it can be logged.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EnforceContext {
    /// Request definition section name, e.g. `r2`.
    pub r_type: String,
    /// Policy definition section name, e.g. `p2`.
    pub p_type: String,
    /// Policy effect section name, e.g. `e2`.
    pub e_type: String,
    /// Matcher section name, e.g. `m2`.
    pub m_type: String,
}
85
86impl EnforceContext {
87    pub fn new(suffix: &str) -> Self {
88        Self {
89            r_type: format!("r{}", suffix),
90            p_type: format!("p{}", suffix),
91            e_type: format!("e{}", suffix),
92            m_type: format!("m{}", suffix),
93        }
94    }
95    pub fn get_cache_key(&self) -> String {
96        format!(
97            "EnforceContext{{{}-{}-{}-{}}}",
98            &self.r_type, &self.p_type, &self.e_type, &self.m_type,
99        )
100    }
101}
102
103impl EventEmitter<Event> for Enforcer {
104    fn on(&mut self, e: Event, f: fn(&mut Self, EventData)) {
105        self.events.entry(e).or_default().push(f)
106    }
107
108    fn off(&mut self, e: Event) {
109        self.events.remove(&e);
110    }
111
112    fn emit(&mut self, e: Event, d: EventData) {
113        if let Some(cbs) = self.events.get(&e) {
114            for cb in cbs.clone().iter() {
115                cb(self, d.clone())
116            }
117        }
118    }
119}
120
121impl Enforcer {
122    pub(crate) fn private_enforce(
123        &self,
124        rvals: &[Dynamic],
125    ) -> Result<(bool, Option<Vec<usize>>)> {
126        if !self.enabled {
127            return Ok((true, None));
128        }
129
130        let mut scope: Scope = Scope::new();
131
132        let r_ast = get_or_err!(self, "r", ModelError::R, "request");
133        let p_ast = get_or_err!(self, "p", ModelError::P, "policy");
134        let m_ast = get_or_err!(self, "m", ModelError::M, "matcher");
135        let e_ast = get_or_err!(self, "e", ModelError::E, "effector");
136
137        if r_ast.tokens.len() != rvals.len() {
138            return Err(RequestError::UnmatchRequestDefinition(
139                r_ast.tokens.len(),
140                rvals.len(),
141            )
142            .into());
143        }
144
145        for (rtoken, rval) in r_ast.tokens.iter().zip(rvals.iter()) {
146            scope.push_constant_dynamic(rtoken, rval.to_owned());
147        }
148
149        let policies = p_ast.get_policy();
150        let (policy_len, scope_len) = (policies.len(), scope.len());
151
152        let mut eft_stream =
153            self.eft.new_stream(&e_ast.value, max(policy_len, 1));
154        let m_ast_compiled = if let Some(default_model) =
155            self.model.as_any().downcast_ref::<DefaultModel>()
156        {
157            default_model.get_compiled_matcher("m").ok_or_else(|| {
158                crate::error::Error::ModelError(crate::error::ModelError::M(
159                    "Matcher 'm' not compiled".to_string(),
160                ))
161            })?
162        } else {
163            // Fallback to original compilation (for other Model implementations)
164            &self
165                .engine
166                .compile_expression(escape_eval(&m_ast.value))
167                .map_err(Into::<Box<EvalAltResult>>::into)?
168        };
169
170        if policy_len == 0 {
171            for token in p_ast.tokens.iter() {
172                scope.push_constant(token, String::new());
173            }
174
175            let eval_result = self
176                .engine
177                .eval_ast_with_scope::<bool>(&mut scope, m_ast_compiled)?;
178            let eft = if eval_result {
179                EffectKind::Allow
180            } else {
181                EffectKind::Indeterminate
182            };
183
184            eft_stream.push_effect(eft);
185
186            return Ok((eft_stream.next(), None));
187        }
188
189        for pvals in policies {
190            scope.rewind(scope_len);
191
192            if p_ast.tokens.len() != pvals.len() {
193                return Err(PolicyError::UnmatchPolicyDefinition(
194                    p_ast.tokens.len(),
195                    pvals.len(),
196                )
197                .into());
198            }
199            for (ptoken, pval) in p_ast.tokens.iter().zip(pvals.iter()) {
200                scope.push_constant(ptoken, pval.to_owned());
201            }
202
203            let eval_result = self
204                .engine
205                .eval_ast_with_scope::<bool>(&mut scope, m_ast_compiled)?;
206            let eft = match p_ast.tokens.iter().position(|x| x == "p_eft") {
207                Some(j) if eval_result => {
208                    let p_eft = &pvals[j];
209                    if p_eft == "deny" {
210                        EffectKind::Deny
211                    } else if p_eft == "allow" {
212                        EffectKind::Allow
213                    } else {
214                        EffectKind::Indeterminate
215                    }
216                }
217                None if eval_result => EffectKind::Allow,
218                _ => EffectKind::Indeterminate,
219            };
220
221            if eft_stream.push_effect(eft) {
222                break;
223            }
224        }
225
226        Ok((eft_stream.next(), {
227            #[cfg(feature = "explain")]
228            {
229                eft_stream.explain()
230            }
231            #[cfg(not(feature = "explain"))]
232            {
233                None
234            }
235        }))
236    }
237
238    pub(crate) fn private_enforce_with_context(
239        &self,
240        ctx: EnforceContext,
241        rvals: &[Dynamic],
242    ) -> Result<(bool, Option<Vec<usize>>)> {
243        if !self.enabled {
244            return Ok((true, None));
245        }
246
247        let mut scope: Scope = Scope::new();
248        let r_ast = get_or_err_with_context!(
249            self,
250            "r",
251            &ctx.r_type,
252            ModelError::R,
253            "request"
254        );
255        let p_ast = get_or_err_with_context!(
256            self,
257            "p",
258            &ctx.p_type,
259            ModelError::P,
260            "policy"
261        );
262        let m_ast = get_or_err_with_context!(
263            self,
264            "m",
265            &ctx.m_type,
266            ModelError::M,
267            "matcher"
268        );
269        let e_ast = get_or_err_with_context!(
270            self,
271            "e",
272            &ctx.e_type,
273            ModelError::E,
274            "effector"
275        );
276
277        if r_ast.tokens.len() != rvals.len() {
278            return Err(RequestError::UnmatchRequestDefinition(
279                r_ast.tokens.len(),
280                rvals.len(),
281            )
282            .into());
283        }
284
285        for (rtoken, rval) in r_ast.tokens.iter().zip(rvals.iter()) {
286            scope.push_constant_dynamic(rtoken, rval.to_owned());
287        }
288
289        let policies = p_ast.get_policy();
290        let (policy_len, scope_len) = (policies.len(), scope.len());
291
292        let mut eft_stream =
293            self.eft.new_stream(&e_ast.value, max(policy_len, 1));
294        let m_ast_compiled = if let Some(default_model) =
295            self.model.as_any().downcast_ref::<DefaultModel>()
296        {
297            default_model.get_compiled_matcher(&ctx.m_type).ok_or_else(
298                || {
299                    crate::error::Error::ModelError(
300                        crate::error::ModelError::M(format!(
301                            "Matcher '{}' not compiled",
302                            ctx.m_type
303                        )),
304                    )
305                },
306            )?
307        } else {
308            // Fallback to original compilation (for other Model implementations)
309            &self
310                .engine
311                .compile_expression(escape_eval(&m_ast.value))
312                .map_err(Into::<Box<EvalAltResult>>::into)?
313        };
314
315        if policy_len == 0 {
316            for token in p_ast.tokens.iter() {
317                scope.push_constant(token, String::new());
318            }
319
320            let eval_result = self
321                .engine
322                .eval_ast_with_scope::<bool>(&mut scope, m_ast_compiled)?;
323            let eft = if eval_result {
324                EffectKind::Allow
325            } else {
326                EffectKind::Indeterminate
327            };
328
329            eft_stream.push_effect(eft);
330
331            return Ok((eft_stream.next(), None));
332        }
333
334        for pvals in policies {
335            scope.rewind(scope_len);
336
337            if p_ast.tokens.len() != pvals.len() {
338                return Err(PolicyError::UnmatchPolicyDefinition(
339                    p_ast.tokens.len(),
340                    pvals.len(),
341                )
342                .into());
343            }
344            for (ptoken, pval) in p_ast.tokens.iter().zip(pvals.iter()) {
345                scope.push_constant(ptoken, pval.to_owned());
346            }
347
348            let eval_result = self
349                .engine
350                .eval_ast_with_scope::<bool>(&mut scope, m_ast_compiled)?;
351            let eft = match p_ast.tokens.iter().position(|x| x == "p_eft") {
352                Some(j) if eval_result => {
353                    let p_eft = &pvals[j];
354                    if p_eft == "deny" {
355                        EffectKind::Deny
356                    } else if p_eft == "allow" {
357                        EffectKind::Allow
358                    } else {
359                        EffectKind::Indeterminate
360                    }
361                }
362                None if eval_result => EffectKind::Allow,
363                _ => EffectKind::Indeterminate,
364            };
365
366            if eft_stream.push_effect(eft) {
367                break;
368            }
369        }
370
371        Ok((eft_stream.next(), {
372            #[cfg(feature = "explain")]
373            {
374                eft_stream.explain()
375            }
376            #[cfg(not(feature = "explain"))]
377            {
378                None
379            }
380        }))
381    }
382
383    fn register_function(engine: &mut Engine, key: &str, f: OperatorFunction) {
384        match f {
385            OperatorFunction::Arg0(func) => {
386                engine.register_fn(key, func);
387            }
388            OperatorFunction::Arg1(func) => {
389                engine.register_fn(key, func);
390            }
391            OperatorFunction::Arg2(func) => {
392                engine.register_fn(key, func);
393            }
394            OperatorFunction::Arg3(func) => {
395                engine.register_fn(key, func);
396            }
397            OperatorFunction::Arg4(func) => {
398                engine.register_fn(key, func);
399            }
400            OperatorFunction::Arg5(func) => {
401                engine.register_fn(key, func);
402            }
403            OperatorFunction::Arg6(func) => {
404                engine.register_fn(key, func);
405            }
406        }
407    }
408
409    pub(crate) fn register_g_functions(&mut self) -> Result<()> {
410        if let Some(ast_map) = self.model.get_model().get("g") {
411            for (fname, ast) in ast_map {
412                register_g_function!(self, fname, ast);
413            }
414        }
415
416        Ok(())
417    }
418}
419
420#[async_trait]
421impl CoreApi for Enforcer {
422    #[allow(clippy::box_default)]
423    async fn new_raw<M: TryIntoModel, A: TryIntoAdapter>(
424        m: M,
425        a: A,
426    ) -> Result<Self> {
427        let model = m.try_into_model().await?;
428        let adapter = a.try_into_adapter().await?;
429        let fm = FunctionMap::default();
430        let eft = Box::new(DefaultEffector);
431        let rm = Arc::new(RwLock::new(DefaultRoleManager::new(10)));
432
433        let mut engine = Engine::new_raw();
434
435        engine.register_global_module(CASBIN_PACKAGE.as_shared_module());
436
437        for (key, &func) in fm.get_functions() {
438            Self::register_function(&mut engine, key, func);
439        }
440
441        let mut e = Self {
442            model,
443            adapter,
444            fm,
445            eft,
446            rm,
447            enabled: true,
448            auto_save: true,
449            auto_build_role_links: true,
450            #[cfg(feature = "watcher")]
451            auto_notify_watcher: true,
452            #[cfg(feature = "watcher")]
453            watcher: None,
454            events: HashMap::new(),
455            engine,
456            #[cfg(feature = "logging")]
457            logger: Box::new(DefaultLogger::default()),
458        };
459
460        #[cfg(any(feature = "logging", feature = "watcher"))]
461        e.on(Event::PolicyChange, notify_logger_and_watcher);
462
463        e.register_g_functions()?;
464
465        // If using DefaultModel, compile matcher expressions
466        if let Some(default_model) =
467            e.model.as_any_mut().downcast_mut::<DefaultModel>()
468        {
469            default_model.compile_matchers(&e.engine)?;
470        }
471
472        Ok(e)
473    }
474
475    #[inline]
476    async fn new<M: TryIntoModel, A: TryIntoAdapter>(
477        m: M,
478        a: A,
479    ) -> Result<Self> {
480        let mut e = Self::new_raw(m, a).await?;
481
482        // Do not initialize the full policy when using a filtered adapter
483        if !e.adapter.is_filtered() {
484            e.load_policy().await?;
485        }
486        Ok(e)
487    }
488
489    #[inline]
490    fn add_function(&mut self, fname: &str, f: OperatorFunction) {
491        self.fm.add_function(fname, f);
492        Self::register_function(&mut self.engine, fname, f);
493    }
494
495    #[inline]
496    fn get_model(&self) -> &dyn Model {
497        &*self.model
498    }
499
500    #[inline]
501    fn get_mut_model(&mut self) -> &mut dyn Model {
502        &mut *self.model
503    }
504
505    #[inline]
506    fn get_adapter(&self) -> &dyn Adapter {
507        &*self.adapter
508    }
509
510    #[inline]
511    fn get_mut_adapter(&mut self) -> &mut dyn Adapter {
512        &mut *self.adapter
513    }
514
515    #[cfg(feature = "watcher")]
516    #[inline]
517    fn set_watcher(&mut self, w: Box<dyn Watcher>) {
518        self.watcher = Some(w);
519    }
520
521    #[cfg(feature = "logging")]
522    #[inline]
523    fn get_logger(&self) -> &dyn Logger {
524        &*self.logger
525    }
526
527    #[cfg(feature = "logging")]
528    #[inline]
529    fn set_logger(&mut self, l: Box<dyn Logger>) {
530        self.logger = l;
531    }
532
533    #[cfg(feature = "watcher")]
534    #[inline]
535    fn get_watcher(&self) -> Option<&dyn Watcher> {
536        if let Some(ref watcher) = self.watcher {
537            Some(&**watcher)
538        } else {
539            None
540        }
541    }
542
543    #[cfg(feature = "watcher")]
544    #[inline]
545    fn get_mut_watcher(&mut self) -> Option<&mut dyn Watcher> {
546        if let Some(ref mut watcher) = self.watcher {
547            Some(&mut **watcher)
548        } else {
549            None
550        }
551    }
552
553    #[inline]
554    fn get_role_manager(&self) -> Arc<RwLock<dyn RoleManager>> {
555        Arc::clone(&self.rm)
556    }
557
558    #[inline]
559    fn set_role_manager(
560        &mut self,
561        rm: Arc<RwLock<dyn RoleManager>>,
562    ) -> Result<()> {
563        self.rm = rm;
564        if self.auto_build_role_links {
565            self.build_role_links()?;
566        }
567
568        self.register_g_functions()
569    }
570
571    async fn set_model<M: TryIntoModel>(&mut self, m: M) -> Result<()> {
572        self.model = m.try_into_model().await?;
573
574        // If using DefaultModel, recompile matcher expressions
575        if let Some(default_model) =
576            self.model.as_any_mut().downcast_mut::<DefaultModel>()
577        {
578            default_model.compile_matchers(&self.engine)?;
579        }
580
581        self.load_policy().await?;
582        Ok(())
583    }
584
585    async fn set_adapter<A: TryIntoAdapter>(&mut self, a: A) -> Result<()> {
586        self.adapter = a.try_into_adapter().await?;
587        self.load_policy().await?;
588        Ok(())
589    }
590
591    #[inline]
592    fn set_effector(&mut self, e: Box<dyn Effector>) {
593        self.eft = e;
594    }
595
596    /// Enforce decides whether a "subject" can access a "object" with the operation "action",
597    /// input parameters are usually: (sub, obj, act).
598    ///
599    /// # Examples
600    /// ```
601    /// use casbin::prelude::*;
602    /// #[cfg(feature = "runtime-async-std")]
603    /// #[async_std::main]
604    /// async fn main() -> Result<()> {
605    ///     let mut e = Enforcer::new("examples/basic_model.conf", "examples/basic_policy.csv").await?;
606    ///     assert_eq!(true, e.enforce(("alice", "data1", "read"))?);
607    ///     Ok(())
608    /// }
609    ///
610    /// #[cfg(feature = "runtime-tokio")]
611    /// #[tokio::main]
612    /// async fn main() -> Result<()> {
613    ///     let mut e = Enforcer::new("examples/basic_model.conf", "examples/basic_policy.csv").await?;
614    ///     assert_eq!(true, e.enforce(("alice", "data1", "read"))?);
615    ///
616    ///     Ok(())
617    /// }
618    /// #[cfg(all(not(feature = "runtime-async-std"), not(feature = "runtime-tokio")))]
619    /// fn main() {}
620    /// ```
621    fn enforce<ARGS: EnforceArgs>(&self, rvals: ARGS) -> Result<bool> {
622        let rvals = rvals.try_into_vec()?;
623        #[allow(unused_variables)]
624        let (authorized, indices) = self.private_enforce(&rvals)?;
625
626        #[cfg(feature = "logging")]
627        {
628            self.logger.print_enforce_log(
629                rvals.iter().map(|x| x.to_string()).collect(),
630                authorized,
631                false,
632            );
633
634            #[cfg(feature = "explain")]
635            if let Some(indices) = indices {
636                let all_rules = get_or_err!(self, "p", ModelError::P, "policy")
637                    .get_policy();
638
639                let rules: Vec<String> = indices
640                    .into_iter()
641                    .filter_map(|y| {
642                        all_rules.iter().nth(y).map(|x| x.join(", "))
643                    })
644                    .collect();
645
646                self.logger.print_explain_log(rules);
647            }
648        }
649
650        Ok(authorized)
651    }
652    /// Enforce decides whether a "subject" can access a "object" with the operation "action",
653    /// input parameters are usually: (sub, obj, act).
654    /// this function will add suffix to each model eg. r2, p2, e2, m2, g2,
655    ///
656    /// # Examples
657    /// ```
658    /// use casbin::prelude::*;
659    /// use casbin::EnforceContext;
660    ///
661    /// #[cfg(feature = "runtime-async-std")]
662    /// #[async_std::main]
663    /// async fn main() -> Result<()> {
664    ///     let mut e = Enforcer::new("examples/multi_section_model.conf", "examples/multi_section_policy.csv").await?;
665    ///     assert_eq!(true, e.enforce(("alice", "read", "project1"))?);
666    ///     let ctx = EnforceContext::new("2");
667    ///     assert_eq!(true, e.enforce_with_context(ctx, ("james", "execute"))?);
668    ///     Ok(())
669    /// }
670    ///
671    /// #[cfg(feature = "runtime-tokio")]
672    /// #[tokio::main]
673    /// async fn main() -> Result<()> {
674    ///     let mut e = Enforcer::new("examples/multi_section_model.conf", "examples/multi_section_policy.csv").await?;
675    ///     assert_eq!(true, e.enforce(("alice", "read", "project1"))?);
676    ///     let ctx = EnforceContext::new("2");
677    ///     assert_eq!(true, e.enforce_with_context(ctx, ("james", "execute"))?);
678    ///
679    ///     Ok(())
680    /// }
681    /// #[cfg(all(not(feature = "runtime-async-std"), not(feature = "runtime-tokio")))]
682    /// fn main() {}
683    /// ```
684    fn enforce_with_context<ARGS: EnforceArgs>(
685        &self,
686        ctx: EnforceContext,
687        rvals: ARGS,
688    ) -> Result<bool> {
689        let rvals = rvals.try_into_vec()?;
690        #[allow(unused_variables)]
691        let (authorized, indices) =
692            self.private_enforce_with_context(ctx, &rvals)?;
693
694        #[cfg(feature = "logging")]
695        {
696            self.logger.print_enforce_log(
697                rvals.iter().map(|x| x.to_string()).collect(),
698                authorized,
699                false,
700            );
701
702            #[cfg(feature = "explain")]
703            if let Some(indices) = indices {
704                let all_rules = get_or_err!(self, "p", ModelError::P, "policy")
705                    .get_policy();
706
707                let rules: Vec<String> = indices
708                    .into_iter()
709                    .filter_map(|y| {
710                        all_rules.iter().nth(y).map(|x| x.join(", "))
711                    })
712                    .collect();
713
714                self.logger.print_explain_log(rules);
715            }
716        }
717
718        Ok(authorized)
719    }
720
721    fn enforce_mut<ARGS: EnforceArgs>(&mut self, rvals: ARGS) -> Result<bool> {
722        self.enforce(rvals)
723    }
724
725    #[cfg(feature = "explain")]
726    fn enforce_ex<ARGS: EnforceArgs>(
727        &self,
728        rvals: ARGS,
729    ) -> Result<(bool, Vec<Vec<String>>)> {
730        let rvals = rvals.try_into_vec()?;
731        #[allow(unused_variables)]
732        let (authorized, indices) = self.private_enforce(&rvals)?;
733
734        let rules = match indices {
735            Some(indices) => {
736                let all_rules = get_or_err!(self, "p", ModelError::P, "policy")
737                    .get_policy();
738
739                indices
740                    .into_iter()
741                    .filter_map(|y| all_rules.iter().nth(y).cloned())
742                    .collect::<Vec<_>>()
743            }
744            None => vec![],
745        };
746
747        Ok((authorized, rules))
748    }
749
750    fn build_role_links(&mut self) -> Result<()> {
751        self.rm.write().clear();
752        self.model.build_role_links(Arc::clone(&self.rm))?;
753
754        Ok(())
755    }
756
757    #[cfg(feature = "incremental")]
758    fn build_incremental_role_links(&mut self, d: EventData) -> Result<()> {
759        self.model
760            .build_incremental_role_links(Arc::clone(&self.rm), d)?;
761
762        Ok(())
763    }
764
765    async fn load_policy(&mut self) -> Result<()> {
766        self.model.clear_policy();
767        self.adapter.load_policy(&mut *self.model).await?;
768
769        if self.auto_build_role_links {
770            self.build_role_links()?;
771        }
772
773        Ok(())
774    }
775
776    async fn load_filtered_policy<'a>(&mut self, f: Filter<'a>) -> Result<()> {
777        self.model.clear_policy();
778        self.adapter
779            .load_filtered_policy(&mut *self.model, f)
780            .await?;
781
782        if self.auto_build_role_links {
783            self.build_role_links()?;
784        }
785
786        Ok(())
787    }
788
789    #[inline]
790    fn is_filtered(&self) -> bool {
791        self.adapter.is_filtered()
792    }
793
794    #[inline]
795    fn is_enabled(&self) -> bool {
796        self.enabled
797    }
798
799    async fn save_policy(&mut self) -> Result<()> {
800        assert!(!self.is_filtered(), "cannot save filtered policy");
801
802        self.adapter.save_policy(&mut *self.model).await?;
803
804        let mut policies = self.get_all_policy();
805        let gpolicies = self.get_all_grouping_policy();
806
807        policies.extend(gpolicies);
808
809        #[cfg(any(feature = "logging", feature = "watcher"))]
810        self.emit(Event::PolicyChange, EventData::SavePolicy(policies));
811
812        Ok(())
813    }
814
815    #[inline]
816    async fn clear_policy(&mut self) -> Result<()> {
817        if self.auto_save {
818            self.adapter.clear_policy().await?;
819        }
820        self.model.clear_policy();
821
822        #[cfg(any(feature = "logging", feature = "watcher"))]
823        self.emit(Event::PolicyChange, EventData::ClearPolicy);
824
825        Ok(())
826    }
827
828    #[inline]
829    fn enable_enforce(&mut self, enabled: bool) {
830        self.enabled = enabled;
831
832        #[cfg(feature = "logging")]
833        self.logger.print_status_log(enabled);
834    }
835
836    #[cfg(feature = "logging")]
837    #[inline]
838    fn enable_log(&mut self, enabled: bool) {
839        self.logger.enable_log(enabled);
840    }
841
842    #[inline]
843    fn enable_auto_save(&mut self, auto_save: bool) {
844        self.auto_save = auto_save;
845    }
846
847    #[inline]
848    fn enable_auto_build_role_links(&mut self, auto_build_role_links: bool) {
849        self.auto_build_role_links = auto_build_role_links;
850    }
851
852    #[cfg(feature = "watcher")]
853    #[inline]
854    fn enable_auto_notify_watcher(&mut self, auto_notify_watcher: bool) {
855        if !auto_notify_watcher {
856            self.off(Event::PolicyChange);
857        } else {
858            self.on(Event::PolicyChange, notify_logger_and_watcher);
859        }
860
861        self.auto_notify_watcher = auto_notify_watcher;
862    }
863
864    #[inline]
865    fn has_auto_save_enabled(&self) -> bool {
866        self.auto_save
867    }
868
869    #[cfg(feature = "watcher")]
870    #[inline]
871    fn has_auto_notify_watcher_enabled(&self) -> bool {
872        self.auto_notify_watcher
873    }
874
875    #[inline]
876    fn has_auto_build_role_links_enabled(&self) -> bool {
877        self.auto_build_role_links
878    }
879}
880
881#[cfg(test)]
882mod tests {
883    use super::*;
884    use crate::prelude::*;
885
886    fn is_send<T: Send>() -> bool {
887        true
888    }
889
890    fn is_sync<T: Sync>() -> bool {
891        true
892    }
893
894    #[test]
895    fn test_send_sync() {
896        assert!(is_send::<Enforcer>());
897        assert!(is_sync::<Enforcer>());
898    }
899
900    #[cfg(not(target_arch = "wasm32"))]
901    #[cfg_attr(
902        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
903        async_std::test
904    )]
905    #[cfg_attr(
906        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
907        tokio::test
908    )]
909    async fn test_enforcer_swap_adapter_type() {
910        let mut m = DefaultModel::default();
911        m.add_def("r", "r", "sub, obj, act");
912        m.add_def("p", "p", "sub, obj, act");
913        m.add_def("e", "e", "some(where (p.eft == allow))");
914        m.add_def(
915            "m",
916            "m",
917            "r.sub == p.sub && keyMatch(r.obj, p.obj) && regexMatch(r.act, p.act)",
918        );
919
920        let file = FileAdapter::new("examples/basic_policy.csv");
921        let mem = MemoryAdapter::default();
922        let mut e = Enforcer::new(m, file).await.unwrap();
923        // this should fail since FileAdapter has basically no add_policy
924        assert!(e
925            .adapter
926            .add_policy(
927                "p",
928                "p",
929                vec!["alice".into(), "data".into(), "read".into()]
930            )
931            .await
932            .unwrap());
933        e.set_adapter(mem).await.unwrap();
934        // this passes since our MemoryAdapter has a working add_policy method
935        assert!(e
936            .adapter
937            .add_policy(
938                "p",
939                "p",
940                vec!["alice".into(), "data".into(), "read".into()]
941            )
942            .await
943            .unwrap())
944    }
945
946    #[cfg(not(target_arch = "wasm32"))]
947    #[cfg_attr(
948        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
949        async_std::test
950    )]
951    #[cfg_attr(
952        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
953        tokio::test
954    )]
955    async fn test_key_match_model_in_memory() {
956        let mut m = DefaultModel::default();
957        m.add_def("r", "r", "sub, obj, act");
958        m.add_def("p", "p", "sub, obj, act");
959        m.add_def("e", "e", "some(where (p.eft == allow))");
960        m.add_def(
961            "m",
962            "m",
963            "r.sub == p.sub && keyMatch(r.obj, p.obj) && regexMatch(r.act, p.act)",
964        );
965
966        let adapter = FileAdapter::new("examples/keymatch_policy.csv");
967        let e = Enforcer::new(m, adapter).await.unwrap();
968        assert_eq!(
969            true,
970            e.enforce(("alice", "/alice_data/resource1", "GET"))
971                .unwrap()
972        );
973        assert_eq!(
974            true,
975            e.enforce(("alice", "/alice_data/resource1", "POST"))
976                .unwrap()
977        );
978        assert_eq!(
979            true,
980            e.enforce(("alice", "/alice_data/resource2", "GET"))
981                .unwrap()
982        );
983        assert_eq!(
984            false,
985            e.enforce(("alice", "/alice_data/resource2", "POST"))
986                .unwrap()
987        );
988        assert_eq!(
989            false,
990            e.enforce(("alice", "/bob_data/resource1", "GET")).unwrap()
991        );
992        assert_eq!(
993            false,
994            e.enforce(("alice", "/bob_data/resource1", "POST")).unwrap()
995        );
996        assert_eq!(
997            false,
998            e.enforce(("alice", "/bob_data/resource2", "GET")).unwrap()
999        );
1000        assert_eq!(
1001            false,
1002            e.enforce(("alice", "/bob_data/resource2", "POST")).unwrap()
1003        );
1004
1005        assert_eq!(
1006            false,
1007            e.enforce(("bob", "/alice_data/resource1", "GET")).unwrap()
1008        );
1009        assert_eq!(
1010            false,
1011            e.enforce(("bob", "/alice_data/resource1", "POST")).unwrap()
1012        );
1013        assert_eq!(
1014            true,
1015            e.enforce(("bob", "/alice_data/resource2", "GET")).unwrap()
1016        );
1017        assert_eq!(
1018            false,
1019            e.enforce(("bob", "/alice_data/resource2", "POST")).unwrap()
1020        );
1021        assert_eq!(
1022            false,
1023            e.enforce(("bob", "/bob_data/resource1", "GET")).unwrap()
1024        );
1025        assert_eq!(
1026            true,
1027            e.enforce(("bob", "/bob_data/resource1", "POST")).unwrap()
1028        );
1029        assert_eq!(
1030            false,
1031            e.enforce(("bob", "/bob_data/resource2", "GET")).unwrap()
1032        );
1033        assert_eq!(
1034            true,
1035            e.enforce(("bob", "/bob_data/resource2", "POST")).unwrap()
1036        );
1037
1038        assert_eq!(true, e.enforce(("cathy", "/cathy_data", "GET")).unwrap());
1039        assert_eq!(true, e.enforce(("cathy", "/cathy_data", "POST")).unwrap());
1040        assert_eq!(
1041            false,
1042            e.enforce(("cathy", "/cathy_data", "DELETE")).unwrap()
1043        );
1044    }
1045
1046    #[cfg(not(target_arch = "wasm32"))]
1047    #[cfg_attr(
1048        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1049        async_std::test
1050    )]
1051    #[cfg_attr(
1052        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1053        tokio::test
1054    )]
1055    async fn test_key_match_model_in_memory_deny() {
1056        let mut m = DefaultModel::default();
1057        m.add_def("r", "r", "sub, obj, act");
1058        m.add_def("p", "p", "sub, obj, act");
1059        m.add_def("e", "e", "!some(where (p.eft == deny))");
1060        m.add_def(
1061            "m",
1062            "m",
1063            "r.sub == p.sub && keyMatch(r.obj, p.obj) && regexMatch(r.act, p.act)",
1064        );
1065
1066        let adapter = FileAdapter::new("examples/keymatch_policy.csv");
1067        let e = Enforcer::new(m, adapter).await.unwrap();
1068        assert_eq!(
1069            true,
1070            e.enforce(("alice", "/alice_data/resource2", "POST"))
1071                .unwrap()
1072        );
1073    }
1074
1075    use crate::RbacApi;
1076    #[cfg(not(target_arch = "wasm32"))]
1077    #[cfg_attr(
1078        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1079        async_std::test
1080    )]
1081    #[cfg_attr(
1082        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1083        tokio::test
1084    )]
1085    async fn test_rbac_model_in_memory_indeterminate() {
1086        let mut m = DefaultModel::default();
1087        m.add_def("r", "r", "sub, obj, act");
1088        m.add_def("p", "p", "sub, obj, act");
1089        m.add_def("g", "g", "_, _");
1090        m.add_def("e", "e", "some(where (p.eft == allow))");
1091        m.add_def(
1092            "m",
1093            "m",
1094            "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act",
1095        );
1096
1097        let adapter = MemoryAdapter::default();
1098        let mut e = Enforcer::new(m, adapter).await.unwrap();
1099        e.add_permission_for_user(
1100            "alice",
1101            vec!["data1", "invalid"]
1102                .iter()
1103                .map(|s| s.to_string())
1104                .collect(),
1105        )
1106        .await
1107        .unwrap();
1108        assert_eq!(false, e.enforce(("alice", "data1", "read")).unwrap());
1109    }
1110
1111    #[cfg(not(target_arch = "wasm32"))]
1112    #[cfg_attr(
1113        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1114        async_std::test
1115    )]
1116    #[cfg_attr(
1117        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1118        tokio::test
1119    )]
1120    async fn test_rbac_model_in_memory() {
1121        let mut m = DefaultModel::default();
1122        m.add_def("r", "r", "sub, obj, act");
1123        m.add_def("p", "p", "sub, obj, act");
1124        m.add_def("g", "g", "_, _");
1125        m.add_def("e", "e", "some(where (p.eft == allow))");
1126        m.add_def(
1127            "m",
1128            "m",
1129            "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act",
1130        );
1131
1132        let adapter = MemoryAdapter::default();
1133        let mut e = Enforcer::new(m, adapter).await.unwrap();
1134        e.add_permission_for_user(
1135            "alice",
1136            vec!["data1", "read"]
1137                .iter()
1138                .map(|s| s.to_string())
1139                .collect(),
1140        )
1141        .await
1142        .unwrap();
1143        e.add_permission_for_user(
1144            "bob",
1145            vec!["data2", "write"]
1146                .iter()
1147                .map(|s| s.to_string())
1148                .collect(),
1149        )
1150        .await
1151        .unwrap();
1152        e.add_permission_for_user(
1153            "data2_admin",
1154            vec!["data2", "read"]
1155                .iter()
1156                .map(|s| s.to_string())
1157                .collect(),
1158        )
1159        .await
1160        .unwrap();
1161        e.add_permission_for_user(
1162            "data2_admin",
1163            vec!["data2", "write"]
1164                .iter()
1165                .map(|s| s.to_string())
1166                .collect(),
1167        )
1168        .await
1169        .unwrap();
1170        e.add_role_for_user("alice", "data2_admin", None)
1171            .await
1172            .unwrap();
1173
1174        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1175        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1176        assert_eq!(true, e.enforce(("alice", "data2", "read")).unwrap());
1177        assert_eq!(true, e.enforce(("alice", "data2", "write")).unwrap());
1178        assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
1179        assert_eq!(false, e.enforce(("bob", "data1", "write")).unwrap());
1180        assert_eq!(false, e.enforce(("bob", "data2", "read")).unwrap());
1181        assert_eq!(true, e.enforce(("bob", "data2", "write")).unwrap());
1182    }
1183
1184    #[cfg(not(target_arch = "wasm32"))]
1185    #[cfg_attr(
1186        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1187        async_std::test
1188    )]
1189    #[cfg_attr(
1190        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1191        tokio::test
1192    )]
1193    async fn test_not_used_rbac_model_in_memory() {
1194        let mut m = DefaultModel::default();
1195        m.add_def("r", "r", "sub, obj, act");
1196        m.add_def("p", "p", "sub, obj, act");
1197        m.add_def("g", "g", "_, _");
1198        m.add_def("e", "e", "some(where (p.eft == allow))");
1199        m.add_def(
1200            "m",
1201            "m",
1202            "g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act",
1203        );
1204
1205        let adapter = MemoryAdapter::default();
1206        let mut e = Enforcer::new(m, adapter).await.unwrap();
1207        e.add_permission_for_user(
1208            "alice",
1209            vec!["data1", "read"]
1210                .iter()
1211                .map(|s| s.to_string())
1212                .collect(),
1213        )
1214        .await
1215        .unwrap();
1216        e.add_permission_for_user(
1217            "bob",
1218            vec!["data2", "write"]
1219                .iter()
1220                .map(|s| s.to_string())
1221                .collect(),
1222        )
1223        .await
1224        .unwrap();
1225
1226        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1227        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1228        assert_eq!(false, e.enforce(("alice", "data2", "read")).unwrap());
1229        assert_eq!(false, e.enforce(("alice", "data2", "write")).unwrap());
1230        assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
1231        assert_eq!(false, e.enforce(("bob", "data1", "write")).unwrap());
1232        assert_eq!(false, e.enforce(("bob", "data2", "read")).unwrap());
1233        assert_eq!(true, e.enforce(("bob", "data2", "write")).unwrap());
1234    }
1235
1236    #[cfg(feature = "ip")]
1237    #[cfg(not(target_arch = "wasm32"))]
1238    #[cfg_attr(
1239        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1240        async_std::test
1241    )]
1242    #[cfg_attr(
1243        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1244        tokio::test
1245    )]
1246    async fn test_ip_match_model() {
1247        let m = DefaultModel::from_file("examples/ipmatch_model.conf")
1248            .await
1249            .unwrap();
1250
1251        let adapter = FileAdapter::new("examples/ipmatch_policy.csv");
1252        let e = Enforcer::new(m, adapter).await.unwrap();
1253
1254        assert!(e.enforce(("192.168.2.123", "data1", "read")).unwrap());
1255
1256        assert!(e.enforce(("10.0.0.5", "data2", "write")).unwrap());
1257
1258        assert!(!e.enforce(("192.168.2.123", "data1", "write")).unwrap());
1259        assert!(!e.enforce(("192.168.2.123", "data2", "read")).unwrap());
1260        assert!(!e.enforce(("192.168.2.123", "data2", "write")).unwrap());
1261
1262        assert!(!e.enforce(("192.168.0.123", "data1", "read")).unwrap());
1263        assert!(!e.enforce(("192.168.0.123", "data1", "write")).unwrap());
1264        assert!(!e.enforce(("192.168.0.123", "data2", "read")).unwrap());
1265        assert!(!e.enforce(("192.168.0.123", "data2", "write")).unwrap());
1266
1267        assert!(!e.enforce(("10.0.0.5", "data1", "read")).unwrap());
1268        assert!(!e.enforce(("10.0.0.5", "data1", "write")).unwrap());
1269        assert!(!e.enforce(("10.0.0.5", "data2", "read")).unwrap());
1270
1271        assert!(!e.enforce(("192.168.0.1", "data1", "read")).unwrap());
1272        assert!(!e.enforce(("192.168.0.1", "data1", "write")).unwrap());
1273        assert!(!e.enforce(("192.168.0.1", "data2", "read")).unwrap());
1274        assert!(!e.enforce(("192.168.0.1", "data2", "write")).unwrap());
1275    }
1276
1277    #[cfg(not(target_arch = "wasm32"))]
1278    #[cfg_attr(
1279        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1280        async_std::test
1281    )]
1282    #[cfg_attr(
1283        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1284        tokio::test
1285    )]
1286    async fn test_enable_auto_save() {
1287        let m = DefaultModel::from_file("examples/basic_model.conf")
1288            .await
1289            .unwrap();
1290
1291        let adapter = FileAdapter::new("examples/basic_policy.csv");
1292        let mut e = Enforcer::new(m, adapter).await.unwrap();
1293        e.enable_auto_save(false);
1294        e.remove_policy(
1295            vec!["alice", "data1", "read"]
1296                .iter()
1297                .map(|s| s.to_string())
1298                .collect(),
1299        )
1300        .await
1301        .unwrap();
1302        e.load_policy().await.unwrap();
1303
1304        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1305        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1306        assert_eq!(false, e.enforce(("alice", "data2", "read")).unwrap());
1307        assert_eq!(false, e.enforce(("alice", "data2", "write")).unwrap());
1308        assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
1309        assert_eq!(false, e.enforce(("bob", "data1", "write")).unwrap());
1310        assert_eq!(false, e.enforce(("bob", "data2", "read")).unwrap());
1311        assert_eq!(true, e.enforce(("bob", "data2", "write")).unwrap());
1312
1313        e.enable_auto_save(true);
1314        e.remove_policy(
1315            vec!["alice", "data1", "read"]
1316                .iter()
1317                .map(|s| s.to_string())
1318                .collect(),
1319        )
1320        .await
1321        .unwrap();
1322        e.load_policy().await.unwrap();
1323        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1324        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1325        assert_eq!(false, e.enforce(("alice", "data2", "read")).unwrap());
1326        assert_eq!(false, e.enforce(("alice", "data2", "write")).unwrap());
1327        assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
1328        assert_eq!(false, e.enforce(("bob", "data1", "write")).unwrap());
1329        assert_eq!(false, e.enforce(("bob", "data2", "read")).unwrap());
1330        assert_eq!(true, e.enforce(("bob", "data2", "write")).unwrap());
1331    }
1332
1333    #[cfg(not(target_arch = "wasm32"))]
1334    #[cfg_attr(
1335        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1336        async_std::test
1337    )]
1338    #[cfg_attr(
1339        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1340        tokio::test
1341    )]
1342    async fn test_role_links() {
1343        let m = DefaultModel::from_file("examples/rbac_model.conf")
1344            .await
1345            .unwrap();
1346
1347        let adapter = MemoryAdapter::default();
1348        let mut e = Enforcer::new(m, adapter).await.unwrap();
1349        e.enable_auto_build_role_links(false);
1350        e.build_role_links().unwrap();
1351        assert_eq!(false, e.enforce(("user501", "data9", "read")).unwrap());
1352    }
1353
1354    #[cfg(not(target_arch = "wasm32"))]
1355    #[cfg_attr(
1356        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1357        async_std::test
1358    )]
1359    #[cfg_attr(
1360        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1361        tokio::test
1362    )]
1363    async fn test_get_and_set_model() {
1364        let m1 = DefaultModel::from_file("examples/basic_model.conf")
1365            .await
1366            .unwrap();
1367        let adapter1 = FileAdapter::new("examples/basic_policy.csv");
1368        let mut e = Enforcer::new(m1, adapter1).await.unwrap();
1369
1370        assert_eq!(false, e.enforce(("root", "data1", "read")).unwrap());
1371
1372        let m2 = DefaultModel::from_file("examples/basic_with_root_model.conf")
1373            .await
1374            .unwrap();
1375        let adapter2 = FileAdapter::new("examples/basic_policy.csv");
1376        let e2 = Enforcer::new(m2, adapter2).await.unwrap();
1377
1378        e.model = e2.model;
1379        assert_eq!(true, e.enforce(("root", "data1", "read")).unwrap());
1380    }
1381
1382    #[cfg(not(target_arch = "wasm32"))]
1383    #[cfg_attr(
1384        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1385        async_std::test
1386    )]
1387    #[cfg_attr(
1388        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1389        tokio::test
1390    )]
1391    async fn test_get_and_set_adapter_in_mem() {
1392        let m1 = DefaultModel::from_file("examples/basic_model.conf")
1393            .await
1394            .unwrap();
1395        let adapter1 = FileAdapter::new("examples/basic_policy.csv");
1396        let mut e = Enforcer::new(m1, adapter1).await.unwrap();
1397
1398        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1399        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1400
1401        let m2 = DefaultModel::from_file("examples/basic_model.conf")
1402            .await
1403            .unwrap();
1404        let adapter2 = FileAdapter::new("examples/basic_inverse_policy.csv");
1405        let e2 = Enforcer::new(m2, adapter2).await.unwrap();
1406
1407        e.adapter = e2.adapter;
1408        e.load_policy().await.unwrap();
1409        assert_eq!(false, e.enforce(("alice", "data1", "read")).unwrap());
1410        assert_eq!(true, e.enforce(("alice", "data1", "write")).unwrap());
1411    }
1412
1413    #[cfg(not(target_arch = "wasm32"))]
1414    #[cfg_attr(
1415        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1416        async_std::test
1417    )]
1418    #[cfg_attr(
1419        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1420        tokio::test
1421    )]
1422    async fn test_keymatch_custom_model() {
1423        use crate::model::key_match;
1424
1425        let m1 = DefaultModel::from_file("examples/keymatch_custom_model.conf")
1426            .await
1427            .unwrap();
1428        let adapter1 = FileAdapter::new("examples/keymatch_policy.csv");
1429        let mut e = Enforcer::new(m1, adapter1).await.unwrap();
1430
1431        e.add_function(
1432            "keyMatchCustom",
1433            OperatorFunction::Arg2(|s1: Dynamic, s2: Dynamic| {
1434                let s1_str = s1.to_string();
1435                let s2_str = s2.to_string();
1436                key_match(&s1_str, &s2_str).into()
1437            }),
1438        );
1439
1440        assert_eq!(
1441            true,
1442            e.enforce(("alice", "/alice_data/123", "GET")).unwrap()
1443        );
1444        assert_eq!(
1445            true,
1446            e.enforce(("alice", "/alice_data/resource1", "POST"))
1447                .unwrap()
1448        );
1449
1450        assert_eq!(
1451            true,
1452            e.enforce(("bob", "/alice_data/resource2", "GET")).unwrap()
1453        );
1454
1455        assert_eq!(
1456            true,
1457            e.enforce(("bob", "/bob_data/resource1", "POST")).unwrap()
1458        );
1459
1460        assert_eq!(true, e.enforce(("cathy", "/cathy_data", "GET")).unwrap());
1461        assert_eq!(true, e.enforce(("cathy", "/cathy_data", "POST")).unwrap());
1462    }
1463
1464    #[cfg(not(target_arch = "wasm32"))]
1465    #[cfg_attr(
1466        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1467        async_std::test
1468    )]
1469    #[cfg_attr(
1470        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1471        tokio::test
1472    )]
1473    async fn test_filtered_file_adapter() {
1474        let adapter = FileAdapter::new_filtered_adapter(
1475            "examples/rbac_with_domains_policy.csv",
1476        );
1477        let mut e =
1478            Enforcer::new("examples/rbac_with_domains_model.conf", adapter)
1479                .await
1480                .unwrap();
1481
1482        let filter = Filter {
1483            p: vec!["", "domain1"],
1484            g: vec!["", "", "domain1"],
1485        };
1486
1487        e.load_filtered_policy(filter).await.unwrap();
1488        assert_eq!(
1489            e.enforce(("alice", "domain1", "data1", "read")).unwrap(),
1490            true
1491        );
1492        assert!(e.enforce(("alice", "domain1", "data1", "write")).unwrap());
1493        assert!(!e.enforce(("alice", "domain1", "data2", "read")).unwrap());
1494        assert!(!e.enforce(("alice", "domain1", "data2", "write")).unwrap());
1495        assert!(!e.enforce(("bob", "domain2", "data2", "read")).unwrap());
1496        assert!(!e.enforce(("bob", "domain2", "data2", "write")).unwrap());
1497    }
1498
1499    #[cfg(not(target_arch = "wasm32"))]
1500    #[cfg_attr(
1501        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1502        async_std::test
1503    )]
1504    #[cfg_attr(
1505        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1506        tokio::test
1507    )]
1508    async fn test_set_role_manager() {
1509        let mut e = Enforcer::new(
1510            "examples/rbac_with_domains_model.conf",
1511            "examples/rbac_with_domains_policy.csv",
1512        )
1513        .await
1514        .unwrap();
1515
1516        let new_rm = Arc::new(RwLock::new(DefaultRoleManager::new(10)));
1517
1518        e.set_role_manager(new_rm).unwrap();
1519
1520        assert!(e.enforce(("alice", "domain1", "data1", "read")).unwrap(),);
1521        assert!(e.enforce(("alice", "domain1", "data1", "write")).unwrap());
1522        assert!(e.enforce(("bob", "domain2", "data2", "read")).unwrap());
1523        assert!(e.enforce(("bob", "domain2", "data2", "write")).unwrap());
1524    }
1525
1526    #[cfg(not(target_arch = "wasm32"))]
1527    #[cfg_attr(
1528        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1529        async_std::test
1530    )]
1531    #[cfg_attr(
1532        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1533        tokio::test
1534    )]
1535    async fn test_policy_abac1() {
1536        use serde::Serialize;
1537
1538        let mut m = DefaultModel::default();
1539        m.add_def("r", "r", "sub, obj, act");
1540        m.add_def("p", "p", "sub_rule, obj, act");
1541        m.add_def("e", "e", "some(where (p.eft == allow))");
1542        m.add_def(
1543            "m",
1544            "m",
1545            "eval(p.sub_rule) && r.obj == p.obj && r.act == p.act",
1546        );
1547
1548        let a = MemoryAdapter::default();
1549
1550        let mut e = Enforcer::new(m, a).await.unwrap();
1551
1552        e.add_policy(
1553            vec!["r.sub.age > 18", "/data1", "read"]
1554                .into_iter()
1555                .map(|x| x.to_string())
1556                .collect(),
1557        )
1558        .await
1559        .unwrap();
1560
1561        #[derive(Serialize, Hash)]
1562        pub struct Person<'a> {
1563            name: &'a str,
1564            age: u8,
1565        }
1566
1567        assert_eq!(
1568            e.enforce((
1569                Person {
1570                    name: "alice",
1571                    age: 16
1572                },
1573                "/data1",
1574                "read"
1575            ))
1576            .unwrap(),
1577            false
1578        );
1579        assert_eq!(
1580            e.enforce((
1581                Person {
1582                    name: "bob",
1583                    age: 19
1584                },
1585                "/data1",
1586                "read"
1587            ))
1588            .unwrap(),
1589            true
1590        );
1591    }
1592
1593    #[cfg(not(target_arch = "wasm32"))]
1594    #[cfg_attr(
1595        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1596        async_std::test
1597    )]
1598    #[cfg_attr(
1599        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1600        tokio::test
1601    )]
1602    async fn test_policy_abac2() {
1603        use serde::Serialize;
1604
1605        let mut m = DefaultModel::default();
1606        m.add_def("r", "r", "sub, obj, act");
1607        m.add_def("p", "p", "sub, obj, act");
1608        m.add_def("e", "e", "some(where (p.eft == allow))");
1609        m.add_def("g", "g", "_, _");
1610        m.add_def(
1611            "m",
1612            "m",
1613            "(g(r.sub, p.sub) || eval(p.sub) == true) && r.act == p.act",
1614        );
1615
1616        let a = MemoryAdapter::default();
1617
1618        let mut e = Enforcer::new(m, a).await.unwrap();
1619
1620        e.add_policy(
1621            vec![r#""admin""#, "post", "write"]
1622                .into_iter()
1623                .map(|x| x.to_string())
1624                .collect(),
1625        )
1626        .await
1627        .unwrap();
1628
1629        e.add_policy(
1630            vec!["r.sub == r.obj.author", "post", "write"]
1631                .into_iter()
1632                .map(|x| x.to_string())
1633                .collect(),
1634        )
1635        .await
1636        .unwrap();
1637
1638        e.add_grouping_policy(
1639            vec!["alice", r#""admin""#]
1640                .into_iter()
1641                .map(|x| x.to_string())
1642                .collect(),
1643        )
1644        .await
1645        .unwrap();
1646
1647        #[derive(Serialize, Hash)]
1648        pub struct Post<'a> {
1649            author: &'a str,
1650        }
1651
1652        assert_eq!(
1653            e.enforce(("alice", Post { author: "bob" }, "write"))
1654                .unwrap(),
1655            true
1656        );
1657
1658        assert_eq!(
1659            e.enforce(("bob", Post { author: "bob" }, "write")).unwrap(),
1660            true
1661        );
1662    }
1663
1664    #[cfg(feature = "explain")]
1665    #[cfg(not(target_arch = "wasm32"))]
1666    #[cfg_attr(
1667        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1668        async_std::test
1669    )]
1670    #[cfg_attr(
1671        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1672        tokio::test
1673    )]
1674    async fn test_enforce_ex() {
1675        use crate::adapter;
1676
1677        let model = DefaultModel::from_file("examples/basic_model.conf")
1678            .await
1679            .unwrap();
1680
1681        let adapter = adapter::FileAdapter::new("examples/basic_policy.csv");
1682
1683        let e = Enforcer::new(model, adapter).await.unwrap();
1684
1685        assert_eq!(
1686            e.enforce_ex(("alice", "data1", "read")).unwrap(),
1687            (
1688                true,
1689                vec![vec![
1690                    "alice".to_string(),
1691                    "data1".to_string(),
1692                    "read".to_string()
1693                ]]
1694            )
1695        );
1696        assert_eq!(
1697            e.enforce_ex(("alice", "data1", "write")).unwrap(),
1698            (false, vec![])
1699        );
1700        assert_eq!(
1701            e.enforce_ex(("alice", "data2", "read")).unwrap(),
1702            (false, vec![])
1703        );
1704        assert_eq!(
1705            e.enforce_ex(("alice", "data2", "write")).unwrap(),
1706            (false, vec![])
1707        );
1708        assert_eq!(
1709            e.enforce_ex(("bob", "data1", "read")).unwrap(),
1710            (false, vec![])
1711        );
1712        assert_eq!(
1713            e.enforce_ex(("bob", "data1", "write")).unwrap(),
1714            (false, vec![])
1715        );
1716        assert_eq!(
1717            e.enforce_ex(("bob", "data2", "read")).unwrap(),
1718            (false, vec![])
1719        );
1720        assert_eq!(
1721            e.enforce_ex(("bob", "data2", "write")).unwrap(),
1722            (
1723                true,
1724                vec![vec![
1725                    "bob".to_string(),
1726                    "data2".to_string(),
1727                    "write".to_string()
1728                ]]
1729            )
1730        );
1731
1732        let e = Enforcer::new(
1733            "examples/rbac_model.conf",
1734            "examples/rbac_policy.csv",
1735        )
1736        .await
1737        .unwrap();
1738
1739        assert_eq!(
1740            e.enforce_ex(("alice", "data1", "read")).unwrap(),
1741            (
1742                true,
1743                vec![vec![
1744                    "alice".to_string(),
1745                    "data1".to_string(),
1746                    "read".to_string()
1747                ]]
1748            )
1749        );
1750        assert_eq!(
1751            e.enforce_ex(("alice", "data1", "write")).unwrap(),
1752            (false, vec![])
1753        );
1754        assert_eq!(
1755            e.enforce_ex(("alice", "data2", "read")).unwrap(),
1756            (
1757                true,
1758                vec![vec![
1759                    "data2_admin".to_string(),
1760                    "data2".to_string(),
1761                    "read".to_string()
1762                ]]
1763            )
1764        );
1765        assert_eq!(
1766            e.enforce_ex(("alice", "data2", "write")).unwrap(),
1767            (
1768                true,
1769                vec![vec![
1770                    "data2_admin".to_string(),
1771                    "data2".to_string(),
1772                    "write".to_string()
1773                ]]
1774            )
1775        );
1776        assert_eq!(
1777            e.enforce_ex(("bob", "data1", "read")).unwrap(),
1778            (false, vec![])
1779        );
1780        assert_eq!(
1781            e.enforce_ex(("bob", "data1", "write")).unwrap(),
1782            (false, vec![])
1783        );
1784        assert_eq!(
1785            e.enforce_ex(("bob", "data2", "read")).unwrap(),
1786            (false, vec![])
1787        );
1788        assert_eq!(
1789            e.enforce_ex(("bob", "data2", "write")).unwrap(),
1790            (
1791                true,
1792                vec![vec![
1793                    "bob".to_string(),
1794                    "data2".to_string(),
1795                    "write".to_string()
1796                ]]
1797            )
1798        );
1799
1800        let e = Enforcer::new(
1801            "examples/priority_model.conf",
1802            "examples/priority_policy.csv",
1803        )
1804        .await
1805        .unwrap();
1806
1807        assert_eq!(
1808            e.enforce_ex(("alice", "data1", "read")).unwrap(),
1809            (
1810                true,
1811                vec![vec![
1812                    "alice".to_string(),
1813                    "data1".to_string(),
1814                    "read".to_string(),
1815                    "allow".to_string()
1816                ]]
1817            )
1818        );
1819        assert_eq!(
1820            e.enforce_ex(("alice", "data1", "write")).unwrap(),
1821            (
1822                false,
1823                vec![vec![
1824                    "data1_deny_group".to_string(),
1825                    "data1".to_string(),
1826                    "write".to_string(),
1827                    "deny".to_string()
1828                ]]
1829            )
1830        );
1831        assert_eq!(
1832            e.enforce_ex(("alice", "data2", "read")).unwrap(),
1833            (false, vec![])
1834        );
1835        assert_eq!(
1836            e.enforce_ex(("alice", "data2", "write")).unwrap(),
1837            (false, vec![])
1838        );
1839        assert_eq!(
1840            e.enforce_ex(("bob", "data1", "write")).unwrap(),
1841            (false, vec![])
1842        );
1843        assert_eq!(
1844            e.enforce_ex(("bob", "data2", "read")).unwrap(),
1845            (
1846                true,
1847                vec![vec![
1848                    "data2_allow_group".to_string(),
1849                    "data2".to_string(),
1850                    "read".to_string(),
1851                    "allow".to_string()
1852                ]]
1853            )
1854        );
1855        assert_eq!(
1856            e.enforce_ex(("bob", "data2", "write")).unwrap(),
1857            (
1858                false,
1859                vec![vec![
1860                    "bob".to_string(),
1861                    "data2".to_string(),
1862                    "write".to_string(),
1863                    "deny".to_string()
1864                ]]
1865            )
1866        );
1867    }
1868
1869    #[cfg(not(target_arch = "wasm32"))]
1870    #[cfg_attr(
1871        all(feature = "runtime-async-std", not(target_arch = "wasm32")),
1872        async_std::test
1873    )]
1874    #[cfg_attr(
1875        all(feature = "runtime-tokio", not(target_arch = "wasm32")),
1876        tokio::test
1877    )]
1878    async fn test_custom_function_with_dynamic_types() {
1879        use crate::prelude::*;
1880
1881        let m = DefaultModel::from_str(
1882            r#"
1883[request_definition]
1884r = sub, obj, act
1885
1886[policy_definition]
1887p = sub, obj, act
1888
1889[policy_effect]
1890e = some(where (p.eft == allow))
1891
1892[matchers]
1893m = r.sub == p.sub && r.obj == p.obj && r.act == p.act
1894"#,
1895        )
1896        .await
1897        .unwrap();
1898
1899        let adapter = MemoryAdapter::default();
1900        let mut e = Enforcer::new(m, adapter).await.unwrap();
1901
1902        // Test 1: Custom function that takes integer arguments
1903        e.add_function(
1904            "greaterThan",
1905            OperatorFunction::Arg2(|a: Dynamic, b: Dynamic| {
1906                // Dynamic can hold integers - extract and compare
1907                let a_int = a.as_int().unwrap_or(0);
1908                let b_int = b.as_int().unwrap_or(0);
1909                (a_int > b_int).into()
1910            }),
1911        );
1912
1913        // Test 2: Custom function that works with booleans
1914        e.add_function(
1915            "customAnd",
1916            OperatorFunction::Arg2(|a: Dynamic, b: Dynamic| {
1917                // Dynamic can hold booleans - extract and perform logic
1918                let a_bool = a.as_bool().unwrap_or(false);
1919                let b_bool = b.as_bool().unwrap_or(false);
1920                (a_bool && b_bool).into()
1921            }),
1922        );
1923
1924        // Test 3: Custom function that works with strings
1925        e.add_function(
1926            "stringContains",
1927            OperatorFunction::Arg2(|haystack: Dynamic, needle: Dynamic| {
1928                // Dynamic can hold strings - convert and check
1929                let haystack_str = haystack.to_string();
1930                let needle_str = needle.to_string();
1931                haystack_str.contains(&needle_str).into()
1932            }),
1933        );
1934
1935        // Test 4: Custom function with 3 arguments
1936        e.add_function(
1937            "between",
1938            OperatorFunction::Arg3(
1939                |val: Dynamic, min: Dynamic, max: Dynamic| {
1940                    // Check if val is between min and max (inclusive)
1941                    let val_int = val.as_int().unwrap_or(0);
1942                    let min_int = min.as_int().unwrap_or(0);
1943                    let max_int = max.as_int().unwrap_or(0);
1944                    (val_int >= min_int && val_int <= max_int).into()
1945                },
1946            ),
1947        );
1948
1949        // Verify that custom functions are registered without errors
1950        // In real usage, these would be called from policy matchers
1951
1952        // Test basic enforcement still works with Dynamic-based functions
1953        e.add_policy(vec![
1954            "alice".to_owned(),
1955            "data1".to_owned(),
1956            "read".to_owned(),
1957        ])
1958        .await
1959        .unwrap();
1960
1961        assert_eq!(true, e.enforce(("alice", "data1", "read")).unwrap());
1962        assert_eq!(false, e.enforce(("alice", "data1", "write")).unwrap());
1963        assert_eq!(false, e.enforce(("bob", "data1", "read")).unwrap());
1964    }
1965}