Skip to content

Commit

Permalink
fix encoding wrapper for batchencoding
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Dec 10, 2024
1 parent 56f2f08 commit 8cdd245
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion lightning_ir/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ def _cat_outputs(
for key, value in output.items():
agg[key].append(value)
types[key] = type(value)
return OutputClass(**{key: _cat_outputs(value, types[key]) for key, value in agg.items()})
kwargs = {key: _cat_outputs(value, types[key]) for key, value in agg.items()}
if OutputClass is BatchEncoding:
return OutputClass(kwargs)
return OutputClass(**kwargs)


class BatchEncodingWrapper(Protocol):
Expand Down

0 comments on commit 8cdd245

Please sign in to comment.