1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
|
import io
from importlib import import_module

import cv2
import numpy as np
import PIL.Image as PIL_Image
import tensorflow as tf

# Star import: the upper-case settings used below (MODEL_PATH, RESIZE, IMAGE_CHANNEL,
# GEN_CHAR_SET, NETWORK_MAP, ...) are not defined elsewhere in this module.
from config import *
from constants import RunMode
from pretreatment import preprocessing
from framework import GraphOCR


def get_image_batch(img_bytes):
    """Decode raw encoded image bytes into a one-element batch for the model.

    :param img_bytes: the encoded image file contents (anything PIL can open).
    :return: a list holding a single float32 ndarray, preprocessed, resized to
        RESIZE, transposed so that width is the first axis and scaled by 1/255.
    """

    def load_image(image_bytes):
        """Decode, preprocess, resize and normalise one encoded image."""
        data_stream = io.BytesIO(image_bytes)
        pil_image = PIL_Image.open(data_stream)
        size = pil_image.size
        # getbands() gives the band count without copying every band the way
        # split() does; more than three bands presumably means an alpha channel
        # (e.g. RGBA) that is flattened onto a white background.
        if len(pil_image.getbands()) > 3 and REPLACE_TRANSPARENT:
            background = PIL_Image.new('RGB', size, (255, 255, 255))
            # The image itself is passed as the paste mask.
            background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
            pil_image = background
        if IMAGE_CHANNEL == 1:
            pil_image = pil_image.convert('L')
        im = np.array(pil_image)
        im = preprocessing(im, BINARYZATION, SMOOTH, BLUR).astype(np.float32)
        if RESIZE[0] == -1:
            # Width -1 means "keep aspect ratio": scale the original width by the
            # same factor that maps the original height onto RESIZE[1].
            ratio = RESIZE[1] / size[1]
            resize_width = int(ratio * size[0])
            im = cv2.resize(im, (resize_width, RESIZE[1]))
        else:
            im = cv2.resize(im, (RESIZE[0], RESIZE[1]))
        # Swap the first two axes so the width dimension comes first.
        im = im.swapaxes(0, 1)
        if IMAGE_CHANNEL == 1:
            # Add an explicit trailing channel axis for single-channel input.
            im = im[:, :, np.newaxis]
        return im / 255.

    return [load_image(img_bytes)]


def decode_maps(charset):
    """Return a dict mapping each position in *charset* to its character."""
    return dict(enumerate(charset))


def predict_func(image_batch, _sess, dense_decoded, op_input):
    """Run the decode tensor on *image_batch* and turn indices into text.

    :param image_batch: batch as produced by get_image_batch.
    :param _sess: session owning the graph that contains both tensors.
    :param dense_decoded: tensor of decoded character indices per batch item.
    :param op_input: input placeholder the batch is fed into.
    :return: the decoded strings of all batch items concatenated ('' when the
        result is empty).
    """
    dense_decoded_code = _sess.run(dense_decoded, feed_dict={op_input: image_batch})
    # Build the index -> character table once per call, not once per character.
    char_map = decode_maps(GEN_CHAR_SET)
    decoded_expression = [
        # -1 entries are skipped (presumably the padding of the dense decode
        # output — TODO confirm against the model definition).
        ''.join(char_map[char_index] for char_index in item if char_index != -1)
        for item in dense_decoded_code
    ]
    return ''.join(decoded_expression)


if WARP_CTC:
    # Imported only for its side effects (the module object is discarded);
    # presumably this registers the warp-ctc ops before the graph is built.
    import_module('warpctc_tensorflow')

graph = tf.Graph()
tf_checkpoint = tf.train.latest_checkpoint(MODEL_PATH)
if tf_checkpoint is None:
    # Fail fast with a clear message instead of calling saver.restore with None.
    raise ValueError('No checkpoint found in MODEL_PATH: {!r}'.format(MODEL_PATH))

sess = tf.Session(
    graph=graph,
    config=tf.ConfigProto(
        gpu_options=tf.GPUOptions(
            allocator_type='BFC',
            per_process_gpu_memory_fraction=0.01,
        )
    ),
)

with graph.as_default():
    model = GraphOCR(
        RunMode.Predict,
        NETWORK_MAP[NEU_CNN],
        NETWORK_MAP[NEU_RECURRENT],
    )
    model.build_graph()
    # Every variable's value comes from the checkpoint, so no separate
    # initializer run is needed.
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, tf_checkpoint)

dense_decoded_op = sess.graph.get_tensor_by_name("dense_decoded:0")
x_op = sess.graph.get_tensor_by_name('input:0')
# Make the graph read-only so later calls cannot accidentally add ops to it.
sess.graph.finalize()


def predict_img(img_bytes):
    """Run the loaded model on raw image bytes and return the decoded text."""
    batch = get_image_batch(img_bytes)
    return predict_func(batch, sess, dense_decoded_op, x_op)
|