diff --git a/storey/sources.py b/storey/sources.py index 95aba149..c215192b 100644 --- a/storey/sources.py +++ b/storey/sources.py @@ -1,4 +1,5 @@ import asyncio +import copy import csv import math import queue @@ -510,6 +511,7 @@ def __init__(self, paths: Union[List[str], str], header: bool = False, build_dic def _init(self): self._event_buffer = queue.Queue(1024) self._types = [] + self._none_columns = set() def _infer_type(self, value): lowercase = value.lower() @@ -534,6 +536,9 @@ def _infer_type(self, value): except ValueError: pass + if value == '': + return 'n' + return 's' def _parse_field(self, field, index): @@ -557,6 +562,8 @@ def _parse_field(self, field, index): if field == '': return None return self._datetime_from_timestamp(field) + if typ == 'n': + return None raise TypeError(f'Unknown type: {typ}') def _datetime_from_timestamp(self, timestamp): @@ -583,8 +590,17 @@ def _blocking_io_loop(self): parsed_line = next(csv.reader([line])) if self._type_inference: if not self._types: - for field in parsed_line: - self._types.append(self._infer_type(field)) + for index, field in enumerate(parsed_line): + type_field = self._infer_type(field) + self._types.append(type_field) + if type_field == 'n': + self._none_columns.add(index) + else: + for index in copy.copy(self._none_columns): + type_field = self._infer_type(parsed_line[index]) + if type_field != 'n': + self._types[index] = type_field + self._none_columns.remove(index) for i in range(len(parsed_line)): parsed_line[i] = self._parse_field(parsed_line[i], i) element = parsed_line diff --git a/tests/test_flow.py b/tests/test_flow.py index 19f84f35..f867bb45 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -2554,3 +2554,25 @@ def test_custom_string_time_input(): result = awaitable_result.await_result() controller.terminate() assert result.time == datetime(2021, 5, 9, 14, 5, 27, tzinfo=pytz.utc) + + +def test_csv_none_value_first_row(tmpdir): + out_file_par = f'{tmpdir}/test_csv_none_value_first_row_{uuid.uuid4().hex}.parquet' + out_file_csv = f'{tmpdir}/test_csv_none_value_first_row_{uuid.uuid4().hex}.csv' + + columns = ['first_name', "bid", "bool", "time"] + data = pd.DataFrame([['katya', None, None, None], ['dina', 45.7, True, datetime(2021, 4, 21, 15, 56, 53, 385444)]], + columns=columns) + data.to_csv(out_file_csv) + + controller = build_flow([ + CSVSource(out_file_csv, header=True, key_field='first_name', build_dict=True), + ParquetTarget(out_file_par) + ]).run() + + controller.await_termination() + read_back_df = pd.read_parquet(out_file_par) + + for c in columns: + assert read_back_df.dtypes.to_dict()[c] == data.dtypes.to_dict()[c] +