diff --git a/src/fastserve/__main__.py b/src/fastserve/__main__.py
index c1e2887..1ea4d09 100644
--- a/src/fastserve/__main__.py
+++ b/src/fastserve/__main__.py
@@ -1,6 +1,9 @@
-from fastserve import BaseFastServe
+from fastserve.core import BaseFastServe
 from fastserve.handler import DummyHandler
+from fastserve.utils import BaseRequest

 handler = DummyHandler()
-serve = BaseFastServe(handler=handler)
+serve = BaseFastServe(
+    handle=handler.handle, batch_size=1, timeout=0, input_schema=BaseRequest
+)
 serve.run_server()
diff --git a/src/fastserve/core.py b/src/fastserve/core.py
index 10dbd49..bcbb461 100644
--- a/src/fastserve/core.py
+++ b/src/fastserve/core.py
@@ -28,7 +28,7 @@ async def lifespan(app: FastAPI):
             yield
             self.batch_processing.cancel()

-        self._app = FastAPI(lifespan=lifespan)
+        self._app = FastAPI(lifespan=lifespan, title="FastServe")

     def _serve(
         self,
diff --git a/src/fastserve/models/llama_cpp.py b/src/fastserve/models/llama_cpp.py
index 89bac84..7c185b8 100644
--- a/src/fastserve/models/llama_cpp.py
+++ b/src/fastserve/models/llama_cpp.py
@@ -39,39 +39,24 @@ def __init__(
         *args,
         **kwargs,
     ):
-        super().__init__(batch_size, timeout, input_schema=PromptRequest)
-
         if not os.path.exists(model_path):
             raise FileNotFoundError(f"{model_path} not found.")
-        if lazy:
-            self.llm = None
-        else:
-            self.llm = Llama(
-                model_path=model_path,
-                main_gpu=main_gpu,
-                n_ctx=n_ctx,
-                verbose=False,
-                *args,
-                **kwargs,
-            )
+        self.llm = Llama(
+            model_path=model_path,
+            main_gpu=main_gpu,
+            n_ctx=n_ctx,
+            verbose=False,
+            *args,
+            **kwargs,
+        )
         self.n_ctx = n_ctx
         self.model_path = model_path
         self.main_gpu = main_gpu
         self.args = args
         self.kwargs = kwargs
+        super().__init__(batch_size, timeout, input_schema=PromptRequest)

     def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
-        if self.llm is None:
-            logger.info("Initializing model")
-            self.llm = Llama(
-                model_path=self.model_path,
-                main_gpu=self.main_gpu,
-                n_ctx=self.n_ctx,
-                verbose=False,
-                *self.args,
-                **self.kwargs,
-            )
-
         result = self.llm(prompt=prompt, *args, **kwargs)
         logger.info(result)
         return result
diff --git a/src/fastserve/models/sdxl_turbo.py b/src/fastserve/models/sdxl_turbo.py
index b78eda5..fcc0104 100644
--- a/src/fastserve/models/sdxl_turbo.py
+++ b/src/fastserve/models/sdxl_turbo.py
@@ -20,7 +20,6 @@ class ServeSDXLTurbo(FastServe):
     def __init__(
         self, batch_size=1, timeout=0.0, device="cuda", num_inference_steps: int = 1
     ) -> None:
-        super().__init__(self, batch_size, timeout)
         if num_inference_steps > 1:
             logging.warning(
                 "It is recommended to use inference_steps=1 for SDXL Turbo model."
@@ -31,6 +30,7 @@ def __init__(
             "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
         )
         self.pipe.to(device)
+        super().__init__(batch_size, timeout)

     def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         prompts = [b.prompt for b in batch]
diff --git a/src/fastserve/models/ssd.py b/src/fastserve/models/ssd.py
index 7816b28..7475fba 100644
--- a/src/fastserve/models/ssd.py
+++ b/src/fastserve/models/ssd.py
@@ -18,7 +18,6 @@ class ServeSSD1B(FastServe):
     def __init__(
         self, batch_size=1, timeout=0.0, device="cuda", num_inference_steps: int = 50
     ) -> None:
-        super().__init__(self, batch_size, timeout)
         self.num_inference_steps = num_inference_steps
         self.input_schema = PromptRequest
         self.pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -28,6 +27,7 @@ def __init__(
             variant="fp16",
         )
         self.pipe.to(device)
+        super().__init__(batch_size, timeout)

     def handle(self, batch: List[PromptRequest]) -> List[StreamingResponse]:
         prompts = [b.prompt for b in batch]