Skip to content

Commit

Permalink
Expose a reference to the mapping from the ground truth data label to…
Browse files Browse the repository at this point in the history
… the assigned binary label used in the logistic regression model
  • Loading branch information
DGPardo committed Dec 24, 2023
1 parent 0d8b62a commit cf15ccd
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions algorithms/linfa-logistic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,12 @@ impl<F: Float, C: PartialOrd + Clone> FittedLogisticRegression<F, C> {
&self.params
}

/// Get the model positive and negative classes mapped to their
/// corresponding problem input labels.
pub fn labels(&self) -> &BinaryClassLabels<F, C> {
&self.labels
}

/// Given a feature matrix, predict the probabilities that a sample
/// should be classified as the larger of the two classes learned when the
/// model was fitted.
Expand Down Expand Up @@ -745,9 +751,9 @@ impl<C: PartialOrd + Clone + Default, F: Float, D: Data<Elem = F>>
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
struct ClassLabel<F, C: PartialOrd> {
class: C,
label: F,
pub struct ClassLabel<F, C: PartialOrd> {
pub class: C,
pub label: F,
}

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -756,9 +762,9 @@ struct ClassLabel<F, C: PartialOrd> {
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
struct BinaryClassLabels<F, C: PartialOrd> {
pos: ClassLabel<F, C>,
neg: ClassLabel<F, C>,
pub struct BinaryClassLabels<F, C: PartialOrd> {
pub pos: ClassLabel<F, C>,
pub neg: ClassLabel<F, C>,
}

/// Internal representation of a logistic regression problem.
Expand Down Expand Up @@ -1008,6 +1014,8 @@ mod test {
&res.predict(dataset.records()),
dataset.targets().as_single_targets()
);
assert_eq!(res.labels().pos.class, "dog");
assert_eq!(res.labels().neg.class, "cat");
}

#[test]
Expand Down

0 comments on commit cf15ccd

Please sign in to comment.