Support Deberta V2 Models #2739

Open
emschwartz opened this issue Jan 25, 2025 · 13 comments

Comments

@emschwartz

@BradyBonnette has been maintaining a branch to add support for DebertaV2 models #1177 (comment). It would be amazing to have that merged in 🙏

The one issue I ran into with the current diff is that there's an unsupported operation when running with the metal feature:

RUST_BACKTRACE=1 cargo run  --example debertav2 --features=metal -- --model-id=blaze999/Medical-NER --revision=main --sentence='63 year old woman with history of CAD presented to ER'
    Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.60s
     Running `target/debug/examples/debertav2 --model-id=blaze999/Medical-NER --revision=main '--sentence=63 year old woman with history of CAD presented to ER'`
Loaded model and tokenizers in 2.037694333s
Tokenized and loaded inputs in 1.533875ms
Error: Metal gather I64 F32 not implemented
   0: std::backtrace_rs::backtrace::libunwind::trace
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/../../backtrace/src/backtrace/libunwind.rs:116:5
   1: std::backtrace_rs::backtrace::trace_unsynchronized
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
   2: std::backtrace::Backtrace::create
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/backtrace.rs:331:13
   3: candle_core::error::Error::bt
             at ./candle-core/src/error.rs:254:25
   4: <candle_core::metal_backend::MetalStorage as candle_core::backend::BackendStorage>::gather
             at ./candle-core/src/error.rs:283:20
   5: candle_core::storage::Storage::gather
             at ./candle-core/src/storage.rs:615:31
   6: candle_core::tensor::Tensor::gather
             at ./candle-core/src/tensor.rs:1555:13
   7: candle_transformers::models::debertav2::DebertaV2DisentangledSelfAttention::disentangled_attention_bias
             at ./candle-transformers/src/models/debertav2.rs:648:23
   8: candle_transformers::models::debertav2::DebertaV2DisentangledSelfAttention::forward
             at ./candle-transformers/src/models/debertav2.rs:461:28
   9: candle_transformers::models::debertav2::DebertaV2Attention::forward
             at ./candle-transformers/src/models/debertav2.rs:742:27
  10: candle_transformers::models::debertav2::DebertaV2Layer::forward
             at ./candle-transformers/src/models/debertav2.rs:892:32
  11: candle_transformers::models::debertav2::DebertaV2Encoder::forward
             at ./candle-transformers/src/models/debertav2.rs:1076:29
  12: candle_transformers::models::debertav2::DebertaV2Model::forward
             at ./candle-transformers/src/models/debertav2.rs:1217:13
  13: candle_transformers::models::debertav2::DebertaV2NERModel::forward
             at ./candle-transformers/src/models/debertav2.rs:1280:22
  14: debertav2::main
             at ./candle-examples/examples/debertav2/main.rs:271:26
  15: core::ops::function::FnOnce::call_once
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/ops/function.rs:250:5
  16: std::sys::backtrace::__rust_begin_short_backtrace
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/sys/backtrace.rs:154:18
  17: std::rt::lang_start::{{closure}}
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/rt.rs:195:18
  18: core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/core/src/ops/function.rs:284:13
  19: std::panicking::try::do_call
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:557:40
  20: std::panicking::try
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:520:19
  21: std::panic::catch_unwind
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panic.rs:358:14
  22: std::rt::lang_start_internal::{{closure}}
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/rt.rs:174:48
  23: std::panicking::try::do_call
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:557:40
  24: std::panicking::try
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:520:19
  25: std::panic::catch_unwind
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panic.rs:358:14
  26: std::rt::lang_start_internal
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/rt.rs:174:20
  27: std::rt::lang_start
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/rt.rs:194:17
  28: _main


Stack backtrace:
   0: std::backtrace_rs::backtrace::libunwind::trace
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/../../backtrace/src/backtrace/libunwind.rs:116:5
   1: std::backtrace_rs::backtrace::trace_unsynchronized
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5
   2: std::backtrace::Backtrace::create
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/backtrace.rs:331:13
   3: anyhow::error::<impl core::convert::From<E> for anyhow::Error>::from
             at .../.cargo/registry/src/index.crates.io-6f17d22bba15001f/anyhow-1.0.95/src/backtrace.rs:27:14
   4: <core::result::Result<T,F> as core::ops::try_trait::FromResidual<core::result::Result<core::convert::Infallible,E>>>::from_residual
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/result.rs:2010:27
   5: debertav2::main
             at ./candle-examples/examples/debertav2/main.rs:271:26
   6: core::ops::function::FnOnce::call_once
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/ops/function.rs:250:5
   7: std::sys::backtrace::__rust_begin_short_backtrace
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/sys/backtrace.rs:154:18
   8: std::rt::lang_start::{{closure}}
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/rt.rs:195:18
   9: core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/core/src/ops/function.rs:284:13
  10: std::panicking::try::do_call
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:557:40
  11: std::panicking::try
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:520:19
  12: std::panic::catch_unwind
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panic.rs:358:14
  13: std::rt::lang_start_internal::{{closure}}
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/rt.rs:174:48
  14: std::panicking::try::do_call
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:557:40
  15: std::panicking::try
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panicking.rs:520:19
  16: std::panic::catch_unwind
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/panic.rs:358:14
  17: std::rt::lang_start_internal
             at /rustc/90b35a6239c3d8bdabc530a6a0816f7ff89a0aaf/library/std/src/rt.rs:174:20
  18: std::rt::lang_start
             at .../.rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/std/src/rt.rs:194:17
  19: _main
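For context on the failing op: in DeBERTa's disentangled attention, `gather` selects relative-position scores out of a wider score matrix using an index tensor, and the error above says the Metal backend had no kernel for I64 indexes over F32 values. The std-only Rust sketch below shows what a last-dimension gather computes. It assumes a 2-D row-major layout and is an illustration of the semantics, not candle's implementation or its Metal kernel; `gather_last_dim` is a hypothetical name.

```rust
/// Gather along the last dimension of a row-major 2-D tensor.
/// `values` has shape (rows, cols), `indexes` has shape (rows, k), and the
/// result has shape (rows, k) with out[r][j] = values[r][indexes[r][j]].
pub fn gather_last_dim(
    values: &[f32],
    cols: usize,
    indexes: &[i64],
    k: usize,
) -> Result<Vec<f32>, String> {
    if cols == 0 || values.len() % cols != 0 {
        return Err(format!("{} values do not form rows of {} columns", values.len(), cols));
    }
    let rows = values.len() / cols;
    if indexes.len() != rows * k {
        return Err(format!("expected {} indexes, got {}", rows * k, indexes.len()));
    }
    let mut out = Vec::with_capacity(rows * k);
    for r in 0..rows {
        for j in 0..k {
            let idx = indexes[r * k + j];
            // i64 indexes can be negative or too large; reject both.
            if idx < 0 || (idx as usize) >= cols {
                return Err(format!("index {} out of range for {} columns", idx, cols));
            }
            out.push(values[r * cols + idx as usize]);
        }
    }
    Ok(out)
}
```

For a 2x3 input `[1, 2, 3, 4, 5, 6]` and indexes `[2, 0, 1, 1]` with `k = 2`, row 0 picks columns 2 and 0 and row 1 picks column 1 twice.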
@LaurentMazare
Collaborator

I've just pushed #2740, which should hopefully add the missing Metal kernels.
As for the model itself, we'll need a PR to get this merged at some point.

@BradyBonnette

👀

Funnily enough, I've been working on it a bit this afternoon; I had something not quite right on the text classification side of things.

I'll sync up my fork here in a little bit, push up my text classification fixes, and see if that works for you, @emschwartz.

@emschwartz
Author

Amazing! I just tried rebasing the fork on @LaurentMazare's change and it does indeed work. Thank you both for your work on this!

@emschwartz
Author

Also, for what it's worth, I wanted to use nvidia/quality-classifier-deberta, so I translated their PyTorch code sample using @BradyBonnette's branch.

Here's the (AI-generated / modified) version of @BradyBonnette's example:

nvidia/quality-classifier-deberta with candle
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use std::fmt::Display;
use std::path::PathBuf;

use anyhow::ensure;
use anyhow::{Error as E, Result};
use candle::{Device, Module, ModuleT, Tensor, D};
use candle_nn::VarBuilder;
use candle_transformers::models::debertav2::{Config as DebertaV2Config, DebertaV2Model};
use clap::{ArgGroup, Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{Encoding, PaddingParams, Tokenizer, TruncationParams};

struct QualityModel {
    pub device: Device,
    deberta: DebertaV2Model,
    dropout: candle_nn::Dropout,
    classifier: candle_nn::Linear,
}

impl QualityModel {
    pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result<Self> {
        let deberta = DebertaV2Model::load(vb.pp("model"), config)?;
        let dropout = candle_nn::Dropout::new(0.2);
        let classifier = candle_nn::linear(
            config.hidden_size,
            config.id2label.as_ref().unwrap().len(),
            vb.pp("fc"),
        )?;

        Ok(Self {
            device: vb.device().clone(),
            deberta,
            dropout,
            classifier,
        })
    }

    pub fn forward(&self, input_ids: &Tensor, attention_mask: Option<Tensor>) -> Result<Tensor> {
        let outputs = self.deberta.forward(input_ids, None, attention_mask)?;
        let first_token = outputs.narrow(1, 0, 1)?;
        let dropped = self.dropout.forward(&first_token, false)?;
        let logits = self.classifier.forward(&dropped)?;
        let logits = logits.squeeze(1)?;
        // Convert to f32 for Metal compatibility
        let logits = logits.to_dtype(candle::DType::F32)?;
        Ok(logits)
    }
}

enum TaskType {
    Quality(QualityModel),
}

#[derive(Parser, Debug, Clone, ValueEnum)]
enum ArgsTask {
    Quality,
}

impl Display for ArgsTask {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            ArgsTask::Quality => write!(f, "quality"),
        }
    }
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
#[command(group(ArgGroup::new("model")
    .required(true)
    .args(&["model_id", "model_path"])))]
#[command(group(ArgGroup::new("input")
    .required(true)
    .args(&["sentences", "input_file"])))]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// The model id to use from HuggingFace
    #[arg(long)]
    model_id: Option<String>,

    /// Revision of the model to use (default: "main")
    #[arg(long, default_value = "main")]
    revision: String,

    /// Specify a sentence to inference. Specify multiple times to inference multiple sentences.
    #[arg(long = "sentence", name="sentences", num_args = 1.., group="input")]
    sentences: Vec<String>,

    /// Path to a file containing sentences to inference (one per line)
    #[arg(long, group = "input")]
    input_file: Option<String>,

    /// Use the pytorch weights rather than the by-default safetensors
    #[arg(long)]
    use_pth: bool,

    /// Perform a very basic benchmark on inferencing, using N number of iterations
    #[arg(long)]
    benchmark_iters: Option<usize>,

    /// Which task to run
    #[arg(long, default_value_t = ArgsTask::Quality)]
    task: ArgsTask,

    /// Use model from a specific directory instead of HuggingFace local cache.
    /// Using this ignores model_id and revision args.
    #[arg(long)]
    model_path: Option<PathBuf>,
}

impl Args {
    fn build_model_and_tokenizer(&self) -> Result<(TaskType, DebertaV2Config, Tokenizer)> {
        let device = candle_examples::device(self.cpu)?;

        // Get files from either the HuggingFace API, or from a specified local directory.
        let (config_filename, tokenizer_filename, weights_filename) = {
            match &self.model_path {
                Some(base_path) => {
                    ensure!(
                        base_path.is_dir(),
                        std::io::Error::new(
                            std::io::ErrorKind::Other,
                            format!("Model path {} is not a directory.", base_path.display()),
                        )
                    );

                    let config = base_path.join("config.json");
                    let tokenizer = base_path.join("tokenizer.json");
                    let weights = if self.use_pth {
                        base_path.join("pytorch_model.bin")
                    } else {
                        base_path.join("model.safetensors")
                    };
                    (config, tokenizer, weights)
                }
                None => {
                    let model_id = self
                        .model_id
                        .as_deref()
                        .unwrap_or("nvidia/quality-classifier-deberta");
                    let repo = Repo::with_revision(
                        model_id.to_string(),
                        RepoType::Model,
                        self.revision.clone(),
                    );
                    let api = Api::new()?;
                    let api = api.repo(repo);
                    let config = api.get("config.json")?;
                    let tokenizer = api.get("tokenizer.json")?;
                    let weights = if self.use_pth {
                        api.get("pytorch_model.bin")?
                    } else {
                        api.get("model.safetensors")?
                    };
                    (config, tokenizer, weights)
                }
            }
        };
        let config = std::fs::read_to_string(config_filename)?;
        // Parse the minimal config first
        let minimal_config: serde_json::Value = serde_json::from_str(&config)?;

        // Get the base model config
        let base_model = minimal_config["base_model"]
            .as_str()
            .unwrap_or("microsoft/deberta-v3-base");
        let api = Api::new()?;
        let repo = Repo::new(base_model.to_string(), RepoType::Model);
        let base_config = api.repo(repo).get("config.json")?;
        let base_config = std::fs::read_to_string(base_config)?;
        let mut config: DebertaV2Config = serde_json::from_str(&base_config)?;

        // Override with values from minimal config
        if let Some(id2label) = minimal_config.get("id2label") {
            config.id2label = Some(serde_json::from_value(id2label.clone())?);
        }
        if let Some(label2id) = minimal_config.get("label2id") {
            config.label2id = Some(serde_json::from_value(label2id.clone())?);
        }

        let mut tokenizer = Tokenizer::from_file(tokenizer_filename)
            .map_err(|e| candle::Error::Msg(format!("Tokenizer error: {e}")))?;
        tokenizer.with_padding(Some(PaddingParams::default()));
        tokenizer
            .with_truncation(Some(TruncationParams::default()))
            .unwrap();

        let vb = if self.use_pth {
            VarBuilder::from_pth(
                &weights_filename,
                candle_transformers::models::debertav2::DTYPE,
                &device,
            )?
        } else {
            unsafe {
                VarBuilder::from_mmaped_safetensors(
                    &[weights_filename],
                    candle_transformers::models::debertav2::DTYPE,
                    &device,
                )?
            }
        };

        match self.task {
            ArgsTask::Quality => Ok((
                TaskType::Quality(QualityModel::load(vb, &config)?),
                config,
                tokenizer,
            )),
        }
    }

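    fn get_sentences(&self) -> Result<Vec<String>> {
        if let Some(file_path) = &self.input_file {
            // One sentence per line, as documented on `input_file`; blank lines are skipped.
            Ok(std::fs::read_to_string(file_path)?
                .lines()
                .filter(|line| !line.trim().is_empty())
                .map(String::from)
                .collect())
        } else {
            Ok(self.sentences.clone())
        }
    }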
}

fn get_device(model_type: &TaskType) -> &Device {
    match model_type {
        TaskType::Quality(model) => &model.device,
    }
}

enum InputEncoding {
    Single(Encoding),
    Batch(Vec<Encoding>),
}

struct ModelInput {
    encoding: InputEncoding,
    input_ids: Tensor,
    attention_mask: Tensor,
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    if args.model_id.is_some() && args.model_path.is_some() {
        eprintln!("Error: Cannot specify both --model-id and --model-path.");
        std::process::exit(1);
    }

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    let model_load_time = std::time::Instant::now();
    let (task_type, model_config, tokenizer) = args.build_model_and_tokenizer()?;
    println!(
        "Loaded model and tokenizers in {:?}",
        model_load_time.elapsed()
    );

    let device = get_device(&task_type);

    let tokenize_time = std::time::Instant::now();
    let sentences = args.get_sentences()?;
    let model_input: ModelInput = match sentences.len() {
        1 => {
            let encoding = tokenizer
                .encode(sentences.first().unwrap().as_str(), true)
                .map_err(E::msg)?;

            ModelInput {
                input_ids: Tensor::new(&encoding.get_ids()[..], &device)?.unsqueeze(0)?,
                attention_mask: Tensor::new(&encoding.get_attention_mask()[..], &device)?
                    .unsqueeze(0)?,
                encoding: InputEncoding::Single(encoding),
            }
        }
        _ => {
            let tokenizer_encodings = tokenizer.encode_batch(sentences, true).map_err(E::msg)?;

            let mut encoding_stack: Vec<Tensor> = Vec::default();
            let mut attention_mask_stack: Vec<Tensor> = Vec::default();

            for encoding in &tokenizer_encodings {
                encoding_stack.push(Tensor::new(encoding.get_ids(), &device)?);
                attention_mask_stack.push(Tensor::new(encoding.get_attention_mask(), &device)?);
            }

            ModelInput {
                encoding: InputEncoding::Batch(tokenizer_encodings),
                input_ids: Tensor::stack(&encoding_stack[..], 0)?,
                attention_mask: Tensor::stack(&attention_mask_stack[..], 0)?,
            }
        }
    };

    println!(
        "Tokenized and loaded inputs in {:?}",
        tokenize_time.elapsed()
    );

    match task_type {
        TaskType::Quality(quality_model) => {
            let inference_time = std::time::Instant::now();
            let logits =
                quality_model.forward(&model_input.input_ids, Some(model_input.attention_mask))?;
            println!("Inferenced inputs in {:?}", inference_time.elapsed());

            let scores = candle_nn::ops::softmax(&logits, D::Minus1)?;
            let predicted_labels = scores.argmax(D::Minus1)?;
            // Convert to f32 for Metal compatibility
            let predicted_labels = predicted_labels.to_dtype(candle::DType::F32)?;

            println!(
                "\nPredicted labels: {:?}",
                predicted_labels.to_vec1::<f32>()?
            );
            if let Some(id2label) = model_config.id2label {
                let labels: Vec<_> = predicted_labels
                    .to_vec1::<f32>()?
                    .iter()
                    .map(|&id| id2label.get(&(id as u32)).unwrap().clone())
                    .collect();
                println!("Predicted classes: {:?}", labels);
            }

            println!("\nConfidence scores:");
            println!("{:?}", scores.to_vec2::<f32>()?);
        }
    }
    Ok(())
}

fn create_benchmark<F>(
    num_iters: usize,
    model_input: ModelInput,
) -> impl Fn(F) -> Result<(), candle::Error>
where
    F: Fn(&Tensor, Tensor) -> Result<(), candle::Error>,
{
    move |code: F| -> Result<(), candle::Error> {
        println!("Running {num_iters} iterations...");
        let mut durations = Vec::with_capacity(num_iters);
        for _ in 0..num_iters {
            let attention_mask = model_input.attention_mask.clone();
            let start = std::time::Instant::now();
            code(&model_input.input_ids, attention_mask)?;
            let duration = start.elapsed();
            durations.push(duration.as_nanos());
        }

        let min_time = *durations.iter().min().unwrap();
        let max_time = *durations.iter().max().unwrap();
        let avg_time = durations.iter().sum::<u128>() as f64 / num_iters as f64;

        println!("Min time: {:.3} ms", min_time as f64 / 1_000_000.0);
        println!("Avg time: {:.3} ms", avg_time / 1_000_000.0);
        println!("Max time: {:.3} ms", max_time as f64 / 1_000_000.0);
        Ok(())
    }
}
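The `QualityModel` above classifies a whole input by taking the hidden state of the first token, applying dropout (the identity when `train` is false), then a linear layer, with softmax and argmax done in `main`. A std-only sketch of that pipeline on plain vectors, assuming row-major weights of shape (num_labels, hidden) as in a PyTorch/candle `Linear`; the function names are hypothetical and not part of candle or the fork.

```rust
use std::collections::HashMap;

/// y = W x + b, with W stored row-major as (out_features, in_features).
pub fn linear(weight: &[f32], bias: &[f32], x: &[f32]) -> Vec<f32> {
    let in_features = x.len();
    bias.iter()
        .enumerate()
        .map(|(o, b)| {
            let row = &weight[o * in_features..(o + 1) * in_features];
            row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b
        })
        .collect()
}

/// Numerically stable softmax over one vector of logits.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Index of the largest score; the first maximum wins on ties.
pub fn argmax(scores: &[f32]) -> usize {
    let mut best = 0;
    for (i, s) in scores.iter().enumerate() {
        if *s > scores[best] {
            best = i;
        }
    }
    best
}

/// Classify one sequence from its hidden states (seq_len x hidden, row-major):
/// first-token pooling -> linear -> softmax -> argmax -> label lookup.
/// Dropout is omitted because it does nothing at inference time.
pub fn classify_first_token(
    hidden_states: &[f32],
    hidden: usize,
    weight: &[f32],
    bias: &[f32],
    id2label: &HashMap<u32, String>,
) -> Option<(String, f32)> {
    let first_token = hidden_states.get(..hidden)?;
    let scores = softmax(&linear(weight, bias, first_token));
    let best = argmax(&scores);
    id2label.get(&(best as u32)).map(|label| (label.clone(), scores[best]))
}
```

Returning the winning probability alongside the label mirrors how the example prints both predicted classes and confidence scores.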

@BradyBonnette

@emschwartz I just pushed up my latest updates, which also contain the latest sync from the main HuggingFace repo (including the Metal updates).

Feel free to give it a try directly from my fork if you wish. It now does both token classification and text classification. I'm looking over some of the stuff I originally did in my main.rs example file and wondering what I was thinking lol. I think I can shorten some of that. I don't have anything in the README yet, since I'm making sure everything else still works.
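Token classification (as in the Medical-NER run from the original report) differs from text classification mainly in where the argmax happens: every token gets its own label instead of pooling the first token. A std-only sketch under that assumption; the function name, the `keep` mask used to skip special tokens, and the label names in the usage are hypothetical, not the fork's API.

```rust
/// Token classification sketch. `logits` is (seq_len x num_labels), row-major,
/// and `num_labels` must be non-zero. Each kept token gets the label with the
/// highest logit; positions with `keep == false` (special/padding tokens) are skipped.
pub fn label_tokens(
    tokens: &[&str],
    logits: &[f32],
    num_labels: usize,
    keep: &[bool],
    labels: &[&str],
) -> Vec<(String, String)> {
    tokens
        .iter()
        .zip(logits.chunks(num_labels))
        .zip(keep)
        .filter(|(_, k)| **k)
        .map(|((token, row), _)| {
            // Per-token argmax; the first maximum wins on ties.
            let best = row
                .iter()
                .enumerate()
                .fold(0, |best, (i, &v)| if v > row[best] { i } else { best });
            (token.to_string(), labels[best].to_string())
        })
        .collect()
}
```

For tokens `["[CLS]", "CAD", "[SEP]"]` with only "CAD" kept, the result is a single `("CAD", label)` pair.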

@BradyBonnette

@emschwartz I added a preliminary README to my example. If you want to try anything out, feel free. It's probably good for someone other than me to try it before I make a pull request :)

@emschwartz
Author

Looks pretty good to me! My one suggestion for the README would be to make the first example one that pulls a model from HuggingFace rather than assuming the user has a locally fine-tuned model, which feels like a more advanced use case. It might also be worth adding a note about using the GPU feature flag that matches your platform, since all the examples use cuda.

Thanks again for your work on this!

@BradyBonnette

BradyBonnette commented Jan 26, 2025

@emschwartz No problem! My pleasure.

Thanks for the feedback, I can definitely do that. I didn't have any agenda in mind when I put the examples in that order other than "that's how they appeared in my history" lol. I just added it to show that you can use a local model if you choose, but I guess in reality even HuggingFace models are "local" at some point; my severely bloated ~/.cache directory is a testament to that.

I suppose the only thing about the GPU feature flags is that they're flags of the candle-examples crate, not of the example itself, so a reader who wanted to find out more would have to look them up. I could add a note along the lines of "adjust feature flags to match your system".

I could also add an example showing how to run it on the CPU, which is really nothing more than adding --cpu to the command, but it might be good to demonstrate.
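On the feature-flag point: the --cpu flag is passed to candle_examples::device (as in the quality-classifier code earlier in this thread), which, as far as I can tell, picks CUDA first, then Metal, and otherwise falls back to the CPU, counting CUDA or Metal as available only when the example was built with the matching cargo feature. That is why the same command runs on the CPU unless it is built with something like --features metal. A std-only sketch of that order; `Choice` and `pick_device` are hypothetical stand-ins, not candle APIs.

```rust
/// Which device an example would run on (a stand-in for candle's Device).
#[derive(Debug, PartialEq)]
pub enum Choice {
    Cpu,
    Cuda(usize),
    Metal(usize),
}

/// Selection order sketch: an explicit --cpu wins, then CUDA, then Metal,
/// then a CPU fallback. The two availability flags stand in for checks that
/// are only true when the matching feature was enabled at build time.
pub fn pick_device(cpu_flag: bool, cuda_available: bool, metal_available: bool) -> Choice {
    if cpu_flag {
        Choice::Cpu
    } else if cuda_available {
        Choice::Cuda(0)
    } else if metal_available {
        Choice::Metal(0)
    } else {
        Choice::Cpu
    }
}
```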

Were you able to try out my stuff with the quality-classifier thing you were working with?

@emschwartz
Author

Makes sense!

Were you able to try out my stuff with the quality-classifier thing you were working with?

Yup! Works like a charm 😊

@BradyBonnette

Glad to hear it. Tomorrow I'll change up some of the documentation a little bit, then I'll start a pull request. I'm almost positive there will be some things that need to be changed or added, and that's OK. I'm just glad to see that it works somewhere other than my development machine! :)

@emschwartz
Author

@BradyBonnette I just ran into a small inconsistency between the results I'm seeing with Candle and with PyTorch. The tokenization differs slightly, and it seems to come down to whether a token is parsed as "_Light" or as "Light".

I saw in the other thread that you mentioned something about SentencePiece and loading the spm.model file. Did you figure out how to get that to work? I don't see it mentioned in your example code.

Thanks again for your help!

@BradyBonnette

@emschwartz Yeah, I do remember something weird with spm.model files, but for the life of me I cannot remember what it was. Do you have an example I can run? If I remember correctly, SentencePiece tokens should always have the "special underscore" in front of them when parsed.
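The "special underscore" is SentencePiece's "▁" (U+2581) word-boundary marker, which goes on word-initial pieces rather than on every piece. One plausible source of a "_Light" vs "Light" mismatch is whether the first word of the input gets that marker (in the tokenizers library this is governed by the Metaspace pre-tokenizer's prepend setting). A simplified std-only sketch of that rule, assuming space-separated input; `metaspace_words` is a hypothetical name and this is not the tokenizers implementation.

```rust
/// SentencePiece's word-boundary marker, U+2581, which often shows up as "_" in logs.
pub const META: char = '\u{2581}';

/// Simplified Metaspace-style pre-tokenization: a word preceded by a space gets
/// the marker; the very first word gets it only when `prepend_first` is true.
/// Runs of spaces are simplified here (empty pieces are dropped).
pub fn metaspace_words(text: &str, prepend_first: bool) -> Vec<String> {
    text.split(' ')
        .enumerate()
        .filter(|(_, w)| !w.is_empty())
        .map(|(i, w)| {
            // i == 0 means nothing preceded this word in the input.
            if i == 0 && !prepend_first {
                w.to_string()
            } else {
                format!("{}{}", META, w)
            }
        })
        .collect()
}
```

Comparing the output for the same sentence with `prepend_first` on and off reproduces the kind of one-token difference described above.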

@BradyBonnette

@emschwartz I've got a draft PR up but haven't submitted it yet. If you find anything else, let me know. I did a complete rebase/squash of all my commits in preparation for the PR, so if you were using anything from my fork you might have to reconcile that.
