
Use outlines.processors and SequenceGeneratorAdapter for outlines.models.vllm #1053

Merged: 1 commit merged into dottxt-ai:main on Jul 20, 2024

Conversation

@lapp0 (Collaborator) commented on Jul 20, 2024

Fixes: models.vllm was the only model (other than models.exllamav2) still using SequenceGenerator

Changes

  • Update the outlines.generate handlers to default to SequenceGeneratorAdapter for all models except ExLlamaV2Model
  • Update OutlinesLogitsProcessors to accept vLLM input_ids, which are passed as tuples
  • Fix an FSMLogitsProcessor bug: it couldn't handle batched sequences whose input_ids include the prompt ids. This wasn't caught previously because model_llamacpp cannot perform batch generation. (A sketch of both processor fixes follows this list.)
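For illustration, here is a minimal sketch of the two processor fixes above: accepting input_ids passed as a tuple and tracking the prompt length so that only generated tokens drive the FSM. The class and the "only allow EOS" logic are hypothetical stand-ins, not outlines' actual implementation.

```python
import torch


class FSMLogitsProcessorSketch:
    """Hypothetical sketch of the fixes described above, not outlines' code."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id
        self._prompt_length = None  # set on the first call for a sequence

    def __call__(self, input_ids, scores: torch.Tensor) -> torch.Tensor:
        # vLLM passes input_ids as a tuple rather than a list/tensor.
        input_ids = list(input_ids)

        # Record the prompt length once, so batched sequences whose
        # input_ids still contain the prompt ids are handled correctly.
        if self._prompt_length is None:
            self._prompt_length = len(input_ids)
        generated_ids = input_ids[self._prompt_length:]

        # A real processor would advance its FSM with generated_ids and
        # compute the set of allowed next tokens; here we only allow EOS.
        allowed = [self.eos_token_id]
        mask = torch.full_like(scores, float("-inf"))
        mask[allowed] = 0
        return scores + mask
```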

Tests

  • In tests/generate/test_generate.py:
    • test model_vllm
    • enable model_vllm and model_transformers_vision only if CUDA is available (a rough sketch of this gating follows the list)
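A rough sketch of the CUDA gating; the fixture names below are illustrative, not the exact fixture list in test_generate.py.

```python
import pytest
import torch

# Only exercise the GPU-backed models when CUDA is available; on CPU-only
# runners these parametrizations are left out entirely.
MODEL_FIXTURES = ["model_transformers", "model_llamacpp"]
if torch.cuda.is_available():
    MODEL_FIXTURES += ["model_vllm", "model_transformers_vision"]


@pytest.fixture(params=MODEL_FIXTURES)
def model(request):
    return request.getfixturevalue(request.param)
```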

Benchmarks

Regex logits processing performance changed within an acceptable range:

  • Structured generation with torch: 268±7μs → 225±1μs
  • Passthrough (no-op) with numpy: 160±0.9μs → 185±3μs

Benchmarks that have improved:

| Change | Before [a7e3381] | After [b2c28a7] | Ratio | Benchmark (Parameter) |
| --- | --- | --- | --- | --- |
| - | 268±7μs | 225±1μs | 0.84 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('torch', 'Z*') |

Benchmarks that have stayed the same:

| Change | Before [a7e3381] | After [b2c28a7] | Ratio | Benchmark (Parameter) |
| --- | --- | --- | --- | --- |
|  | 5.21±0.01s | 5.18±0.02s | 0.99 | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm('complex_schema') |
|  | 3.67±0.02s | 3.62±0.03s | 0.99 | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_fsm('simple_schema') |
|  | 90.4±1μs | 87.3±0.3μs | 0.97 | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_regex('complex_schema') |
|  | 49.4±0.04μs | 50.0±0.9μs | 1.01 | bench_json_schema.JsonSchemaBenchmark.time_json_schema_to_regex('simple_schema') |
|  | 5.69±0.02s | 5.71±0.02s | 1 | bench_numba_compile.NumbaCompileBenchmark.time_compile_numba |
|  | 180±10μs | 173±1μs | 0.96 | bench_processors.LogitsProcessorPassthroughBenchmark.time_passthrough('torch') |
|  | 256±3μs | 243±2μs | 0.95 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('numpy', 'Z*') |
|  | 1.03±0.01ms | 1.02±0.01ms | 0.99 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('numpy', '[^Z]*') |
|  | 1.03±0.01ms | 1.02±0.02ms | 0.99 | bench_processors.LogitsProcessorStructuredBenchmark.time_structured_generation('torch', '[^Z]*') |
|  | 592M | 592M | 1 | bench_regex_guide.MemoryRegexGuideBenchmark.peakmem_regex_to_guide('complex_span_constrained_relation_extraction') |
|  | 493M | 494M | 1 | bench_regex_guide.MemoryRegexGuideBenchmark.peakmem_regex_to_guide('simple_phone') |
|  | 2.72±0.05s | 2.72±0.02s | 1 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('complex_phone') |
|  | 6.36±0.01s | 6.35±0.02s | 1 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('complex_span_constrained_relation_extraction') |
|  | 2.58±0.01s | 2.54±0.02s | 0.99 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('date') |
|  | 2.53±0.02s | 2.58±0.02s | 1.02 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('email') |
|  | 2.52±0.06s | 2.48±0.02s | 0.98 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('ip') |
|  | 2.47±0.03s | 2.43±0.02s | 0.98 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('simple_phone') |
|  | 2.43±0.02s | 2.41±0.02s | 0.99 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('ssn') |
|  | 2.46±0.06s | 2.38±0.01s | 0.97 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('time') |
|  | 2.64±0.04s | 2.62±0.03s | 0.99 | bench_regex_guide.RegexGuideBenchmark.time_regex_to_guide('url') |
Benchmarks that have got worse:

| Change | Before [a7e3381] | After [b2c28a7] | Ratio | Benchmark (Parameter) |
| --- | --- | --- | --- | --- |
| + | 160±0.9μs | 185±3μs | 1.16 | bench_processors.LogitsProcessorPassthroughBenchmark.time_passthrough('numpy') |

Performance degradation detected!

@lapp0 marked this pull request as draft on July 20, 2024, 08:16
@lapp0 marked this pull request as ready for review on July 20, 2024, 08:49
@lapp0 added the structured generation and vLLM labels and removed the run-benchmarks label on Jul 20, 2024
@rlouf merged commit 47dfa4b into dottxt-ai:main on Jul 20, 2024
6 of 7 checks passed
Comment on lines +29 to +30:

```python
if hasattr(self.model, "get_tokenizer"):
    tokenizer = self.model.get_tokenizer()
```


Heads up: for the AsyncLLMEngine (as shown in the outlines vLLM example), this will return a coroutine: https://github.com/vllm-project/vllm/blob/main/vllm/engine/async_llm_engine.py#L506

I'm trying to figure out the best path forward because I'd love to use this with my vLLM-based service, but it seems like this work is part of something bigger so I don't want to dive in and start propagating async through this code without checking in with you first. Happy to contribute, but could use a little guidance on the strategy 😁

@lapp0 (Collaborator, Author) commented on Jul 29, 2024


It's uncertain whether we will move towards async for outlines.generate, but it has been proposed in #655. Currently, outlines.generate with outlines.models.vllm uses a vllm.LLM.

Bear in mind that outlines.serve already has a vLLM server integration, and vice versa: vLLM has an outlines.processors integration in progress.

Does outlines.serve or vLLM's outlines integration satisfy your needs, or were you thinking of something different?
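
For reference, a minimal sketch of that synchronous path (the model id and regex below are illustrative placeholders, not from this PR):

```python
from outlines import models, generate

# outlines.models.vllm wraps a synchronous vllm.LLM; outlines.generate then
# drives it through SequenceGeneratorAdapter with an outlines.processors
# logits processor.
model = models.vllm("microsoft/Phi-3-mini-4k-instruct")  # placeholder model id

# Constrain generation to an ISO-style date.
generator = generate.regex(model, r"20[0-9]{2}-[0-9]{2}-[0-9]{2}")
print(generator("The PR was merged on: "))
```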


Ahhh, thank you for the sanity check! After re-reviewing the outlines.serve code, I realized I didn't go deep enough and needed to pass my engine.engine (engines all the way down 🐢) to get all the way to the vllm.LLM. Thanks again for the pointers!
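
For anyone following along, a rough sketch of that pattern (an assumption about vLLM's AsyncLLMEngine API as of mid-2024, not code from this PR):

```python
from vllm import AsyncEngineArgs, AsyncLLMEngine

# AsyncLLMEngine wraps a synchronous engine; the inner `engine` attribute is
# what exposes the blocking APIs ("engines all the way down").
async_engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="gpt2"))

# AsyncLLMEngine.get_tokenizer() is a coroutine, but the wrapped engine's
# get_tokenizer() is a plain method, so no await is needed here.
tokenizer = async_engine.engine.get_tokenizer()
```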

@lapp0 (Collaborator, Author) commented on Jul 29, 2024


No problem! Bear in mind that after our next major release (because of this PR), the tokenizer, not the engine, will be passed to the processor. serve.py has a PR to reflect this behavior: https://github.com/outlines-dev/outlines/pull/1061/files#diff-535a1da5f8addb89d07782185c32b54f85189b25786d1c9b7cbd002b55939e16R74


Noted! Will keep an eye out for that. Thanks again for everything; super excited for the awesome capabilities you all have enabled with outlines!

Labels: structured generation, vLLM
3 participants