Skip to content

Commit

Permalink
Relax Fn requirements (#2620)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee authored Dec 16, 2024
1 parent dda336a commit 53bc165
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions crates/burn-core/src/module/param/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ pub trait Parameter: Clone + core::fmt::Debug + Send {

#[allow(clippy::type_complexity)]
struct Uninitialized<P: Parameter> {
init: Box<dyn Fn(&P::Device, bool) -> P + Send>,
init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
device: P::Device,
is_require_grad: bool,
}

impl<P: Parameter> Uninitialized<P> {
fn initialize(&self) -> P {
let init = &self.init;
fn initialize(self) -> P {
let init = self.init;
init(&self.device, self.is_require_grad)
}
}
Expand All @@ -97,7 +97,7 @@ impl<T: Parameter> Param<T> {
/// Create a new parameter that is not already initialized.
pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
where
F: Fn(&T::Device, bool) -> T + Send + 'static,
F: FnOnce(&T::Device, bool) -> T + Send + 'static,
{
Self {
id,
Expand All @@ -120,12 +120,8 @@ impl<T: Parameter> Param<T> {
.expect("Should have an initialization when no state provided.")
.write()
.unwrap();
let state = result.as_ref().expect("Should exist when not initialized");
let tensor = state.initialize();

*result = None;

tensor
let state = result.take().expect("Should exist when not initialized");
state.initialize()
})
.clone()
}
Expand All @@ -145,7 +141,7 @@ impl<T: Parameter> Param<T> {
}

/// Execute the given function on the inner value.
pub fn map<F: Fn(T) -> T>(self, func: F) -> Self {
pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
let (id, tensor) = self.consume();
let tensor = func(tensor);

Expand Down Expand Up @@ -251,12 +247,8 @@ impl<T: Parameter> Deref for Param<T> {
.write()
.unwrap();

let state = result.as_ref().expect("Should exist when not initialized");
let tensor = state.initialize();

*result = None;

tensor
let state = result.take().expect("Should exist when not initialized");
state.initialize()
})
}
}

0 comments on commit 53bc165

Please sign in to comment.