Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

in mem storage: got rid of the namespace indirection #344

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 7 additions & 25 deletions limitador/src/limit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::error::Error;
Expand Down Expand Up @@ -28,7 +28,7 @@ mod deprecated {
#[cfg(feature = "lenient_conditions")]
pub use deprecated::check_deprecated_syntax_usages_and_reset;

#[derive(Debug, Hash, Eq, PartialEq, Clone, Serialize, Deserialize)]
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Namespace(String);

impl From<&str> for Namespace {
Expand All @@ -49,7 +49,7 @@ impl From<String> for Namespace {
}
}

#[derive(Eq, Debug, Clone, Serialize, Deserialize)]
#[derive(Eq, Debug, Clone, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Limit {
#[serde(skip_serializing, default)]
id: Option<String>,
Expand All @@ -62,13 +62,11 @@ pub struct Limit {

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
#[serde(serialize_with = "ordered_condition_set")]
conditions: HashSet<Condition>,
#[serde(serialize_with = "ordered_set")]
variables: HashSet<String>,
conditions: BTreeSet<Condition>,
variables: BTreeSet<String>,
}

#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash)]
#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[serde(try_from = "String", into = "String")]
pub struct Condition {
var_name: String,
Expand Down Expand Up @@ -267,7 +265,7 @@ impl From<Condition> for String {
}
}

#[derive(PartialEq, Eq, Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)]
pub enum Predicate {
Equal,
NotEqual,
Expand All @@ -291,22 +289,6 @@ impl From<Predicate> for String {
}
}

fn ordered_condition_set<S>(value: &HashSet<Condition>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let ordered: BTreeSet<String> = value.iter().map(|c| c.clone().into()).collect();
ordered.serialize(serializer)
}

fn ordered_set<S>(value: &HashSet<String>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let ordered: BTreeSet<_> = value.iter().collect();
ordered.serialize(serializer)
}

impl Limit {
pub fn new<N: Into<Namespace>, T: TryInto<Condition>>(
namespace: N,
Expand Down
118 changes: 41 additions & 77 deletions limitador/src/storage/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,48 @@ use crate::limit::{Limit, Namespace};
use crate::storage::atomic_expiring_value::AtomicExpiringValue;
use crate::storage::{Authorization, CounterStorage, StorageErr};
use moka::sync::Cache;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::ops::Deref;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

type NamespacedLimitCounters<T> = HashMap<Namespace, HashMap<Limit, T>>;

pub struct InMemoryStorage {
limits_for_namespace: RwLock<NamespacedLimitCounters<AtomicExpiringValue>>,
simple_limits: RwLock<BTreeMap<Limit, AtomicExpiringValue>>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏼

qualified_counters: Cache<Counter, Arc<AtomicExpiringValue>>,
}

impl CounterStorage for InMemoryStorage {
#[tracing::instrument(skip_all)]
fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result<bool, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

let mut value = 0;

if counter.is_qualified() {
if let Some(counter) = self.qualified_counters.get(counter) {
value = counter.value();
}
} else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) {
if let Some(counter) = limits.get(counter.limit()) {
value = counter.value();
}
}
let value = if counter.is_qualified() {
self.qualified_counters
.get(counter)
.map(|c| c.value())
.unwrap_or_default()
} else {
let limits_by_namespace = self.simple_limits.read().unwrap();
limits_by_namespace
.get(counter.limit())
.map(|c| c.value())
.unwrap_or_default()
};

Ok(counter.max_value() >= value + delta)
}

#[tracing::instrument(skip_all)]
fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> {
if limit.variables().is_empty() {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
limits_by_namespace
.entry(limit.namespace().clone())
.or_default()
.entry(limit.clone())
.or_default();
let mut limits_by_namespace = self.simple_limits.write().unwrap();
limits_by_namespace.entry(limit.clone()).or_default();
}
Ok(())
}

#[tracing::instrument(skip_all)]
fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
let mut counters = self.simple_limits.write().unwrap();
let now = SystemTime::now();
if counter.is_qualified() {
let value = match self.qualified_counters.get(counter) {
Expand All @@ -62,23 +55,13 @@ impl CounterStorage for InMemoryStorage {
};
value.update(delta, counter.window(), now);
} else {
match limits_by_namespace.entry(counter.limit().namespace().clone()) {
match counters.entry(counter.limit().clone()) {
Entry::Vacant(v) => {
let mut limits = HashMap::new();
limits.insert(
counter.limit().clone(),
AtomicExpiringValue::new(delta, now + counter.window()),
);
v.insert(limits);
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) {
Entry::Vacant(v) => {
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
},
}
}
Ok(())
Expand All @@ -91,7 +74,7 @@ impl CounterStorage for InMemoryStorage {
delta: u64,
load_counters: bool,
) -> Result<Authorization, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();
let limits_by_namespace = self.simple_limits.read().unwrap();
let mut first_limited = None;
let mut counter_values_to_update: Vec<(&AtomicExpiringValue, Duration)> = Vec::new();
let mut qualified_counter_values_to_updated: Vec<(Arc<AtomicExpiringValue>, Duration)> =
Expand Down Expand Up @@ -119,10 +102,8 @@ impl CounterStorage for InMemoryStorage {

// Process simple counters
for counter in counters.iter_mut().filter(|c| !c.is_qualified()) {
let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace
.get(counter.limit().namespace())
.and_then(|limits| limits.get(counter.limit()))
.unwrap();
let atomic_expiring_value: &AtomicExpiringValue =
limits_by_namespace.get(counter.limit()).unwrap();

if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) {
if !load_counters {
Expand All @@ -135,7 +116,7 @@ impl CounterStorage for InMemoryStorage {
// Process qualified counters
for counter in counters.iter_mut().filter(|c| c.is_qualified()) {
let value = match self.qualified_counters.get(counter) {
None => self.qualified_counters.get_with(counter.clone(), || {
None => self.qualified_counters.get_with_by_ref(counter, || {
Arc::new(AtomicExpiringValue::new(0, now + counter.window()))
}),
Some(counter) => counter,
Expand Down Expand Up @@ -171,24 +152,14 @@ impl CounterStorage for InMemoryStorage {
fn get_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<HashSet<Counter>, StorageErr> {
let mut res = HashSet::new();

let namespaces: HashSet<&Namespace> = limits.iter().map(|l| l.namespace()).collect();
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

for namespace in namespaces {
if let Some(limits) = limits_by_namespace.get(namespace) {
for limit in limits.keys() {
if limits.contains_key(limit) {
for (counter, expiring_value) in self.counters_in_namespace(namespace) {
let mut counter_with_val = counter.clone();
counter_with_val.set_remaining(
counter_with_val.max_value() - expiring_value.value(),
);
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
for limit in limits {
for (counter, expiring_value) in self.counters_in_namespace(limit.namespace()) {
let mut counter_with_val = counter.clone();
counter_with_val
.set_remaining(counter_with_val.max_value() - expiring_value.value());
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
Expand Down Expand Up @@ -218,15 +189,15 @@ impl CounterStorage for InMemoryStorage {

#[tracing::instrument(skip_all)]
fn clear(&self) -> Result<(), StorageErr> {
self.limits_for_namespace.write().unwrap().clear();
self.simple_limits.write().unwrap().clear();
Ok(())
}
}

impl InMemoryStorage {
pub fn new(cache_size: u64) -> Self {
Self {
limits_for_namespace: RwLock::new(HashMap::new()),
simple_limits: RwLock::new(BTreeMap::new()),
qualified_counters: Cache::new(cache_size),
}
}
Expand All @@ -237,11 +208,11 @@ impl InMemoryStorage {
) -> HashMap<Counter, AtomicExpiringValue> {
let mut res: HashMap<Counter, AtomicExpiringValue> = HashMap::new();

if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) {
for (limit, value) in counters_by_limit {
for (limit, counter) in self.simple_limits.read().unwrap().iter() {
if limit.namespace() == namespace {
res.insert(
Counter::new(limit.clone(), HashMap::default()),
value.clone(),
counter.clone(),
);
}
}
Expand All @@ -256,14 +227,7 @@ impl InMemoryStorage {
}

fn delete_counters_of_limit(&self, limit: &Limit) {
if let Some(counters_by_limit) = self
.limits_for_namespace
.write()
.unwrap()
.get_mut(limit.namespace())
{
counters_by_limit.remove(limit);
}
self.simple_limits.write().unwrap().remove(limit);
}

fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool {
Expand Down
Loading