diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 2e01445cb..542b6fc26 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -104,10 +104,6 @@ def __init__( parameters that cannot be set by calling this class' methods. """ - if model_name not in ["gpt-4", "gpt-3.5-turbo"]: - raise ValueError( - "Invalid model_name. It must be either 'gpt-4' or 'gpt-3.5-turbo'." - ) try: import openai @@ -125,6 +121,13 @@ def __init__( raise ValueError( "You must specify an API key to use the OpenAI API integration." ) + try: + client = openai.OpenAI() + client.models.retrieve(model_name) + except openai.NotFoundError: + raise ValueError( + "Invalid model_name. Check openai models list at https://platform.openai.com/docs/models" + ) if config is not None: self.config = replace(config, model=model_name) # type: ignore diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2801673fe..c2e885eb1 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1,7 +1,6 @@ import pytest from outlines.models.openai import ( - OpenAI, build_optimistic_mask, find_longest_intersection, find_response_choices_intersection, @@ -49,11 +48,3 @@ def test_find_longest_common_prefix(response, choice, expected_prefix): def test_build_optimistic_mask(transposed, mask_size, expected_mask): mask = build_optimistic_mask(transposed, mask_size) assert mask == expected_mask - - -def test_model_name_validation(): - with pytest.raises(ValueError): - OpenAI(model_name="invalid_model_name") - - with pytest.raises(ValueError): - OpenAI(model_name="gpt-4-1106-preview")