Skip to content

Commit

Permalink
[WASM] Avoid triggering clinit from internal calls in class
Browse files Browse the repository at this point in the history
This handles basic inner cycles for clinit due to calls to methods.

Consider following:

```
class Foo {
   // initialized in Foo.clinit
   public static final Foo instance = new Foo();

   // Also calls Foo.clinit since it is public
   public Foo() {}
}
```

We essentially have two entry points in the classes that we need to clinit. This introduces a cycle in clinit (clinit call ctor, ctor calls clinit) and we cannot hoist instance to global because it is no longer trivial.

This could also happen with following pattern:

```
class Foo {
   // initialized in Foo.clinit
   public static final int instance = someHelper();

   // Also calls Foo.clinit since it is public
   public static int someHelper() { ... }
}
```

To overcome this problem, when there is a public method that triggers clinit, we can make it private and add a public one that calls clinit. Then re-writing Internal calls go through private one so they don't trigger clinit. Since this only requires local knowledge, it is compatible with modular compilation.

PiperOrigin-RevId: 590882728
  • Loading branch information
gkdn authored and copybara-github committed Dec 14, 2023
1 parent 893fd3e commit 71f2c81
Show file tree
Hide file tree
Showing 37 changed files with 869 additions and 185 deletions.
2 changes: 0 additions & 2 deletions transpiler/java/com/google/j2cl/transpiler/ast/AstUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,6 @@ private static Method createForwardingMethod(
MethodDescriptor toMethodDescriptor,
String jsDocDescription,
boolean isStaticDispatch) {
checkArgument(!fromMethodDescriptor.getEnclosingTypeDescriptor().isInterface());

List<Variable> parameters =
createParameterVariables(fromMethodDescriptor.getParameterTypeDescriptors());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ public final void applyTo(CompilationUnit compilationUnit) {
if (type.isNative()) {
continue;
}
synthesizeClinitCallsInMethods(type);
synthesizeSuperClinitCalls(type);
// Apply the additional normalizations defined in subclasses.
applyTo(type);
Expand Down Expand Up @@ -87,7 +86,7 @@ private static String getUniqueIdentifier(MemberDescriptor memberDescriptor) {
}

/** Add clinit calls to methods and (real js) constructors. */
private void synthesizeClinitCallsInMethods(Type type) {
public void synthesizeClinitCallsInMethods(Type type) {
type.accept(
new AbstractRewriter() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public void applyTo(Type type) {
if (type.isJsEnum()) {
return;
}
synthesizeClinitCallsInMethods(type);
synthesizeSettersAndGetters(type);
synthesizeClinitMethod(type);
synthesizeStaticFieldDeclaration(type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
*/
package com.google.j2cl.transpiler.passes;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.j2cl.common.SourcePosition;
import com.google.j2cl.transpiler.ast.AbstractRewriter;
import com.google.j2cl.transpiler.ast.AstUtils;
import com.google.j2cl.transpiler.ast.BinaryExpression;
import com.google.j2cl.transpiler.ast.BooleanLiteral;
import com.google.j2cl.transpiler.ast.DeclaredTypeDescriptor;
Expand All @@ -27,12 +29,17 @@
import com.google.j2cl.transpiler.ast.FieldAccess;
import com.google.j2cl.transpiler.ast.FieldDescriptor;
import com.google.j2cl.transpiler.ast.IfStatement;
import com.google.j2cl.transpiler.ast.Member;
import com.google.j2cl.transpiler.ast.Method;
import com.google.j2cl.transpiler.ast.MethodCall;
import com.google.j2cl.transpiler.ast.MethodDescriptor;
import com.google.j2cl.transpiler.ast.MultiExpression;
import com.google.j2cl.transpiler.ast.PrimitiveTypes;
import com.google.j2cl.transpiler.ast.ReturnStatement;
import com.google.j2cl.transpiler.ast.Statement;
import com.google.j2cl.transpiler.ast.Type;
import com.google.j2cl.transpiler.ast.Visibility;
import java.util.HashMap;
import java.util.List;

/**
Expand All @@ -46,13 +53,35 @@ public class ImplementStaticInitializationViaConditionChecks
@Override
public void applyTo(Type type) {
synthesizeClinitCallsOnFieldAccess(type);
synthesizeClinitCallsInMethods(type);
synthesizeClinitMethod(type);
}

/** Add clinit calls to field accesses. */
private void synthesizeClinitCallsOnFieldAccess(Type type) {
HashMap<MethodDescriptor, MethodDescriptor> neededPrivateMethodsByPublic = new HashMap<>();
type.accept(
new AbstractRewriter() {
@Override
public Expression rewriteMethodCall(MethodCall methodCall) {
// To avoid calling clinit when calling the methods on the same class, the non-private
// method is converted to private method call. Then later on (below) we will add these
// new private methods.

MethodDescriptor target = methodCall.getTarget();
if (target.isMemberOf(type.getDeclaration()) && triggersClinit(target, type)) {
checkState(target.isStatic());

// No need to call clinit when accessing the method from members in the enclosing
// type.
MethodDescriptor privateDescriptor = createPrivateDescriptor(target);
methodCall = MethodCall.Builder.from(methodCall).setTarget(privateDescriptor).build();
neededPrivateMethodsByPublic.put(
target.getDeclarationDescriptor(), privateDescriptor.getDeclarationDescriptor());
}
return methodCall;
}

@Override
public Expression rewriteFieldAccess(FieldAccess fieldAccess) {
FieldDescriptor target = fieldAccess.getTarget();
Expand All @@ -75,6 +104,35 @@ public Expression rewriteFieldAccess(FieldAccess fieldAccess) {
return fieldAccess;
}
});

// Insert private methods needed and make the public ones bridge to them to trigger clinit.
List<Member> members = type.getMembers();
for (int i = 0; i < members.size(); i++) {
if (!members.get(i).isMethod()) {
continue;
}
Method method = (Method) members.get(i);
MethodDescriptor privateDescriptor =
neededPrivateMethodsByPublic.remove(method.getDescriptor());
if (privateDescriptor == null) {
continue;
}
Method newPublicMethod =
AstUtils.createForwardingMethod(
method.getSourcePosition(),
null,
method.getDescriptor(),
privateDescriptor,
"Bridge to private");
members.set(i, newPublicMethod);
members.add(++i, Method.Builder.from(method).setMethodDescriptor(privateDescriptor).build());
}
checkState(neededPrivateMethodsByPublic.isEmpty(), neededPrivateMethodsByPublic);
}

private static MethodDescriptor createPrivateDescriptor(MethodDescriptor descriptor) {
return descriptor.transform(
m -> m.setVisibility(Visibility.PRIVATE).setName(descriptor.getName() + "_$private"));
}

/** Implements the static initialization method ($clinit). */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4969,10 +4969,20 @@
(func [email protected]
(result (ref null $bridgemethods.AccidentalOverrideBridge))
;;@ bridgemethods/BridgeMethod.java:96:6
(local $$instance (ref null $bridgemethods.AccidentalOverrideBridge))
(block
;;@ bridgemethods/BridgeMethod.java:96:6
(call $$clinit__void_<once>[email protected] )
;;@ bridgemethods/BridgeMethod.java:96:6
(return (call [email protected] ))
)
)

;;; AccidentalOverrideBridge AccidentalOverrideBridge.$create_$private()
(func [email protected]
(result (ref null $bridgemethods.AccidentalOverrideBridge))
;;@ bridgemethods/BridgeMethod.java:96:6
(local $$instance (ref null $bridgemethods.AccidentalOverrideBridge))
(block
;;@ bridgemethods/BridgeMethod.java:96:6
(local.set $$instance (struct.new $bridgemethods.AccidentalOverrideBridge (ref.as_non_null (global.get $bridgemethods.AccidentalOverrideBridge.vtable)) (ref.as_non_null (global.get $bridgemethods.AccidentalOverrideBridge.itable)) (i32.const 0)))
;;@ bridgemethods/BridgeMethod.java:96:6
Expand Down Expand Up @@ -5002,7 +5012,7 @@
(local.set $this (ref.cast (ref $bridgemethods.AccidentalOverrideBridge) (local.get $this.untyped)))
(block
;;@ bridgemethods/BridgeMethod.java:100:4
(local.set $g (call $$create__@bridgemethods.AccidentalOverrideBridge ))
(local.set $g (call $$create_$private__@bridgemethods.AccidentalOverrideBridge ))
;;@ bridgemethods/BridgeMethod.java:101:4
(drop (call_ref $function.m_get__java_lang_String__java_lang_String (ref.as_non_null (local.get $g))(call $function.no.side.effects.$getString_||__java_lang_String (ref.func $$getString_||[email protected]) )(struct.get $bridgemethods.Getter.vtable $m_get__java_lang_String__java_lang_String (ref.cast (ref $bridgemethods.Getter.vtable) (struct.get $itable $slot0 (struct.get $java.lang.Object $itable (local.get $g)))))))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -793,10 +793,20 @@
(func [email protected]
(result (ref null $cast.CastGenerics))
;;@ cast/CastGenerics.java:20:13
(local $$instance (ref null $cast.CastGenerics))
(block
;;@ cast/CastGenerics.java:20:13
(call $$clinit__void_<once>[email protected] )
;;@ cast/CastGenerics.java:20:13
(return (call [email protected] ))
)
)

;;; CastGenerics<T, E> CastGenerics.$create_$private()
(func [email protected]
(result (ref null $cast.CastGenerics))
;;@ cast/CastGenerics.java:20:13
(local $$instance (ref null $cast.CastGenerics))
(block
;;@ cast/CastGenerics.java:20:13
(local.set $$instance (struct.new $cast.CastGenerics (ref.as_non_null (global.get $cast.CastGenerics.vtable)) (ref.as_non_null (global.get $itable.empty)) (i32.const 0) (ref.null $java.lang.Object)))
;;@ cast/CastGenerics.java:20:13
Expand Down Expand Up @@ -862,10 +872,10 @@
;;@ cast/CastGenerics.java:50:70
(call $$clinit__void_<once>[email protected] )
;;@ cast/CastGenerics.java:51:4
(local.set $str (ref.cast (ref null $java.lang.String) (struct.get $cast.CastGenerics [email protected] (call $$create__@cast.CastGenerics ))))
(local.set $str (ref.cast (ref null $java.lang.String) (struct.get $cast.CastGenerics [email protected] (call $$create_$private__@cast.CastGenerics ))))
;;@ cast/CastGenerics.java:52:4
(local.set $str (ref.cast (ref null $java.lang.String) (block (result (ref null $java.lang.Object))
(local.set $$qualifier (call $$create__@cast.CastGenerics ))
(local.set $$qualifier (call $$create_$private__@cast.CastGenerics ))
(call_ref $function.m_method__java_lang_Object_$pp_cast (ref.as_non_null (local.get $$qualifier))(struct.get $cast.CastGenerics.vtable $m_method__java_lang_Object_$pp_cast (struct.get $cast.CastGenerics $vtable(local.get $$qualifier))))
)))
;;@ cast/CastGenerics.java:54:4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,23 @@
)
)

;;; void CastOnArrayInit.fun(Foo<E>... args)
;;; void CastOnArrayInit.fun(Foo<E>... arg0)
(func $m_fun__arrayOf_castonarrayinit_CastOnArrayInit_Foo__void@castonarrayinit.CastOnArrayInit
(param $args (ref null $javaemul.internal.WasmArray.OfObject))
(param $arg0 (ref null $javaemul.internal.WasmArray.OfObject))
;;@ castonarrayinit/CastOnArrayInit.java:21:25
(block
;;@ castonarrayinit/CastOnArrayInit.java:21:45
;;@ castonarrayinit/CastOnArrayInit.java:21:25
(call $$clinit__void_<once>[email protected] )
;;@ castonarrayinit/CastOnArrayInit.java:21:25
(call $m_fun_$private__arrayOf_castonarrayinit_CastOnArrayInit_Foo__void@castonarrayinit.CastOnArrayInit (local.get $arg0))
)
)

;;; void CastOnArrayInit.fun_$private(Foo<E>... args)
(func $m_fun_$private__arrayOf_castonarrayinit_CastOnArrayInit_Foo__void@castonarrayinit.CastOnArrayInit
(param $args (ref null $javaemul.internal.WasmArray.OfObject))
;;@ castonarrayinit/CastOnArrayInit.java:21:25
(block
)
)

Expand All @@ -126,7 +136,7 @@
;;@ castonarrayinit/CastOnArrayInit.java:25:4
(local.set $f2 (call [email protected] ))
;;@ castonarrayinit/CastOnArrayInit.java:26:4
(call $m_fun__arrayOf_castonarrayinit_CastOnArrayInit_Foo__void@castonarrayinit.CastOnArrayInit (call $m_newWithLiteral__arrayOf_java_lang_Object__javaemul_internal_WasmArray_OfObject@javaemul.internal.WasmArray.OfObject (array.new_fixed $java.lang.Object.array (local.get $f1)(local.get $f2))))
(call $m_fun_$private__arrayOf_castonarrayinit_CastOnArrayInit_Foo__void@castonarrayinit.CastOnArrayInit (call $m_newWithLiteral__arrayOf_java_lang_Object__javaemul_internal_WasmArray_OfObject@javaemul.internal.WasmArray.OfObject (array.new_fixed $java.lang.Object.array (local.get $f1)(local.get $f2))))
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,20 @@
(func [email protected]
(result (ref null $cloneable.Cloneables.WithoutCloneable))
;;@ cloneable/Cloneables.java:19:22
(local $$instance (ref null $cloneable.Cloneables.WithoutCloneable))
(block
;;@ cloneable/Cloneables.java:19:22
(call $$clinit__void_<once>[email protected] )
;;@ cloneable/Cloneables.java:19:22
(return (call [email protected] ))
)
)

;;; WithoutCloneable WithoutCloneable.$create_$private()
(func [email protected]
(result (ref null $cloneable.Cloneables.WithoutCloneable))
;;@ cloneable/Cloneables.java:19:22
(local $$instance (ref null $cloneable.Cloneables.WithoutCloneable))
(block
;;@ cloneable/Cloneables.java:19:22
(local.set $$instance (struct.new $cloneable.Cloneables.WithoutCloneable (ref.as_non_null (global.get $cloneable.Cloneables.WithoutCloneable.vtable)) (ref.as_non_null (global.get $itable.empty)) (i32.const 0)))
;;@ cloneable/Cloneables.java:19:22
Expand Down Expand Up @@ -375,7 +385,7 @@
(local.set $this (ref.cast (ref $cloneable.Cloneables.WithoutCloneable) (local.get $this.untyped)))
(block
;;@ cloneable/Cloneables.java:23:6
(return (call $$create__@cloneable.Cloneables.WithoutCloneable ))
(return (call $$create_$private__@cloneable.Cloneables.WithoutCloneable ))
)
)
(elem declare func [email protected])
Expand Down Expand Up @@ -443,10 +453,20 @@
(func [email protected]
(result (ref null $cloneable.Cloneables.WithCloneable))
;;@ cloneable/Cloneables.java:27:22
(local $$instance (ref null $cloneable.Cloneables.WithCloneable))
(block
;;@ cloneable/Cloneables.java:27:22
(call $$clinit__void_<once>[email protected] )
;;@ cloneable/Cloneables.java:27:22
(return (call [email protected] ))
)
)

;;; WithCloneable WithCloneable.$create_$private()
(func [email protected]
(result (ref null $cloneable.Cloneables.WithCloneable))
;;@ cloneable/Cloneables.java:27:22
(local $$instance (ref null $cloneable.Cloneables.WithCloneable))
(block
;;@ cloneable/Cloneables.java:27:22
(local.set $$instance (struct.new $cloneable.Cloneables.WithCloneable (ref.as_non_null (global.get $cloneable.Cloneables.WithCloneable.vtable)) (ref.as_non_null (global.get $cloneable.Cloneables.WithCloneable.itable)) (i32.const 0)))
;;@ cloneable/Cloneables.java:27:22
Expand Down Expand Up @@ -476,7 +496,7 @@
(local.set $this (ref.cast (ref $cloneable.Cloneables.WithCloneable) (local.get $this.untyped)))
(block
;;@ cloneable/Cloneables.java:31:6
(return (call $$create__@cloneable.Cloneables.WithCloneable ))
(return (call $$create_$private__@cloneable.Cloneables.WithCloneable ))
)
)
(elem declare func [email protected])
Expand Down Expand Up @@ -544,10 +564,20 @@
(func [email protected]
(result (ref null $cloneable.Cloneables.WithoutCloneableChild))
;;@ cloneable/Cloneables.java:35:28
(local $$instance (ref null $cloneable.Cloneables.WithoutCloneableChild))
(block
;;@ cloneable/Cloneables.java:35:28
(call $$clinit__void_<once>[email protected] )
;;@ cloneable/Cloneables.java:35:28
(return (call [email protected] ))
)
)

;;; WithoutCloneableChild WithoutCloneableChild.$create_$private()
(func [email protected]
(result (ref null $cloneable.Cloneables.WithoutCloneableChild))
;;@ cloneable/Cloneables.java:35:28
(local $$instance (ref null $cloneable.Cloneables.WithoutCloneableChild))
(block
;;@ cloneable/Cloneables.java:35:28
(local.set $$instance (struct.new $cloneable.Cloneables.WithoutCloneableChild (ref.as_non_null (global.get $cloneable.Cloneables.WithoutCloneableChild.vtable)) (ref.as_non_null (global.get $itable.empty)) (i32.const 0)))
;;@ cloneable/Cloneables.java:35:28
Expand Down Expand Up @@ -577,7 +607,7 @@
(local.set $this (ref.cast (ref $cloneable.Cloneables.WithoutCloneableChild) (local.get $this.untyped)))
(block
;;@ cloneable/Cloneables.java:38:6
(return (call $$create__@cloneable.Cloneables.WithoutCloneableChild ))
(return (call $$create_$private__@cloneable.Cloneables.WithoutCloneableChild ))
)
)
(elem declare func $m_clone__java_lang_Object@cloneable.Cloneables.WithoutCloneableChild)
Expand Down Expand Up @@ -645,10 +675,20 @@
(func [email protected]
(result (ref null $cloneable.Cloneables.WithCloneableChild))
;;@ cloneable/Cloneables.java:42:28
(local $$instance (ref null $cloneable.Cloneables.WithCloneableChild))
(block
;;@ cloneable/Cloneables.java:42:28
(call $$clinit__void_<once>[email protected] )
;;@ cloneable/Cloneables.java:42:28
(return (call [email protected] ))
)
)

;;; WithCloneableChild WithCloneableChild.$create_$private()
(func [email protected]
(result (ref null $cloneable.Cloneables.WithCloneableChild))
;;@ cloneable/Cloneables.java:42:28
(local $$instance (ref null $cloneable.Cloneables.WithCloneableChild))
(block
;;@ cloneable/Cloneables.java:42:28
(local.set $$instance (struct.new $cloneable.Cloneables.WithCloneableChild (ref.as_non_null (global.get $cloneable.Cloneables.WithCloneableChild.vtable)) (ref.as_non_null (global.get $cloneable.Cloneables.WithCloneableChild.itable)) (i32.const 0)))
;;@ cloneable/Cloneables.java:42:28
Expand Down Expand Up @@ -678,7 +718,7 @@
(local.set $this (ref.cast (ref $cloneable.Cloneables.WithCloneableChild) (local.get $this.untyped)))
(block
;;@ cloneable/Cloneables.java:45:6
(return (call $$create__@cloneable.Cloneables.WithCloneableChild ))
(return (call $$create_$private__@cloneable.Cloneables.WithCloneableChild ))
)
)
(elem declare func $m_clone__java_lang_Object@cloneable.Cloneables.WithCloneableChild)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,10 +865,20 @@
(func [email protected]
(result (ref null $collisions.T))
;;@ collisions/goog.java:74:6
(local $$instance (ref null $collisions.T))
(block
;;@ collisions/goog.java:74:6
(call $$clinit__void_<once>[email protected] )
;;@ collisions/goog.java:74:6
(return (call [email protected] ))
)
)

;;; T<T> T.$create_$private()
(func [email protected]
(result (ref null $collisions.T))
;;@ collisions/goog.java:74:6
(local $$instance (ref null $collisions.T))
(block
;;@ collisions/goog.java:74:6
(local.set $$instance (struct.new $collisions.T (ref.as_non_null (global.get $collisions.T.vtable)) (ref.as_non_null (global.get $itable.empty)) (i32.const 0)))
;;@ collisions/goog.java:74:6
Expand Down Expand Up @@ -916,7 +926,7 @@
(block
;;@ collisions/goog.java:80:4
(local.set $t (block (result (ref null $java.lang.Number))
(local.set $$qualifier (call $$create__@collisions.T ))
(local.set $$qualifier (call $$create_$private__@collisions.T ))
(call_ref $function.m_m__java_lang_Number_$pp_collisions (ref.as_non_null (local.get $$qualifier))(struct.get $collisions.T.vtable $m_m__java_lang_Number_$pp_collisions (struct.get $collisions.T $vtable(local.get $$qualifier))))
))
;;@ collisions/goog.java:81:4
Expand Down
Loading

0 comments on commit 71f2c81

Please sign in to comment.