Skip to content

Commit

Permalink
ML-625: csv - discovering type if None is in first row (#226)
Browse files Browse the repository at this point in the history
* ML-625: csv - discovering type if None is in first row

* pr comments

* pr comments2
  • Loading branch information
katyakats authored Jun 1, 2021
1 parent 80a1246 commit 83de86e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
20 changes: 18 additions & 2 deletions storey/sources.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import csv
import math
import queue
Expand Down Expand Up @@ -510,6 +511,7 @@ def __init__(self, paths: Union[List[str], str], header: bool = False, build_dic
def _init(self):
self._event_buffer = queue.Queue(1024)
self._types = []
self._none_columns = set()

def _infer_type(self, value):
lowercase = value.lower()
Expand All @@ -534,6 +536,9 @@ def _infer_type(self, value):
except ValueError:
pass

if value == '':
return 'n'

return 's'

def _parse_field(self, field, index):
Expand All @@ -557,6 +562,8 @@ def _parse_field(self, field, index):
if field == '':
return None
return self._datetime_from_timestamp(field)
if typ == 'n':
return None
raise TypeError(f'Unknown type: {typ}')

def _datetime_from_timestamp(self, timestamp):
Expand All @@ -583,8 +590,17 @@ def _blocking_io_loop(self):
parsed_line = next(csv.reader([line]))
if self._type_inference:
if not self._types:
for field in parsed_line:
self._types.append(self._infer_type(field))
for index, field in enumerate(parsed_line):
type_field = self._infer_type(field)
self._types.append(type_field)
if type_field == 'n':
self._none_columns.add(index)
else:
for index in copy.copy(self._none_columns):
type_field = self._infer_type(parsed_line[index])
if type_field != 'n':
self._types[index] = type_field
self._none_columns.remove(index)
for i in range(len(parsed_line)):
parsed_line[i] = self._parse_field(parsed_line[i], i)
element = parsed_line
Expand Down
22 changes: 22 additions & 0 deletions tests/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2554,3 +2554,25 @@ def test_custom_string_time_input():
result = awaitable_result.await_result()
controller.terminate()
assert result.time == datetime(2021, 5, 9, 14, 5, 27, tzinfo=pytz.utc)


def test_csv_none_value_first_row(tmpdir):
out_file_par = f'{tmpdir}/test_csv_none_value_first_row_{uuid.uuid4().hex}.parquet'
out_file_csv = f'{tmpdir}/test_csv_none_value_first_row_{uuid.uuid4().hex}.csv'

columns = ['first_name', "bid", "bool", "time"]
data = pd.DataFrame([['katya', None, None, None], ['dina', 45.7, True, datetime(2021, 4, 21, 15, 56, 53, 385444)]],
columns=columns)
data.to_csv(out_file_csv)

controller = build_flow([
CSVSource(out_file_csv, header=True, key_field='first_name', build_dict=True),
ParquetTarget(out_file_par)
]).run()

controller.await_termination()
read_back_df = pd.read_parquet(out_file_par)

for c in columns:
assert read_back_df.dtypes.to_dict()[c] == data.dtypes.to_dict()[c]

0 comments on commit 83de86e

Please sign in to comment.