Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gccrs: add support for lang_item eq and PartialEq trait #3347

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,26 @@ CompileExpr::visit (HIR::ComparisonExpr &expr)
auto rhs = CompileExpr::Compile (expr.get_rhs (), ctx);
auto location = expr.get_locus ();

// this might be an operator overload situation lets check
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
expr.get_mappings ().get_hirid (), &fntype);
if (is_op_overload)
{
auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type
= LangItem::ComparisonToLangItem (expr.get_expr_type ());

rhs = address_expression (rhs, EXPR_LOCATION (rhs));

translated = resolve_operator_overload (
lang_item_type, expr, lhs, rhs, expr.get_lhs (),
tl::optional<std::reference_wrapper<HIR::Expr>> (expr.get_rhs ()),
segment);
return;
}

translated = Backend::comparison_expression (op, lhs, rhs, location);
}

Expand Down Expand Up @@ -1478,7 +1498,8 @@ CompileExpr::get_receiver_from_dyn (const TyTy::DynamicObjectType *dyn,
tree
CompileExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs, tree rhs,
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr)
HIR::Expr &lhs_expr, tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment)
{
TyTy::FnType *fntype;
bool is_op_overload = ctx->get_tyctx ()->lookup_operator_overload (
Expand All @@ -1499,7 +1520,10 @@ CompileExpr::resolve_operator_overload (
}

// lookup compiled functions since it may have already been compiled
HIR::PathIdentSegment segment_name (LangItem::ToString (lang_item_type));
HIR::PathIdentSegment segment_name
= specified_segment.is_error ()
? HIR::PathIdentSegment (LangItem::ToString (lang_item_type))
: specified_segment;
tree fn_expr = resolve_method_address (fntype, receiver, expr.get_locus ());

// lookup the autoderef mappings
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/backend/rust-compile-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ class CompileExpr : private HIRCompileBase, protected HIR::HIRExpressionVisitor
tree resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, tree lhs,
tree rhs, HIR::Expr &lhs_expr,
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr);
tl::optional<std::reference_wrapper<HIR::Expr>> rhs_expr,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

tree compile_bool_literal (const HIR::LiteralExpr &expr,
const TyTy::BaseType *tyty);
Expand Down
6 changes: 6 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,12 @@ OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
locus (expr.get_locus ())
{}

OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
: node_mappings (expr.get_mappings ()),
lvalue_mappings (expr.get_expr ().get_mappings ()),
locus (expr.get_locus ())
{}

AnonConst::AnonConst (NodeId id, std::unique_ptr<Expr> expr)
: id (id), expr (std::move (expr))
{
Expand Down
2 changes: 2 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,8 @@ class OperatorExprMeta

OperatorExprMeta (HIR::ArrayIndexExpr &expr);

OperatorExprMeta (HIR::ComparisonExpr &expr);

const Analysis::NodeMapping &get_mappings () const { return node_mappings; }

const Analysis::NodeMapping &get_lvalue_mappings () const
Expand Down
27 changes: 22 additions & 5 deletions gcc/rust/typecheck/rust-hir-type-check-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,21 @@ TypeCheckExpr::visit (HIR::ComparisonExpr &expr)
auto lhs = TypeCheckExpr::Resolve (expr.get_lhs ());
auto rhs = TypeCheckExpr::Resolve (expr.get_rhs ());

auto borrowed_rhs
= new TyTy::ReferenceType (mappings.get_next_hir_id (),
TyTy::TyVar (rhs->get_ref ()), Mutability::Imm);
context->insert_implicit_type (borrowed_rhs->get_ref (), borrowed_rhs);

auto seg_name = LangItem::ComparisonToSegment (expr.get_expr_type ());
auto segment = HIR::PathIdentSegment (seg_name);
auto lang_item_type = LangItem::ComparisonToLangItem (expr.get_expr_type ());

bool operator_overloaded
= resolve_operator_overload (lang_item_type, expr, lhs, borrowed_rhs,
segment);
if (operator_overloaded)
return;

unify_site (expr.get_mappings ().get_hirid (),
TyTy::TyWithLocation (lhs, expr.get_lhs ().get_locus ()),
TyTy::TyWithLocation (rhs, expr.get_rhs ().get_locus ()),
Expand Down Expand Up @@ -1640,10 +1655,10 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
}

bool
TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs,
TyTy::BaseType *rhs)
TypeCheckExpr::resolve_operator_overload (
LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment)
{
// look up lang item for arithmetic type
std::string associated_item_name = LangItem::ToString (lang_item_type);
Expand All @@ -1661,7 +1676,9 @@ TypeCheckExpr::resolve_operator_overload (LangItem::Kind lang_item_type,
current_context = context->peek_context ();
}

auto segment = HIR::PathIdentSegment (associated_item_name);
auto segment = specified_segment.is_error ()
? HIR::PathIdentSegment (associated_item_name)
: specified_segment;
auto candidates = MethodResolver::Probe (lhs, segment);

// remove any recursive candidates
Expand Down
4 changes: 3 additions & 1 deletion gcc/rust/typecheck/rust-hir-type-check-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor
protected:
bool resolve_operator_overload (LangItem::Kind lang_item_type,
HIR::OperatorExprMeta expr,
TyTy::BaseType *lhs, TyTy::BaseType *rhs);
TyTy::BaseType *lhs, TyTy::BaseType *rhs,
HIR::PathIdentSegment specified_segment
= HIR::PathIdentSegment::create_error ());

bool resolve_fn_trait_call (HIR::CallExpr &expr,
TyTy::BaseType *function_tyty,
Expand Down
44 changes: 44 additions & 0 deletions gcc/rust/util/rust-lang-item.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ const BiMap<std::string, LangItem::Kind> Rust::LangItem::lang_items = {{

{"into_iter", Kind::INTOITER_INTOITER},
{"next", Kind::ITERATOR_NEXT},

{"eq", Kind::EQ},
{"partial_ord", Kind::PARTIAL_ORD},
}};

tl::optional<LangItem::Kind>
Expand Down Expand Up @@ -145,6 +148,47 @@ LangItem::OperatorToLangItem (ArithmeticOrLogicalOperator op)
rust_unreachable ();
}

LangItem::Kind
LangItem::ComparisonToLangItem (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
case ComparisonOperator::EQUAL:
return LangItem::Kind::EQ;

case ComparisonOperator::GREATER_THAN:
case ComparisonOperator::LESS_THAN:
case ComparisonOperator::GREATER_OR_EQUAL:
case ComparisonOperator::LESS_OR_EQUAL:
return LangItem::Kind::PARTIAL_ORD;
}

rust_unreachable ();
}

std::string
LangItem::ComparisonToSegment (ComparisonOperator op)
{
switch (op)
{
case ComparisonOperator::NOT_EQUAL:
return "ne";
case ComparisonOperator::EQUAL:
return "eq";
case ComparisonOperator::GREATER_THAN:
return "gt";
case ComparisonOperator::LESS_THAN:
return "lt";
case ComparisonOperator::GREATER_OR_EQUAL:
return "ge";
case ComparisonOperator::LESS_OR_EQUAL:
return "le";
}

rust_unreachable ();
}

LangItem::Kind
LangItem::CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op)
{
Expand Down
5 changes: 5 additions & 0 deletions gcc/rust/util/rust-lang-item.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class LangItem

NEGATION,
NOT,
EQ,
PARTIAL_ORD,

ADD_ASSIGN,
SUB_ASSIGN,
Expand Down Expand Up @@ -136,6 +138,9 @@ class LangItem
static Kind
CompoundAssignmentOperatorToLangItem (ArithmeticOrLogicalOperator op);
static Kind NegationOperatorToLangItem (NegationOperator op);
static Kind ComparisonToLangItem (ComparisonOperator op);

static std::string ComparisonToSegment (ComparisonOperator op);
};

} // namespace Rust
Expand Down
8 changes: 4 additions & 4 deletions gcc/rust/util/rust-operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ enum class ComparisonOperator
{
EQUAL, // std::cmp::PartialEq::eq
NOT_EQUAL, // std::cmp::PartialEq::ne
GREATER_THAN, // std::cmp::PartialEq::gt
LESS_THAN, // std::cmp::PartialEq::lt
GREATER_OR_EQUAL, // std::cmp::PartialEq::ge
LESS_OR_EQUAL // std::cmp::PartialEq::le
GREATER_THAN, // std::cmp::PartialOrd::gt
LESS_THAN, // std::cmp::PartialOrd::lt
GREATER_OR_EQUAL, // std::cmp::PartialOrd::ge
LESS_OR_EQUAL // std::cmp::PartialOrd::le
};

enum class LazyBooleanOperator
Expand Down
78 changes: 78 additions & 0 deletions gcc/testsuite/rust/compile/cmp1.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// { dg-options "-w" }
// taken from https://github.com/rust-lang/rust/blob/e1884a8e3c3e813aada8254edfa120e85bf5ffca/library/core/src/cmp.rs#L98

#[lang = "sized"]
pub trait Sized {}

#[lang = "eq"]
#[stable(feature = "rust1", since = "1.0.0")]
#[doc(alias = "==")]
#[doc(alias = "!=")]
pub trait PartialEq<Rhs: ?Sized = Self> {
/// This method tests for `self` and `other` values to be equal, and is used
/// by `==`.
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn eq(&self, other: &Rhs) -> bool;

/// This method tests for `!=`.
#[inline]
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
fn ne(&self, other: &Rhs) -> bool {
!self.eq(other)
}
}

enum BookFormat {
Paperback,
Hardback,
Ebook,
}

impl PartialEq<BookFormat> for BookFormat {
fn eq(&self, other: &BookFormat) -> bool {
self == other
}
}

pub struct Book {
isbn: i32,
format: BookFormat,
}

// Implement <Book> == <BookFormat> comparisons
impl PartialEq<BookFormat> for Book {
fn eq(&self, other: &BookFormat) -> bool {
self.format == *other
}
}

// Implement <BookFormat> == <Book> comparisons
impl PartialEq<Book> for BookFormat {
fn eq(&self, other: &Book) -> bool {
*self == other.format
}
}

// Implement <Book> == <Book> comparisons
impl PartialEq<Book> for Book {
fn eq(&self, other: &Book) -> bool {
self.isbn == other.isbn
}
}

pub fn main() {
let b1 = Book {
isbn: 1,
format: BookFormat::Paperback,
};
let b2 = Book {
isbn: 2,
format: BookFormat::Paperback,
};

let _c1: bool = b1 == BookFormat::Paperback;
let _c2: bool = BookFormat::Paperback == b2;
let _c3: bool = b1 != b2;
}
1 change: 1 addition & 0 deletions gcc/testsuite/rust/compile/nr2/exclude
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,5 @@ additional-trait-bounds2.rs
auto_traits2.rs
auto_traits3.rs
issue-3140.rs
cmp1.rs
# please don't delete the trailing newline
Loading