Skip to content

Commit

Permalink
Incorporate high level llama api
Browse files Browse the repository at this point in the history
  • Loading branch information
dtiarks authored and rlouf committed Jan 22, 2024
1 parent ae8374b commit a8d24fe
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 379 deletions.
36 changes: 8 additions & 28 deletions examples/llamacpp_processor.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from enum import Enum

import numpy as np
from llama_cpp import Llama, LogitsProcessorList, StoppingCriteria, StoppingCriteriaList
from numpy.typing import NDArray
from llama_cpp import Llama, LogitsProcessorList
from pydantic import BaseModel, constr

from outlines.generate.processors.llamacpp import JSONLogitsProcessor
from outlines.models.llamacpp import JSONLogitsProcessor


class Weapon(str, Enum):
Expand All @@ -31,38 +29,20 @@ class Character(BaseModel):
strength: int


# TODO: why do we need this?
class EosCriteria(StoppingCriteria):
    """Stopping criterion that reports when the end-of-sequence token appears.

    An instance is called with the token ids seen so far and the current
    logits, and answers whether generation should stop.
    """

    def __init__(self, eos_token_id):
        """Store the token id that marks the end of a sequence."""
        self.eos_token_id = eos_token_id

    def __call__(
        self, input_ids: NDArray[np.intc], logits: NDArray[np.single]
    ) -> bool:
        """Return True if ``eos_token_id`` occurs in ``input_ids[1:]``.

        ``logits`` is accepted to match the stopping-criteria call signature
        but is not read.
        """
        # The first position is deliberately excluded from the search
        # (presumably it holds a leading BOS-style token -- TODO confirm).
        # Return an explicit boolean on every path instead of falling
        # through to an implicit ``None`` when the token is absent.
        return self.eos_token_id in input_ids[1:]


if __name__ == "__main__":
llama = Llama("./phi-2.Q4_K_M.gguf")

prompt = b"Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"
prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

logits_processor = JSONLogitsProcessor(Character, llama)
stopping_criteria_list = StoppingCriteriaList([EosCriteria(llama.token_eos())])

json_str = ""
tokens = llama.tokenize(prompt)
for token in llama.generate(
tokens,
json_str = llama.create_completion(
prompt,
top_k=40,
top_p=0.95,
temp=0.7,
temperature=0.7,
max_tokens=120,
logits_processor=LogitsProcessorList([logits_processor]),
stopping_criteria=stopping_criteria_list,
):
d = llama.detokenize([token])
try:
json_str += d.decode("utf-8")
except UnicodeDecodeError:
continue
)["choices"][0]["text"]

print(json_str)
Empty file.
133 changes: 0 additions & 133 deletions outlines/generate/processors/llamacpp.py

This file was deleted.

Loading

0 comments on commit a8d24fe

Please sign in to comment.