Skip to content

Commit

Permalink
output format option
Browse files Browse the repository at this point in the history
  • Loading branch information
ODAncona committed Feb 24, 2025
1 parent b2b2cec commit e0b44fe
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 59 deletions.
File renamed without changes.
16 changes: 16 additions & 0 deletions src/default_template_xml.hbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

<directory>{{ absolute_code_path }}</directory>

<source-tree>
{{ source_tree }}
</source-tree>

<files>
{{#each files}}
{{#if code}}
<file path="{{ path }}">
{{ code }}
</file>
{{/if}}
{{/each}}
</files>
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub use filter::should_include_file;
pub use git::{get_git_diff, get_git_diff_between_branches, get_git_log};
pub use path::{label, traverse_directory};
pub use sort::{sort_files, sort_tree, FileSortMethod};
pub use template::{handle_undefined_variables, handlebars_setup, render_template, write_to_file};
pub use template::{
handle_undefined_variables, handlebars_setup, render_template, write_to_file, OutputFormat,
};
pub use tokenizer::{count_tokens, TokenFormat, TokenizerType};
pub use util::strip_utf8_bom;
114 changes: 60 additions & 54 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use clap::Parser;
use code2prompt::{
count_tokens, get_git_diff, get_git_diff_between_branches, get_git_log,
handle_undefined_variables, handlebars_setup, label, render_template, traverse_directory,
write_to_file, FileSortMethod, TokenFormat, TokenizerType,
write_to_file, FileSortMethod, OutputFormat, TokenFormat, TokenizerType,
};
use colored::*;
use indicatif::{ProgressBar, ProgressStyle};
Expand All @@ -16,11 +16,7 @@ use num_format::{SystemLocale, ToFormattedString};
use serde_json::json;
use std::path::PathBuf;

// Constants
const DEFAULT_TEMPLATE_NAME: &str = "default";
const CUSTOM_TEMPLATE_NAME: &str = "custom";

// CLI Arguments
// ~~~ CLI Arguments ~~~
#[derive(Parser)]
#[clap(
name = env!("CARGO_PKG_NAME"),
Expand All @@ -45,24 +41,32 @@ struct Cli {
#[clap(long)]
include_priority: bool,

/// Optional output file path
#[clap(short = 'O', long = "output-file")]
output_file: Option<String>,

/// Output format: markdown, json, or xml
#[clap(short = 'F', long = "output-format", default_value = "markdown")]
output_format: OutputFormat,

/// Optional Path to a custom Handlebars template
#[clap(short, long)]
template: Option<PathBuf>,

/// Exclude files/folders from the source tree based on exclude patterns
#[clap(long)]
exclude_from_tree: bool,

/// Display the token count of the generated prompt.
/// Accepts a format: "raw" (machine parsable) or "format" (human readable).
#[clap(long, value_name = "FORMAT", default_value = "format")]
tokens: TokenFormat,

/// Optional tokenizer to use for token count
///
/// Supported tokenizers: cl100k (default), p50k, p50k_edit, r50k, gpt2
#[clap(short = 'c', long)]
encoding: Option<String>,

/// Optional output file path
#[clap(short, long)]
output: Option<String>,
/// Display the token count of the generated prompt.
/// Accepts a format: "raw" (machine parsable) or "format" (human readable).
#[clap(long, value_name = "FORMAT", default_value = "format")]
tokens: TokenFormat,

/// Include git diff
#[clap(short, long)]
Expand All @@ -80,34 +84,26 @@ struct Cli {
#[clap(short, long)]
line_number: bool,

/// Disable wrapping code inside markdown code blocks
#[clap(long)]
no_codeblock: bool,

/// Use relative paths instead of absolute paths, including the parent directory
#[clap(long)]
relative_paths: bool,

/// Optional Disable copying to clipboard
#[clap(long)]
no_clipboard: bool,

/// Optional Path to a custom Handlebars template
#[clap(short, long)]
template: Option<PathBuf>,

/// Print output as JSON
#[clap(long)]
json: bool,

/// Follow symlinks
#[clap(short = 'f', long)]
#[clap(short = 'L', long)]
follow_symlinks: bool,

/// Include hidden directories and files
#[clap(long)]
hidden: bool,

/// Disable wrapping code inside markdown code blocks
#[clap(long)]
no_codeblock: bool,

/// Optional Disable copying to clipboard
#[clap(long)]
no_clipboard: bool,

/// Skip .gitignore rules
#[clap(long)]
no_ignore: bool,
Expand Down Expand Up @@ -138,9 +134,6 @@ fn main() -> Result<()> {
}

// ~~~ Initialization ~~~
// Handlebars Template Setup
let (template_content, template_name) = get_template(&args)?;
let handlebars = handlebars_setup(&template_content, template_name)?;

// Progress Bar Setup
let spinner = setup_spinner("Traversing directory and building tree...");
Expand Down Expand Up @@ -193,6 +186,7 @@ fn main() -> Result<()> {
}
};

// ~~~ Git Related ~~~
// Git Diff
let git_diff = if args.diff {
spinner.set_message("Generating git diff...");
Expand All @@ -201,7 +195,7 @@ fn main() -> Result<()> {
String::new()
};

// git diff two get_git_diff_between_branches
// git diff between two branches
let mut git_diff_branch: String = String::new();
if let Some(branches) = &args.git_diff_branch {
spinner.set_message("Generating git diff between two branches...");
Expand All @@ -214,7 +208,7 @@ fn main() -> Result<()> {
.unwrap_or_default()
}

// git diff two get_git_diff_between_branches
// git log between two branches
let mut git_log_branch: String = String::new();
if let Some(branches) = &args.git_log_branch {
spinner.set_message("Generating git log between two branches...");
Expand All @@ -228,7 +222,8 @@ fn main() -> Result<()> {

spinner.finish_with_message("Done!".green().to_string());

// Prepare JSON Data
// ~~~ Template ~~~
// Template Data
let mut data = json!({
"absolute_code_path": label(&args.path),
"source_tree": tree,
Expand All @@ -243,11 +238,15 @@ fn main() -> Result<()> {
serde_json::to_string_pretty(&data).unwrap()
);

// Handle undefined variables
// Template Setup
let (template_content, template_name) = get_template(&args)?;
let handlebars = handlebars_setup(&template_content, &template_name)?;

// Handle User Defined Variables
handle_undefined_variables(&mut data, &template_content)?;

// Render the template
let rendered = render_template(&handlebars, template_name, &data)?;
// Template Rendering
let rendered = render_template(&handlebars, &template_name, &data)?;

// ~~~ Token Count ~~~
let tokenizer_type = args
Expand All @@ -258,13 +257,13 @@ fn main() -> Result<()> {
.unwrap_or(TokenizerType::Cl100kBase);

let token_count = count_tokens(&rendered, &tokenizer_type);
let model_info = tokenizer_type.description();

let formatted_token_count: String = match args.tokens {
TokenFormat::Raw => token_count.to_string(),
TokenFormat::Format => token_count.to_formatted_string(&SystemLocale::default().unwrap()),
};

let model_info = tokenizer_type.description();

println!(
"{}{}{} Token count: {}, Model info: {}",
"[".bold().white(),
Expand All @@ -274,7 +273,7 @@ fn main() -> Result<()> {
model_info
);

// ~~~ Output ~~~
// ~~~ Informations ~~~
let paths: Vec<String> = files
.iter()
.filter_map(|file| {
Expand All @@ -284,15 +283,15 @@ fn main() -> Result<()> {
})
.collect();

if args.json {
if args.output_format == OutputFormat::Json {
let json_output = json!({
"prompt": rendered,
"directory_name": label(&args.path),
"directory_name": &label(&args.path),
"token_count": token_count,
"model_info": model_info,
"files": paths,
"files": &paths,
});
println!("{}", serde_json::to_string_pretty(&json_output)?);
println!("{}", serde_json::to_string_pretty(&json_output).unwrap());
return Ok(());
}

Expand Down Expand Up @@ -331,7 +330,7 @@ fn main() -> Result<()> {
}

// ~~~ Output File ~~~
if let Some(output_path) = &args.output {
if let Some(output_path) = &args.output_file {
write_to_file(output_path, &rendered)?;
}

Expand Down Expand Up @@ -387,15 +386,22 @@ fn parse_patterns(patterns: &Option<String>) -> Vec<String> {
/// # Returns
///
/// * `Result<(String, &str)>` - A tuple containing the template content and name
fn get_template(args: &Cli) -> Result<(String, &str)> {
fn get_template(args: &Cli) -> Result<(String, String)> {
let format = &args.output_format;
if let Some(template_path) = &args.template {
let content = std::fs::read_to_string(template_path)
.context("Failed to read custom template file")?;
Ok((content, CUSTOM_TEMPLATE_NAME))
Ok((content, "custom".to_string()))
} else {
Ok((
include_str!("default_template.hbs").to_string(),
DEFAULT_TEMPLATE_NAME,
))
match format {
OutputFormat::Markdown | OutputFormat::Json => Ok((
include_str!("default_template_md.hbs").to_string(),
"markdown".to_string(),
)),
OutputFormat::Xml => Ok((
include_str!("default_template_xml.hbs").to_string(),
"xml".to_string(),
)),
}
}
}
1 change: 0 additions & 1 deletion src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use serde_json::json;
use std::fs;
use std::path::Path;
use termtree::Tree;

/// Traverses the directory and returns the string representation of the tree and the vector of JSON file representations.
///
/// # Arguments
Expand Down
2 changes: 1 addition & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl CodePrompt {

// Setup template
let template_content =
template.unwrap_or_else(|| include_str!("default_template.hbs").to_string());
template.unwrap_or_else(|| include_str!("default_template_md.hbs").to_string());
let handlebars = handlebars_setup(&template_content, "template")
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;

Expand Down
27 changes: 26 additions & 1 deletion src/template.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//! This module contains the functions to set up the Handlebars template engine and render the template with the provided data.
//! It also includes functions for handling user-defined variables, copying the rendered output to the clipboard, and writing it to a file.
use anyhow::Result;
use anyhow::{anyhow, Result};
use colored::*;
use handlebars::{no_escape, Handlebars};
use inquire::Text;
use regex::Regex;
use std::io::Write;
use std::str::FromStr;

/// Set up the Handlebars template engine with a template string and a template name.
///
Expand Down Expand Up @@ -127,3 +128,27 @@ pub fn write_to_file(output_path: &str, rendered: &str) -> Result<()> {
);
Ok(())
}

/// Enum to represent the output format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputFormat {
Markdown,
Json,
Xml,
}

impl FromStr for OutputFormat {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"markdown" | "md" => Ok(OutputFormat::Markdown),
"json" => Ok(OutputFormat::Json),
"xml" => Ok(OutputFormat::Xml),
_ => Err(anyhow!(
"Invalid output format: {}. Allowed values: markdown, json, xml",
s
)),
}
}
}
2 changes: 1 addition & 1 deletion tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ mod tests {
let mut cmd =
Command::cargo_bin("code2prompt").expect("Failed to find code2prompt binary");
cmd.arg(&self.dir.path().to_str().unwrap())
.arg("--output")
.arg("--output-file")
.arg(&self.output_file)
.arg("--no-clipboard");
cmd
Expand Down

0 comments on commit e0b44fe

Please sign in to comment.