"""Tests for openhands.utils.chunk_localizer (Chunk, create_chunks, LCS matching)."""
  1. import pytest
  2. from openhands.utils.chunk_localizer import (
  3. Chunk,
  4. create_chunks,
  5. get_top_k_chunk_matches,
  6. normalized_lcs,
  7. )
  8. def test_chunk_creation():
  9. chunk = Chunk(text='test chunk', line_range=(1, 1))
  10. assert chunk.text == 'test chunk'
  11. assert chunk.line_range == (1, 1)
  12. assert chunk.normalized_lcs is None
  13. def test_chunk_visualization(capsys):
  14. chunk = Chunk(text='line1\nline2', line_range=(1, 2))
  15. assert chunk.visualize() == '1|line1\n2|line2\n'
  16. def test_create_chunks_raw_string():
  17. text = 'line1\nline2\nline3\nline4\nline5'
  18. chunks = create_chunks(text, size=2)
  19. assert len(chunks) == 3
  20. assert chunks[0].text == 'line1\nline2'
  21. assert chunks[0].line_range == (1, 2)
  22. assert chunks[1].text == 'line3\nline4'
  23. assert chunks[1].line_range == (3, 4)
  24. assert chunks[2].text == 'line5'
  25. assert chunks[2].line_range == (5, 5)
  26. def test_normalized_lcs():
  27. chunk = 'abcdef'
  28. edit_draft = 'abcxyz'
  29. assert normalized_lcs(chunk, edit_draft) == 0.5
  30. def test_get_top_k_chunk_matches():
  31. text = 'chunk1\nchunk2\nchunk3\nchunk4'
  32. query = 'chunk2'
  33. matches = get_top_k_chunk_matches(text, query, k=2, max_chunk_size=1)
  34. assert len(matches) == 2
  35. assert matches[0].text == 'chunk2'
  36. assert matches[0].line_range == (2, 2)
  37. assert matches[0].normalized_lcs == 1.0
  38. assert matches[1].text == 'chunk1'
  39. assert matches[1].line_range == (1, 1)
  40. assert matches[1].normalized_lcs == 5 / 6
  41. assert matches[0].normalized_lcs > matches[1].normalized_lcs
  42. def test_create_chunks_with_empty_lines():
  43. text = 'line1\n\nline3\n\n\nline6'
  44. chunks = create_chunks(text, size=2)
  45. assert len(chunks) == 3
  46. assert chunks[0].text == 'line1\n'
  47. assert chunks[0].line_range == (1, 2)
  48. assert chunks[1].text == 'line3\n'
  49. assert chunks[1].line_range == (3, 4)
  50. assert chunks[2].text == '\nline6'
  51. assert chunks[2].line_range == (5, 6)
  52. def test_create_chunks_with_large_size():
  53. text = 'line1\nline2\nline3'
  54. chunks = create_chunks(text, size=10)
  55. assert len(chunks) == 1
  56. assert chunks[0].text == text
  57. assert chunks[0].line_range == (1, 3)
  58. def test_create_chunks_with_last_chunk_smaller():
  59. text = 'line1\nline2\nline3'
  60. chunks = create_chunks(text, size=2)
  61. assert len(chunks) == 2
  62. assert chunks[0].text == 'line1\nline2'
  63. assert chunks[0].line_range == (1, 2)
  64. assert chunks[1].text == 'line3'
  65. assert chunks[1].line_range == (3, 3)
  66. def test_normalized_lcs_edge_cases():
  67. assert normalized_lcs('', '') == 0.0
  68. assert normalized_lcs('a', '') == 0.0
  69. assert normalized_lcs('', 'a') == 0.0
  70. assert normalized_lcs('abcde', 'ace') == 0.6
  71. def test_get_top_k_chunk_matches_with_ties():
  72. text = 'chunk1\nchunk2\nchunk3\nchunk1'
  73. query = 'chunk'
  74. matches = get_top_k_chunk_matches(text, query, k=3, max_chunk_size=1)
  75. assert len(matches) == 3
  76. assert all(match.normalized_lcs == 5 / 6 for match in matches)
  77. assert {match.text for match in matches} == {'chunk1', 'chunk2', 'chunk3'}
  78. def test_get_top_k_chunk_matches_with_large_k():
  79. text = 'chunk1\nchunk2\nchunk3'
  80. query = 'chunk'
  81. matches = get_top_k_chunk_matches(text, query, k=10, max_chunk_size=1)
  82. assert len(matches) == 3 # Should return all chunks even if k is larger
  83. @pytest.mark.parametrize('chunk_size', [1, 2, 3, 4])
  84. def test_create_chunks_different_sizes(chunk_size):
  85. text = 'line1\nline2\nline3\nline4'
  86. chunks = create_chunks(text, size=chunk_size)
  87. assert len(chunks) == (4 + chunk_size - 1) // chunk_size
  88. assert sum(len(chunk.text.split('\n')) for chunk in chunks) == 4
  89. def test_chunk_visualization_with_special_characters():
  90. chunk = Chunk(text='line1\nline2\t\nline3\r', line_range=(1, 3))
  91. assert chunk.visualize() == '1|line1\n2|line2\t\n3|line3\r\n'
  92. def test_normalized_lcs_with_unicode():
  93. chunk = 'Hello, 世界!'
  94. edit_draft = 'Hello, world!'
  95. assert 0 < normalized_lcs(chunk, edit_draft) < 1
  96. def test_get_top_k_chunk_matches_with_overlapping_chunks():
  97. text = 'chunk1\nchunk2\nchunk3\nchunk4'
  98. query = 'chunk2\nchunk3'
  99. matches = get_top_k_chunk_matches(text, query, k=2, max_chunk_size=2)
  100. assert len(matches) == 2
  101. assert matches[0].text == 'chunk1\nchunk2'
  102. assert matches[0].line_range == (1, 2)
  103. assert matches[1].text == 'chunk3\nchunk4'
  104. assert matches[1].line_range == (3, 4)
  105. assert matches[0].normalized_lcs == matches[1].normalized_lcs