gpt_batch 3.0 with api-base, embedding_get, new doc
fengsxy committed May 4, 2024
1 parent 73d16c8 commit a1eb754
Showing 5 changed files with 158 additions and 8 deletions.
64 changes: 59 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,73 @@
# GPT Batcher

A simple tool to batch process messages using OpenAI's GPT models.
A simple tool to batch process messages using OpenAI's GPT models. `GPTBatcher` allows for efficient handling of multiple requests simultaneously, ensuring quick responses and robust error management.

## Installation

Clone this repository and run:
To get started with `GPTBatcher`, clone this repository to your local machine. Navigate to the repository directory and install the required dependencies (if any) by running:

```bash
pip install -r requirements.txt
```

## Quick Start

To use `GPTBatcher`, instantiate it with your OpenAI API key and the name of the model you wish to use. Here's a quick guide:

### Handling Message Lists

This example demonstrates how to send a list of questions and receive answers:

```python
from gpt_batch.batcher import GPTBatcher

# Initialize the batcher
batcher = GPTBatcher(api_key='your_key_here', model_name='gpt-3.5-turbo-1106')

# Send a list of messages and receive answers
result = batcher.handle_message_list(['question_1', 'question_2', 'question_3', 'question_4'])
print(result)
# Expected output: ["answer_1", "answer_2", "answer_3", "answer_4"]
```
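Because `handle_message_list` re-sorts results by their original index, answers come back in input order (assuming all requests succeed). The sketch below shows pairing questions back to answers; the literal strings are stand-ins, not real API output:

```python
# Stand-in data: 'answers' plays the role of batcher.handle_message_list(questions).
questions = ['question_1', 'question_2', 'question_3']
answers = ['answer_1', 'answer_2', 'answer_3']

# Results preserve input order, so questions and answers zip together directly.
qa_pairs = dict(zip(questions, answers))
print(qa_pairs['question_2'])  # → answer_2
```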

### Handling Embedding Lists

This example shows how to get embeddings for a list of strings:

```python
from gpt_batch.batcher import GPTBatcher

# Reinitialize the batcher for embeddings
batcher = GPTBatcher(api_key='your_key_here', model_name='text-embedding-3-small')

# Send a list of strings and get their embeddings
result = batcher.handle_embedding_list(['question_1', 'question_2', 'question_3', 'question_4'])
print(result)
# Expected output: a list of embedding vectors (one list of floats per input string, in input order)
```
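Each embedding comes back as a plain list of floats, so downstream similarity math needs only the standard library. Here is a minimal cosine-similarity sketch; the short three-element vectors are illustrative stand-ins for real embedding output:

```python
import math

def cosine_similarity(a, b):
    # Dot product divided by the product of vector magnitudes.
    # Assumes equal-length, non-zero vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

# Stand-ins for two vectors from handle_embedding_list (real ones are much longer):
emb_1 = [0.1, 0.3, 0.5]
emb_2 = [0.2, 0.1, 0.4]

score = cosine_similarity(emb_1, emb_2)
```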

## Configuration

The `GPTBatcher` class can be customized with several parameters to adjust its performance and behavior:

- **api_key** (str): Your OpenAI API key.
- **model_name** (str): Identifier for the GPT model version you want to use; default is 'gpt-3.5-turbo-0125'.
- **system_prompt** (str): Initial text or question to seed the model; default is empty.
- **temperature** (float): Adjusts the creativity of the responses; default is 1.
- **num_workers** (int): Number of parallel workers for request handling; default is 64.
- **timeout_duration** (int): Timeout for API responses in seconds; default is 60.
- **retry_attempts** (int): How many times to retry a failed request; default is 2.
- **api_base_url** (str, optional): Custom base URL for the OpenAI API endpoint (for example, a proxy); default is None.
- **miss_index** (list): Tracks indices of requests that failed to process correctly.

For more detailed documentation on the parameters and methods, refer to the class docstring.
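The `num_workers`, `timeout_duration`, and `retry_attempts` settings drive a standard submit/wait/resubmit loop over a thread pool. The self-contained sketch below illustrates that mechanism with a local function standing in for the API call; it mirrors the pattern, not the library's exact code:

```python
from concurrent.futures import ThreadPoolExecutor, wait

def run_with_retries(func, items, num_workers=4, timeout_duration=5, retry_attempts=2):
    """Submit items to a thread pool; resubmit anything not finished in time."""
    results = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        future_to_item = {executor.submit(func, item): item for item in items}
        for _ in range(retry_attempts):
            done, not_done = wait(future_to_item.keys(), timeout=timeout_duration)
            results.extend(f.result() for f in done if f.done())
            if not not_done:
                break
            # Resubmit only the items whose futures timed out; items that still
            # fail after all attempts are dropped (the real batcher records
            # their indices in miss_index).
            future_to_item = {executor.submit(func, future_to_item[f]): future_to_item[f]
                              for f in not_done}
    return results

# Trivial stand-in for an API call:
squares = run_with_retries(lambda x: x * x, [1, 2, 3])
```

The real batcher additionally splits the input into chunks of `num_workers` before submitting.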

## License

Specify your licensing information here.

1 change: 1 addition & 0 deletions gpt_batch/__init__.py
@@ -1,3 +1,4 @@
from .batcher import GPTBatcher


__all__ = ['GPTBatcher']
67 changes: 65 additions & 2 deletions gpt_batch/batcher.py
@@ -4,7 +4,31 @@
from tqdm import tqdm

class GPTBatcher:
"""
A class to handle batching and sending requests to the OpenAI GPT model efficiently.
Attributes:
client (OpenAI): The client instance to communicate with the OpenAI API using the provided API key.
model_name (str): The name of the GPT model to be used. Default is 'gpt-3.5-turbo-0125'.
system_prompt (str): Initial prompt or context to be used with the model. Default is an empty string.
temperature (float): Controls the randomness of the model's responses. Higher values lead to more diverse outputs. Default is 1.
num_workers (int): Number of worker threads used for handling concurrent requests. Default is 64.
timeout_duration (int): Maximum time (in seconds) to wait for a response from the API before timing out. Default is 60 seconds.
retry_attempts (int): Number of retries if a request fails. Default is 2.
miss_index (list): Tracks the indices of requests that failed to process correctly.
Parameters:
api_key (str): API key for authenticating requests to the OpenAI API.
model_name (str, optional): Specifies the GPT model version. Default is 'gpt-3.5-turbo-0125'.
system_prompt (str, optional): Initial text or question to seed the model with. Default is empty.
temperature (float, optional): Sets the creativity of the responses. Default is 1.
num_workers (int, optional): Number of parallel workers for request handling. Default is 64.
timeout_duration (int, optional): Timeout for API responses in seconds. Default is 60.
retry_attempts (int, optional): How many times to retry a failed request. Default is 2.
"""

def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",temperature=1,num_workers=64,timeout_duration=60,retry_attempts=2,api_base_url=None):

self.client = OpenAI(api_key=api_key)
self.model_name = model_name
self.system_prompt = system_prompt
@@ -13,6 +37,8 @@ def __init__(self, api_key, model_name="gpt-3.5-turbo-0125", system_prompt="",te
self.timeout_duration = timeout_duration
self.retry_attempts = retry_attempts
self.miss_index =[]
if api_base_url:
self.client.base_url = api_base_url

def get_attitude(self, ask_text):
index, ask_text = ask_text
@@ -44,7 +70,7 @@ def process_attitude(self, message_list):
new_list.extend(future.result() for future in done if future.done())
if len(not_done) == 0:
break
future_to_message = {executor.submit(self.get_attitude, future_to_message[future]): future_to_message[future] for future in not_done}
executor.shutdown(wait=False)
return new_list

@@ -80,6 +106,43 @@ def handle_message_list(self,message_list):
attitude_list = self.complete_attitude_list(attitude_list, max_length)
attitude_list = [x[1] for x in attitude_list]
return attitude_list

def process_embedding(self,message_list):
new_list = []
executor = ThreadPoolExecutor(max_workers=self.num_workers)
# Split message_list into chunks
message_chunks = list(self.chunk_list(message_list, self.num_workers))
fixed_get_embedding = partial(self.get_embedding)
for chunk in tqdm(message_chunks, desc="Processing messages"):
future_to_message = {executor.submit(fixed_get_embedding, message): message for message in chunk}
for i in range(self.retry_attempts):
done, not_done = wait(future_to_message.keys(), timeout=self.timeout_duration)
for future in not_done:
future.cancel()
new_list.extend(future.result() for future in done if future.done())
if len(not_done) == 0:
break
future_to_message = {executor.submit(fixed_get_embedding, future_to_message[future]): future_to_message[future] for future in not_done}
executor.shutdown(wait=False)
return new_list

def get_embedding(self, text):
index,text = text
response = self.client.embeddings.create(
input=text,
model=self.model_name)
return (index,response.data[0].embedding)

def handle_embedding_list(self,message_list):
indexed_list = [(index, data) for index, data in enumerate(message_list)]
max_length = len(indexed_list)
attitude_list = self.process_embedding(indexed_list)
attitude_list.sort(key=lambda x: x[0])
attitude_list = self.complete_attitude_list(attitude_list, max_length)
attitude_list = [x[1] for x in attitude_list]
return attitude_list

def get_miss_index(self):
return self.miss_index

# Add other necessary methods similar to the above, refactored to fit within this class structure.

2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='gpt_batch',
version='0.1.1',
version='0.1.2',
packages=find_packages(),
install_requires=[
'openai', 'tqdm'
32 changes: 32 additions & 0 deletions tests/test_batcher.py
@@ -18,6 +18,38 @@ def test_handle_message_list():
assert len(results) == 2, "There should be two results, one for each message"
assert all(len(result) >= 2 for result in results), "Each result should have at least two elements"

def test_handle_embedding_list():
# Initialize the GPTBatcher with hypothetical valid credentials
# The API key is read from the TEST_KEY environment variable
api_key = os.getenv('TEST_KEY')
if not api_key:
raise ValueError("API key must be set in the environment variables")
batcher = GPTBatcher(api_key=api_key, model_name='text-embedding-3-small')
embedding_list = [ "I think privacy is important", "I don't think privacy is important"]
results = batcher.handle_embedding_list(embedding_list)
assert len(results) == 2, "There should be two results, one for each message"
assert all(len(result) >= 2 for result in results), "Each result should have at least two elements"

def test_base_url():
# Initialize the GPTBatcher with hypothetical valid credentials
# The API key is read from the TEST_KEY environment variable
api_key = os.getenv('TEST_KEY')
if not api_key:
raise ValueError("API key must be set in the environment variables")
batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', api_base_url="https://api.openai.com/v2/")
assert batcher.client.base_url == "https://api.openai.com/v2/", "The base URL should be set to the provided value"

def test_get_miss_index():
# Initialize the GPTBatcher with hypothetical valid credentials
# The API key is read from the TEST_KEY environment variable
api_key = os.getenv('TEST_KEY')
if not api_key:
raise ValueError("API key must be set in the environment variables")
batcher = GPTBatcher(api_key=api_key, model_name='gpt-3.5-turbo-1106', system_prompt="Your task is to discuss privacy questions and provide persuasive answers with supporting reasons.")
message_list = ["I think privacy is important", "I don't think privacy is important"]
results = batcher.handle_message_list(message_list)
miss_index = batcher.get_miss_index()
assert miss_index == [], "The miss index should be empty"
# Optionally, you can add a test configuration if you have specific needs
if __name__ == "__main__":
pytest.main()
