Skip to content

Commit

Permalink
Merge pull request #344 from Kuadrant/inmem_single_lookup
Browse files Browse the repository at this point in the history
in mem storage: got rid of the namespace indirection
  • Loading branch information
alexsnaps committed Sep 18, 2024
2 parents 748b1d1 + 0c08ce4 commit cdb9850
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 102 deletions.
32 changes: 7 additions & 25 deletions limitador/src/limit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::limit::conditions::{ErrorType, Literal, SyntaxError, Token, TokenType};
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BTreeSet, HashMap, HashSet};
use std::error::Error;
Expand Down Expand Up @@ -28,7 +28,7 @@ mod deprecated {
#[cfg(feature = "lenient_conditions")]
pub use deprecated::check_deprecated_syntax_usages_and_reset;

#[derive(Debug, Hash, Eq, PartialEq, Clone, Serialize, Deserialize)]
#[derive(Debug, Hash, Eq, PartialEq, Clone, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Namespace(String);

impl From<&str> for Namespace {
Expand All @@ -49,7 +49,7 @@ impl From<String> for Namespace {
}
}

#[derive(Eq, Debug, Clone, Serialize, Deserialize)]
#[derive(Eq, Debug, Clone, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Limit {
#[serde(skip_serializing, default)]
id: Option<String>,
Expand All @@ -62,13 +62,11 @@ pub struct Limit {

// Need to sort to generate the same object when using the JSON as a key or
// value in Redis.
#[serde(serialize_with = "ordered_condition_set")]
conditions: HashSet<Condition>,
#[serde(serialize_with = "ordered_set")]
variables: HashSet<String>,
conditions: BTreeSet<Condition>,
variables: BTreeSet<String>,
}

#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash)]
#[derive(Deserialize, Serialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[serde(try_from = "String", into = "String")]
pub struct Condition {
var_name: String,
Expand Down Expand Up @@ -267,7 +265,7 @@ impl From<Condition> for String {
}
}

#[derive(PartialEq, Eq, Debug, Clone, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Hash)]
pub enum Predicate {
Equal,
NotEqual,
Expand All @@ -291,22 +289,6 @@ impl From<Predicate> for String {
}
}

fn ordered_condition_set<S>(value: &HashSet<Condition>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let ordered: BTreeSet<String> = value.iter().map(|c| c.clone().into()).collect();
ordered.serialize(serializer)
}

fn ordered_set<S>(value: &HashSet<String>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let ordered: BTreeSet<_> = value.iter().collect();
ordered.serialize(serializer)
}

impl Limit {
pub fn new<N: Into<Namespace>, T: TryInto<Condition>>(
namespace: N,
Expand Down
118 changes: 41 additions & 77 deletions limitador/src/storage/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,55 +3,48 @@ use crate::limit::{Limit, Namespace};
use crate::storage::atomic_expiring_value::AtomicExpiringValue;
use crate::storage::{Authorization, CounterStorage, StorageErr};
use moka::sync::Cache;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::ops::Deref;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};

type NamespacedLimitCounters<T> = HashMap<Namespace, HashMap<Limit, T>>;

pub struct InMemoryStorage {
limits_for_namespace: RwLock<NamespacedLimitCounters<AtomicExpiringValue>>,
simple_limits: RwLock<BTreeMap<Limit, AtomicExpiringValue>>,
qualified_counters: Cache<Counter, Arc<AtomicExpiringValue>>,
}

impl CounterStorage for InMemoryStorage {
#[tracing::instrument(skip_all)]
fn is_within_limits(&self, counter: &Counter, delta: u64) -> Result<bool, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

let mut value = 0;

if counter.is_qualified() {
if let Some(counter) = self.qualified_counters.get(counter) {
value = counter.value();
}
} else if let Some(limits) = limits_by_namespace.get(counter.limit().namespace()) {
if let Some(counter) = limits.get(counter.limit()) {
value = counter.value();
}
}
let value = if counter.is_qualified() {
self.qualified_counters
.get(counter)
.map(|c| c.value())
.unwrap_or_default()
} else {
let limits_by_namespace = self.simple_limits.read().unwrap();
limits_by_namespace
.get(counter.limit())
.map(|c| c.value())
.unwrap_or_default()
};

Ok(counter.max_value() >= value + delta)
}

#[tracing::instrument(skip_all)]
fn add_counter(&self, limit: &Limit) -> Result<(), StorageErr> {
if limit.variables().is_empty() {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
limits_by_namespace
.entry(limit.namespace().clone())
.or_default()
.entry(limit.clone())
.or_default();
let mut limits_by_namespace = self.simple_limits.write().unwrap();
limits_by_namespace.entry(limit.clone()).or_default();
}
Ok(())
}

#[tracing::instrument(skip_all)]
fn update_counter(&self, counter: &Counter, delta: u64) -> Result<(), StorageErr> {
let mut limits_by_namespace = self.limits_for_namespace.write().unwrap();
let mut counters = self.simple_limits.write().unwrap();
let now = SystemTime::now();
if counter.is_qualified() {
let value = match self.qualified_counters.get(counter) {
Expand All @@ -62,23 +55,13 @@ impl CounterStorage for InMemoryStorage {
};
value.update(delta, counter.window(), now);
} else {
match limits_by_namespace.entry(counter.limit().namespace().clone()) {
match counters.entry(counter.limit().clone()) {
Entry::Vacant(v) => {
let mut limits = HashMap::new();
limits.insert(
counter.limit().clone(),
AtomicExpiringValue::new(delta, now + counter.window()),
);
v.insert(limits);
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
Entry::Occupied(mut o) => match o.get_mut().entry(counter.limit().clone()) {
Entry::Vacant(v) => {
v.insert(AtomicExpiringValue::new(delta, now + counter.window()));
}
Entry::Occupied(o) => {
o.get().update(delta, counter.window(), now);
}
},
}
}
Ok(())
Expand All @@ -91,7 +74,7 @@ impl CounterStorage for InMemoryStorage {
delta: u64,
load_counters: bool,
) -> Result<Authorization, StorageErr> {
let limits_by_namespace = self.limits_for_namespace.read().unwrap();
let limits_by_namespace = self.simple_limits.read().unwrap();
let mut first_limited = None;
let mut counter_values_to_update: Vec<(&AtomicExpiringValue, Duration)> = Vec::new();
let mut qualified_counter_values_to_updated: Vec<(Arc<AtomicExpiringValue>, Duration)> =
Expand Down Expand Up @@ -119,10 +102,8 @@ impl CounterStorage for InMemoryStorage {

// Process simple counters
for counter in counters.iter_mut().filter(|c| !c.is_qualified()) {
let atomic_expiring_value: &AtomicExpiringValue = limits_by_namespace
.get(counter.limit().namespace())
.and_then(|limits| limits.get(counter.limit()))
.unwrap();
let atomic_expiring_value: &AtomicExpiringValue =
limits_by_namespace.get(counter.limit()).unwrap();

if let Some(limited) = process_counter(counter, atomic_expiring_value.value(), delta) {
if !load_counters {
Expand All @@ -135,7 +116,7 @@ impl CounterStorage for InMemoryStorage {
// Process qualified counters
for counter in counters.iter_mut().filter(|c| c.is_qualified()) {
let value = match self.qualified_counters.get(counter) {
None => self.qualified_counters.get_with(counter.clone(), || {
None => self.qualified_counters.get_with_by_ref(counter, || {
Arc::new(AtomicExpiringValue::new(0, now + counter.window()))
}),
Some(counter) => counter,
Expand Down Expand Up @@ -171,24 +152,14 @@ impl CounterStorage for InMemoryStorage {
fn get_counters(&self, limits: &HashSet<Arc<Limit>>) -> Result<HashSet<Counter>, StorageErr> {
let mut res = HashSet::new();

let namespaces: HashSet<&Namespace> = limits.iter().map(|l| l.namespace()).collect();
let limits_by_namespace = self.limits_for_namespace.read().unwrap();

for namespace in namespaces {
if let Some(limits) = limits_by_namespace.get(namespace) {
for limit in limits.keys() {
if limits.contains_key(limit) {
for (counter, expiring_value) in self.counters_in_namespace(namespace) {
let mut counter_with_val = counter.clone();
counter_with_val.set_remaining(
counter_with_val.max_value() - expiring_value.value(),
);
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
for limit in limits {
for (counter, expiring_value) in self.counters_in_namespace(limit.namespace()) {
let mut counter_with_val = counter.clone();
counter_with_val
.set_remaining(counter_with_val.max_value() - expiring_value.value());
counter_with_val.set_expires_in(expiring_value.ttl());
if counter_with_val.expires_in().unwrap() > Duration::ZERO {
res.insert(counter_with_val);
}
}
}
Expand Down Expand Up @@ -218,15 +189,15 @@ impl CounterStorage for InMemoryStorage {

#[tracing::instrument(skip_all)]
fn clear(&self) -> Result<(), StorageErr> {
self.limits_for_namespace.write().unwrap().clear();
self.simple_limits.write().unwrap().clear();
Ok(())
}
}

impl InMemoryStorage {
pub fn new(cache_size: u64) -> Self {
Self {
limits_for_namespace: RwLock::new(HashMap::new()),
simple_limits: RwLock::new(BTreeMap::new()),
qualified_counters: Cache::new(cache_size),
}
}
Expand All @@ -237,11 +208,11 @@ impl InMemoryStorage {
) -> HashMap<Counter, AtomicExpiringValue> {
let mut res: HashMap<Counter, AtomicExpiringValue> = HashMap::new();

if let Some(counters_by_limit) = self.limits_for_namespace.read().unwrap().get(namespace) {
for (limit, value) in counters_by_limit {
for (limit, counter) in self.simple_limits.read().unwrap().iter() {
if limit.namespace() == namespace {
res.insert(
Counter::new(limit.clone(), HashMap::default()),
value.clone(),
counter.clone(),
);
}
}
Expand All @@ -256,14 +227,7 @@ impl InMemoryStorage {
}

fn delete_counters_of_limit(&self, limit: &Limit) {
if let Some(counters_by_limit) = self
.limits_for_namespace
.write()
.unwrap()
.get_mut(limit.namespace())
{
counters_by_limit.remove(limit);
}
self.simple_limits.write().unwrap().remove(limit);
}

fn counter_is_within_limits(counter: &Counter, current_val: Option<&u64>, delta: u64) -> bool {
Expand Down

0 comments on commit cdb9850

Please sign in to comment.