Commit c553b12
fixes
aniketmaurya committed Dec 20, 2023
1 parent 046f902 commit c553b12
Showing 5 changed files with 17 additions and 29 deletions.
7 changes: 5 additions & 2 deletions src/fastserve/__main__.py
@@ -1,6 +1,9 @@
-from fastserve import BaseFastServe
+from fastserve.core import BaseFastServe
 from fastserve.handler import DummyHandler
+from fastserve.utils import BaseRequest
 
 handler = DummyHandler()
-serve = BaseFastServe(handler=handler)
+serve = BaseFastServe(
+    handle=handler.handle, batch_size=1, timeout=0, input_schema=BaseRequest
+)
 serve.run_server()
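The new entry point wires a handle callable, batching parameters, and an input schema directly into BaseFastServe instead of passing a handler object. A minimal sketch of serving a custom handler through the same constructor — `echo_handle` is hypothetical, and the assumption (consistent with the handle methods further down in this commit) is that the callable receives a micro-batch as a list of input_schema instances and returns one result per request:

# Sketch only: `echo_handle` is a hypothetical handler; the constructor call
# mirrors the new src/fastserve/__main__.py shown above.
from typing import List

from fastserve.core import BaseFastServe
from fastserve.utils import BaseRequest


def echo_handle(batch: List[BaseRequest]) -> List[dict]:
    # Return one response per request in the micro-batch.
    return [{"echo": str(request)} for request in batch]


serve = BaseFastServe(
    handle=echo_handle, batch_size=4, timeout=0.5, input_schema=BaseRequest
)
serve.run_server()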
2 changes: 1 addition & 1 deletion src/fastserve/core.py
@@ -28,7 +28,7 @@ async def lifespan(app: FastAPI):
             yield
             self.batch_processing.cancel()
 
-        self._app = FastAPI(lifespan=lifespan)
+        self._app = FastAPI(lifespan=lifespan, title="FastServe")
 
     def _serve(
         self,
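The only change in core.py is giving the FastAPI app an explicit title. An illustration of the effect, relying only on stock FastAPI behaviour: the title shows up on the app object, in the generated OpenAPI schema, and therefore in the /docs page.

# Illustration only: standalone FastAPI app, not the FastServe wiring itself.
from fastapi import FastAPI

app = FastAPI(title="FastServe")
assert app.title == "FastServe"
assert app.openapi()["info"]["title"] == "FastServe"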
33 changes: 9 additions & 24 deletions src/fastserve/models/llama_cpp.py
@@ -39,39 +39,24 @@ def __init__(
         *args,
         **kwargs,
     ):
-        super().__init__(batch_size, timeout, input_schema=PromptRequest)
-
         if not os.path.exists(model_path):
             raise FileNotFoundError(f"{model_path} not found.")
-        if lazy:
-            self.llm = None
-        else:
-            self.llm = Llama(
-                model_path=model_path,
-                main_gpu=main_gpu,
-                n_ctx=n_ctx,
-                verbose=False,
-                *args,
-                **kwargs,
-            )
+        self.llm = Llama(
+            model_path=model_path,
+            main_gpu=main_gpu,
+            n_ctx=n_ctx,
+            verbose=False,
+            *args,
+            **kwargs,
+        )
         self.n_ctx = n_ctx
         self.model_path = model_path
         self.main_gpu = main_gpu
         self.args = args
         self.kwargs = kwargs
+        super().__init__(batch_size, timeout, input_schema=PromptRequest)
 
     def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
-        if self.llm is None:
-            logger.info("Initializing model")
-            self.llm = Llama(
-                model_path=self.model_path,
-                main_gpu=self.main_gpu,
-                n_ctx=self.n_ctx,
-                verbose=False,
-                *self.args,
-                **self.kwargs,
-            )
-
         result = self.llm(prompt=prompt, *args, **kwargs)
         logger.info(result)
         return result
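This removes the lazy-loading path in llama_cpp.py: the Llama model is now constructed eagerly in __init__, the base-class initialisation happens after the model attributes are set, and __call__ simply forwards the prompt and any extra arguments to self.llm. A usage sketch under stated assumptions — the class name ServeLlamaCpp, the import path, and the model file are illustrative, and llama-cpp-python must be installed with a local GGUF model available:

# Sketch only: class name, import path, and model path are assumptions.
from fastserve.models.llama_cpp import ServeLlamaCpp

serve = ServeLlamaCpp(model_path="./weights/llama-2-7b.Q4_K_M.gguf", n_ctx=2048)
result = serve("Write a haiku about servers.", max_tokens=64)  # forwarded to self.llm
print(result)
# serve.run_server()  # assumed to be inherited from the FastServe base class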
2 changes: 1 addition & 1 deletion src/fastserve/models/sdxl_turbo.py
@@ -20,7 +20,6 @@ class ServeSDXLTurbo(FastServe):
     def __init__(
         self, batch_size=1, timeout=0.0, device="cuda", num_inference_steps: int = 1
     ) -> None:
-        super().__init__(self, batch_size, timeout)
         if num_inference_steps > 1:
             logging.warning(
                 "It is recommended to use inference_steps=1 for SDXL Turbo model."
@@ -31,6 +30,7 @@ def __init__(
             "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
         )
         self.pipe.to(device)
+        super().__init__(batch_size, timeout)
 
     def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         prompts = [b.prompt for b in batch]
2 changes: 1 addition & 1 deletion src/fastserve/models/ssd.py
@@ -18,7 +18,6 @@ class ServeSSD1B(FastServe):
     def __init__(
         self, batch_size=1, timeout=0.0, device="cuda", num_inference_steps: int = 50
     ) -> None:
-        super().__init__(self, batch_size, timeout)
         self.num_inference_steps = num_inference_steps
         self.input_schema = PromptRequest
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -28,6 +27,7 @@ def __init__(
             variant="fp16",
         )
         self.pipe.to(device)
+        super().__init__(batch_size, timeout)
 
     def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         prompts = [b.prompt for b in batch]
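Both diffusion servers had the same bug: `super().__init__(self, batch_size, timeout)` passes `self` explicitly on top of the implicit instance argument, shifting every positional argument one slot to the right in the base-class __init__. The fix drops the stray `self` and, as in llama_cpp.py, defers the base-class initialisation until the pipeline is ready. A small self-contained demonstration of the original mistake, using stand-in classes rather than the FastServe ones:

# Demonstration only: Base/Child are stand-ins, not FastServe classes.
class Base:
    def __init__(self, batch_size=1, timeout=0.0):
        self.batch_size = batch_size
        self.timeout = timeout


class Child(Base):
    def __init__(self):
        super().__init__(self, 2)  # old pattern: `self` lands in batch_size, 2 in timeout
        # super().__init__(2, 0.0)  # fixed pattern: arguments land where intended


c = Child()
print(type(c.batch_size))  # <class '__main__.Child'> -- the instance, not an int
print(c.timeout)           # 2 -- the intended batch size ended up as the timeout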
