// casbin/adapter/memory_adapter.rs

1use crate::{
2    adapter::{Adapter, Filter},
3    model::Model,
4    Result,
5};
6
7use async_trait::async_trait;
8use hashlink::LinkedHashSet;
9
10#[derive(Default)]
11pub struct MemoryAdapter {
12    policy: LinkedHashSet<Vec<String>>,
13    is_filtered: bool,
14}
15
16#[async_trait]
17impl Adapter for MemoryAdapter {
18    async fn load_policy(&mut self, m: &mut dyn Model) -> Result<()> {
19        self.is_filtered = false;
20        for line in self.policy.iter() {
21            let sec = &line[0];
22            let ptype = &line[1];
23            let rule = line[2..].to_vec().clone();
24
25            if let Some(t1) = m.get_mut_model().get_mut(sec) {
26                if let Some(t2) = t1.get_mut(ptype) {
27                    t2.get_mut_policy().insert(rule);
28                }
29            }
30        }
31
32        Ok(())
33    }
34
35    async fn load_filtered_policy<'a>(
36        &mut self,
37        m: &mut dyn Model,
38        f: Filter<'a>,
39    ) -> Result<()> {
40        for line in self.policy.iter() {
41            let sec = &line[0];
42            let ptype = &line[1];
43            let rule = line[1..].to_vec().clone();
44            let mut is_filtered = false;
45
46            if sec == "p" {
47                for (i, r) in f.p.iter().enumerate() {
48                    if !r.is_empty() && r != &rule[i + 1] {
49                        is_filtered = true;
50                    }
51                }
52            }
53            if sec == "g" {
54                for (i, r) in f.g.iter().enumerate() {
55                    if !r.is_empty() && r != &rule[i + 1] {
56                        is_filtered = true;
57                    }
58                }
59            }
60
61            if !is_filtered {
62                if let Some(ast_map) = m.get_mut_model().get_mut(sec) {
63                    if let Some(ast) = ast_map.get_mut(ptype) {
64                        ast.get_mut_policy().insert(rule);
65                    }
66                }
67            } else {
68                self.is_filtered = true;
69            }
70        }
71        Ok(())
72    }
73
74    async fn save_policy(&mut self, m: &mut dyn Model) -> Result<()> {
75        self.policy.clear();
76
77        if let Some(ast_map) = m.get_model().get("p") {
78            for (ptype, ast) in ast_map {
79                if let Some(sec) = ptype.chars().next().map(|x| x.to_string()) {
80                    for policy in ast.get_policy() {
81                        let mut rule = policy.clone();
82                        rule.insert(0, ptype.clone());
83                        rule.insert(0, sec.clone());
84                        self.policy.insert(rule);
85                    }
86                }
87            }
88        }
89
90        if let Some(ast_map) = m.get_model().get("g") {
91            for (ptype, ast) in ast_map {
92                if let Some(sec) = ptype.chars().next().map(|x| x.to_string()) {
93                    for policy in ast.get_policy() {
94                        let mut rule = policy.clone();
95                        rule.insert(0, ptype.clone());
96                        rule.insert(0, sec.clone());
97                        self.policy.insert(rule);
98                    }
99                }
100            }
101        }
102
103        Ok(())
104    }
105
106    async fn clear_policy(&mut self) -> Result<()> {
107        self.policy.clear();
108        self.is_filtered = false;
109        Ok(())
110    }
111
112    async fn add_policy(
113        &mut self,
114        sec: &str,
115        ptype: &str,
116        mut rule: Vec<String>,
117    ) -> Result<bool> {
118        rule.insert(0, ptype.to_owned());
119        rule.insert(0, sec.to_owned());
120
121        Ok(self.policy.insert(rule))
122    }
123
124    async fn add_policies(
125        &mut self,
126        sec: &str,
127        ptype: &str,
128        rules: Vec<Vec<String>>,
129    ) -> Result<bool> {
130        let mut all_added = true;
131        let rules: Vec<Vec<String>> = rules
132            .into_iter()
133            .map(|mut rule| {
134                rule.insert(0, ptype.to_owned());
135                rule.insert(0, sec.to_owned());
136                rule
137            })
138            .collect();
139
140        for rule in &rules {
141            if self.policy.contains(rule) {
142                all_added = false;
143                return Ok(all_added);
144            }
145        }
146        self.policy.extend(rules);
147
148        Ok(all_added)
149    }
150
151    async fn remove_policies(
152        &mut self,
153        sec: &str,
154        ptype: &str,
155        rules: Vec<Vec<String>>,
156    ) -> Result<bool> {
157        let mut all_removed = true;
158        let rules: Vec<Vec<String>> = rules
159            .into_iter()
160            .map(|mut rule| {
161                rule.insert(0, ptype.to_owned());
162                rule.insert(0, sec.to_owned());
163                rule
164            })
165            .collect();
166
167        for rule in &rules {
168            if !self.policy.contains(rule) {
169                all_removed = false;
170                return Ok(all_removed);
171            }
172        }
173        for rule in &rules {
174            self.policy.remove(rule);
175        }
176
177        Ok(all_removed)
178    }
179
180    async fn remove_policy(
181        &mut self,
182        sec: &str,
183        ptype: &str,
184        mut rule: Vec<String>,
185    ) -> Result<bool> {
186        rule.insert(0, ptype.to_owned());
187        rule.insert(0, sec.to_owned());
188
189        Ok(self.policy.remove(&rule))
190    }
191
192    async fn remove_filtered_policy(
193        &mut self,
194        sec: &str,
195        ptype: &str,
196        field_index: usize,
197        field_values: Vec<String>,
198    ) -> Result<bool> {
199        if field_values.is_empty() {
200            return Ok(false);
201        }
202
203        let mut tmp = LinkedHashSet::new();
204        let mut res = false;
205        for rule in &self.policy {
206            if sec == rule[0] && ptype == rule[1] {
207                let mut matched = true;
208                for (i, field_value) in field_values.iter().enumerate() {
209                    if !field_value.is_empty()
210                        && &rule[field_index + i + 2] != field_value
211                    {
212                        matched = false;
213                        break;
214                    }
215                }
216
217                if matched {
218                    res = true;
219                } else {
220                    tmp.insert(rule.clone());
221                }
222            } else {
223                tmp.insert(rule.clone());
224            }
225        }
226        self.policy = tmp;
227
228        Ok(res)
229    }
230
231    fn is_filtered(&self) -> bool {
232        self.is_filtered
233    }
234}