Skip to content

Commit

Permalink
Fix keyspace isolation with ScriptEvaluateAsync (#1377)
Browse files Browse the repository at this point in the history
Co-authored-by: Gunnar Liljas <gunnar.liljas@revide.se>
  • Loading branch information
gliljas and Gunnar Liljas authored Mar 18, 2020
1 parent 659d514 commit 4f58848
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/StackExchange.Redis/KeyspaceIsolation/WrapperBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public Task<bool> KeyExpireAsync(RedisKey key, TimeSpan? expiry, CommandFlags fl
return Inner.KeyExpireAsync(ToInner(key), expiry, flags);
}

public Task<TimeSpan?> KeyIdleTimeAsync(RedisKey key,CommandFlags flags = CommandFlags.None)
public Task<TimeSpan?> KeyIdleTimeAsync(RedisKey key, CommandFlags flags = CommandFlags.None)
{
return Inner.KeyIdleTimeAsync(ToInner(key), flags);
}
Expand Down Expand Up @@ -386,12 +386,14 @@ public Task<RedisResult> ScriptEvaluateAsync(string script, RedisKey[] keys = nu

public Task<RedisResult> ScriptEvaluateAsync(LuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
return Inner.ScriptEvaluateAsync(script, parameters, flags);
// TODO: The return value could contain prefixed keys. It might make sense to 'unprefix' those?
return script.EvaluateAsync(Inner, parameters, Prefix, flags);
}

public Task<RedisResult> ScriptEvaluateAsync(LoadedLuaScript script, object parameters = null, CommandFlags flags = CommandFlags.None)
{
return Inner.ScriptEvaluateAsync(script, parameters, flags);
// TODO: The return value could contain prefixed keys. It might make sense to 'unprefix' those?
return script.EvaluateAsync(Inner, parameters, Prefix, flags);
}

public Task<long> SetAddAsync(RedisKey key, RedisValue[] values, CommandFlags flags = CommandFlags.None)
Expand Down Expand Up @@ -908,12 +910,12 @@ protected ICollection<object> ToInner(ICollection<object> args)
{
var withPrefix = new object[args.Count];
int i = 0;
foreach(var oldArg in args)
foreach (var oldArg in args)
{
object newArg;
if (oldArg is RedisKey key)
{
newArg = ToInner(key);
newArg = ToInner(key);
}
else if (oldArg is RedisChannel channel)
{
Expand Down
53 changes: 53 additions & 0 deletions tests/StackExchange.Redis.Tests/Scripting.cs
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,32 @@ public void LuaScriptWithWrappedDatabase()
}
}

[Fact]
public async Task AsyncLuaScriptWithWrappedDatabase()
{
const string Script = "redis.call('set', @key, @value)";

using (var conn = Create(allowAdmin: true))
{
Skip.IfMissingFeature(conn, nameof(RedisFeatures.Scripting), f => f.Scripting);
var db = conn.GetDatabase();
var wrappedDb = KeyspaceIsolation.DatabaseExtensions.WithKeyPrefix(db, "prefix-");
var key = Me();
await db.KeyDeleteAsync(key, CommandFlags.FireAndForget);

var prepared = LuaScript.Prepare(Script);
await wrappedDb.ScriptEvaluateAsync(prepared, new { key = (RedisKey)key, value = 123 });
var val1 = await wrappedDb.StringGetAsync(key);
Assert.Equal(123, (int)val1);

var val2 = await db.StringGetAsync("prefix-" + key);
Assert.Equal(123, (int)val2);

var val3 = await db.StringGetAsync(key);
Assert.True(val3.IsNull);
}
}

[Fact]
public void LoadedLuaScriptWithWrappedDatabase()
{
Expand Down Expand Up @@ -883,6 +909,33 @@ public void LoadedLuaScriptWithWrappedDatabase()
}
}

[Fact]
public async Task AsyncLoadedLuaScriptWithWrappedDatabase()
{
const string Script = "redis.call('set', @key, @value)";

using (var conn = Create(allowAdmin: true))
{
Skip.IfMissingFeature(conn, nameof(RedisFeatures.Scripting), f => f.Scripting);
var db = conn.GetDatabase();
var wrappedDb = KeyspaceIsolation.DatabaseExtensions.WithKeyPrefix(db, "prefix2-");
var key = Me();
await db.KeyDeleteAsync(key, CommandFlags.FireAndForget);

var server = conn.GetServer(conn.GetEndPoints()[0]);
var prepared = await LuaScript.Prepare(Script).LoadAsync(server);
await wrappedDb.ScriptEvaluateAsync(prepared, new { key = (RedisKey)key, value = 123 }, flags: CommandFlags.FireAndForget);
var val1 = await wrappedDb.StringGetAsync(key);
Assert.Equal(123, (int)val1);

var val2 = await db.StringGetAsync("prefix2-" + key);
Assert.Equal(123, (int)val2);

var val3 = await db.StringGetAsync(key);
Assert.True(val3.IsNull);
}
}

[Fact]
public void ScriptWithKeyPrefixViaTokens()
{
Expand Down

0 comments on commit 4f58848

Please sign in to comment.