diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index d1f7dbbb40..612fc32cc3 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -834,6 +834,7 @@ RUN(NAME lambda_01 LABELS cpython llvm llvm_jit) RUN(NAME c_mangling LABELS cpython llvm llvm_jit c) RUN(NAME class_01 LABELS cpython llvm llvm_jit) +RUN(NAME class_02 LABELS cpython llvm llvm_jit) # callback_04 is to test emulation. So just run with cpython RUN(NAME callback_04 IMPORT_PATH .. LABELS cpython) diff --git a/integration_tests/class_02.py b/integration_tests/class_02.py new file mode 100644 index 0000000000..11325b1c06 --- /dev/null +++ b/integration_tests/class_02.py @@ -0,0 +1,44 @@ +from lpython import i32 +class Character: + def __init__(self:"Character", name:str, health:i32, attack_power:i32): + self.name :str = name + self.health :i32 = health + self.attack_power : i32 = attack_power + self.is_immortal : bool = False + + def attack(self:"Character", other:"Character") -> str: + other.health -= self.attack_power + return self.name+" attacks "+ other.name+" for "+str(self.attack_power)+" damage." + + def is_alive(self:"Character")->bool: + if self.is_immortal: + return True + else: + return self.health > 0 + +def main(): + hero : Character = Character("Hero", 10, 20) + monster : Character = Character("Monster", 50, 15) + print(hero.attack(monster)) + print(monster.health) + assert monster.health == 30 + print(monster.is_alive()) + assert monster.is_alive() == True + print("Hero gains temporary immortality") + hero.is_immortal = True + print(monster.attack(hero)) + print(hero.health) + assert hero. health == -5 + print(hero.is_alive()) + assert hero.is_alive() == True + print("Hero's immortality runs out") + hero.is_immortal = False + print(hero.is_alive()) + assert hero.is_alive() == False + print("Restarting") + hero = Character("Hero", 10, 20) + print(hero.is_alive()) + assert hero.is_alive() == True + +main() + \ No newline at end of file diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 49a8255213..2e923e9f8e 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -3087,6 +3087,30 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void instantiate_methods(const ASR::Struct_t &x) { + SymbolTable *current_scope_copy = current_scope; + current_scope = x.m_symtab; + for ( auto &item : x.m_symtab->get_scope() ) { + if ( is_a(*item.second) ) { + ASR::Function_t *v = down_cast(item.second); + instantiate_function(*v); + } + } + current_scope = current_scope_copy; + } + + void visit_methods (const ASR::Struct_t &x) { + SymbolTable *current_scope_copy = current_scope; + current_scope = x.m_symtab; + for ( auto &item : x.m_symtab->get_scope() ) { + if ( is_a(*item.second) ) { + ASR::Function_t *v = down_cast(item.second); + visit_Function(*v); + } + } + current_scope = current_scope_copy; + } + void start_module_init_function_prototype(const ASR::Module_t &x) { uint32_t h = get_hash((ASR::asr_t*)&x); llvm::FunctionType *function_type = llvm::FunctionType::get( @@ -3128,6 +3152,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } else if (is_a(*item.second)) { ASR::EnumType_t *et = down_cast(item.second); visit_EnumType(*et); + } else if (is_a(*item.second)) { + ASR::Struct_t *st = down_cast(item.second); + instantiate_methods(*st); } } finish_module_init_function_prototype(x); @@ -4179,6 +4206,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor if (is_a(*item.second)) { ASR::Function_t *s = ASR::down_cast(item.second); visit_Function(*s); + } else if ( is_a(*item.second) ) { + ASR::Struct_t *st = down_cast(item.second); + visit_methods(*st); } } } diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 5a12d1e3f8..4909e0461b 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1281,8 +1281,20 @@ class CommonVisitor : public AST::BaseVisitor { visit_expr_list(pos_args, n_pos_args, kwargs, n_kwargs, args, st, loc); } + if ( st->n_member_functions > 0 ) { + // Empty struct constructor + // Initializers handled in init proc call + Vecempty_args; + empty_args.reserve(al, 1); + for (size_t i = 0; i < st->n_members; i++) { + empty_args.push_back(al, st->m_initializers[i]); + } + ASR::ttype_t* der_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, stemp)); + return ASR::make_StructConstructor_t(al, loc, stemp, empty_args.p, + empty_args.size(), der_type, nullptr); + } - if (args.size() > 0 && args.size() > st->n_members) { + if ( args.size() > 0 && args.size() > st->n_members ) { throw SemanticError("StructConstructor has more arguments than the number of struct members", loc); } @@ -1656,6 +1668,10 @@ class CommonVisitor : public AST::BaseVisitor { } else { throw SemanticError("Only Name in Subscript supported for now in annotation", annotation->base.loc); } + } else if ( AST::is_a(*annotation) ) { + //self case in methods + intent = ASRUtils::intent_inout; + return annotation; } return annotation; } @@ -1913,6 +1929,15 @@ class CommonVisitor : public AST::BaseVisitor { current_scope->add_symbol(import_name, import_struct_member); } return ASRUtils::TYPE(ASR::make_Union_t(al, attr_annotation->base.base.loc, import_struct_member)); + } else if ( AST::is_a(annotation) ) { + AST::ConstantStr_t *n = AST::down_cast(&annotation); + ASR::symbol_t *sym = current_scope->parent->parent->resolve_symbol(n->m_value); + if ( sym == nullptr || !ASR::is_a(*sym) ) { + throw SemanticError("Only Struct implemented for constant" + " str annotation", loc); + } + //TODO: Change the returned type from Class to Struct + return ASRUtils::TYPE(ASR::make_Class_t(al,loc,sym)); } throw SemanticError("Only Name, Subscript, and Call supported for now in annotation of annotated assignment.", loc); @@ -2930,22 +2955,13 @@ class CommonVisitor : public AST::BaseVisitor { assign_asr_target = assign_asr_target_copy; } - void handle_init_method(const AST::FunctionDef_t &x, + void get_members_init (const AST::FunctionDef_t &x, Vec& member_names, Vec &member_init){ if(x.n_decorator_list > 0) { throw SemanticError("Decorators for __init__ not implemented", x.base.base.loc); } - if( x.m_args.n_args > 1 ) { - throw SemanticError("Only default constructors implemented ", - x.base.base.loc); - } - // TODO: the obj_name can be anything - std::string obj_name = "self"; - if ( std::string(x.m_args.m_args[0].m_arg) != obj_name) { - throw SemanticError("Only `self` can be used as object name for now", - x.base.base.loc); - } + std::string obj_name = x.m_args.m_args->m_arg; for(size_t i = 0; i < x.n_body; i++) { std::string var_name; if (! AST::is_a(*x.m_body[i]) ){ @@ -2958,7 +2974,7 @@ class CommonVisitor : public AST::BaseVisitor { if(AST::is_a(*a->m_value)) { AST::Name_t* n = AST::down_cast(a->m_value); if(std::string(n->m_id) != obj_name) { - throw SemanticError("Object name doesn't matach", + throw SemanticError("Object name doesn't match", x.m_body[i]->base.loc); } } @@ -2968,7 +2984,6 @@ class CommonVisitor : public AST::BaseVisitor { throw SemanticError("Only Attribute supported as target in " "AnnAssign inside Class", x.m_body[i]->base.loc); } - ASR::expr_t* init_expr = nullptr; ASR::abiType abi = ASR::abiType::Source; bool is_allocatable = false, is_const = false; ASR::ttype_t *type = ast_expr_to_asr_type(ann_assign.m_annotation->base.loc, @@ -2976,38 +2991,13 @@ class CommonVisitor : public AST::BaseVisitor { ASR::storage_typeType storage_type = ASR::storage_typeType::Default; create_add_variable_to_scope(var_name, type, - ann_assign.base.base.loc, abi, storage_type); - - if (ann_assign.m_value == nullptr) { - throw SemanticError("Missing an initialiser for the data member", - x.m_body[i]->base.loc); - } - this->visit_expr(*ann_assign.m_value); - if (tmp && ASR::is_a(*tmp)) { - ASR::expr_t* value = ASRUtils::EXPR(tmp); - ASR::ttype_t* underlying_type = type; - cast_helper(underlying_type, value, value->base.loc); - if (!ASRUtils::check_equal_type(underlying_type, ASRUtils::expr_type(value), true)) { - std::string ltype = ASRUtils::type_to_str_python(underlying_type); - std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(value)); - diag.add(diag::Diagnostic( - "Type mismatch in annotation-assignment, the types must be compatible", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", - {ann_assign.m_target->base.loc, value->base.loc}) - }) - ); - throw SemanticAbort(); - } - init_expr = value; - } + ann_assign.base.base.loc, abi, storage_type); ASR::symbol_t* var_sym = current_scope->resolve_symbol(var_name); ASR::call_arg_t c_arg; c_arg.loc = var_sym->base.loc; - c_arg.m_value = init_expr; + c_arg.m_value = nullptr; member_init.push_back(al, c_arg); } - } void visit_ClassMembers(const AST::ClassDef_t& x, @@ -3037,11 +3027,9 @@ class CommonVisitor : public AST::BaseVisitor { *f = AST::down_cast(x.m_body[i]); std::string f_name = f->m_name; if (f_name == "__init__") { - this->handle_init_method(*f, member_names, member_init); - // This seems hackish, as struct depends on itself - // We need to handle this later. - // Removing this throws a ASR verify error - struct_dependencies.push_back(al, x.m_name); + this->get_members_init(*f, member_names, member_init); + this->visit_stmt(*x.m_body[i]); + member_fn_names.push_back(al, f->m_name); } else { this->visit_stmt(*x.m_body[i]); member_fn_names.push_back(al, f->m_name); @@ -3312,7 +3300,9 @@ class CommonVisitor : public AST::BaseVisitor { if ( AST::is_a(*x.m_body[i]) ) { AST::FunctionDef_t* f = AST::down_cast(x.m_body[i]); - if ( std::string(f->m_name) != std::string("__init__") ) { + if ( std::string(f->m_name) == std::string("__init__") ) { + this->visit_init_body(*f); + } else { this->visit_stmt(*x.m_body[i]); } } @@ -3333,9 +3323,6 @@ class CommonVisitor : public AST::BaseVisitor { "instead use the dataclass decorator ", x.base.base.loc); } - visit_ClassMembers(x, member_names, member_fn_names, - struct_dependencies, member_init, false, class_abi, true); - LCOMPILERS_ASSERT(member_init.size() == member_names.size()); ASR::symbol_t* class_sym = ASR::down_cast( ASR::make_Struct_t(al, x.base.base.loc, current_scope, x.m_name, struct_dependencies.p, struct_dependencies.size(), @@ -3343,22 +3330,26 @@ class CommonVisitor : public AST::BaseVisitor { member_fn_names.size(), class_abi, ASR::accessType::Public, false, false, member_init.p, member_init.size(), nullptr, nullptr)); - ASR::ttype_t* class_type = ASRUtils::TYPE( - ASRUtils::make_StructType_t_util(al, x.base.base.loc, - class_sym)); - std::string self_name = "self"; - if ( current_scope->get_symbol(self_name) ) { - throw SemanticError("`self` cannot be used as a data member " - "for now", x.base.base.loc); - } - create_add_variable_to_scope(self_name, class_type, - x.base.base.loc, class_abi); parent_scope->add_symbol(x.m_name, class_sym); + visit_ClassMembers(x, member_names, member_fn_names, + struct_dependencies, member_init, false, class_abi, true); + ASR::Struct_t* st = ASR::down_cast(class_sym); + st->m_dependencies = struct_dependencies.p; + st->n_dependencies = struct_dependencies.n; + st->m_member_functions = member_fn_names.p; + st->n_member_functions = member_fn_names.n; + st->m_members = member_names.p; + st->n_members = member_names.n; + st->m_initializers = member_init.p; + st->n_initializers = member_init.n; + } current_scope = parent_scope; } } + virtual void visit_init_body (const AST::FunctionDef_t &/*x*/) = 0; + void add_name(const Location &loc) { std::string var_name = "__name__"; std::string var_value = module_name; @@ -4391,6 +4382,10 @@ class SymbolTableVisitor : public CommonVisitor { // Implement visit_Global for Symbol Table visitor. void visit_Global(const AST::Global_t &/*x*/) {} + void visit_init_body (const AST::FunctionDef_t &/*x*/) { + //Implemented in BodyVisitor + } + void visit_FunctionDef(const AST::FunctionDef_t &x) { dependencies.clear(al); SymbolTable *parent_scope = current_scope; @@ -5109,6 +5104,50 @@ class BodyVisitor : public CommonVisitor { tmp = asr; } + void visit_init_body (const AST::FunctionDef_t &x) { + SymbolTable *old_scope = current_scope; + ASR::symbol_t *t = current_scope->get_symbol("__init__"); + if ( t==nullptr ) { + throw SemanticError("__init__ fn not declared", x.base.base.loc); + } + if ( !ASR::is_a(*t) ) { + throw SemanticError("__init__ is not a function", x.base.base.loc); + } + ASR::Function_t *f = ASR::down_cast(t); + //Transform statements into correct format + Vec new_body; + new_body.reserve(al, 1); + for (size_t i=0; i(x.m_body[i]); + if ( ann_assign.m_value != nullptr ) { + Vectarget; + target.reserve(al, 1); + target.push_back(al, ann_assign.m_target); + AST::ast_t* assgn_ast = AST::make_Assign_t(al, ann_assign.base.base.loc, + target.p, 1, ann_assign.m_value, nullptr); + AST::stmt_t* assgn = AST::down_cast(assgn_ast); + new_body.push_back(al, assgn); + } + } + current_scope = f->m_symtab; + Vec body; + body.reserve(al, x.n_body); + Vec rts; + rts.reserve(al, 4); + dependencies.clear(al); + transform_stmts(body, new_body.n, new_body.p); + for (const auto &rt: rt_vec) { rts.push_back(al, rt); } + f->m_body = body.p; + f->n_body = body.size(); + ASR::FunctionType_t* func_type = ASR::down_cast(f->m_function_signature); + func_type->m_restrictions = rts.p; + func_type->n_restrictions = rts.size(); + f->m_dependencies = dependencies.p; + f->n_dependencies = dependencies.size(); + rt_vec.clear(); + current_scope = old_scope; + } + void handle_fn(const AST::FunctionDef_t &x, ASR::Function_t &v) { current_scope = v.m_symtab; Vec body; @@ -5212,6 +5251,29 @@ class BodyVisitor : public CommonVisitor { } ASR::expr_t *init_expr = nullptr; visit_AnnAssignUtil(x, var_name, init_expr); + ASR::symbol_t* sym = current_scope->get_symbol(var_name); + if ( sym && ASR::is_a(*sym) ) { + ASR::Variable_t* var = ASR::down_cast(sym); + if ( ASR::is_a(*(var->m_type)) && + !ASR::down_cast((var->m_type))->m_is_cstruct && + ASR::is_a(*init_expr) ) { + AST::Call_t* call = AST::down_cast(x.m_value); + if ( call->n_keywords>0 ) { + throw SemanticError("Kwargs not implemented yet", x.base.base.loc); + } + Vec args; + args.reserve(al, call->n_args + 1); + ASR::call_arg_t self_arg; + self_arg.loc = x.base.base.loc; + self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, sym)); + args.push_back(al, self_arg); + visit_expr_list(call->m_args, call->n_args, args); + ASR::symbol_t* der = ASR::down_cast((var->m_type))->m_derived_type; + std::string call_name = "__init__"; + ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc); + tmp = make_call_helper(al, call_sym, current_scope, args, call_name, x.base.base.loc); + } + } } void visit_Delete(const AST::Delete_t &x) { @@ -5485,6 +5547,26 @@ class BodyVisitor : public CommonVisitor { } tmp_vec.push_back(ASR::make_Assignment_t(al, x.base.base.loc, target, tmp_value, overloaded)); + if ( target->type == ASR::exprType::Var && + tmp_value->type == ASR::exprType::StructConstructor ) { + Vec new_args; new_args.reserve(al, 1); + ASR::call_arg_t self_arg; + self_arg.loc = x.base.base.loc; + ASR::symbol_t* st = ASR::down_cast(target)->m_v; + self_arg.m_value = target; + new_args.push_back(al,self_arg); + AST::Call_t* call = AST::down_cast(x.m_value); + if ( call->n_keywords>0 ) { + throw SemanticError("Kwargs not implemented yet", x.base.base.loc); + } + visit_expr_list(call->m_args, call->n_args, new_args); + ASR::symbol_t* der = ASR::down_cast( + ASR::down_cast(st)->m_type)->m_derived_type; + std::string call_name = "__init__"; + ASR::symbol_t* call_sym = get_struct_member(der, call_name, x.base.base.loc); + tmp_vec.push_back(make_call_helper(al, call_sym, + current_scope, new_args, call_name, x.base.base.loc)); + } } // to make sure that we add only those statements in tmp_vec tmp = nullptr; @@ -6113,7 +6195,7 @@ class BodyVisitor : public CommonVisitor { throw SemanticError("'" + attr + "' is not implemented for Complex type", loc); } - } else if( ASR::is_a(*type) ) { + } else if( ASR::is_a(*type)) { ASR::StructType_t* der = ASR::down_cast(type); ASR::symbol_t* der_sym = ASRUtils::symbol_get_past_external(der->m_derived_type); ASR::Struct_t* der_type = ASR::down_cast(der_sym); @@ -6166,6 +6248,59 @@ class BodyVisitor : public CommonVisitor { } tmp = ASR::make_StructInstanceMember_t(al, loc, val, member_sym, member_var_type, nullptr); + } else if( ASR::is_a(*type) ) { //TODO: Remove Class_t from here + ASR::Class_t* der = ASR::down_cast(type); + ASR::symbol_t* der_sym = ASRUtils::symbol_get_past_external(der->m_class_type); + ASR::Struct_t* der_type = ASR::down_cast(der_sym); + bool member_found = false; + std::string member_name = attr_char; + for( size_t i = 0; i < der_type->n_members && !member_found; i++ ) { + member_found = std::string(der_type->m_members[i]) == member_name; + } + if( !member_found ) { + throw SemanticError("No member " + member_name + + " found in " + std::string(der_type->m_name), + loc); + } + ASR::expr_t *val = ASR::down_cast(ASR::make_Var_t(al, loc, t)); + ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name); + LCOMPILERS_ASSERT(ASR::is_a(*member_sym)); + ASR::Variable_t* member_var = ASR::down_cast(member_sym); + ASR::ttype_t* member_var_type = member_var->m_type; + if( ASR::is_a(*member_var->m_type) ) { + ASR::StructType_t* member_var_struct_t = ASR::down_cast(member_var->m_type); + if( !ASR::is_a(*member_var_struct_t->m_derived_type) ) { + ASR::Struct_t* struct_type = ASR::down_cast(member_var_struct_t->m_derived_type); + ASR::symbol_t* struct_type_asr_owner = ASRUtils::get_asr_owner(member_var_struct_t->m_derived_type); + if( struct_type_asr_owner && ASR::is_a(*struct_type_asr_owner) ) { + std::string struct_var_name = ASR::down_cast(struct_type_asr_owner)->m_name; + std::string struct_member_name = struct_type->m_name; + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == member_var_struct_t->m_derived_type && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name, false); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + loc, current_scope, s2c(al, import_name), + member_var_struct_t->m_derived_type, s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + member_var_type = ASRUtils::TYPE(ASRUtils::make_StructType_t_util(al, loc, import_struct_member)); + } + } + } + tmp = ASR::make_StructInstanceMember_t(al, loc, val, member_sym, + member_var_type, nullptr); } else if (ASR::is_a(*type)) { ASR::Enum_t* enum_ = ASR::down_cast(type); ASR::EnumType_t* enum_type = ASR::down_cast(enum_->m_enum_type); @@ -7891,8 +8026,19 @@ we will have to use something else. ASR::Variable_t* var = ASR::down_cast(st); if (ASR::is_a(*var->m_type)) { // call to struct member function + // modifying args to pass the object as self ASR::StructType_t* var_struct = ASR::down_cast(var->m_type); + Vec new_args; new_args.reserve(al, args.n + 1); + ASR::call_arg_t self_arg; + self_arg.loc = args[0].loc; + self_arg.m_value = ASRUtils::EXPR(ASR::make_Var_t(al, loc, st)); + new_args.push_back(al, self_arg); + for (size_t i=0; im_derived_type, call_name, loc); + tmp = make_call_helper(al, st, current_scope, new_args, call_name, loc); + return; } else { // this case when we have variable and attribute st = current_scope->resolve_symbol(mod_name);