diff --git a/czv-python/tests/test_count.py b/czv-python/tests/test_count.py index c796835..e5b2f9a 100644 --- a/czv-python/tests/test_count.py +++ b/czv-python/tests/test_count.py @@ -2,12 +2,12 @@ import czv import pytest from .test_data import test_data -class TestCountFunc: +class TestRowCount: @pytest.mark.parametrize( "file_name,expected", [("fruits.csv", 3), ("constituents_altnames.csv", 33971)], ) - def test_count(self, file_name, expected): + def test_row_count(self, file_name, expected): """Count the total number of non-header rows.""" result = czv.row_count(file_path=test_data[file_name]) @@ -17,8 +17,19 @@ class TestCountFunc: "file_name,expected", [("fruits.csv", 4), ("constituents_altnames.csv", 33972)], ) - def test_include_header_row(self, file_name, expected): + def test_row_count_include_header_row(self, file_name, expected): """Count the total number of rows including the header row.""" result = czv.row_count(file_path=test_data[file_name], include_header_row=True) assert result == expected + +class TestColumnCount: + @pytest.mark.parametrize( + "file_name,expected", + [("fruits.csv", 2), ("constituents_altnames.csv", 6)], + ) + def test_column_count(self, file_name, expected): + """Count the total number of columns.""" + + result = czv.column_count(file_path=test_data[file_name]) + assert result == expected