Skip to content

Commit

Permalink
amend attempt output manipulation to edit underlying structure
Browse files Browse the repository at this point in the history
  • Loading branch information
leondz committed Jul 18, 2024
1 parent 80debd7 commit e9a2697
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
1 change: 0 additions & 1 deletion garak/attempt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Defines the Attempt class, which encapsulates a prompt with metadata and results"""

from collections.abc import Iterable
from types import GeneratorType
from typing import Any, List
import uuid
Expand Down
5 changes: 4 additions & 1 deletion garak/probes/leakreplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def _attempt_prestore_hook(self, attempt: Attempt, seq: int) -> Attempt:
return attempt

def _postprocess_hook(self, attempt: Attempt) -> Attempt:
attempt.outputs = [re.sub("</?name>", "", o) for o in attempt.outputs]
for idx, thread in enumerate(attempt.messages):
attempt.messages[idx][-1]["content"] = re.sub(
"</?name>", "", thread[-1]["content"]
)
return attempt


Expand Down
29 changes: 29 additions & 0 deletions tests/probes/test_probes_leakreplay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Portions Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import garak
import garak._config
import garak._plugins
import garak.attempt
import garak.cli


def test_leakreplay_hitlog():

args = "-m test.Blank -p leakreplay -d always.Fail".split()
garak.cli.main(args)


def test_leakreplay_output_count():
generations = 1
garak._config.load_base_config()
garak._config.transient.reportfile = open("/dev/null", "w+")
a = garak.attempt.Attempt(prompt="test")
p = garak._plugins.load_plugin(
"probes.leakreplay.LiteratureCloze80", config_root=garak._config
)
g = garak._plugins.load_plugin("generators.test.Blank", config_root=garak._config)
g.generations = generations
p.generator = g
results = p._execute_all([a])
assert len(a.all_outputs) == generations

0 comments on commit e9a2697

Please sign in to comment.