Skip to content

Commit

Permalink
fixed race condition when generating
Browse files — browse the repository at this point in the history
  • Loading branch information
LostRuins committed Aug 20, 2024
1 parent 7ee359a commit c1ae350
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
5 changes: 3 additions & 2 deletions gpttype_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,9 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
llama_reset_timings(llama_ctx_v4);
}

generation_finished = false; // Set current generation status
generated_tokens.clear(); // New Generation, new tokens

concat_output_mtx.lock();
concat_output = "";
concat_output_reader_copy_poll = "";
Expand Down Expand Up @@ -2140,8 +2143,6 @@ generation_outputs gpttype_generate(const generation_inputs inputs)

bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;

generation_finished = false; // Set current generation status
generated_tokens.clear(); // New Generation, new tokens

std::string grammarstr = inputs.grammar;
bool grammar_retain_state = inputs.grammar_retain_state;
Expand Down
10 changes: 3 additions & 7 deletions koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
modelbusy = threading.Lock()
requestsinqueue = 0
defaultport = 5001
KcppVersion = "1.73"
KcppVersion = "1.73.1"
showdebug = True
guimode = False
showsamplerwarning = True
Expand Down Expand Up @@ -1412,11 +1412,7 @@ def run_blocking(): # api format 1=basic,2=kai,3=oai,4=oai-chat
global last_non_horde_req_time
last_non_horde_req_time = time.time()

return generate(
genparams=genparams,
is_quiet=is_quiet,
stream_flag=stream_flag
)
return generate(genparams=genparams,is_quiet=is_quiet,stream_flag=stream_flag)

genout = {"text": "", "status": -1, "stopreason": -1}
if stream_flag:
Expand Down Expand Up @@ -1486,7 +1482,7 @@ async def handle_sse_stream(self, genparams, api_format):
current_token = 0
incomplete_token_buffer = bytearray()
async_sleep_short = 0.02
await asyncio.sleep(0.3) #anti race condition, prevent check from overtaking generate
await asyncio.sleep(0.5) #anti race condition, prevent check from overtaking generate
try:
tokenReserve = "" #keeps fully formed tokens that we cannot send out yet
while True:
Expand Down

0 comments on commit c1ae350

Please sign in to comment.