Skip to content

Commit

Permalink
appbuilder-text2image-update (#481)
Browse files Browse the repository at this point in the history
* update

* update

* update

* update

* update

* update ut

---------

Co-authored-by: yinjiaqi <yinjiaqi@MacBook-Pro.local>
Co-authored-by: MrChengmo <cmchengmo@163.com>
  • Loading branch information
3 people committed Aug 20, 2024
1 parent ddcfbfe commit 6ac8f79
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 273 deletions.
72 changes: 38 additions & 34 deletions appbuilder/core/components/text_to_image/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import math

from typing import Optional
from appbuilder.core.component import Component
from appbuilder.core.message import Message
from appbuilder.core._client import HTTPClient
Expand Down Expand Up @@ -72,38 +73,45 @@ def run(
width: int = 1024,
height: int = 1024,
image_num: int = 1,
timeout: float = None,
retry: int = 0,
request_id: str = None
image: Optional[str] = None,
url: Optional[str] = None,
pdf_file: Optional[str] = None,
pdf_file_num: Optional[str] = None,
change_degree: Optional[int] = None,
text_content: Optional[str] = None,
task_time_out: Optional[int]= None,
text_check: Optional[int] = 1,
request_id: Optional[str] = None
):
"""
输入文本并返回生成的图片url。
参数:
message (obj:`Message`): 输入消息,用于模型的主要输入内容。这是一个必需的参数。举例: Message(content={"prompt": "上海的经典风景"})
width (int,可选): 图片宽度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560。
height (int, 可选): 图片高度,支持:512x512、640x360、360x640、1024x1024、1280x720、720x1280、2048x2048、2560x1440、1440x2560。
image_num (int, 可选): 生成图片数量,默认一张,支持生成 1-8 张。
timeout (float, 可选): 请求的超时时间。
retry (int, 可选): 请求的重试次数。
headers = self._http_client.auth_header()
headers["Content-Type"] = "application/json"
api_url = self._http_client.service_url("/v1/bce/aip/ernievilg/v1/txt2imgv2")

req = Text2ImageSubmitRequest(
prompt=message.content["prompt"],
width=width,
height=height,
image_num=image_num,
image=image,
url=url,
pdf_file=pdf_file,
pdf_file_num=pdf_file_num,
change_degree=change_degree,
text_content=text_content,
task_time_out=task_time_out,
text_check=text_check
)
response = self.http_client.session.post(api_url, json=req.model_dump(), headers=headers, timeout=None)
self._http_client.check_response_header(response)
data = response.json()
resp= Text2ImageSubmitResponse(**data)

返回:
obj:`Message`: 输出生成图片的url。举例: Message(content={"img_urls": ["xxx"]})。
"""
inp = Text2ImageInMessage(**message.content)
text2ImageSubmitRequest = Text2ImageSubmitRequest()
text2ImageSubmitRequest.prompt = inp.prompt
text2ImageSubmitRequest.width = width
text2ImageSubmitRequest.height = height
text2ImageSubmitRequest.image_num = image_num
text2ImageSubmitResponse = self.submitText2ImageTask(text2ImageSubmitRequest, request_id=request_id)
taskId = text2ImageSubmitResponse.data.primary_task_id
taskId = resp.data.task_id
if taskId is not None:
task_request_time = 1

while True:
request = Text2ImageQueryRequest()
request.task_id = taskId
request = Text2ImageQueryRequest(task_id=taskId)
text2ImageQueryResponse = self.queryText2ImageData(request, request_id=request_id)
if text2ImageQueryResponse.data.task_progress is not None:
task_progress = float(text2ImageQueryResponse.data.task_progress)
Expand Down Expand Up @@ -143,19 +151,16 @@ def submitText2ImageTask(
obj:`Text2ImageSubmitResponse`: 接口返回的输出消息。
"""
url = self.http_client.service_url("/v1/bce/aip/ernievilg/v1/txt2imgv2")
data = Text2ImageSubmitRequest.to_json(request)
data = request.model_dump()
headers = self.http_client.auth_header(request_id)
headers['content-type'] = 'application/json'
if retry != self.http_client.retry.total:
self.http_client.retry.total = retry
response = self.http_client.session.post(url, data=data, headers=headers, timeout=timeout)
response = self.http_client.session.post(url, json=data, headers=headers, timeout=timeout)
self.http_client.check_response_header(response)
data = response.json()
self.http_client.check_response_json(data)
request_id = self.http_client.response_request_id(response)
self.__class__.check_service_error(request_id, data)
response = Text2ImageSubmitResponse.from_json(payload=json.dumps(data))
response.request_id = request_id
response = Text2ImageSubmitResponse(**data)
return response

def queryText2ImageData(
Expand Down Expand Up @@ -191,8 +196,7 @@ def queryText2ImageData(
self.http_client.check_response_json(data)
request_id = self.http_client.response_request_id(response)
self.__class__.check_service_error(request_id, data)
response = Text2ImageQueryResponse.from_json(payload=json.dumps(data))
response.request_id = request_id
response = Text2ImageQueryResponse(**data)
return response

def extract_img_urls(self, response: Text2ImageQueryResponse):
Expand Down
Loading

0 comments on commit 6ac8f79

Please sign in to comment.