diff --git a/loki/program_unit.py b/loki/program_unit.py
index b7b51a072..daf501b99 100644
--- a/loki/program_unit.py
+++ b/loki/program_unit.py
@@ -327,7 +327,8 @@ def enrich(self, definitions, recurse=False):
         """
         definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))
 
-        for imprt in self.imports:
+        # Enrich type info from all known imports (including parent scopes)
+        for imprt in self.all_imports:
             if not (module := definitions_map.get(imprt.module)):
                 # Skip modules that are not available in the definitions list
                 continue
diff --git a/loki/tests/test_modules.py b/loki/tests/test_modules.py
index a8397f75a..668e95322 100644
--- a/loki/tests/test_modules.py
+++ b/loki/tests/test_modules.py
@@ -7,15 +7,16 @@
 
 import pytest
 
-from loki import (
-    Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
-    BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
-    Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
-    Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
-)
+from loki import Module, Subroutine, fexprgen, fgen
 from loki.build import jit_compile, clean_test
+from loki.expression import symbols as sym
 from loki.frontend import available_frontends, OMNI
+from loki.ir import (
+    nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
+    FindVariables, SubstituteExpressions, Transformer
+)
 from loki.sourcefile import Sourcefile
+from loki.types import BasicType, DerivedType, SymbolAttributes
 
 
 @pytest.mark.parametrize('frontend', available_frontends())
@@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
 end module a_module
 """.strip()
     module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
-    assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2
-    assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1
+    assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2
+    assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1
     assert 'derived_type' in module.typedef_map
     assert len(module.routines) == 1
     assert module.routines[0].name == 'my_routine'
@@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
     assert fexprgen(a.shape) == exptected_array_shape
 
     # Check the LHS of the assignment has correct meta-data
-    stmt = FindNodes(Assignment).visit(routine.body)[0]
+    stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
     pt_ext_arr = stmt.lhs
     assert pt_ext_arr.type.dtype == BasicType.REAL
     assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
@@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):
 
     # Verify correct attachment of type information
     assert 'ext_type' in module.symbol_attrs
-    assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef)
-    assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef)
-    assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
-    assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef)
+    assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
+    assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
+    assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
+    assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
     assert 'other_type' in module.symbol_attrs
     assert 'other_type' not in module['other_routine'].symbol_attrs
-    assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef)
-    assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
+    assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
+    assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
 
     # OMNI resolves explicit shape parameters in the frontend parser
     exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
@@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
     assert fexprgen(pt_ext_a.shape) == exptected_array_shape
 
     # Check the LHS of the assignment has correct meta-data
-    stmt = FindNodes(Assignment).visit(routine.body)[0]
+    stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
     pt_ext_arr = stmt.lhs
     assert pt_ext_arr.type.dtype == BasicType.REAL
     assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
@@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
     x = module.variable_map['x']  # That's the symbol for variable 'x'
     real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
     int_type = SymbolAttributes('integer')
-    a = Variable(name='a', type=real_type, scope=module)
-    b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
-    c = Variable(name='c', type=int_type, scope=module)
+    a = sym.Variable(name='a', type=real_type, scope=module)
+    b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
+    c = sym.Variable(name='c', type=int_type, scope=module)
 
     # Add new variables and check that they are all in the module spec
     module.variables += (a, b, c)
@@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
     new_module = module.clone()
 
     n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
-    n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]
+    n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0]
 
     # Remove the declaration of `n` and replace it with `3`
     new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
-    new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)
+    new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec)
 
     # Check the new module has been changed
-    assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
-    new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
+    assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1
+    new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body)
     assert len(new_type_decls) == 2
     assert new_type_decls[0].symbols[0] == 'vector(3)'
     assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'
 
     # Check the old one has not changed
-    assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
-    type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
+    assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2
+    type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body)
     assert len(type_decls) == 2
     assert type_decls[0].symbols[0] == 'vector(n)'
     assert type_decls[1].symbols[0] == 'matrix(n, n)'
@@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
         assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name'))
 
     # Verify Import IR node
-    for imprt in FindNodes(Import).visit(mod3.spec):
+    for imprt in FindNodes(ir.Import).visit(mod3.spec):
         if imprt.module == 'test_rename_mod':
             assert imprt.rename_list
             assert not imprt.symbols
@@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
         assert mod3.symbol_attrs[s].use_name == use_name
 
     # Verify Import IR node
-    for imprt in FindNodes(Import).visit(mod3.spec):
+    for imprt in FindNodes(ir.Import).visit(mod3.spec):
         if imprt.module == 'test_rename_mod':
             assert imprt.rename_list
             assert not imprt.symbols
@@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):
 
     # Check properties on the Import IR node in the external module
     assert ext_mod.imported_symbols == ('int16',)
-    imprt = FindNodes(Import).visit(ext_mod.spec)[0]
+    imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
     assert imprt.nature.lower() == 'intrinsic'
     assert imprt.module.lower() == 'iso_c_binding'
     assert ext_mod.imported_symbol_map['int16'].type.imported is True
@@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
     assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
     assert set(kinds.imported_symbols) == {'int8', 'int16'}
 
-    my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols}
-    import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols}
+    my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
+    import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}
 
     assert my_import_map['int8'] is my_import_map['int16']
     assert import_map['int8'] is import_map['int16']
@@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
     assert routine1.contains is None
 
     routine1 = routine1.clone(contains=routine2)
-    assert isinstance(routine1.contains, Section)
-    assert isinstance(routine1.contains.body[0], Intrinsic)
+    assert isinstance(routine1.contains, ir.Section)
+    assert isinstance(routine1.contains.body[0], ir.Intrinsic)
     assert routine1.contains.body[0].text == 'CONTAINS'
 
     module = module.clone(contains=routine1)
-    assert isinstance(module.contains, Section)
-    assert isinstance(module.contains.body[0], Intrinsic)
+    assert isinstance(module.contains, ir.Section)
+    assert isinstance(module.contains.body[0], ir.Intrinsic)
     assert module.contains.body[0].text == 'CONTAINS'
 
 
@@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
     b = driver.symbol_map['b']
 
     if complete_tree:
-        assert isinstance(a, Scalar)
+        assert isinstance(a, sym.Scalar)
         assert a.type.dtype is BasicType.INTEGER
-        assert isinstance(b, Scalar)
+        assert isinstance(b, sym.Scalar)
         assert b.type.dtype is BasicType.INTEGER
     else:
-        assert isinstance(a, DeferredTypeSymbol)
+        assert isinstance(a, sym.DeferredTypeSymbol)
         assert a.type.dtype is BasicType.DEFERRED
-        assert isinstance(b, DeferredTypeSymbol)
+        assert isinstance(b, sym.DeferredTypeSymbol)
         assert b.type.dtype is BasicType.DEFERRED
 
     assert a.type.imported
@@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path):
         assert calls[0].arguments[0].type.parameter
         assert calls[0].arguments[0].type.initial == 16
         assert calls[0].arguments[0].type.module is source['foo']
+
+
+@pytest.mark.parametrize('frontend', available_frontends())
+def test_module_enrichment_typdefs(frontend, tmp_path):
+    """ Test that module-level enrihcment is propagated correctly """
+
+    fcode_state_mod = """
+module state_type_mod
+  implicit none
+
+  type state_type
+    real, pointer, dimension(:,:) :: a
+  end type state_type
+
+end module state_type_mod
+"""
+
+    fcode_driver_mod = """
+module driver_mod
+  use state_type_mod, only: state_type
+  implicit none
+
+contains
+  subroutine driver_routine(state)
+    type(state_type), intent(inout) :: state
+
+    state%a = 1
+
+  end subroutine driver_routine
+end module driver_mod
+"""
+    state_mod = Sourcefile.from_source(fcode_state_mod, frontend=frontend, xmods=[tmp_path])['state_type_mod']
+    driver_mod = Sourcefile.from_source(fcode_driver_mod, frontend=frontend, xmods=[tmp_path])['driver_mod']
+    driver = driver_mod['driver_routine']
+
+    state = driver.variable_map['state']
+    assert isinstance(state.type.dtype, DerivedType)
+    assert state.type.dtype.typedef == BasicType.DEFERRED
+
+    # Enrich typedef on the outer module Import
+    driver_mod.enrich([state_mod], recurse=True)
+
+    state = driver.variable_map['state']
+
+    # Ensure type info has been propagated to inner subroutine
+    assert isinstance(state.type.dtype, DerivedType)
+    assert isinstance(state.type.dtype.typedef, ir.TypeDef)
+
+    assigns = FindNodes(ir.Assignment).visit(driver.body)
+    assert len(assigns) == 1
+    assert assigns[0].lhs.type.dtype == BasicType.REAL
+    assert assigns[0].lhs.type.shape == (':', ':')