Skip to content

Commit

Permalink
fix models
Browse files Browse the repository at this point in the history
  • Loading branch information
vinicvaz committed Dec 7, 2023
1 parent 4915b51 commit 6a12755
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 44 deletions.
28 changes: 22 additions & 6 deletions .domino/compiled_metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,32 @@
"description": "Input data for TextSummarizerPiece",
"properties": {
"input_file_path": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": "",
"description": "The path to the text file to summarize.",
"title": "Input File Path",
"type": "string"
"from_upstream": "always",
"title": "Input File Path"
},
"input_text": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": "",
"description": "The text to summarize.",
"title": "Input Text",
"type": "string"
"widget": "textarea"
},
"output_type": {
"allOf": [
Expand Down Expand Up @@ -1046,8 +1062,8 @@
],
"default": null,
"description": "Text to summarize",
"required": false,
"title": "Text"
"title": "Text",
"widget": "textarea"
},
"text_file_path": {
"anyOf": [
Expand All @@ -1060,7 +1076,7 @@
],
"default": null,
"description": "Use it only if not using text field. File path to the text to summarize",
"required": false,
"from_upstream": "always",
"title": "Text File Path"
},
"output_type": {
Expand Down
16 changes: 11 additions & 5 deletions pieces/TextSummarizerLocalPiece/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field, FilePath, validators
from typing import Union
from typing import Union, Optional
from enum import Enum


Expand All @@ -12,13 +12,19 @@ class InputModel(BaseModel):
"""
Input data for TextSummarizerPiece
"""
input_file_path: str = Field(
input_file_path: Optional[str] = Field(
description='The path to the text file to summarize.',
default=""
default="",
json_schema_extra={
"from_upstream": "always"
}
)
input_text: str = Field(
input_text: Optional[str] = Field(
description='The text to summarize.',
default=""
default="",
json_schema_extra={
'widget': "textarea",
}
)
output_type: OutputTypeType = Field(
description='The type of output fot the result text.',
Expand Down
73 changes: 43 additions & 30 deletions pieces/TextSummarizerLocalPiece/piece.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,49 @@



def summarize_long_text(text: str, summarizer, iteration: int=0):
"""
Generate the summary by concatenating the summaries of the individual chunks.
"""
iteration += 1
print(f"Iteration: {iteration}")

# Preprocess text
text = text.lower().replace(".", " ").replace(",", " ").replace("\n", " ")
text = "".join(ch if ch.isalnum() or ch == " " else " " for ch in text)

# Split the input text into chunks
chunk_size = 1000
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
print(f"chunks to process: {len(chunks)}")

# Generate the summary for each chunk
summary_list = [
summarizer(chunk, max_length=60, min_length=30, no_repeat_ngram_size=3)[0]['summary_text']
for chunk in chunks
]
summary = " ".join(summary_list)

if len(summary) > 2000:
return summarize_long_text(summary, summarizer, iteration)
else:
return summary


class TextSummarizerLocalPiece(BasePiece):

def summarize_long_text(self, text: str, summarizer, iteration: int=0):
"""
Generate the summary by concatenating the summaries of the individual chunks.
"""
iteration += 1
print(f"Iteration: {iteration}")

# Preprocess text
text = text.lower().replace(".", " ").replace(",", " ").replace("\n", " ")
text = "".join(ch if ch.isalnum() or ch == " " else " " for ch in text)

# Split the input text into chunks
chunk_size = 1000
chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
print(f"chunks to process: {len(chunks)}")

# Generate the summary for each chunk
summary_list = [
summarizer(chunk, max_length=60, min_length=30, no_repeat_ngram_size=3)[0]['summary_text']
for chunk in chunks
]
summary = " ".join(summary_list)

if len(summary) > 2000:
return self.summarize_long_text(summary, summarizer, iteration)
else:
return summary

def format_display_result(self, final_summary: str):
md_text = f"""
## Summarized text
{final_summary}
"""
file_path = f"{self.results_path}/display_result.md"
with open(file_path, "w") as f:
f.write(md_text)
self.display_result = {
"file_type": "md",
"file_path": file_path
}

def piece_function(self, input_data: InputModel):

# Set device
Expand Down Expand Up @@ -65,7 +77,7 @@ def piece_function(self, input_data: InputModel):

# Run summarizer
self.logger.info("Running summarizer...")
result = summarize_long_text(text=text_str, summarizer=summarizer)
result = self.summarize_long_text(text=text_str, summarizer=summarizer)

# Return result
if input_data.output_type == "xcom":
Expand All @@ -81,6 +93,7 @@ def piece_function(self, input_data: InputModel):
with open(output_file_path, "w") as f:
f.write(result)

self.format_display_result(final_summary=result)
return OutputModel(
message=msg,
summary_result=summary_result,
Expand Down
10 changes: 7 additions & 3 deletions pieces/TextSummarizerPiece/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ class LLMModelType(str, Enum):
class InputModel(BaseModel):
"""
TextSummarizerPiece Input model
"""
"""
text: Optional[str] = Field(
default=None,
description="Text to summarize",
required=False # Setting to false because can use text or text_file_path
json_schema_extra={
'widget': "textarea",
}
)
text_file_path: Optional[str] = Field(
default=None,
description="Use it only if not using text field. File path to the text to summarize",
required=False # Setting to false because can use text or text_file_path
json_schema_extra={
"from_upstream": "always"
}
)
output_type: OutputTypeType = Field(
default=OutputTypeType.string,
Expand Down

0 comments on commit 6a12755

Please sign in to comment.