From 53bc165204ca4a2ecb63537036a2276cfbb75f33 Mon Sep 17 00:00:00 2001 From: Arthur Brussee Date: Mon, 16 Dec 2024 15:32:22 +0000 Subject: [PATCH] Relax Fn requirements (#2620) --- crates/burn-core/src/module/param/base.rs | 26 ++++++++--------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/crates/burn-core/src/module/param/base.rs b/crates/burn-core/src/module/param/base.rs index 5331f54240..b7e3bf8868 100644 --- a/crates/burn-core/src/module/param/base.rs +++ b/crates/burn-core/src/module/param/base.rs @@ -72,14 +72,14 @@ pub trait Parameter: Clone + core::fmt::Debug + Send { #[allow(clippy::type_complexity)] struct Uninitialized { - init: Box P + Send>, + init: Box P + Send>, device: P::Device, is_require_grad: bool, } impl Uninitialized

{ - fn initialize(&self) -> P { - let init = &self.init; + fn initialize(self) -> P { + let init = self.init; init(&self.device, self.is_require_grad) } } @@ -97,7 +97,7 @@ impl Param { /// Create a new parameter that is not already initialized. pub fn uninitialized(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self where - F: Fn(&T::Device, bool) -> T + Send + 'static, + F: FnOnce(&T::Device, bool) -> T + Send + 'static, { Self { id, @@ -120,12 +120,8 @@ impl Param { .expect("Should have an initialization when no state provided.") .write() .unwrap(); - let state = result.as_ref().expect("Should exist when not initialized"); - let tensor = state.initialize(); - - *result = None; - - tensor + let state = result.take().expect("Should exist when not initialized"); + state.initialize() }) .clone() } @@ -145,7 +141,7 @@ impl Param { } /// Execute the given function on the inner value. - pub fn map T>(self, func: F) -> Self { + pub fn map T>(self, func: F) -> Self { let (id, tensor) = self.consume(); let tensor = func(tensor); @@ -251,12 +247,8 @@ impl Deref for Param { .write() .unwrap(); - let state = result.as_ref().expect("Should exist when not initialized"); - let tensor = state.initialize(); - - *result = None; - - tensor + let state = result.take().expect("Should exist when not initialized"); + state.initialize() }) } }