
Commit

run prediction in bg thread (shared across pages to interleave CPU/GPU)
bertsky committed Sep 18, 2024
1 parent 9611e2c commit fb2a680
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ocrd_calamari/recognize.py
@@ -3,6 +3,7 @@
 from typing import Optional
 import itertools
 from glob import glob
+from concurrent.futures import ThreadPoolExecutor
 
 import numpy as np
 from ocrd import Processor, OcrdPage, OcrdPageResult
@@ -46,8 +47,6 @@ def batched(iterable, n):
 itertools.batched = batched
 
 class CalamariRecognize(Processor):
-    # max_workers = 1
-
     @property
     def executable(self):
         return 'ocrd-calamari-recognize'
@@ -83,6 +82,9 @@ def setup(self):
         voter_params.type = VoterParams.Type.Value(self.parameter["voter"].upper())
         self.voter = voter_from_proto(voter_params)
 
+        # run in a background thread so GPU parts can be interleaved with CPU pre-/post-processing across pages
+        self.executor = ThreadPoolExecutor(max_workers=1)
+
     def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
         """
         Perform text recognition with Calamari.
@@ -158,9 +160,9 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
 
             lines, coords, images = zip(*lines)
             # not exposed in MultiPredictor yet, cf. calamari#361:
-            # results = self.predictor.predict_raw(images, progress_bar=False, batch_size=BATCH_SIZE)
+            # results = self.executor.submit(self.predictor.predict_raw, images, progress_bar=False, batch_size=BATCH_SIZE).result()
             # avoid too large a batch size (causing OOM on CPU or GPU)
-            fun = lambda x: self.predictor.predict_raw(x, progress_bar=False)
+            fun = lambda x: self.executor.submit(self.predictor.predict_raw, x, progress_bar=False).result()
             results = itertools.chain.from_iterable(
                 map(fun, itertools.batched(images, BATCH_SIZE)))
             for line, line_coords, raw_results in zip(lines, coords, results):
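For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit introduces: a single-worker ThreadPoolExecutor, shared across pages, keeps the GPU-bound predict calls serialized, while each page's CPU-bound pre-/post-processing runs in its own thread and can overlap with another page's GPU phase. All names in the sketch (fake_predict, preprocess, postprocess, the page labels) are hypothetical stand-ins, not part of the ocrd_calamari or calamari API.

# Sketch only: illustrates the single-worker executor pattern from the commit
# above with dummy CPU/GPU work; none of these helpers exist in ocrd_calamari.
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

BATCH_SIZE = 4

def batched(iterable, n):
    # backport of itertools.batched (Python >= 3.12), as in the module above
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch

def fake_predict(batch):
    # stands in for predictor.predict_raw(...): the GPU-bound part
    time.sleep(0.1)
    return ["text for " + image for image in batch]

def preprocess(page):
    # stands in for CPU-bound image extraction and line segmentation
    time.sleep(0.05)
    return [f"{page}/line{i}" for i in range(10)]

def postprocess(results):
    # stands in for CPU-bound result decoding and PAGE-XML serialization
    time.sleep(0.05)
    return list(results)

# one worker, shared across pages: GPU calls stay serialized
gpu_executor = ThreadPoolExecutor(max_workers=1)

def process_page(page):
    images = preprocess(page)  # CPU phase, runs in the calling thread
    # submit GPU work batch-wise; .result() blocks only this page's thread,
    # so another page's CPU phase can overlap with this page's GPU phase
    results = (result
               for batch in batched(images, BATCH_SIZE)
               for result in gpu_executor.submit(fake_predict, batch).result())
    return postprocess(results)

if __name__ == "__main__":
    # mimic OCR-D processing several pages in parallel worker threads
    with ThreadPoolExecutor(max_workers=2) as pages:
        for texts in pages.map(process_page, ["page1", "page2", "page3"]):
            print(len(texts), "lines recognized")

With max_workers=1, submit() gives the predictor exclusive access (no concurrent GPU calls), and .result() blocks only the submitting page's thread; batching the images additionally caps the per-call memory footprint, as noted in the diff comments.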
