test_doclayout.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import unittest
  2. from unittest.mock import patch, MagicMock
  3. import numpy as np
  4. from pdf2zh.doclayout import (
  5. OnnxModel,
  6. YoloResult,
  7. YoloBox,
  8. )
  9. class TestOnnxModel(unittest.TestCase):
  10. @patch("onnx.load")
  11. @patch("onnxruntime.InferenceSession")
  12. def setUp(self, mock_inference_session, mock_onnx_load):
  13. # Mock ONNX model metadata
  14. mock_model = MagicMock()
  15. mock_model.metadata_props = [
  16. MagicMock(key="stride", value="32"),
  17. MagicMock(key="names", value="['class1', 'class2']"),
  18. ]
  19. mock_onnx_load.return_value = mock_model
  20. # Initialize OnnxModel with a fake path
  21. self.model_path = "fake_model_path.onnx"
  22. self.model = OnnxModel(self.model_path)
  23. def test_stride_property(self):
  24. # Test that stride is correctly set from model metadata
  25. self.assertEqual(self.model.stride, 32)
  26. def test_resize_and_pad_image(self):
  27. # Create a dummy image (100x200)
  28. image = np.ones((100, 200, 3), dtype=np.uint8)
  29. resized_image = self.model.resize_and_pad_image(image, 1024)
  30. # Validate the output shape
  31. self.assertEqual(resized_image.shape[0], 512)
  32. self.assertEqual(resized_image.shape[1], 1024)
  33. # Check that padding has been added
  34. padded_height = resized_image.shape[0] - image.shape[0]
  35. padded_width = resized_image.shape[1] - image.shape[1]
  36. self.assertGreater(padded_height, 0)
  37. self.assertGreater(padded_width, 0)
  38. def test_scale_boxes(self):
  39. img1_shape = (1024, 1024) # Model input shape
  40. img0_shape = (500, 300) # Original image shape
  41. boxes = np.array([[512, 512, 768, 768]]) # Example bounding box
  42. scaled_boxes = self.model.scale_boxes(img1_shape, boxes, img0_shape)
  43. # Verify the output is scaled correctly
  44. self.assertEqual(scaled_boxes.shape, boxes.shape)
  45. self.assertTrue(np.all(scaled_boxes <= max(img0_shape)))
  46. def test_predict(self):
  47. # Mock model inference output
  48. mock_output = np.random.random((1, 300, 6))
  49. self.model.model.run.return_value = [mock_output]
  50. # Create a dummy image
  51. image = np.ones((500, 300, 3), dtype=np.uint8)
  52. results = self.model.predict(image)
  53. # Validate predictions
  54. self.assertEqual(len(results), 1)
  55. self.assertIsInstance(results[0], YoloResult)
  56. self.assertGreater(len(results[0].boxes), 0)
  57. self.assertIsInstance(results[0].boxes[0], YoloBox)
  58. class TestYoloResult(unittest.TestCase):
  59. def test_yolo_result(self):
  60. # Example prediction data
  61. boxes = [
  62. [100, 200, 300, 400, 0.9, 0],
  63. [50, 100, 150, 200, 0.8, 1],
  64. ]
  65. names = ["class1", "class2"]
  66. result = YoloResult(boxes, names)
  67. # Validate the number of boxes and their order by confidence
  68. self.assertEqual(len(result.boxes), 2)
  69. self.assertGreater(result.boxes[0].conf, result.boxes[1].conf)
  70. self.assertEqual(result.names, names)
  71. class TestYoloBox(unittest.TestCase):
  72. def test_yolo_box(self):
  73. # Example box data
  74. box_data = [100, 200, 300, 400, 0.9, 0]
  75. box = YoloBox(box_data)
  76. # Validate box properties
  77. self.assertEqual(box.xyxy, box_data[:4])
  78. self.assertEqual(box.conf, box_data[4])
  79. self.assertEqual(box.cls, box_data[5])
  80. if __name__ == "__main__":
  81. unittest.main()