Skip to content

Commit

Permalink
Make use_training_labels positional required (#11529)
Browse files Browse the repository at this point in the history
* Make use_training_labels positional required

* Update sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py

* lint

* comments
  • Loading branch information
rakshith91 authored May 21, 2020
1 parent 4c7b8de commit bc01929
Show file tree
Hide file tree
Showing 13 changed files with 47 additions and 45 deletions.
3 changes: 1 addition & 2 deletions sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
- `FormField` does not have a page_number.
- `begin_recognize_receipts` APIs now return `RecognizedReceipt` instead of `USReceipt`
- `USReceiptType` is renamed to `ReceiptType`
- `use_training_labels` is now a required positional param in the `begin_training` APIs.
- `stream` and `url` parameters found on methods for `FormRecognizerClient` have been renamed to `form` and `form_url`, respectively.
For recognize receipt methods, parameters have been renamed to `receipt` and `receipt_url`.



**New features**

- Support to copy a custom model from one Form Recognizer resource to another
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import json
from typing import (
Optional,
Any,
Iterable,
Dict,
Expand Down Expand Up @@ -88,8 +87,8 @@ def __init__(self, endpoint, credential, **kwargs):
)

@distributed_trace
def begin_train_model(self, training_files_url, use_training_labels=False, **kwargs):
# type: (str, Optional[bool], Any) -> LROPoller
def begin_train_model(self, training_files_url, use_training_labels, **kwargs):
# type: (str, bool, Any) -> LROPoller
"""Create and train a custom model. The request must include a `training_files_url` parameter that is an
externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri).
Models are trained using documents that are of the following content type - 'application/pdf',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import json
from typing import (
Optional,
Any,
AsyncIterable,
Dict,
Expand Down Expand Up @@ -96,7 +95,7 @@ def __init__(
async def train_model(
self,
training_files_url: str,
use_training_labels: Optional[bool] = False,
use_training_labels: bool,
**kwargs: Any
) -> CustomFormModel:
"""Create and train a custom model. The request must include a `training_files_url` parameter that is an
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def train_model_without_labels(self):
) as form_training_client:

# Default for train_model is `use_training_labels=False`
model = await form_training_client.train_model(self.container_sas_url)
model = await form_training_client.train_model(self.container_sas_url, use_training_labels=False)

# Custom model information
print("Model ID: {}".format(model.model_id))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def train_model_without_labels(self):
form_training_client = FormTrainingClient(self.endpoint, AzureKeyCredential(self.key))

# Default for begin_train_model is `use_training_labels=False`
poller = form_training_client.begin_train_model(self.container_sas_url)
poller = form_training_client.begin_train_model(self.container_sas_url, use_training_labels=False)
model = poller.result()

# Custom model information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_auto_detect_unsupported_stream_content(self, resource_group, location,
def test_custom_form_damaged_file(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with self.assertRaises(HttpResponseError):
Expand All @@ -73,7 +73,7 @@ def test_custom_form_damaged_file(self, client, container_sas_url):
def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with open(self.form_jpg, "rb") as stream:
Expand All @@ -98,7 +98,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url):
def test_custom_form_multipage_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

with open(self.multipage_invoice_pdf, "rb") as stream:
Expand Down Expand Up @@ -179,7 +179,7 @@ def test_custom_form_multipage_labeled(self, client, container_sas_url):
def test_custom_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down Expand Up @@ -216,7 +216,7 @@ def callback(raw_response, _, headers):
def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca
async def test_custom_form_damaged_file(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with self.assertRaises(HttpResponseError):
form = await fr_client.recognize_custom_forms(
Expand All @@ -73,7 +73,7 @@ async def test_custom_form_damaged_file(self, client, container_sas_url):
async def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with open(self.form_jpg, "rb") as fd:
myfile = fd.read()
Expand All @@ -94,7 +94,7 @@ async def test_custom_form_unlabeled(self, client, container_sas_url):
async def test_custom_form_multipage_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

with open(self.multipage_invoice_pdf, "rb") as fd:
myfile = fd.read()
Expand Down Expand Up @@ -168,7 +168,7 @@ async def test_custom_form_multipage_labeled(self, client, container_sas_url):
async def test_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down Expand Up @@ -204,7 +204,7 @@ def callback(raw_response, _, headers):
async def test_custom_forms_multipage_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_custom_form_bad_url(self, client, container_sas_url):
def test_custom_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

poller = fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)
Expand All @@ -89,7 +89,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url):
def test_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

poller = fr_client.begin_recognize_custom_forms_from_url(
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url):
def test_custom_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down Expand Up @@ -193,7 +193,7 @@ def callback(raw_response, _, headers):
def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

responses = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def test_form_bad_url(self, client, container_sas_url):
async def test_form_unlabeled(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

form = await fr_client.recognize_custom_forms_from_url(model.model_id, self.form_url_jpg)

Expand All @@ -85,7 +85,7 @@ async def test_form_unlabeled(self, client, container_sas_url):
async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

forms = await fr_client.recognize_custom_forms_from_url(
model.model_id,
Expand Down Expand Up @@ -148,7 +148,7 @@ async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_
async def test_form_unlabeled_transform(self, client, container_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down Expand Up @@ -181,7 +181,7 @@ def callback(raw_response, _, headers):
async def test_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url):
fr_client = client.get_form_recognizer_client()

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

responses = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_mgmt_model_labeled(self, client, container_sas_url):
@GlobalTrainingAccountPreparer()
def test_mgmt_model_unlabeled(self, client, container_sas_url):

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
unlabeled_model_from_train = poller.result()

unlabeled_model_from_get = client.get_custom_model(unlabeled_model_from_train.model_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def test_mgmt_model_labeled(self, client, container_sas_url):
@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
async def test_mgmt_model_unlabeled(self, client, container_sas_url):
unlabeled_model_from_train = await client.train_model(container_sas_url)
unlabeled_model_from_train = await client.train_model(container_sas_url, use_training_labels=False)

unlabeled_model_from_get = await client.get_custom_model(unlabeled_model_from_train.model_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class TestTraining(FormRecognizerTest):
def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key):
client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx"))
with self.assertRaises(ClientAuthenticationError):
poller = client.begin_train_model("xx")
poller = client.begin_train_model("xx", use_training_labels=False)

@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
def test_training(self, client, container_sas_url):

poller = client.begin_train_model(training_files_url=container_sas_url)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False)
model = poller.result()

self.assertIsNotNone(model.model_id)
Expand All @@ -52,7 +52,7 @@ def test_training(self, client, container_sas_url):
@GlobalTrainingAccountPreparer(multipage=True)
def test_training_multipage(self, client, container_sas_url):

poller = client.begin_train_model(container_sas_url)
poller = client.begin_train_model(container_sas_url, use_training_labels=False)
model = poller.result()

self.assertIsNotNone(model.model_id)
Expand Down Expand Up @@ -83,7 +83,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

poller = client.begin_train_model(training_files_url=container_sas_url, cls=callback)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, cls=callback)
model = poller.result()

raw_model = raw_response[0]
Expand All @@ -102,7 +102,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

poller = client.begin_train_model(container_sas_url, cls=callback)
poller = client.begin_train_model(container_sas_url, use_training_labels=False, cls=callback)
model = poller.result()

raw_model = raw_response[0]
Expand Down Expand Up @@ -199,16 +199,16 @@ def callback(response):
@GlobalTrainingAccountPreparer()
def test_training_with_files_filter(self, client, container_sas_url):

poller = client.begin_train_model(training_files_url=container_sas_url, include_sub_folders=True)
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True)
model = poller.result()
self.assertEqual(len(model.training_documents), 6)
self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders

poller = client.begin_train_model(container_sas_url, prefix="subfolder", include_sub_folders=True)
poller = client.begin_train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True)
model = poller.result()
self.assertEqual(len(model.training_documents), 1)
self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders

with self.assertRaises(HttpResponseError):
poller = client.begin_train_model(training_files_url=container_sas_url, prefix="xxx")
poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx")
model = poller.result()
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ class TestTrainingAsync(AsyncFormRecognizerTest):
async def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key):
client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx"))
with self.assertRaises(ClientAuthenticationError):
result = await client.train_model("xx")
result = await client.train_model("xx", use_training_labels=False)

@GlobalFormRecognizerAccountPreparer()
@GlobalTrainingAccountPreparer()
async def test_training(self, client, container_sas_url):

model = await client.train_model(training_files_url=container_sas_url)
model = await client.train_model(
training_files_url=container_sas_url,
use_training_labels=False)

self.assertIsNotNone(model.model_id)
self.assertIsNotNone(model.created_on)
Expand All @@ -51,7 +53,7 @@ async def test_training(self, client, container_sas_url):
@GlobalTrainingAccountPreparer(multipage=True)
async def test_training_multipage(self, client, container_sas_url):

model = await client.train_model(container_sas_url)
model = await client.train_model(container_sas_url, use_training_labels=False)

self.assertIsNotNone(model.model_id)
self.assertIsNotNone(model.created_on)
Expand Down Expand Up @@ -81,7 +83,10 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

model = await client.train_model(training_files_url=container_sas_url, cls=callback)
model = await client.train_model(
training_files_url=container_sas_url,
use_training_labels=False,
cls=callback)

raw_model = raw_response[0]
custom_model = raw_response[1]
Expand All @@ -99,7 +104,7 @@ def callback(response):
raw_response.append(raw_model)
raw_response.append(custom_model)

model = await client.train_model(container_sas_url, cls=callback)
model = await client.train_model(container_sas_url, use_training_labels=False, cls=callback)

raw_model = raw_response[0]
custom_model = raw_response[1]
Expand Down Expand Up @@ -189,13 +194,13 @@ def callback(response):
@GlobalTrainingAccountPreparer()
async def test_training_with_files_filter(self, client, container_sas_url):

model = await client.train_model(training_files_url=container_sas_url, include_sub_folders=True)
model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True)
self.assertEqual(len(model.training_documents), 6)
self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders

model = await client.train_model(container_sas_url, prefix="subfolder", include_sub_folders=True)
model = await client.train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True)
self.assertEqual(len(model.training_documents), 1)
self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders

with self.assertRaises(HttpResponseError):
model = await client.train_model(training_files_url=container_sas_url, prefix="xxx")
model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx")

0 comments on commit bc01929

Please sign in to comment.