diff --git a/xlb/operator/boundary_masker/mesh_boundary_masker.py b/xlb/operator/boundary_masker/mesh_boundary_masker.py index c48724fe..b4aca04e 100644 --- a/xlb/operator/boundary_masker/mesh_boundary_masker.py +++ b/xlb/operator/boundary_masker/mesh_boundary_masker.py @@ -61,59 +61,65 @@ def index_to_position(index: wp.vec3i): # Function to precompute useful values per triangle, assuming spacing is (1,1,1) @wp.func - def pre_compute( v: wp.mat33f, # triangle vertices - normal: wp.vec3f ): # triangle normal - c = wp.vec3f(float(normal[0] > 0.0),float(normal[1] > 0.0),float(normal[2] > 0.0)) + def pre_compute( + verts: wp.mat33f, # triangle vertices + normal: wp.vec3f, # triangle normal + ): + corner = wp.vec3f(float(normal[0] > 0.0), float(normal[1] > 0.0), float(normal[2] > 0.0)) - d1 = wp.dot(normal, c-v[0]) - d2 = wp.dot(normal, wp.vec3f(1.,1.,1.) - c - v[0]) + dist1 = wp.dot(normal, corner - verts[0]) + dist2 = wp.dot(normal, wp.vec3f(1.0, 1.0, 1.0) - corner - verts[0]) - edges = wp.transpose(wp.mat33(v[1]-v[0],v[2]-v[1],v[0]-v[2])) - ne0 = wp.mat33f(0.0) - ne1 = wp.mat33f(0.0) - de = wp.mat33f(0.0) + edges = wp.transpose(wp.mat33(verts[1] - verts[0], verts[2] - verts[1], verts[0] - verts[2])) + normal_edge0 = wp.mat33f(0.0) + normal_edge1 = wp.mat33f(0.0) + dist_edge = wp.mat33f(0.0) - for ax0 in range(0,3): - ax2 = ( ax0 + 2 ) % 3 + for axis0 in range(0, 3): + axis2 = (axis0 + 2) % 3 sgn = 1.0 - if ( normal[ax2] < 0.0 ): + if normal[axis2] < 0.0: sgn = -1.0 - for i in range(0,3): - ne0[i][ax0] = -1.0 * sgn * edges[i][ax0] - ne1[i][ax0] = sgn * edges[i][ax0] + for i in range(0, 3): + normal_edge0[i][axis0] = -1.0 * sgn * edges[i][axis0] + normal_edge1[i][axis0] = sgn * edges[i][axis0] - de[i][ax0] = -1. * ( ne0[i][ax0] * v[i][ax0] + ne1[i][ax0] * v[i][ax0] ) \ - + wp.max(0., ne0[i][ax0] ) + wp.max(0., ne1[i][ax0]) + dist_edge[i][axis0] = ( + -1.0 * (normal_edge0[i][axis0] * verts[i][axis0] + normal_edge1[i][axis0] * verts[i][axis0]) + + wp.max(0.0, normal_edge0[i][axis0]) + + wp.max(0.0, normal_edge1[i][axis0]) + ) - return d1, d2, ne0, ne1, de + return dist1, dist2, normal_edge0, normal_edge1, dist_edge # Check whether this triangle intersects the unit cube at position low @wp.func - def triangle_box_intersect( low: wp.vec3f, - normal: wp.vec3f, - d1: wp.float32, - d2: wp.float32, - ne0: wp.mat33f, - ne1: wp.mat33f, - de: wp.mat33f ): - if ( ( wp.dot(normal, low ) + d1 ) * ( wp.dot( normal, low ) + d2 ) <= 0.0 ): + def triangle_box_intersect( + low: wp.vec3f, + normal: wp.vec3f, + dist1: wp.float32, + dist2: wp.float32, + normal_edge0: wp.mat33f, + normal_edge1: wp.mat33f, + dist_edge: wp.mat33f, + ): + if (wp.dot(normal, low) + dist1) * (wp.dot(normal, low) + dist2) <= 0.0: intersect = True # Loop over primary axis for projection - for ax0 in range(0,3): - ax1 = ( ax0 + 1 ) % 3 - for i in range(0,3): - intersect = intersect and ( ne0[i][ax0] * low[ax0] + ne1[i][ax0] * low[ax1] + de[i][ax0] >= 0.0 ) + for ax0 in range(0, 3): + ax1 = (ax0 + 1) % 3 + for i in range(0, 3): + intersect = intersect and (normal_edge0[i][ax0] * low[ax0] + normal_edge1[i][ax0] * low[ax1] + dist_edge[i][ax0] >= 0.0) return intersect else: return False @wp.func - def mesh_voxel_intersect( mesh_id: wp.uint64, low: wp.vec3 ): - - query = wp.mesh_query_aabb(mesh_id, low, low + wp.vec3f(1.,1.,1.)) + def mesh_voxel_intersect(mesh_id: wp.uint64, low: wp.vec3): + query = wp.mesh_query_aabb(mesh_id, low, low + wp.vec3f(1.0, 1.0, 1.0)) for f in query: v0 = wp.mesh_eval_position(mesh_id, f, 1.0, 0.0) @@ -121,12 +127,12 @@ def mesh_voxel_intersect( mesh_id: wp.uint64, low: wp.vec3 ): v2 = wp.mesh_eval_position(mesh_id, f, 0.0, 0.0) normal = wp.mesh_eval_face_normal(mesh_id, f) - v = wp.transpose(wp.mat33f(v0,v1,v2)) + v = wp.transpose(wp.mat33f(v0, v1, v2)) # TODO: run this on triangles in advance - d1, d2, ne0, ne1, de = pre_compute( v= v, normal= normal ) + d1, d2, ne0, ne1, de = pre_compute(verts=v, normal=normal) - if triangle_box_intersect(low=low, normal=normal, d1=d1, d2=d2, ne0=ne0, ne1=ne1, de=de): + if triangle_box_intersect(low=low, normal=normal, dist1=d1, dist2=d2, normal_edge0=ne0, normal_edge1=ne1, dist_edge=de): return True return False @@ -177,9 +183,9 @@ def warp_implementation( ): assert bc.mesh_vertices is not None, f'Please provide the mesh vertices for {bc.__class__.__name__} BC using keyword "mesh_vertices"!' assert bc.indices is None, f"Please use IndicesBoundaryMasker operator if {bc.__class__.__name__} is imposed on known indices of the grid!" - assert ( - bc.mesh_vertices.shape[1] == self.velocity_set.d - ), "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + assert bc.mesh_vertices.shape[1] == self.velocity_set.d, ( + "Mesh points must be reshaped into an array (N, 3) where N indicates number of points!" + ) mesh_vertices = bc.mesh_vertices id_number = bc.id